changeset 1071:93f87d5d9fbb feature/diagonal_tensor

Add a lazy diagonal tensor
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 24 Mar 2022 09:29:57 +0100
parents 2b6298905692
children 21c209cd95c8 bdb49f82a571
files src/LazyTensors/LazyTensors.jl src/LazyTensors/tensor_types.jl test/LazyTensors/tensor_types_test.jl test/Manifest.toml test/Project.toml
diffstat 5 files changed, 86 insertions(+), 2 deletions(-) [+]
line wrap: on
line diff
--- a/src/LazyTensors/LazyTensors.jl	Wed Mar 23 21:59:32 2022 +0100
+++ b/src/LazyTensors/LazyTensors.jl	Thu Mar 24 09:29:57 2022 +0100
@@ -3,9 +3,10 @@
 export TensorApplication
 export TensorTranspose
 export TensorComposition
-export DenseTensor
 export IdentityTensor
 export ScalingTensor
+export DiagonalTensor
+export DenseTensor
 export InflatedTensor
 export LazyOuterProduct
 export ⊗
--- a/src/LazyTensors/tensor_types.jl	Wed Mar 23 21:59:32 2022 +0100
+++ b/src/LazyTensors/tensor_types.jl	Thu Mar 24 09:29:57 2022 +0100
@@ -37,6 +37,24 @@
 
 
 """
+    DiagonalTensor{T,D,...} <: LazyTensor{T,D,D}
+    DiagonalTensor(a::AbstractArray)
+
+A lazy tensor with diagonal `a`.
+"""
+struct DiagonalTensor{T,D,AT<:AbstractArray{T,D}} <: LazyTensor{T,D,D}
+    diagonal::AT
+end
+
+range_size(tm::DiagonalTensor) = size(tm.diagonal)
+domain_size(tm::DiagonalTensor) = size(tm.diagonal)
+
+
+LazyTensors.apply(tm::DiagonalTensor{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = tm.diagonal[I...]*v[I...]
+LazyTensors.apply_transpose(tm::DiagonalTensor{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = tm.diagonal[I...]*v[I...]
+
+
+"""
     DenseTensor{T,R,D,...}(A, range_indicies, domain_indicies)
 
 LazyTensor defined by the AbstractArray A. `range_indicies` and `domain_indicies` define which indicies of A should
--- a/test/LazyTensors/tensor_types_test.jl	Wed Mar 23 21:59:32 2022 +0100
+++ b/test/LazyTensors/tensor_types_test.jl	Thu Mar 24 09:29:57 2022 +0100
@@ -1,5 +1,6 @@
 using Test
 using Sbplib.LazyTensors
+using BenchmarkTools
 
 @testset "IdentityTensor" begin
     @test IdentityTensor{Float64}((4,5)) isa IdentityTensor{T,2} where T
@@ -58,6 +59,47 @@
     @inferred (st'*v)[2,2]
 end
 
+@testset "DiagonalTensor" begin
+    @test DiagonalTensor([1,2,3,4]) isa LazyTensor{Int,1,1}
+    @test DiagonalTensor([1 2 3; 4 5 6]) isa LazyTensor{Int,2,2}
+    @test DiagonalTensor([1. 2. 3.; 4. 5. 6.]) isa LazyTensor{Float64,2,2}
+
+    @test range_size(DiagonalTensor([1,2,3,4])) == (4,)
+    @test domain_size(DiagonalTensor([1,2,3,4])) == (4,)
+
+    @test range_size(DiagonalTensor([1 2 3; 4 5 6])) == (2,3)
+    @test domain_size(DiagonalTensor([1 2 3; 4 5 6])) == (2,3)
+
+    @testset "apply size=$sz" for sz ∈ [(4,),(3,2),(3,4,2)]
+        diag = rand(sz...)
+        tm = DiagonalTensor(diag)
+
+        v = rand(sz...)
+
+        @test tm*v == diag.*v
+        @test tm'*v == diag.*v
+    end
+
+
+    @testset "allocations size=$sz" for sz ∈ [(4,),(3,2),(3,4,2)]
+        diag = rand(sz...)
+        tm = DiagonalTensor(diag)
+
+        v = rand(sz...)
+
+        @test tm*v == diag.*v
+        @test tm'*v == diag.*v
+    end
+
+    sz = (3,2)
+    diag = rand(sz...)
+    tm = DiagonalTensor(diag)
+
+    v = rand(sz...)
+    LazyTensors.apply(tm,v, 2,1)
+    @test (@ballocated LazyTensors.apply($tm,$v, 2,1)) == 0
+end
+
 
 @testset "DenseTensor" begin
     # Test a standard matrix-vector product
--- a/test/Manifest.toml	Wed Mar 23 21:59:32 2022 +0100
+++ b/test/Manifest.toml	Thu Mar 24 09:29:57 2022 +0100
@@ -1,6 +1,6 @@
 # This file is machine-generated - editing it directly is not advised
 
-julia_version = "1.7.0"
+julia_version = "1.7.1"
 manifest_format = "2.0"
 
 [[deps.ArgTools]]
@@ -12,6 +12,12 @@
 [[deps.Base64]]
 uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
 
+[[deps.BenchmarkTools]]
+deps = ["JSON", "Logging", "Printf", "Profile", "Statistics", "UUIDs"]
+git-tree-sha1 = "4c10eee4af024676200bc7752e536f858c6b8f93"
+uuid = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
+version = "1.3.1"
+
 [[deps.ChainRulesCore]]
 deps = ["Compat", "LinearAlgebra", "SparseArrays"]
 git-tree-sha1 = "4c26b4e9e91ca528ea212927326ece5918a04b47"
@@ -93,6 +99,12 @@
 uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
 version = "1.3.0"
 
+[[deps.JSON]]
+deps = ["Dates", "Mmap", "Parsers", "Unicode"]
+git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e"
+uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
+version = "0.21.3"
+
 [[deps.LibCURL]]
 deps = ["LibCURL_jll", "MozillaCACerts_jll"]
 uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
@@ -161,6 +173,12 @@
 uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
 version = "0.5.5+0"
 
+[[deps.Parsers]]
+deps = ["Dates"]
+git-tree-sha1 = "85b5da0fa43588c75bb1ff986493443f821c70b7"
+uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
+version = "2.2.3"
+
 [[deps.Pkg]]
 deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
 uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -175,6 +193,10 @@
 deps = ["Unicode"]
 uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 
+[[deps.Profile]]
+deps = ["Printf"]
+uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
+
 [[deps.REPL]]
 deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
 uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
--- a/test/Project.toml	Wed Mar 23 21:59:32 2022 +0100
+++ b/test/Project.toml	Thu Mar 24 09:29:57 2022 +0100
@@ -1,4 +1,5 @@
 [deps]
+BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
 Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"