comparison src/LazyTensors/lazy_tensor_operations.jl @ 1839:e1077273eda5 refactor/lazy_tensors/elementwise_ops

Separate the logic of how TensorSum works from the semantics of + and -
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 09 Jan 2025 22:12:26 +0100
parents 200971c71657
children cb3a8450ed44
comparison
equal deleted inserted replaced
1838:4bd998069053 1839:e1077273eda5
78 function TensorSum(ts::Vararg{LazyTensor}) 78 function TensorSum(ts::Vararg{LazyTensor})
79 T = eltype(ts[1]) 79 T = eltype(ts[1])
80 R = range_dim(ts[1]) 80 R = range_dim(ts[1])
81 D = domain_dim(ts[1]) 81 D = domain_dim(ts[1])
82 return TensorSum{T,R,D}(ts) 82 return TensorSum{T,R,D}(ts)
83 end
84
85 # The following methods for :+ are intended to reduce the depth of the tree of operations in some caes
86 function TensorSum(t1::TensorSum, t2::TensorSum)
87 TensorSum(t1.tms..., t2.tms...)
88 end
89
90 function TensorSum(t1::TensorSum, t2::LazyTensor)
91 TensorSum(t1.tms..., t2)
92 end
93
94 function TensorSum(t1::LazyTensor, t2::TensorSum)
95 TensorSum(t1, t2.tms...)
96 end 83 end
97 84
98 function apply(tmBinOp::TensorSum{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} 85 function apply(tmBinOp::TensorSum{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
99 vs = map(tmBinOp.tms) do tm 86 vs = map(tmBinOp.tms) do tm
100 apply(tm,v,I...) 87 apply(tm,v,I...)