Mercurial > repos > public > sbplib_julia
changeset 1077:bdb49f82a571 feature/diagonal_tensor
Merge default
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Thu, 07 Apr 2022 20:28:59 +0200 |
parents | 93f87d5d9fbb (diff) 03f65ef8adb9 (current diff) |
children | 0784b8d2f621 |
files | test/Manifest.toml |
diffstat | 5 files changed, 85 insertions(+), 1 deletions(-) [+] |
line wrap: on
line diff
--- a/src/LazyTensors/LazyTensors.jl Thu Apr 07 07:30:00 2022 +0200 +++ b/src/LazyTensors/LazyTensors.jl Thu Apr 07 20:28:59 2022 +0200 @@ -3,9 +3,10 @@ export TensorApplication export TensorTranspose export TensorComposition -export DenseTensor export IdentityTensor export ScalingTensor +export DiagonalTensor +export DenseTensor export InflatedTensor export LazyOuterProduct export ⊗
--- a/src/LazyTensors/tensor_types.jl Thu Apr 07 07:30:00 2022 +0200 +++ b/src/LazyTensors/tensor_types.jl Thu Apr 07 20:28:59 2022 +0200 @@ -37,6 +37,24 @@ """ + DiagonalTensor{T,D,...} <: LazyTensor{T,D,D} + DiagonalTensor(a::AbstractArray) + +A lazy tensor with diagonal `a`. +""" +struct DiagonalTensor{T,D,AT<:AbstractArray{T,D}} <: LazyTensor{T,D,D} + diagonal::AT +end + +range_size(tm::DiagonalTensor) = size(tm.diagonal) +domain_size(tm::DiagonalTensor) = size(tm.diagonal) + + +LazyTensors.apply(tm::DiagonalTensor{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = tm.diagonal[I...]*v[I...] +LazyTensors.apply_transpose(tm::DiagonalTensor{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = tm.diagonal[I...]*v[I...] + + +""" DenseTensor{T,R,D,...}(A, range_indicies, domain_indicies) LazyTensor defined by the AbstractArray A. `range_indicies` and `domain_indicies` define which indicies of A should
--- a/test/LazyTensors/tensor_types_test.jl Thu Apr 07 07:30:00 2022 +0200 +++ b/test/LazyTensors/tensor_types_test.jl Thu Apr 07 20:28:59 2022 +0200 @@ -1,5 +1,6 @@ using Test using Sbplib.LazyTensors +using BenchmarkTools @testset "IdentityTensor" begin @test IdentityTensor{Float64}((4,5)) isa IdentityTensor{T,2} where T @@ -58,6 +59,47 @@ @inferred (st'*v)[2,2] end +@testset "DiagonalTensor" begin + @test DiagonalTensor([1,2,3,4]) isa LazyTensor{Int,1,1} + @test DiagonalTensor([1 2 3; 4 5 6]) isa LazyTensor{Int,2,2} + @test DiagonalTensor([1. 2. 3.; 4. 5. 6.]) isa LazyTensor{Float64,2,2} + + @test range_size(DiagonalTensor([1,2,3,4])) == (4,) + @test domain_size(DiagonalTensor([1,2,3,4])) == (4,) + + @test range_size(DiagonalTensor([1 2 3; 4 5 6])) == (2,3) + @test domain_size(DiagonalTensor([1 2 3; 4 5 6])) == (2,3) + + @testset "apply size=$sz" for sz ∈ [(4,),(3,2),(3,4,2)] + diag = rand(sz...) + tm = DiagonalTensor(diag) + + v = rand(sz...) + + @test tm*v == diag.*v + @test tm'*v == diag.*v + end + + + @testset "allocations size=$sz" for sz ∈ [(4,),(3,2),(3,4,2)] + diag = rand(sz...) + tm = DiagonalTensor(diag) + + v = rand(sz...) + + @test tm*v == diag.*v + @test tm'*v == diag.*v + end + + sz = (3,2) + diag = rand(sz...) + tm = DiagonalTensor(diag) + + v = rand(sz...) + LazyTensors.apply(tm,v, 2,1) + @test (@ballocated LazyTensors.apply($tm,$v, 2,1)) == 0 +end + @testset "DenseTensor" begin # Test a standard matrix-vector product
--- a/test/Manifest.toml Thu Apr 07 07:30:00 2022 +0200 +++ b/test/Manifest.toml Thu Apr 07 20:28:59 2022 +0200 @@ -12,6 +12,12 @@ [[deps.Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +[[deps.BenchmarkTools]] +deps = ["JSON", "Logging", "Printf", "Profile", "Statistics", "UUIDs"] +git-tree-sha1 = "4c10eee4af024676200bc7752e536f858c6b8f93" +uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +version = "1.3.1" + [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] git-tree-sha1 = "9950387274246d08af38f6eef8cb5480862a435f" @@ -93,6 +99,12 @@ uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" version = "1.4.1" +[[deps.JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.3" + [[deps.LibCURL]] deps = ["LibCURL_jll", "MozillaCACerts_jll"] uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" @@ -161,6 +173,12 @@ uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" version = "0.5.5+0" +[[deps.Parsers]] +deps = ["Dates"] +git-tree-sha1 = "85b5da0fa43588c75bb1ff986493443f821c70b7" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.2.3" + [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -175,6 +193,10 @@ deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" +[[deps.Profile]] +deps = ["Printf"] +uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" + [[deps.REPL]] deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
--- a/test/Project.toml Thu Apr 07 07:30:00 2022 +0200 +++ b/test/Project.toml Thu Apr 07 20:28:59 2022 +0200 @@ -1,4 +1,5 @@ [deps] +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"