comparison src/LazyTensors/lazy_tensor_operations.jl @ 1007:f7a718bcb4da refactor/lazy_tensors

Add checking of sizes to LazyTensorBinaryOperation
author Jonatan Werpers <jonatan@werpers.com>
date Sun, 20 Mar 2022 22:41:28 +0100
parents d9476fede83d
children 2c1a0722ddb9 4dd3c2312d9e 52f07c77299d
comparison
equal deleted inserted replaced
1006:d9476fede83d 1007:f7a718bcb4da
1 # TODO: Go over type parameters
2
3 """ 1 """
4 LazyTensorApplication{T,R,D} <: LazyArray{T,R} 2 LazyTensorApplication{T,R,D} <: LazyArray{T,R}
5 3
6 Struct for lazy application of a LazyTensor. Created using `*`. 4 Struct for lazy application of a LazyTensor. Created using `*`.
7 5
57 struct LazyTensorBinaryOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R} 55 struct LazyTensorBinaryOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R}
58 tm1::T1 56 tm1::T1
59 tm2::T2 57 tm2::T2
60 58
61 function LazyTensorBinaryOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} 59 function LazyTensorBinaryOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}}
60 @boundscheck check_domain_size(tm2, domain_size(tm1))
61 @boundscheck check_range_size(tm2, range_size(tm1))
62 return new{Op,T,R,D,T1,T2}(tm1,tm2) 62 return new{Op,T,R,D,T1,T2}(tm1,tm2)
63 end 63 end
64 end 64 end
65 # TODO: Boundschecking in constructor.
66 65
67 LazyTensorBinaryOperation{Op}(s,t) where Op = LazyTensorBinaryOperation{Op,eltype(s), range_dim(s), domain_dim(s)}(s,t) 66 LazyTensorBinaryOperation{Op}(s,t) where Op = LazyTensorBinaryOperation{Op,eltype(s), range_dim(s), domain_dim(s)}(s,t)
68 67
69 apply(tmBinOp::LazyTensorBinaryOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) 68 apply(tmBinOp::LazyTensorBinaryOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...)
70 apply(tmBinOp::LazyTensorBinaryOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) 69 apply(tmBinOp::LazyTensorBinaryOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...)