diff src/LazyTensors/lazy_tensor_operations.jl @ 1006:d9476fede83d refactor/lazy_tensors

Add check methods for range size
author Jonatan Werpers <jonatan@werpers.com>
date Sun, 20 Mar 2022 22:22:32 +0100
parents becd95ba0fce
children f7a718bcb4da
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Sun Mar 20 22:15:29 2022 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Sun Mar 20 22:22:32 2022 +0100
@@ -271,16 +271,33 @@
 
 function check_domain_size(tm::LazyTensor, sz)
     if domain_size(tm) != sz
-        throw(SizeMismatch(tm,sz))
+        throw(DomainSizeMismatch(tm,sz))
     end
 end
 
-struct SizeMismatch <: Exception
+function check_range_size(tm::LazyTensor, sz)
+    if range_size(tm) != sz
+        throw(RangeSizeMismatch(tm,sz))
+    end
+end
+
+struct DomainSizeMismatch <: Exception
     tm::LazyTensor
     sz
 end
 
-function Base.showerror(io::IO, err::SizeMismatch)
-    print(io, "SizeMismatch: ")
+function Base.showerror(io::IO, err::DomainSizeMismatch)
+    print(io, "DomainSizeMismatch: ")
     print(io, "domain size $(domain_size(err.tm)) of LazyTensor not matching size $(err.sz)")
 end
+
+
+struct RangeSizeMismatch <: Exception
+    tm::LazyTensor
+    sz
+end
+
+function Base.showerror(io::IO, err::RangeSizeMismatch)
+    print(io, "RangeSizeMismatch: ")
+    print(io, "range size $(range_size(err.tm)) of LazyTensor not matching size $(err.sz)")
+end