changeset 1006:d9476fede83d refactor/lazy_tensors

Add check methods for range size
author Jonatan Werpers <jonatan@werpers.com>
date Sun, 20 Mar 2022 22:22:32 +0100
parents becd95ba0fce
children f7a718bcb4da
files src/LazyTensors/LazyTensors.jl src/LazyTensors/lazy_tensor_operations.jl test/LazyTensors/lazy_tensor_operations_test.jl
diffstat 3 files changed, 27 insertions(+), 9 deletions(-) [+]
line wrap: on
line diff
--- a/src/LazyTensors/LazyTensors.jl	Sun Mar 20 22:15:29 2022 +0100
+++ b/src/LazyTensors/LazyTensors.jl	Sun Mar 20 22:22:32 2022 +0100
@@ -9,7 +9,8 @@
 export InflatedLazyTensor
 export LazyOuterProduct
 export ⊗
-export SizeMismatch
+export DomainSizeMismatch
+export RangeSizeMismatch
 
 include("lazy_tensor.jl")
 include("tensor_types.jl")
--- a/src/LazyTensors/lazy_tensor_operations.jl	Sun Mar 20 22:15:29 2022 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Sun Mar 20 22:22:32 2022 +0100
@@ -271,16 +271,33 @@
 
 function check_domain_size(tm::LazyTensor, sz)
     if domain_size(tm) != sz
-        throw(SizeMismatch(tm,sz))
+        throw(DomainSizeMismatch(tm,sz))
     end
 end
 
-struct SizeMismatch <: Exception
+function check_range_size(tm::LazyTensor, sz)
+    if range_size(tm) != sz
+        throw(RangeSizeMismatch(tm,sz))
+    end
+end
+
+struct DomainSizeMismatch <: Exception
     tm::LazyTensor
     sz
 end
 
-function Base.showerror(io::IO, err::SizeMismatch)
-    print(io, "SizeMismatch: ")
+function Base.showerror(io::IO, err::DomainSizeMismatch)
+    print(io, "DomainSizeMismatch: ")
     print(io, "domain size $(domain_size(err.tm)) of LazyTensor not matching size $(err.sz)")
 end
+
+
+struct RangeSizeMismatch <: Exception
+    tm::LazyTensor
+    sz
+end
+
+function Base.showerror(io::IO, err::RangeSizeMismatch)
+    print(io, "RangeSizeMismatch: ")
+    print(io, "range size $(range_size(err.tm)) of LazyTensor not matching size $(err.sz)")
+end
--- a/test/LazyTensors/lazy_tensor_operations_test.jl	Sun Mar 20 22:15:29 2022 +0100
+++ b/test/LazyTensors/lazy_tensor_operations_test.jl	Sun Mar 20 22:22:32 2022 +0100
@@ -82,8 +82,8 @@
     end
 
     @testset "Error on unmatched sizes" begin
-        @test_throws SizeMismatch ScalingTensor(2,(2,))*ones(3)
-        @test_throws SizeMismatch ScalingTensor(2,(2,))*ScalingTensor(2,(3,))*ones(3)
+        @test_throws DomainSizeMismatch ScalingTensor(2,(2,))*ones(3)
+        @test_throws DomainSizeMismatch ScalingTensor(2,(2,))*ScalingTensor(2,(3,))*ones(3)
     end
 
 
@@ -140,7 +140,7 @@
     end
 
     # TODO: Test with size changing tm
-    # TODO: Test for mismatch in dimensions (SizeMismatch?)
+    # TODO: Test for mismatch in dimensions (DomainSizeMismatch?)
 
     @test range_size(A+B) == range_size(A) == range_size(B)
     @test domain_size(A+B) == domain_size(A) == domain_size(B)
@@ -159,7 +159,7 @@
     @test Ã∘B̃ isa LazyTensorComposition
     @test range_size(Ã∘B̃) == (2,)
     @test domain_size(Ã∘B̃) == (4,)
-    @test_throws SizeMismatch B̃∘Ã
+    @test_throws DomainSizeMismatch B̃∘Ã
 
     # @test @inbounds B̃∘Ã # Should not error even though dimensions don't match. (Since ]test runs with forced boundschecking this is currently not testable 2020-10-16)