comparison test/LazyTensors/lazy_tensor_operations_test.jl @ 1900:418566cdd689

Merge refactor/lazy_tensors/elementwise_ops
author Jonatan Werpers <jonatan@werpers.com>
date Fri, 31 Jan 2025 20:35:28 +0100
parents b12e28a03b2e
children
comparison
equal deleted inserted replaced
1896:9d708f3300d5 1900:418566cdd689
18 end 18 end
19 19
20 LazyTensors.apply(m::SizeDoublingMapping{T,R}, v, i::Vararg{Any,R}) where {T,R} = (:apply,v,i) 20 LazyTensors.apply(m::SizeDoublingMapping{T,R}, v, i::Vararg{Any,R}) where {T,R} = (:apply,v,i)
21 LazyTensors.range_size(m::SizeDoublingMapping) = 2 .* m.domain_size 21 LazyTensors.range_size(m::SizeDoublingMapping) = 2 .* m.domain_size
22 LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size 22 LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size
23
24 23
25 24
26 @testset "Mapping transpose" begin 25 @testset "Mapping transpose" begin
27 m = TransposableDummyMapping{Float64,2,3}() 26 m = TransposableDummyMapping{Float64,2,3}()
28 @test m' isa LazyTensor{Float64, 3,2} 27 @test m' isa LazyTensor{Float64, 3,2}
126 @inferred m*v 125 @inferred m*v
127 @inferred (m*v)[1] 126 @inferred (m*v)[1]
128 end 127 end
129 end 128 end
130 129
131 130 @testset "TensorNegation" begin
132 @testset "LazyTensor binary operations" begin 131 A = rand(2,3)
132 B = rand(3,4)
133
134 Ã = DenseTensor(A, (1,), (2,))
135 B̃ = DenseTensor(B, (1,), (2,))
136
137 @test -Ã isa TensorNegation
138
139 v = rand(3)
140 @test (-Ã)*v == -(Ã*v)
141
142 v = rand(4)
143 @test (-B̃)*v == -(B̃*v)
144
145 v = rand(2)
146 @test (-Ã)'*v == -(Ã'*v)
147
148 v = rand(3)
149 @test (-B̃)'*v == -(B̃'*v)
150
151 @test domain_size(-Ã) == (3,)
152 @test domain_size(-B̃) == (4,)
153
154 @test range_size(-Ã) == (2,)
155 @test range_size(-B̃) == (3,)
156 end
157
158 @testset "TensorSum" begin
133 A = ScalingTensor(2.0, (3,)) 159 A = ScalingTensor(2.0, (3,))
134 B = ScalingTensor(3.0, (3,)) 160 B = ScalingTensor(3.0, (3,))
135 161
136 v = [1.1,1.2,1.3] 162 v = [1.1,1.2,1.3]
137 for i ∈ eachindex(v) 163 for i ∈ eachindex(v)
138 @test ((A+B)*v)[i] == 2*v[i] + 3*v[i] 164 @test ((A+B)*v)[i] == 2*v[i] + 3*v[i]
139 end 165 end
140 166
141 for i ∈ eachindex(v) 167 for i ∈ eachindex(v)
142 @test ((A-B)*v)[i] == 2*v[i] - 3*v[i] 168 @test ((A-B)*v)[i] == 2*v[i] - 3*v[i]
169 end
170
171 for i ∈ eachindex(v)
172 @test ((A+B)'*v)[i] == 2*v[i] + 3*v[i]
143 end 173 end
144 174
145 175
146 @test range_size(A+B) == range_size(A) == range_size(B) 176 @test range_size(A+B) == range_size(A) == range_size(B)
147 @test domain_size(A+B) == domain_size(A) == domain_size(B) 177 @test domain_size(A+B) == domain_size(A) == domain_size(B)
153 183
154 @test_throws DomainSizeMismatch ScalingTensor(2.0, (4,)) + SizeDoublingMapping{Float64,1,1}((2,)) 184 @test_throws DomainSizeMismatch ScalingTensor(2.0, (4,)) + SizeDoublingMapping{Float64,1,1}((2,))
155 @test_throws DomainSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (4,)) 185 @test_throws DomainSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (4,))
156 @test_throws RangeSizeMismatch ScalingTensor(2.0, (2,)) + SizeDoublingMapping{Float64,1,1}((2,)) 186 @test_throws RangeSizeMismatch ScalingTensor(2.0, (2,)) + SizeDoublingMapping{Float64,1,1}((2,))
157 @test_throws RangeSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (2,)) 187 @test_throws RangeSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (2,))
188 end
189
190 @testset "Chained operators" begin
191 A = ScalingTensor(1.0, (3,))
192 B = ScalingTensor(2.0, (3,))
193 C = ScalingTensor(3.0, (3,))
194 D = ScalingTensor(4.0, (3,))
195
196 @test A+B+C+D isa TensorSum
197 @test length((A+B+C+D).tms) == 4
198
199
200 @test A+B-C+D isa TensorSum
201 @test length((A+B-C+D).tms) == 4
202
203 v = rand(3)
204 @test (A+B-C+D)*v == 1v + 2v - 3v + 4v
205
206
207 @test -A-B-C-D isa TensorSum
208 @test length((-A-B-C-D).tms) == 4
209
210 v = rand(3)
211 @test (-A-B-C-D)*v == -1v - 2v - 3v - 4v
158 end 212 end
159 end 213 end
160 214
161 215
162 @testset "TensorComposition" begin 216 @testset "TensorComposition" begin