changeset 1005:becd95ba0fce refactor/lazy_tensors

Add bounds checking for lazy tensor application and clea up tests a bit
author Jonatan Werpers <jonatan@werpers.com>
date Sun, 20 Mar 2022 22:15:29 +0100
parents 7fd37aab84fe
children d9476fede83d
files src/LazyTensors/lazy_tensor_operations.jl test/LazyTensors/lazy_tensor_operations_test.jl
diffstat 2 files changed, 35 insertions(+), 19 deletions(-) [+]
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Sun Mar 20 21:35:20 2022 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Sun Mar 20 22:15:29 2022 +0100
@@ -1,7 +1,5 @@
-# TODO: We need to be really careful about good error messages.
 # TODO: Go over type parameters
 
-
 """
     LazyTensorApplication{T,R,D} <: LazyArray{T,R}
 
@@ -16,17 +14,19 @@
     o::AA
 
     function LazyTensorApplication(t::LazyTensor{<:Any,R,D}, o::AbstractArray{<:Any,D}) where {R,D}
+        @boundscheck check_domain_size(t, size(o))
         I = ntuple(i->1, range_dim(t))
         T = typeof(apply(t,o,I...))
         return new{T,R,D,typeof(t), typeof(o)}(t,o)
     end
 end
-# TODO: Do boundschecking on creation!
 
-Base.getindex(ta::LazyTensorApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...)
-Base.getindex(ta::LazyTensorApplication{T,1}, I::CartesianIndex{1}) where {T} = apply(ta.t, ta.o, I.I...) # Would otherwise be caught in the previous method.
+function Base.getindex(ta::LazyTensorApplication{T,R}, I::Vararg{Any,R}) where {T,R}
+    @boundscheck checkbounds(ta, Int.(I)...)
+    return apply(ta.t, ta.o, I...)
+end
+Base.getindex(ta::LazyTensorApplication{T,1} where T, I::CartesianIndex{1}) = ta[Tuple(I)...] # Would otherwise be caught in the previous method.
 Base.size(ta::LazyTensorApplication) = range_size(ta.t)
-# TODO: What else is needed to implement the AbstractArray interface?
 
 
 """
--- a/test/LazyTensors/lazy_tensor_operations_test.jl	Sun Mar 20 21:35:20 2022 +0100
+++ b/test/LazyTensors/lazy_tensor_operations_test.jl	Sun Mar 20 22:15:29 2022 +0100
@@ -34,32 +34,27 @@
     LazyTensors.range_size(m::SizeDoublingMapping) = 2 .* m.domain_size
     LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size
 
-
     m = SizeDoublingMapping{Int, 1, 1}((3,))
+    mm = SizeDoublingMapping{Int, 1, 1}((6,))
     v = [0,1,2]
     @test size(m*v) == 2 .*size(v)
-    @test (m*v)[0] == (:apply,v,(0,))
-    @test (m*m*v)[1] == (:apply,m*v,(1,))
-    @test (m*m*v)[3] == (:apply,m*v,(3,))
-    @test (m*m*v)[6] == (:apply,m*v,(6,))
-    @test_broken BoundsError == (m*m*v)[0]
-    @test_broken BoundsError == (m*m*v)[7]
+    @test (m*v)[1] == (:apply,v,(1,))
+    @test (mm*m*v)[1] == (:apply,m*v,(1,))
+    @test (mm*m*v)[3] == (:apply,m*v,(3,))
+    @test (mm*m*v)[6] == (:apply,m*v,(6,))
     @test_throws MethodError m*m
 
     @test (m*v)[CartesianIndex(2)] == (:apply,v,(2,))
-    @test (m*m*v)[CartesianIndex(2)] == (:apply,m*v,(2,))
-
-    m = SizeDoublingMapping{Int, 2, 1}((3,))
-    @test_throws MethodError m*ones(Int,2,2)
-    @test_throws MethodError m*m*v
+    @test (mm*m*v)[CartesianIndex(2)] == (:apply,m*v,(2,))
 
     m = SizeDoublingMapping{Float64, 2, 2}((3,3))
+    mm = SizeDoublingMapping{Float64, 2, 2}((6,6))
     v = ones(3,3)
     @test size(m*v) == 2 .*size(v)
     @test (m*v)[1,2] == (:apply,v,(1,2))
 
     @test (m*v)[CartesianIndex(2,3)] == (:apply,v,(2,3))
-    @test (m*m*v)[CartesianIndex(4,3)] == (:apply,m*v,(4,3))
+    @test (mm*m*v)[CartesianIndex(4,3)] == (:apply,m*v,(4,3))
 
     m = ScalingTensor(2,(3,))
     v = [1,2,3]
@@ -71,6 +66,27 @@
     @test m*v == [[2 4];[6 8]]
     @test (m*v)[2,1] == 6
 
+    @testset "Error on index out of bounds" begin
+        m = SizeDoublingMapping{Int, 1, 1}((3,))
+        v = [0,1,2]
+
+        @test_throws BoundsError (m*v)[0]
+        @test_throws BoundsError (m*v)[7]
+    end
+
+    @testset "Error on unmatched dimensions" begin
+        v = [0,1,2]
+        m = SizeDoublingMapping{Int, 2, 1}((3,))
+        @test_throws MethodError m*ones(Int,2,2)
+        @test_throws MethodError m*m*v
+    end
+
+    @testset "Error on unmatched sizes" begin
+        @test_throws SizeMismatch ScalingTensor(2,(2,))*ones(3)
+        @test_throws SizeMismatch ScalingTensor(2,(2,))*ScalingTensor(2,(3,))*ones(3)
+    end
+
+
     @testset "Type calculation" begin
         m = ScalingTensor(2,(3,))
         v = [1.,2.,3.]