changeset 1007:f7a718bcb4da refactor/lazy_tensors

Add checking of sizes to LazyTensorBinaryOperation
author Jonatan Werpers <jonatan@werpers.com>
date Sun, 20 Mar 2022 22:41:28 +0100
parents d9476fede83d
children 8920f237caf7
files src/LazyTensors/lazy_tensor_operations.jl test/LazyTensors/lazy_tensor_operations_test.jl
diffstat 2 files changed, 28 insertions(+), 19 deletions(-) [+]
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Sun Mar 20 22:22:32 2022 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Sun Mar 20 22:41:28 2022 +0100
@@ -1,5 +1,3 @@
-# TODO: Go over type parameters
-
 """
     LazyTensorApplication{T,R,D} <: LazyArray{T,R}
 
@@ -59,10 +57,11 @@
     tm2::T2
 
     function LazyTensorBinaryOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}}
+        @boundscheck check_domain_size(tm2, domain_size(tm1))
+        @boundscheck check_range_size(tm2, range_size(tm1))
         return new{Op,T,R,D,T1,T2}(tm1,tm2)
     end
 end
-# TODO: Boundschecking in constructor.
 
 LazyTensorBinaryOperation{Op}(s,t) where Op = LazyTensorBinaryOperation{Op,eltype(s), range_dim(s), domain_dim(s)}(s,t)
 
--- a/test/LazyTensors/lazy_tensor_operations_test.jl	Sun Mar 20 22:22:32 2022 +0100
+++ b/test/LazyTensors/lazy_tensor_operations_test.jl	Sun Mar 20 22:41:28 2022 +0100
@@ -4,15 +4,26 @@
 
 using Tullio
 
-@testset "Mapping transpose" begin
-    struct DummyMapping{T,R,D} <: LazyTensor{T,R,D} end
+struct DummyMapping{T,R,D} <: LazyTensor{T,R,D} end
+
+LazyTensors.apply(m::DummyMapping{T,R}, v, I::Vararg{Any,R}) where {T,R} = :apply
+LazyTensors.apply_transpose(m::DummyMapping{T,R,D}, v, I::Vararg{Any,D}) where {T,R,D} = :apply_transpose
+
+LazyTensors.range_size(m::DummyMapping) = :range_size
+LazyTensors.domain_size(m::DummyMapping) = :domain_size
+
 
-    LazyTensors.apply(m::DummyMapping{T,R}, v, I::Vararg{Any,R}) where {T,R} = :apply
-    LazyTensors.apply_transpose(m::DummyMapping{T,R,D}, v, I::Vararg{Any,D}) where {T,R,D} = :apply_transpose
+struct SizeDoublingMapping{T,R,D} <: LazyTensor{T,R,D}
+    domain_size::NTuple{D,Int}
+end
 
-    LazyTensors.range_size(m::DummyMapping) = :range_size
-    LazyTensors.domain_size(m::DummyMapping) = :domain_size
+LazyTensors.apply(m::SizeDoublingMapping{T,R}, v, i::Vararg{Any,R}) where {T,R} = (:apply,v,i)
+LazyTensors.range_size(m::SizeDoublingMapping) = 2 .* m.domain_size
+LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size
 
+
+
+@testset "Mapping transpose" begin
     m = DummyMapping{Float64,2,3}()
     @test m' isa LazyTensor{Float64, 3,2}
     @test m'' == m
@@ -26,14 +37,6 @@
 
 
 @testset "LazyTensorApplication" begin
-    struct SizeDoublingMapping{T,R,D} <: LazyTensor{T,R,D}
-        domain_size::NTuple{D,Int}
-    end
-
-    LazyTensors.apply(m::SizeDoublingMapping{T,R}, v, i::Vararg{Any,R}) where {T,R} = (:apply,v,i)
-    LazyTensors.range_size(m::SizeDoublingMapping) = 2 .* m.domain_size
-    LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size
-
     m = SizeDoublingMapping{Int, 1, 1}((3,))
     mm = SizeDoublingMapping{Int, 1, 1}((6,))
     v = [0,1,2]
@@ -139,13 +142,20 @@
         @test ((A-B)*v)[i] == 2*v[i] - 3*v[i]
     end
 
-    # TODO: Test with size changing tm
-    # TODO: Test for mismatch in dimensions (DomainSizeMismatch?)
 
     @test range_size(A+B) == range_size(A) == range_size(B)
     @test domain_size(A+B) == domain_size(A) == domain_size(B)
 
     @test ((A+B)*ComplexF64[1.1,1.2,1.3])[3] isa ComplexF64
+
+    @testset "Error on unmatched sizes" begin
+        @test_throws Union{DomainSizeMismatch, RangeSizeMismatch} ScalingTensor(2.0, (3,)) + ScalingTensor(2.0, (4,))
+
+        @test_throws DomainSizeMismatch ScalingTensor(2.0, (4,)) + SizeDoublingMapping{Float64,1,1}((2,))
+        @test_throws DomainSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (4,))
+        @test_throws RangeSizeMismatch ScalingTensor(2.0, (2,)) + SizeDoublingMapping{Float64,1,1}((2,))
+        @test_throws RangeSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (2,))
+    end
 end