Mercurial > repos > public > sbplib_julia
comparison test/LazyTensors/lazy_tensor_operations_test.jl @ 2057:8a2a0d678d6f feature/lazy_tensors/pretty_printing
Merge default
| author | Jonatan Werpers <jonatan@werpers.com> |
|---|---|
| date | Tue, 10 Feb 2026 22:41:19 +0100 |
| parents | 1e8270c18edb b12e28a03b2e |
| children |
comparison
equal
deleted
inserted
replaced
| 1110:c0bff9f6e0fb | 2057:8a2a0d678d6f |
|---|---|
| 1 using Test | 1 using Test |
| 2 using Sbplib.LazyTensors | 2 using Diffinitive.LazyTensors |
| 3 using Sbplib.RegionIndices | 3 using Diffinitive.RegionIndices |
| 4 | 4 |
| 5 using Tullio | 5 using Tullio |
| 6 | 6 |
| 7 struct DummyMapping{T,R,D} <: LazyTensor{T,R,D} end | 7 struct TransposableDummyMapping{T,R,D} <: LazyTensor{T,R,D} end |
| 8 | 8 |
| 9 LazyTensors.apply(m::DummyMapping{T,R}, v, I::Vararg{Any,R}) where {T,R} = :apply | 9 LazyTensors.apply(m::TransposableDummyMapping{T,R}, v, I::Vararg{Any,R}) where {T,R} = :apply |
| 10 LazyTensors.apply_transpose(m::DummyMapping{T,R,D}, v, I::Vararg{Any,D}) where {T,R,D} = :apply_transpose | 10 LazyTensors.apply_transpose(m::TransposableDummyMapping{T,R,D}, v, I::Vararg{Any,D}) where {T,R,D} = :apply_transpose |
| 11 | 11 |
| 12 LazyTensors.range_size(m::DummyMapping) = :range_size | 12 LazyTensors.range_size(m::TransposableDummyMapping) = :range_size |
| 13 LazyTensors.domain_size(m::DummyMapping) = :domain_size | 13 LazyTensors.domain_size(m::TransposableDummyMapping) = :domain_size |
| 14 | 14 |
| 15 | 15 |
| 16 struct SizeDoublingMapping{T,R,D} <: LazyTensor{T,R,D} | 16 struct SizeDoublingMapping{T,R,D} <: LazyTensor{T,R,D} |
| 17 domain_size::NTuple{D,Int} | 17 domain_size::NTuple{D,Int} |
| 18 end | 18 end |
| 20 LazyTensors.apply(m::SizeDoublingMapping{T,R}, v, i::Vararg{Any,R}) where {T,R} = (:apply,v,i) | 20 LazyTensors.apply(m::SizeDoublingMapping{T,R}, v, i::Vararg{Any,R}) where {T,R} = (:apply,v,i) |
| 21 LazyTensors.range_size(m::SizeDoublingMapping) = 2 .* m.domain_size | 21 LazyTensors.range_size(m::SizeDoublingMapping) = 2 .* m.domain_size |
| 22 LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size | 22 LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size |
| 23 | 23 |
| 24 | 24 |
| 25 | |
| 26 @testset "Mapping transpose" begin | 25 @testset "Mapping transpose" begin |
| 27 m = DummyMapping{Float64,2,3}() | 26 m = TransposableDummyMapping{Float64,2,3}() |
| 28 @test m' isa LazyTensor{Float64, 3,2} | 27 @test m' isa LazyTensor{Float64, 3,2} |
| 29 @test m'' == m | 28 @test m'' == m |
| 30 @test apply(m',zeros(Float64,(0,0)), 0, 0, 0) == :apply_transpose | 29 @test apply(m',zeros(Float64,(0,0)), 0, 0, 0) == :apply_transpose |
| 31 @test apply(m'',zeros(Float64,(0,0,0)), 0, 0) == :apply | 30 @test apply(m'',zeros(Float64,(0,0,0)), 0, 0) == :apply |
| 32 @test apply_transpose(m', zeros(Float64,(0,0,0)), 0, 0) == :apply | 31 @test apply_transpose(m', zeros(Float64,(0,0,0)), 0, 0) == :apply |
| 126 @inferred m*v | 125 @inferred m*v |
| 127 @inferred (m*v)[1] | 126 @inferred (m*v)[1] |
| 128 end | 127 end |
| 129 end | 128 end |
| 130 | 129 |
| 131 | 130 @testset "TensorNegation" begin |
| 132 @testset "LazyTensor binary operations" begin | 131 A = rand(2,3) |
| | 132 B = rand(3,4) |
| | 133 |
| | 134 Ã = DenseTensor(A, (1,), (2,)) |
| | 135 B̃ = DenseTensor(B, (1,), (2,)) |
| | 136 |
| | 137 @test -Ã isa TensorNegation |
| | 138 |
| | 139 v = rand(3) |
| | 140 @test (-Ã)*v == -(Ã*v) |
| | 141 |
| | 142 v = rand(4) |
| | 143 @test (-B̃)*v == -(B̃*v) |
| | 144 |
| | 145 v = rand(2) |
| | 146 @test (-Ã)'*v == -(Ã'*v) |
| | 147 |
| | 148 v = rand(3) |
| | 149 @test (-B̃)'*v == -(B̃'*v) |
| | 150 |
| | 151 @test domain_size(-Ã) == (3,) |
| | 152 @test domain_size(-B̃) == (4,) |
| | 153 |
| | 154 @test range_size(-Ã) == (2,) |
| | 155 @test range_size(-B̃) == (3,) |
| | 156 end |
| | 157 |
| | 158 @testset "TensorSum" begin |
| 133 A = ScalingTensor(2.0, (3,)) | 159 A = ScalingTensor(2.0, (3,)) |
| 134 B = ScalingTensor(3.0, (3,)) | 160 B = ScalingTensor(3.0, (3,)) |
| 135 | 161 |
| 136 v = [1.1,1.2,1.3] | 162 v = [1.1,1.2,1.3] |
| 137 for i ∈ eachindex(v) | 163 for i ∈ eachindex(v) |
| 138 @test ((A+B)*v)[i] == 2*v[i] + 3*v[i] | 164 @test ((A+B)*v)[i] == 2*v[i] + 3*v[i] |
| 139 end | 165 end |
| 140 | 166 |
| 141 for i ∈ eachindex(v) | 167 for i ∈ eachindex(v) |
| 142 @test ((A-B)*v)[i] == 2*v[i] - 3*v[i] | 168 @test ((A-B)*v)[i] == 2*v[i] - 3*v[i] |
| | 169 end |
| | 170 |
| | 171 for i ∈ eachindex(v) |
| | 172 @test ((A+B)'*v)[i] == 2*v[i] + 3*v[i] |
| 143 end | 173 end |
| 144 | 174 |
| 145 | 175 |
| 146 @test range_size(A+B) == range_size(A) == range_size(B) | 176 @test range_size(A+B) == range_size(A) == range_size(B) |
| 147 @test domain_size(A+B) == domain_size(A) == domain_size(B) | 177 @test domain_size(A+B) == domain_size(A) == domain_size(B) |
| 153 | 183 |
| 154 @test_throws DomainSizeMismatch ScalingTensor(2.0, (4,)) + SizeDoublingMapping{Float64,1,1}((2,)) | 184 @test_throws DomainSizeMismatch ScalingTensor(2.0, (4,)) + SizeDoublingMapping{Float64,1,1}((2,)) |
| 155 @test_throws DomainSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (4,)) | 185 @test_throws DomainSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (4,)) |
| 156 @test_throws RangeSizeMismatch ScalingTensor(2.0, (2,)) + SizeDoublingMapping{Float64,1,1}((2,)) | 186 @test_throws RangeSizeMismatch ScalingTensor(2.0, (2,)) + SizeDoublingMapping{Float64,1,1}((2,)) |
| 157 @test_throws RangeSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (2,)) | 187 @test_throws RangeSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (2,)) |
| | 188 end |
| | 189 |
| | 190 @testset "Chained operators" begin |
| | 191 A = ScalingTensor(1.0, (3,)) |
| | 192 B = ScalingTensor(2.0, (3,)) |
| | 193 C = ScalingTensor(3.0, (3,)) |
| | 194 D = ScalingTensor(4.0, (3,)) |
| | 195 |
| | 196 @test A+B+C+D isa TensorSum |
| | 197 @test length((A+B+C+D).tms) == 4 |
| | 198 |
| | 199 |
| | 200 @test A+B-C+D isa TensorSum |
| | 201 @test length((A+B-C+D).tms) == 4 |
| | 202 |
| | 203 v = rand(3) |
| | 204 @test (A+B-C+D)*v == 1v + 2v - 3v + 4v |
| | 205 |
| | 206 |
| | 207 @test -A-B-C-D isa TensorSum |
| | 208 @test length((-A-B-C-D).tms) == 4 |
| | 209 |
| | 210 v = rand(3) |
| | 211 @test (-A-B-C-D)*v == -1v - 2v - 3v - 4v |
| 158 end | 212 end |
| 159 end | 213 end |
| 160 | 214 |
| 161 | 215 |
| 162 @testset "TensorComposition" begin | 216 @testset "TensorComposition" begin |
| 186 v = rand(3) | 240 v = rand(3) |
| 187 @test a*Ã isa TensorComposition | 241 @test a*Ã isa TensorComposition |
| 188 @test a*Ã == Ã*a | 242 @test a*Ã == Ã*a |
| 189 @test range_size(a*Ã) == range_size(Ã) | 243 @test range_size(a*Ã) == range_size(Ã) |
| 190 @test domain_size(a*Ã) == domain_size(Ã) | 244 @test domain_size(a*Ã) == domain_size(Ã) |
| 191 @test a*Ã*v == a.*A*v | 245 @test a*Ã*v ≈ a.*A*v rtol=1e-14 |
| 192 end | 246 end |
| 193 | 247 |
| 194 | 248 |
| 195 @testset "InflatedTensor" begin | 249 @testset "InflatedTensor" begin |
| 196 I(sz...) = IdentityTensor(sz...) | 250 I(sz...) = IdentityTensor(sz...) |
| 378 I1 = IdentityTensor(3,2) | 432 I1 = IdentityTensor(3,2) |
| 379 I2 = IdentityTensor(4) | 433 I2 = IdentityTensor(4) |
| 380 @test I1⊗Ã⊗I2 == InflatedTensor(I1, Ã, I2) | 434 @test I1⊗Ã⊗I2 == InflatedTensor(I1, Ã, I2) |
| 381 end | 435 end |
| 382 end | 436 end |
| | 437 |
| | 438 @testset "inflate" begin |
| | 439 I = LazyTensors.inflate(IdentityTensor(),(3,4,5,6), 2) |
| | 440 @test I isa LazyTensor{Float64, 3,3} |
| | 441 @test range_size(I) == (3,5,6) |
| | 442 @test domain_size(I) == (3,5,6) |
| | 443 |
| | 444 @test LazyTensors.inflate(ScalingTensor(1., (4,)),(3,4,5,6), 1) == InflatedTensor(IdentityTensor{Float64}(),ScalingTensor(1., (4,)),IdentityTensor(4,5,6)) |
| | 445 @test LazyTensors.inflate(ScalingTensor(2., (1,)),(3,4,5,6), 2) == InflatedTensor(IdentityTensor(3),ScalingTensor(2., (1,)),IdentityTensor(5,6)) |
| | 446 @test LazyTensors.inflate(ScalingTensor(3., (6,)),(3,4,5,6), 4) == InflatedTensor(IdentityTensor(3,4,5),ScalingTensor(3., (6,)),IdentityTensor{Float64}()) |
| | 447 |
| | 448 @test_throws BoundsError LazyTensors.inflate(ScalingTensor(1., (4,)),(3,4,5,6), 0) |
| | 449 @test_throws BoundsError LazyTensors.inflate(ScalingTensor(1., (4,)),(3,4,5,6), 5) |
| | 450 end |
