Mercurial > repos > public > sbplib_julia
changeset 1005:becd95ba0fce refactor/lazy_tensors
Add bounds checking for lazy tensor application and clea up tests a bit
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Sun, 20 Mar 2022 22:15:29 +0100 |
parents | 7fd37aab84fe |
children | d9476fede83d |
files | src/LazyTensors/lazy_tensor_operations.jl test/LazyTensors/lazy_tensor_operations_test.jl |
diffstat | 2 files changed, 35 insertions(+), 19 deletions(-) [+] |
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl Sun Mar 20 21:35:20 2022 +0100 +++ b/src/LazyTensors/lazy_tensor_operations.jl Sun Mar 20 22:15:29 2022 +0100 @@ -1,7 +1,5 @@ -# TODO: We need to be really careful about good error messages. # TODO: Go over type parameters - """ LazyTensorApplication{T,R,D} <: LazyArray{T,R} @@ -16,17 +14,19 @@ o::AA function LazyTensorApplication(t::LazyTensor{<:Any,R,D}, o::AbstractArray{<:Any,D}) where {R,D} + @boundscheck check_domain_size(t, size(o)) I = ntuple(i->1, range_dim(t)) T = typeof(apply(t,o,I...)) return new{T,R,D,typeof(t), typeof(o)}(t,o) end end -# TODO: Do boundschecking on creation! -Base.getindex(ta::LazyTensorApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...) -Base.getindex(ta::LazyTensorApplication{T,1}, I::CartesianIndex{1}) where {T} = apply(ta.t, ta.o, I.I...) # Would otherwise be caught in the previous method. +function Base.getindex(ta::LazyTensorApplication{T,R}, I::Vararg{Any,R}) where {T,R} + @boundscheck checkbounds(ta, Int.(I)...) + return apply(ta.t, ta.o, I...) +end +Base.getindex(ta::LazyTensorApplication{T,1} where T, I::CartesianIndex{1}) = ta[Tuple(I)...] # Would otherwise be caught in the previous method. Base.size(ta::LazyTensorApplication) = range_size(ta.t) -# TODO: What else is needed to implement the AbstractArray interface? """
--- a/test/LazyTensors/lazy_tensor_operations_test.jl Sun Mar 20 21:35:20 2022 +0100 +++ b/test/LazyTensors/lazy_tensor_operations_test.jl Sun Mar 20 22:15:29 2022 +0100 @@ -34,32 +34,27 @@ LazyTensors.range_size(m::SizeDoublingMapping) = 2 .* m.domain_size LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size - m = SizeDoublingMapping{Int, 1, 1}((3,)) + mm = SizeDoublingMapping{Int, 1, 1}((6,)) v = [0,1,2] @test size(m*v) == 2 .*size(v) - @test (m*v)[0] == (:apply,v,(0,)) - @test (m*m*v)[1] == (:apply,m*v,(1,)) - @test (m*m*v)[3] == (:apply,m*v,(3,)) - @test (m*m*v)[6] == (:apply,m*v,(6,)) - @test_broken BoundsError == (m*m*v)[0] - @test_broken BoundsError == (m*m*v)[7] + @test (m*v)[1] == (:apply,v,(1,)) + @test (mm*m*v)[1] == (:apply,m*v,(1,)) + @test (mm*m*v)[3] == (:apply,m*v,(3,)) + @test (mm*m*v)[6] == (:apply,m*v,(6,)) @test_throws MethodError m*m @test (m*v)[CartesianIndex(2)] == (:apply,v,(2,)) - @test (m*m*v)[CartesianIndex(2)] == (:apply,m*v,(2,)) - - m = SizeDoublingMapping{Int, 2, 1}((3,)) - @test_throws MethodError m*ones(Int,2,2) - @test_throws MethodError m*m*v + @test (mm*m*v)[CartesianIndex(2)] == (:apply,m*v,(2,)) m = SizeDoublingMapping{Float64, 2, 2}((3,3)) + mm = SizeDoublingMapping{Float64, 2, 2}((6,6)) v = ones(3,3) @test size(m*v) == 2 .*size(v) @test (m*v)[1,2] == (:apply,v,(1,2)) @test (m*v)[CartesianIndex(2,3)] == (:apply,v,(2,3)) - @test (m*m*v)[CartesianIndex(4,3)] == (:apply,m*v,(4,3)) + @test (mm*m*v)[CartesianIndex(4,3)] == (:apply,m*v,(4,3)) m = ScalingTensor(2,(3,)) v = [1,2,3] @@ -71,6 +66,27 @@ @test m*v == [[2 4];[6 8]] @test (m*v)[2,1] == 6 + @testset "Error on index out of bounds" begin + m = SizeDoublingMapping{Int, 1, 1}((3,)) + v = [0,1,2] + + @test_throws BoundsError (m*v)[0] + @test_throws BoundsError (m*v)[7] + end + + @testset "Error on unmatched dimensions" begin + v = [0,1,2] + m = SizeDoublingMapping{Int, 2, 1}((3,)) + @test_throws MethodError m*ones(Int,2,2) + @test_throws MethodError m*m*v + end + + @testset "Error on unmatched sizes" begin + @test_throws SizeMismatch ScalingTensor(2,(2,))*ones(3) + @test_throws SizeMismatch ScalingTensor(2,(2,))*ScalingTensor(2,(3,))*ones(3) + end + + @testset "Type calculation" begin m = ScalingTensor(2,(3,)) v = [1.,2.,3.]