diff src/LazyTensors/LazyTensors.jl @ 1900:418566cdd689

Merge refactor/lazy_tensors/elementwise_ops
author Jonatan Werpers <jonatan@werpers.com>
date Fri, 31 Jan 2025 20:35:28 +0100
parents e1077273eda5
children
line wrap: on
line diff
--- a/src/LazyTensors/LazyTensors.jl	Fri Jan 31 15:52:49 2025 +0100
+++ b/src/LazyTensors/LazyTensors.jl	Fri Jan 31 20:35:28 2025 +0100
@@ -9,6 +9,8 @@
 export TensorApplication
 export TensorTranspose
 export TensorComposition
+export TensorNegation
+export TensorSum
 export IdentityTensor
 export ScalingTensor
 export DiagonalTensor
@@ -35,8 +37,13 @@
 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...))
 
 # Addition and subtraction of lazy tensors
-Base.:+(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:+}(s,t)
-Base.:-(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:-}(s,t)
+Base.:+(ts::LazyTensor...) = TensorSum(ts...)
+Base.:-(t::LazyTensor) = TensorNegation(t)
+Base.:-(s::LazyTensor, t::LazyTensor) = s + (-t)
+## Specializations to flatten the nesting of tensors. This helps Julia during inference.
+Base.:+(t::TensorSum, s::TensorSum) = TensorSum(t.tms..., s.tms...)
+Base.:+(t::TensorSum, s::LazyTensor) = TensorSum(t.tms..., s)
+Base.:+(t::LazyTensor, s::TensorSum) = TensorSum(t, s.tms...)
 
 # Composing lazy tensors
 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t)