comparison test/LazyTensors/lazy_tensor_operations_test.jl @ 1005:becd95ba0fce refactor/lazy_tensors

Add bounds checking for lazy tensor application and clean up tests a bit
author Jonatan Werpers <jonatan@werpers.com>
date Sun, 20 Mar 2022 22:15:29 +0100
parents 7fd37aab84fe
children d9476fede83d
comparison
equal deleted inserted replaced
1004:7fd37aab84fe 1005:becd95ba0fce
32 32
33 LazyTensors.apply(m::SizeDoublingMapping{T,R}, v, i::Vararg{Any,R}) where {T,R} = (:apply,v,i) 33 LazyTensors.apply(m::SizeDoublingMapping{T,R}, v, i::Vararg{Any,R}) where {T,R} = (:apply,v,i)
34 LazyTensors.range_size(m::SizeDoublingMapping) = 2 .* m.domain_size 34 LazyTensors.range_size(m::SizeDoublingMapping) = 2 .* m.domain_size
35 LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size 35 LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size
36 36
37
38 m = SizeDoublingMapping{Int, 1, 1}((3,)) 37 m = SizeDoublingMapping{Int, 1, 1}((3,))
38 mm = SizeDoublingMapping{Int, 1, 1}((6,))
39 v = [0,1,2] 39 v = [0,1,2]
40 @test size(m*v) == 2 .*size(v) 40 @test size(m*v) == 2 .*size(v)
41 @test (m*v)[0] == (:apply,v,(0,)) 41 @test (m*v)[1] == (:apply,v,(1,))
42 @test (m*m*v)[1] == (:apply,m*v,(1,)) 42 @test (mm*m*v)[1] == (:apply,m*v,(1,))
43 @test (m*m*v)[3] == (:apply,m*v,(3,)) 43 @test (mm*m*v)[3] == (:apply,m*v,(3,))
44 @test (m*m*v)[6] == (:apply,m*v,(6,)) 44 @test (mm*m*v)[6] == (:apply,m*v,(6,))
45 @test_broken BoundsError == (m*m*v)[0]
46 @test_broken BoundsError == (m*m*v)[7]
47 @test_throws MethodError m*m 45 @test_throws MethodError m*m
48 46
49 @test (m*v)[CartesianIndex(2)] == (:apply,v,(2,)) 47 @test (m*v)[CartesianIndex(2)] == (:apply,v,(2,))
50 @test (m*m*v)[CartesianIndex(2)] == (:apply,m*v,(2,)) 48 @test (mm*m*v)[CartesianIndex(2)] == (:apply,m*v,(2,))
51
52 m = SizeDoublingMapping{Int, 2, 1}((3,))
53 @test_throws MethodError m*ones(Int,2,2)
54 @test_throws MethodError m*m*v
55 49
56 m = SizeDoublingMapping{Float64, 2, 2}((3,3)) 50 m = SizeDoublingMapping{Float64, 2, 2}((3,3))
51 mm = SizeDoublingMapping{Float64, 2, 2}((6,6))
57 v = ones(3,3) 52 v = ones(3,3)
58 @test size(m*v) == 2 .*size(v) 53 @test size(m*v) == 2 .*size(v)
59 @test (m*v)[1,2] == (:apply,v,(1,2)) 54 @test (m*v)[1,2] == (:apply,v,(1,2))
60 55
61 @test (m*v)[CartesianIndex(2,3)] == (:apply,v,(2,3)) 56 @test (m*v)[CartesianIndex(2,3)] == (:apply,v,(2,3))
62 @test (m*m*v)[CartesianIndex(4,3)] == (:apply,m*v,(4,3)) 57 @test (mm*m*v)[CartesianIndex(4,3)] == (:apply,m*v,(4,3))
63 58
64 m = ScalingTensor(2,(3,)) 59 m = ScalingTensor(2,(3,))
65 v = [1,2,3] 60 v = [1,2,3]
66 @test m*v isa AbstractVector 61 @test m*v isa AbstractVector
67 @test m*v == [2,4,6] 62 @test m*v == [2,4,6]
68 63
69 m = ScalingTensor(2,(2,2)) 64 m = ScalingTensor(2,(2,2))
70 v = [[1 2];[3 4]] 65 v = [[1 2];[3 4]]
71 @test m*v == [[2 4];[6 8]] 66 @test m*v == [[2 4];[6 8]]
72 @test (m*v)[2,1] == 6 67 @test (m*v)[2,1] == 6
68
69 @testset "Error on index out of bounds" begin
70 m = SizeDoublingMapping{Int, 1, 1}((3,))
71 v = [0,1,2]
72
73 @test_throws BoundsError (m*v)[0]
74 @test_throws BoundsError (m*v)[7]
75 end
76
77 @testset "Error on unmatched dimensions" begin
78 v = [0,1,2]
79 m = SizeDoublingMapping{Int, 2, 1}((3,))
80 @test_throws MethodError m*ones(Int,2,2)
81 @test_throws MethodError m*m*v
82 end
83
84 @testset "Error on unmatched sizes" begin
85 @test_throws SizeMismatch ScalingTensor(2,(2,))*ones(3)
86 @test_throws SizeMismatch ScalingTensor(2,(2,))*ScalingTensor(2,(3,))*ones(3)
87 end
88
73 89
74 @testset "Type calculation" begin 90 @testset "Type calculation" begin
75 m = ScalingTensor(2,(3,)) 91 m = ScalingTensor(2,(3,))
76 v = [1.,2.,3.] 92 v = [1.,2.,3.]
77 @test m*v isa AbstractVector{Float64} 93 @test m*v isa AbstractVector{Float64}