Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 1005:becd95ba0fce refactor/lazy_tensors
Add bounds checking for lazy tensor application and clean up tests a bit
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Sun, 20 Mar 2022 22:15:29 +0100 |
parents | 271aa6ae1055 |
children | d9476fede83d |
comparison
equal
deleted
inserted
replaced
1004:7fd37aab84fe | 1005:becd95ba0fce |
---|---|
1 # TODO: We need to be really careful about good error messages. | |
2 # TODO: Go over type parameters | 1 # TODO: Go over type parameters |
3 | |
4 | 2 |
5 """ | 3 """ |
6 LazyTensorApplication{T,R,D} <: LazyArray{T,R} | 4 LazyTensorApplication{T,R,D} <: LazyArray{T,R} |
7 | 5 |
8 Struct for lazy application of a LazyTensor. Created using `*`. | 6 Struct for lazy application of a LazyTensor. Created using `*`. |
14 struct LazyTensorApplication{T,R,D, TM<:LazyTensor{<:Any,R,D}, AA<:AbstractArray{<:Any,D}} <: LazyArray{T,R} | 12 struct LazyTensorApplication{T,R,D, TM<:LazyTensor{<:Any,R,D}, AA<:AbstractArray{<:Any,D}} <: LazyArray{T,R} |
15 t::TM | 13 t::TM |
16 o::AA | 14 o::AA |
17 | 15 |
18 function LazyTensorApplication(t::LazyTensor{<:Any,R,D}, o::AbstractArray{<:Any,D}) where {R,D} | 16 function LazyTensorApplication(t::LazyTensor{<:Any,R,D}, o::AbstractArray{<:Any,D}) where {R,D} |
17 @boundscheck check_domain_size(t, size(o)) | |
19 I = ntuple(i->1, range_dim(t)) | 18 I = ntuple(i->1, range_dim(t)) |
20 T = typeof(apply(t,o,I...)) | 19 T = typeof(apply(t,o,I...)) |
21 return new{T,R,D,typeof(t), typeof(o)}(t,o) | 20 return new{T,R,D,typeof(t), typeof(o)}(t,o) |
22 end | 21 end |
23 end | 22 end |
24 # TODO: Do boundschecking on creation! | 23 |
25 | 24 function Base.getindex(ta::LazyTensorApplication{T,R}, I::Vararg{Any,R}) where {T,R} |
26 Base.getindex(ta::LazyTensorApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...) | 25 @boundscheck checkbounds(ta, Int.(I)...) |
27 Base.getindex(ta::LazyTensorApplication{T,1}, I::CartesianIndex{1}) where {T} = apply(ta.t, ta.o, I.I...) # Would otherwise be caught in the previous method. | 26 return apply(ta.t, ta.o, I...) |
27 end | |
28 Base.getindex(ta::LazyTensorApplication{T,1} where T, I::CartesianIndex{1}) = ta[Tuple(I)...] # Would otherwise be caught in the previous method. | |
28 Base.size(ta::LazyTensorApplication) = range_size(ta.t) | 29 Base.size(ta::LazyTensorApplication) = range_size(ta.t) |
29 # TODO: What else is needed to implement the AbstractArray interface? | |
30 | 30 |
31 | 31 |
32 """ | 32 """ |
33 LazyTensorTranspose{T,R,D} <: LazyTensor{T,D,R} | 33 LazyTensorTranspose{T,R,D} <: LazyTensor{T,D,R} |
34 | 34 |