comparison test/LazyTensors/lazy_tensor_operations_test.jl @ 2057:8a2a0d678d6f feature/lazy_tensors/pretty_printing

Merge default
author Jonatan Werpers <jonatan@werpers.com>
date Tue, 10 Feb 2026 22:41:19 +0100
parents 1e8270c18edb b12e28a03b2e
children
comparison
equal deleted inserted replaced
1110:c0bff9f6e0fb 2057:8a2a0d678d6f
1 using Test 1 using Test
2 using Sbplib.LazyTensors 2 using Diffinitive.LazyTensors
3 using Sbplib.RegionIndices 3 using Diffinitive.RegionIndices
4 4
5 using Tullio 5 using Tullio
6 6
7 struct DummyMapping{T,R,D} <: LazyTensor{T,R,D} end 7 struct TransposableDummyMapping{T,R,D} <: LazyTensor{T,R,D} end
8 8
9 LazyTensors.apply(m::DummyMapping{T,R}, v, I::Vararg{Any,R}) where {T,R} = :apply 9 LazyTensors.apply(m::TransposableDummyMapping{T,R}, v, I::Vararg{Any,R}) where {T,R} = :apply
10 LazyTensors.apply_transpose(m::DummyMapping{T,R,D}, v, I::Vararg{Any,D}) where {T,R,D} = :apply_transpose 10 LazyTensors.apply_transpose(m::TransposableDummyMapping{T,R,D}, v, I::Vararg{Any,D}) where {T,R,D} = :apply_transpose
11 11
12 LazyTensors.range_size(m::DummyMapping) = :range_size 12 LazyTensors.range_size(m::TransposableDummyMapping) = :range_size
13 LazyTensors.domain_size(m::DummyMapping) = :domain_size 13 LazyTensors.domain_size(m::TransposableDummyMapping) = :domain_size
14 14
15 15
16 struct SizeDoublingMapping{T,R,D} <: LazyTensor{T,R,D} 16 struct SizeDoublingMapping{T,R,D} <: LazyTensor{T,R,D}
17 domain_size::NTuple{D,Int} 17 domain_size::NTuple{D,Int}
18 end 18 end
20 LazyTensors.apply(m::SizeDoublingMapping{T,R}, v, i::Vararg{Any,R}) where {T,R} = (:apply,v,i) 20 LazyTensors.apply(m::SizeDoublingMapping{T,R}, v, i::Vararg{Any,R}) where {T,R} = (:apply,v,i)
21 LazyTensors.range_size(m::SizeDoublingMapping) = 2 .* m.domain_size 21 LazyTensors.range_size(m::SizeDoublingMapping) = 2 .* m.domain_size
22 LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size 22 LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size
23 23
24 24
25
26 @testset "Mapping transpose" begin 25 @testset "Mapping transpose" begin
27 m = DummyMapping{Float64,2,3}() 26 m = TransposableDummyMapping{Float64,2,3}()
28 @test m' isa LazyTensor{Float64, 3,2} 27 @test m' isa LazyTensor{Float64, 3,2}
29 @test m'' == m 28 @test m'' == m
30 @test apply(m',zeros(Float64,(0,0)), 0, 0, 0) == :apply_transpose 29 @test apply(m',zeros(Float64,(0,0)), 0, 0, 0) == :apply_transpose
31 @test apply(m'',zeros(Float64,(0,0,0)), 0, 0) == :apply 30 @test apply(m'',zeros(Float64,(0,0,0)), 0, 0) == :apply
32 @test apply_transpose(m', zeros(Float64,(0,0,0)), 0, 0) == :apply 31 @test apply_transpose(m', zeros(Float64,(0,0,0)), 0, 0) == :apply
126 @inferred m*v 125 @inferred m*v
127 @inferred (m*v)[1] 126 @inferred (m*v)[1]
128 end 127 end
129 end 128 end
130 129
131 130 @testset "TensorNegation" begin
132 @testset "LazyTensor binary operations" begin 131 A = rand(2,3)
132 B = rand(3,4)
133
134 Ã = DenseTensor(A, (1,), (2,))
135 B̃ = DenseTensor(B, (1,), (2,))
136
137 @test -Ã isa TensorNegation
138
139 v = rand(3)
140 @test (-Ã)*v == -(Ã*v)
141
142 v = rand(4)
143 @test (-B̃)*v == -(B̃*v)
144
145 v = rand(2)
146 @test (-Ã)'*v == -(Ã'*v)
147
148 v = rand(3)
149 @test (-B̃)'*v == -(B̃'*v)
150
151 @test domain_size(-Ã) == (3,)
152 @test domain_size(-B̃) == (4,)
153
154 @test range_size(-Ã) == (2,)
155 @test range_size(-B̃) == (3,)
156 end
157
158 @testset "TensorSum" begin
133 A = ScalingTensor(2.0, (3,)) 159 A = ScalingTensor(2.0, (3,))
134 B = ScalingTensor(3.0, (3,)) 160 B = ScalingTensor(3.0, (3,))
135 161
136 v = [1.1,1.2,1.3] 162 v = [1.1,1.2,1.3]
137 for i ∈ eachindex(v) 163 for i ∈ eachindex(v)
138 @test ((A+B)*v)[i] == 2*v[i] + 3*v[i] 164 @test ((A+B)*v)[i] == 2*v[i] + 3*v[i]
139 end 165 end
140 166
141 for i ∈ eachindex(v) 167 for i ∈ eachindex(v)
142 @test ((A-B)*v)[i] == 2*v[i] - 3*v[i] 168 @test ((A-B)*v)[i] == 2*v[i] - 3*v[i]
169 end
170
171 for i ∈ eachindex(v)
172 @test ((A+B)'*v)[i] == 2*v[i] + 3*v[i]
143 end 173 end
144 174
145 175
146 @test range_size(A+B) == range_size(A) == range_size(B) 176 @test range_size(A+B) == range_size(A) == range_size(B)
147 @test domain_size(A+B) == domain_size(A) == domain_size(B) 177 @test domain_size(A+B) == domain_size(A) == domain_size(B)
153 183
154 @test_throws DomainSizeMismatch ScalingTensor(2.0, (4,)) + SizeDoublingMapping{Float64,1,1}((2,)) 184 @test_throws DomainSizeMismatch ScalingTensor(2.0, (4,)) + SizeDoublingMapping{Float64,1,1}((2,))
155 @test_throws DomainSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (4,)) 185 @test_throws DomainSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (4,))
156 @test_throws RangeSizeMismatch ScalingTensor(2.0, (2,)) + SizeDoublingMapping{Float64,1,1}((2,)) 186 @test_throws RangeSizeMismatch ScalingTensor(2.0, (2,)) + SizeDoublingMapping{Float64,1,1}((2,))
157 @test_throws RangeSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (2,)) 187 @test_throws RangeSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (2,))
188 end
189
190 @testset "Chained operators" begin
191 A = ScalingTensor(1.0, (3,))
192 B = ScalingTensor(2.0, (3,))
193 C = ScalingTensor(3.0, (3,))
194 D = ScalingTensor(4.0, (3,))
195
196 @test A+B+C+D isa TensorSum
197 @test length((A+B+C+D).tms) == 4
198
199
200 @test A+B-C+D isa TensorSum
201 @test length((A+B-C+D).tms) == 4
202
203 v = rand(3)
204 @test (A+B-C+D)*v == 1v + 2v - 3v + 4v
205
206
207 @test -A-B-C-D isa TensorSum
208 @test length((-A-B-C-D).tms) == 4
209
210 v = rand(3)
211 @test (-A-B-C-D)*v == -1v - 2v - 3v - 4v
158 end 212 end
159 end 213 end
160 214
161 215
162 @testset "TensorComposition" begin 216 @testset "TensorComposition" begin
186 v = rand(3) 240 v = rand(3)
187 @test a*Ã isa TensorComposition 241 @test a*Ã isa TensorComposition
188 @test a*Ã == Ã*a 242 @test a*Ã == Ã*a
189 @test range_size(a*Ã) == range_size(Ã) 243 @test range_size(a*Ã) == range_size(Ã)
190 @test domain_size(a*Ã) == domain_size(Ã) 244 @test domain_size(a*Ã) == domain_size(Ã)
191 @test a*Ã*v == a.*A*v 245 @test a*Ã*v ≈ a.*A*v rtol=1e-14
192 end 246 end
193 247
194 248
195 @testset "InflatedTensor" begin 249 @testset "InflatedTensor" begin
196 I(sz...) = IdentityTensor(sz...) 250 I(sz...) = IdentityTensor(sz...)
378 I1 = IdentityTensor(3,2) 432 I1 = IdentityTensor(3,2)
379 I2 = IdentityTensor(4) 433 I2 = IdentityTensor(4)
380 @test I1⊗Ã⊗I2 == InflatedTensor(I1, Ã, I2) 434 @test I1⊗Ã⊗I2 == InflatedTensor(I1, Ã, I2)
381 end 435 end
382 end 436 end
437
438 @testset "inflate" begin
439 I = LazyTensors.inflate(IdentityTensor(),(3,4,5,6), 2)
440 @test I isa LazyTensor{Float64, 3,3}
441 @test range_size(I) == (3,5,6)
442 @test domain_size(I) == (3,5,6)
443
444 @test LazyTensors.inflate(ScalingTensor(1., (4,)),(3,4,5,6), 1) == InflatedTensor(IdentityTensor{Float64}(),ScalingTensor(1., (4,)),IdentityTensor(4,5,6))
445 @test LazyTensors.inflate(ScalingTensor(2., (1,)),(3,4,5,6), 2) == InflatedTensor(IdentityTensor(3),ScalingTensor(2., (1,)),IdentityTensor(5,6))
446 @test LazyTensors.inflate(ScalingTensor(3., (6,)),(3,4,5,6), 4) == InflatedTensor(IdentityTensor(3,4,5),ScalingTensor(3., (6,)),IdentityTensor{Float64}())
447
448 @test_throws BoundsError LazyTensors.inflate(ScalingTensor(1., (4,)),(3,4,5,6), 0)
449 @test_throws BoundsError LazyTensors.inflate(ScalingTensor(1., (4,)),(3,4,5,6), 5)
450 end