changeset 1839:e1077273eda5 refactor/lazy_tensors/elementwise_ops

Separate the logic of how TensorSum works from the semantics of + and -
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 09 Jan 2025 22:12:26 +0100
parents 4bd998069053
children cb3a8450ed44
files src/LazyTensors/LazyTensors.jl src/LazyTensors/lazy_tensor_operations.jl
diffstat 2 files changed, 5 insertions(+), 14 deletions(-) [+]
line wrap: on
line diff
--- a/src/LazyTensors/LazyTensors.jl	Thu Jan 09 21:54:45 2025 +0100
+++ b/src/LazyTensors/LazyTensors.jl	Thu Jan 09 22:12:26 2025 +0100
@@ -39,7 +39,11 @@
 # Addition and subtraction of lazy tensors
 Base.:+(ts::LazyTensor...) = TensorSum(ts...)
 Base.:-(t::LazyTensor) = TensorNegation(t)
-Base.:-(s::LazyTensor, t::LazyTensor) = TensorSum(s,-t)
+Base.:-(s::LazyTensor, t::LazyTensor) = s + (-t)
+## Specializations to flatten the nesting of tensors. This helps Julia during inference.
+Base.:+(t::TensorSum, s::TensorSum) = TensorSum(t.tms..., s.tms...)
+Base.:+(t::TensorSum, s::LazyTensor) = TensorSum(t.tms..., s)
+Base.:+(t::LazyTensor, s::TensorSum) = TensorSum(t, s.tms...)
 
 # Composing lazy tensors
 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t)
--- a/src/LazyTensors/lazy_tensor_operations.jl	Thu Jan 09 21:54:45 2025 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Thu Jan 09 22:12:26 2025 +0100
@@ -82,19 +82,6 @@
     return TensorSum{T,R,D}(ts)
 end
 
-# The following methods for :+ are intended to reduce the depth of the tree of operations in some caes
-function TensorSum(t1::TensorSum, t2::TensorSum)
-    TensorSum(t1.tms..., t2.tms...)
-end
-
-function TensorSum(t1::TensorSum, t2::LazyTensor)
-    TensorSum(t1.tms..., t2)
-end
-
-function TensorSum(t1::LazyTensor, t2::TensorSum)
-    TensorSum(t1, t2.tms...)
-end
-
 function apply(tmBinOp::TensorSum{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
     vs = map(tmBinOp.tms) do tm
         apply(tm,v,I...)