comparison src/LazyTensors/lazy_tensor_operations.jl @ 1005:becd95ba0fce refactor/lazy_tensors

Add bounds checking for lazy tensor application and clean up tests a bit
author Jonatan Werpers <jonatan@werpers.com>
date Sun, 20 Mar 2022 22:15:29 +0100
parents 271aa6ae1055
children d9476fede83d
comparison
equal deleted inserted replaced
1004:7fd37aab84fe 1005:becd95ba0fce
1 # TODO: We need to be really careful about good error messages.
2 # TODO: Go over type parameters 1 # TODO: Go over type parameters
3
4 2
5 """ 3 """
6 LazyTensorApplication{T,R,D} <: LazyArray{T,R} 4 LazyTensorApplication{T,R,D} <: LazyArray{T,R}
7 5
8 Struct for lazy application of a LazyTensor. Created using `*`. 6 Struct for lazy application of a LazyTensor. Created using `*`.
14 struct LazyTensorApplication{T,R,D, TM<:LazyTensor{<:Any,R,D}, AA<:AbstractArray{<:Any,D}} <: LazyArray{T,R} 12 struct LazyTensorApplication{T,R,D, TM<:LazyTensor{<:Any,R,D}, AA<:AbstractArray{<:Any,D}} <: LazyArray{T,R}
15 t::TM 13 t::TM
16 o::AA 14 o::AA
17 15
18 function LazyTensorApplication(t::LazyTensor{<:Any,R,D}, o::AbstractArray{<:Any,D}) where {R,D} 16 function LazyTensorApplication(t::LazyTensor{<:Any,R,D}, o::AbstractArray{<:Any,D}) where {R,D}
17 @boundscheck check_domain_size(t, size(o))
19 I = ntuple(i->1, range_dim(t)) 18 I = ntuple(i->1, range_dim(t))
20 T = typeof(apply(t,o,I...)) 19 T = typeof(apply(t,o,I...))
21 return new{T,R,D,typeof(t), typeof(o)}(t,o) 20 return new{T,R,D,typeof(t), typeof(o)}(t,o)
22 end 21 end
23 end 22 end
24 # TODO: Do boundschecking on creation! 23
25 24 function Base.getindex(ta::LazyTensorApplication{T,R}, I::Vararg{Any,R}) where {T,R}
26 Base.getindex(ta::LazyTensorApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...) 25 @boundscheck checkbounds(ta, Int.(I)...)
27 Base.getindex(ta::LazyTensorApplication{T,1}, I::CartesianIndex{1}) where {T} = apply(ta.t, ta.o, I.I...) # Would otherwise be caught in the previous method. 26 return apply(ta.t, ta.o, I...)
27 end
28 Base.getindex(ta::LazyTensorApplication{T,1} where T, I::CartesianIndex{1}) = ta[Tuple(I)...] # Would otherwise be caught in the previous method.
28 Base.size(ta::LazyTensorApplication) = range_size(ta.t) 29 Base.size(ta::LazyTensorApplication) = range_size(ta.t)
29 # TODO: What else is needed to implement the AbstractArray interface?
30 30
31 31
32 """ 32 """
33 LazyTensorTranspose{T,R,D} <: LazyTensor{T,D,R} 33 LazyTensorTranspose{T,R,D} <: LazyTensor{T,D,R}
34 34