Mercurial > repos > public > sbplib_julia
comparison test/LazyTensors/lazy_tensor_operations_test.jl @ 1007:f7a718bcb4da refactor/lazy_tensors
Add checking of sizes to LazyTensorBinaryOperation
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Sun, 20 Mar 2022 22:41:28 +0100 |
parents | d9476fede83d |
children | 2c1a0722ddb9 56fe037641ef |
comparison
equal
deleted
inserted
replaced
1006:d9476fede83d | 1007:f7a718bcb4da |
---|---|
2 using Sbplib.LazyTensors | 2 using Sbplib.LazyTensors |
3 using Sbplib.RegionIndices | 3 using Sbplib.RegionIndices |
4 | 4 |
5 using Tullio | 5 using Tullio |
6 | 6 |
7 struct DummyMapping{T,R,D} <: LazyTensor{T,R,D} end | |
8 | |
9 LazyTensors.apply(m::DummyMapping{T,R}, v, I::Vararg{Any,R}) where {T,R} = :apply | |
10 LazyTensors.apply_transpose(m::DummyMapping{T,R,D}, v, I::Vararg{Any,D}) where {T,R,D} = :apply_transpose | |
11 | |
12 LazyTensors.range_size(m::DummyMapping) = :range_size | |
13 LazyTensors.domain_size(m::DummyMapping) = :domain_size | |
14 | |
15 | |
16 struct SizeDoublingMapping{T,R,D} <: LazyTensor{T,R,D} | |
17 domain_size::NTuple{D,Int} | |
18 end | |
19 | |
20 LazyTensors.apply(m::SizeDoublingMapping{T,R}, v, i::Vararg{Any,R}) where {T,R} = (:apply,v,i) | |
21 LazyTensors.range_size(m::SizeDoublingMapping) = 2 .* m.domain_size | |
22 LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size | |
23 | |
24 | |
25 | |
7 @testset "Mapping transpose" begin | 26 @testset "Mapping transpose" begin |
8 struct DummyMapping{T,R,D} <: LazyTensor{T,R,D} end | |
9 | |
10 LazyTensors.apply(m::DummyMapping{T,R}, v, I::Vararg{Any,R}) where {T,R} = :apply | |
11 LazyTensors.apply_transpose(m::DummyMapping{T,R,D}, v, I::Vararg{Any,D}) where {T,R,D} = :apply_transpose | |
12 | |
13 LazyTensors.range_size(m::DummyMapping) = :range_size | |
14 LazyTensors.domain_size(m::DummyMapping) = :domain_size | |
15 | |
16 m = DummyMapping{Float64,2,3}() | 27 m = DummyMapping{Float64,2,3}() |
17 @test m' isa LazyTensor{Float64, 3,2} | 28 @test m' isa LazyTensor{Float64, 3,2} |
18 @test m'' == m | 29 @test m'' == m |
19 @test apply(m',zeros(Float64,(0,0)), 0, 0, 0) == :apply_transpose | 30 @test apply(m',zeros(Float64,(0,0)), 0, 0, 0) == :apply_transpose |
20 @test apply(m'',zeros(Float64,(0,0,0)), 0, 0) == :apply | 31 @test apply(m'',zeros(Float64,(0,0,0)), 0, 0) == :apply |
24 @test domain_size(m') == :range_size | 35 @test domain_size(m') == :range_size |
25 end | 36 end |
26 | 37 |
27 | 38 |
28 @testset "LazyTensorApplication" begin | 39 @testset "LazyTensorApplication" begin |
29 struct SizeDoublingMapping{T,R,D} <: LazyTensor{T,R,D} | |
30 domain_size::NTuple{D,Int} | |
31 end | |
32 | |
33 LazyTensors.apply(m::SizeDoublingMapping{T,R}, v, i::Vararg{Any,R}) where {T,R} = (:apply,v,i) | |
34 LazyTensors.range_size(m::SizeDoublingMapping) = 2 .* m.domain_size | |
35 LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size | |
36 | |
37 m = SizeDoublingMapping{Int, 1, 1}((3,)) | 40 m = SizeDoublingMapping{Int, 1, 1}((3,)) |
38 mm = SizeDoublingMapping{Int, 1, 1}((6,)) | 41 mm = SizeDoublingMapping{Int, 1, 1}((6,)) |
39 v = [0,1,2] | 42 v = [0,1,2] |
40 @test size(m*v) == 2 .*size(v) | 43 @test size(m*v) == 2 .*size(v) |
41 @test (m*v)[1] == (:apply,v,(1,)) | 44 @test (m*v)[1] == (:apply,v,(1,)) |
137 | 140 |
138 for i ∈ eachindex(v) | 141 for i ∈ eachindex(v) |
139 @test ((A-B)*v)[i] == 2*v[i] - 3*v[i] | 142 @test ((A-B)*v)[i] == 2*v[i] - 3*v[i] |
140 end | 143 end |
141 | 144 |
142 # TODO: Test with size changing tm | |
143 # TODO: Test for mismatch in dimensions (DomainSizeMismatch?) | |
144 | 145 |
145 @test range_size(A+B) == range_size(A) == range_size(B) | 146 @test range_size(A+B) == range_size(A) == range_size(B) |
146 @test domain_size(A+B) == domain_size(A) == domain_size(B) | 147 @test domain_size(A+B) == domain_size(A) == domain_size(B) |
147 | 148 |
148 @test ((A+B)*ComplexF64[1.1,1.2,1.3])[3] isa ComplexF64 | 149 @test ((A+B)*ComplexF64[1.1,1.2,1.3])[3] isa ComplexF64 |
150 | |
151 @testset "Error on unmatched sizes" begin | |
152 @test_throws Union{DomainSizeMismatch, RangeSizeMismatch} ScalingTensor(2.0, (3,)) + ScalingTensor(2.0, (4,)) | |
153 | |
154 @test_throws DomainSizeMismatch ScalingTensor(2.0, (4,)) + SizeDoublingMapping{Float64,1,1}((2,)) | |
155 @test_throws DomainSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (4,)) | |
156 @test_throws RangeSizeMismatch ScalingTensor(2.0, (2,)) + SizeDoublingMapping{Float64,1,1}((2,)) | |
157 @test_throws RangeSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (2,)) | |
158 end | |
149 end | 159 end |
150 | 160 |
151 | 161 |
152 @testset "LazyTensorComposition" begin | 162 @testset "LazyTensorComposition" begin |
153 A = rand(2,3) | 163 A = rand(2,3) |