Mercurial > repos > public > sbplib_julia
comparison test/LazyTensors/lazy_tensor_operations_test.jl @ 1005:becd95ba0fce refactor/lazy_tensors
Add bounds checking for lazy tensor application and clean up tests a bit
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Sun, 20 Mar 2022 22:15:29 +0100 |
parents | 7fd37aab84fe |
children | d9476fede83d |
comparison
equal
deleted
inserted
replaced
1004:7fd37aab84fe | 1005:becd95ba0fce |
---|---|
32 | 32 |
33 LazyTensors.apply(m::SizeDoublingMapping{T,R}, v, i::Vararg{Any,R}) where {T,R} = (:apply,v,i) | 33 LazyTensors.apply(m::SizeDoublingMapping{T,R}, v, i::Vararg{Any,R}) where {T,R} = (:apply,v,i) |
34 LazyTensors.range_size(m::SizeDoublingMapping) = 2 .* m.domain_size | 34 LazyTensors.range_size(m::SizeDoublingMapping) = 2 .* m.domain_size |
35 LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size | 35 LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size |
36 | 36 |
37 | |
38 m = SizeDoublingMapping{Int, 1, 1}((3,)) | 37 m = SizeDoublingMapping{Int, 1, 1}((3,)) |
38 mm = SizeDoublingMapping{Int, 1, 1}((6,)) | |
39 v = [0,1,2] | 39 v = [0,1,2] |
40 @test size(m*v) == 2 .*size(v) | 40 @test size(m*v) == 2 .*size(v) |
41 @test (m*v)[0] == (:apply,v,(0,)) | 41 @test (m*v)[1] == (:apply,v,(1,)) |
42 @test (m*m*v)[1] == (:apply,m*v,(1,)) | 42 @test (mm*m*v)[1] == (:apply,m*v,(1,)) |
43 @test (m*m*v)[3] == (:apply,m*v,(3,)) | 43 @test (mm*m*v)[3] == (:apply,m*v,(3,)) |
44 @test (m*m*v)[6] == (:apply,m*v,(6,)) | 44 @test (mm*m*v)[6] == (:apply,m*v,(6,)) |
45 @test_broken BoundsError == (m*m*v)[0] | |
46 @test_broken BoundsError == (m*m*v)[7] | |
47 @test_throws MethodError m*m | 45 @test_throws MethodError m*m |
48 | 46 |
49 @test (m*v)[CartesianIndex(2)] == (:apply,v,(2,)) | 47 @test (m*v)[CartesianIndex(2)] == (:apply,v,(2,)) |
50 @test (m*m*v)[CartesianIndex(2)] == (:apply,m*v,(2,)) | 48 @test (mm*m*v)[CartesianIndex(2)] == (:apply,m*v,(2,)) |
51 | |
52 m = SizeDoublingMapping{Int, 2, 1}((3,)) | |
53 @test_throws MethodError m*ones(Int,2,2) | |
54 @test_throws MethodError m*m*v | |
55 | 49 |
56 m = SizeDoublingMapping{Float64, 2, 2}((3,3)) | 50 m = SizeDoublingMapping{Float64, 2, 2}((3,3)) |
51 mm = SizeDoublingMapping{Float64, 2, 2}((6,6)) | |
57 v = ones(3,3) | 52 v = ones(3,3) |
58 @test size(m*v) == 2 .*size(v) | 53 @test size(m*v) == 2 .*size(v) |
59 @test (m*v)[1,2] == (:apply,v,(1,2)) | 54 @test (m*v)[1,2] == (:apply,v,(1,2)) |
60 | 55 |
61 @test (m*v)[CartesianIndex(2,3)] == (:apply,v,(2,3)) | 56 @test (m*v)[CartesianIndex(2,3)] == (:apply,v,(2,3)) |
62 @test (m*m*v)[CartesianIndex(4,3)] == (:apply,m*v,(4,3)) | 57 @test (mm*m*v)[CartesianIndex(4,3)] == (:apply,m*v,(4,3)) |
63 | 58 |
64 m = ScalingTensor(2,(3,)) | 59 m = ScalingTensor(2,(3,)) |
65 v = [1,2,3] | 60 v = [1,2,3] |
66 @test m*v isa AbstractVector | 61 @test m*v isa AbstractVector |
67 @test m*v == [2,4,6] | 62 @test m*v == [2,4,6] |
68 | 63 |
69 m = ScalingTensor(2,(2,2)) | 64 m = ScalingTensor(2,(2,2)) |
70 v = [[1 2];[3 4]] | 65 v = [[1 2];[3 4]] |
71 @test m*v == [[2 4];[6 8]] | 66 @test m*v == [[2 4];[6 8]] |
72 @test (m*v)[2,1] == 6 | 67 @test (m*v)[2,1] == 6 |
68 | |
69 @testset "Error on index out of bounds" begin | |
70 m = SizeDoublingMapping{Int, 1, 1}((3,)) | |
71 v = [0,1,2] | |
72 | |
73 @test_throws BoundsError (m*v)[0] | |
74 @test_throws BoundsError (m*v)[7] | |
75 end | |
76 | |
77 @testset "Error on unmatched dimensions" begin | |
78 v = [0,1,2] | |
79 m = SizeDoublingMapping{Int, 2, 1}((3,)) | |
80 @test_throws MethodError m*ones(Int,2,2) | |
81 @test_throws MethodError m*m*v | |
82 end | |
83 | |
84 @testset "Error on unmatched sizes" begin | |
85 @test_throws SizeMismatch ScalingTensor(2,(2,))*ones(3) | |
86 @test_throws SizeMismatch ScalingTensor(2,(2,))*ScalingTensor(2,(3,))*ones(3) | |
87 end | |
88 | |
73 | 89 |
74 @testset "Type calculation" begin | 90 @testset "Type calculation" begin |
75 m = ScalingTensor(2,(3,)) | 91 m = ScalingTensor(2,(3,)) |
76 v = [1.,2.,3.] | 92 v = [1.,2.,3.] |
77 @test m*v isa AbstractVector{Float64} | 93 @test m*v isa AbstractVector{Float64} |