Mercurial > repos > public > sbplib_julia
comparison test/LazyTensors/lazy_tensor_operations_test.jl @ 1900:418566cdd689
Merge refactor/lazy_tensors/elementwise_ops
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Fri, 31 Jan 2025 20:35:28 +0100 |
parents | b12e28a03b2e |
children |
comparison
equal
deleted
inserted
replaced
1896:9d708f3300d5 | 1900:418566cdd689 |
---|---|
18 end | 18 end |
19 | 19 |
20 LazyTensors.apply(m::SizeDoublingMapping{T,R}, v, i::Vararg{Any,R}) where {T,R} = (:apply,v,i) | 20 LazyTensors.apply(m::SizeDoublingMapping{T,R}, v, i::Vararg{Any,R}) where {T,R} = (:apply,v,i) |
21 LazyTensors.range_size(m::SizeDoublingMapping) = 2 .* m.domain_size | 21 LazyTensors.range_size(m::SizeDoublingMapping) = 2 .* m.domain_size |
22 LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size | 22 LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size |
23 | |
24 | 23 |
25 | 24 |
26 @testset "Mapping transpose" begin | 25 @testset "Mapping transpose" begin |
27 m = TransposableDummyMapping{Float64,2,3}() | 26 m = TransposableDummyMapping{Float64,2,3}() |
28 @test m' isa LazyTensor{Float64, 3,2} | 27 @test m' isa LazyTensor{Float64, 3,2} |
126 @inferred m*v | 125 @inferred m*v |
127 @inferred (m*v)[1] | 126 @inferred (m*v)[1] |
128 end | 127 end |
129 end | 128 end |
130 | 129 |
131 | 130 @testset "TensorNegation" begin |
132 @testset "LazyTensor binary operations" begin | 131 A = rand(2,3) |
132 B = rand(3,4) | |
133 | |
134 Ã = DenseTensor(A, (1,), (2,)) | |
135 B̃ = DenseTensor(B, (1,), (2,)) | |
136 | |
137 @test -Ã isa TensorNegation | |
138 | |
139 v = rand(3) | |
140 @test (-Ã)*v == -(Ã*v) | |
141 | |
142 v = rand(4) | |
143 @test (-B̃)*v == -(B̃*v) | |
144 | |
145 v = rand(2) | |
146 @test (-Ã)'*v == -(Ã'*v) | |
147 | |
148 v = rand(3) | |
149 @test (-B̃)'*v == -(B̃'*v) | |
150 | |
151 @test domain_size(-Ã) == (3,) | |
152 @test domain_size(-B̃) == (4,) | |
153 | |
154 @test range_size(-Ã) == (2,) | |
155 @test range_size(-B̃) == (3,) | |
156 end | |
157 | |
158 @testset "TensorSum" begin | |
133 A = ScalingTensor(2.0, (3,)) | 159 A = ScalingTensor(2.0, (3,)) |
134 B = ScalingTensor(3.0, (3,)) | 160 B = ScalingTensor(3.0, (3,)) |
135 | 161 |
136 v = [1.1,1.2,1.3] | 162 v = [1.1,1.2,1.3] |
137 for i ∈ eachindex(v) | 163 for i ∈ eachindex(v) |
138 @test ((A+B)*v)[i] == 2*v[i] + 3*v[i] | 164 @test ((A+B)*v)[i] == 2*v[i] + 3*v[i] |
139 end | 165 end |
140 | 166 |
141 for i ∈ eachindex(v) | 167 for i ∈ eachindex(v) |
142 @test ((A-B)*v)[i] == 2*v[i] - 3*v[i] | 168 @test ((A-B)*v)[i] == 2*v[i] - 3*v[i] |
169 end | |
170 | |
171 for i ∈ eachindex(v) | |
172 @test ((A+B)'*v)[i] == 2*v[i] + 3*v[i] | |
143 end | 173 end |
144 | 174 |
145 | 175 |
146 @test range_size(A+B) == range_size(A) == range_size(B) | 176 @test range_size(A+B) == range_size(A) == range_size(B) |
147 @test domain_size(A+B) == domain_size(A) == domain_size(B) | 177 @test domain_size(A+B) == domain_size(A) == domain_size(B) |
153 | 183 |
154 @test_throws DomainSizeMismatch ScalingTensor(2.0, (4,)) + SizeDoublingMapping{Float64,1,1}((2,)) | 184 @test_throws DomainSizeMismatch ScalingTensor(2.0, (4,)) + SizeDoublingMapping{Float64,1,1}((2,)) |
155 @test_throws DomainSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (4,)) | 185 @test_throws DomainSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (4,)) |
156 @test_throws RangeSizeMismatch ScalingTensor(2.0, (2,)) + SizeDoublingMapping{Float64,1,1}((2,)) | 186 @test_throws RangeSizeMismatch ScalingTensor(2.0, (2,)) + SizeDoublingMapping{Float64,1,1}((2,)) |
157 @test_throws RangeSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (2,)) | 187 @test_throws RangeSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (2,)) |
188 end | |
189 | |
190 @testset "Chained operators" begin | |
191 A = ScalingTensor(1.0, (3,)) | |
192 B = ScalingTensor(2.0, (3,)) | |
193 C = ScalingTensor(3.0, (3,)) | |
194 D = ScalingTensor(4.0, (3,)) | |
195 | |
196 @test A+B+C+D isa TensorSum | |
197 @test length((A+B+C+D).tms) == 4 | |
198 | |
199 | |
200 @test A+B-C+D isa TensorSum | |
201 @test length((A+B-C+D).tms) == 4 | |
202 | |
203 v = rand(3) | |
204 @test (A+B-C+D)*v == 1v + 2v - 3v + 4v | |
205 | |
206 | |
207 @test -A-B-C-D isa TensorSum | |
208 @test length((-A-B-C-D).tms) == 4 | |
209 | |
210 v = rand(3) | |
211 @test (-A-B-C-D)*v == -1v - 2v - 3v - 4v | |
158 end | 212 end |
159 end | 213 end |
160 | 214 |
161 | 215 |
162 @testset "TensorComposition" begin | 216 @testset "TensorComposition" begin |