diff src/LazyTensors/LazyTensors.jl @ 1954:b0915f43b122 feature/sbp_operators/laplace_curvilinear

Merge feature/grids/geometry_functions
author Jonatan Werpers <jonatan@werpers.com>
date Sat, 08 Feb 2025 09:38:58 +0100
parents e1077273eda5
children
line wrap: on
line diff
--- a/src/LazyTensors/LazyTensors.jl	Sat Feb 08 09:35:13 2025 +0100
+++ b/src/LazyTensors/LazyTensors.jl	Sat Feb 08 09:38:58 2025 +0100
@@ -9,6 +9,8 @@
 export TensorApplication
 export TensorTranspose
 export TensorComposition
+export TensorNegation
+export TensorSum
 export IdentityTensor
 export ScalingTensor
 export DiagonalTensor
@@ -35,8 +37,13 @@
 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...))
 
 # Addition and subtraction of lazy tensors
-Base.:+(ts::LazyTensor...) = ElementwiseTensorOperation{:+}(ts...)
-Base.:-(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:-}(s,t)
+Base.:+(ts::LazyTensor...) = TensorSum(ts...)
+Base.:-(t::LazyTensor) = TensorNegation(t)
+Base.:-(s::LazyTensor, t::LazyTensor) = s + (-t)
+## Specializations to flatten the nesting of tensors. This helps Julia during inference.
+Base.:+(t::TensorSum, s::TensorSum) = TensorSum(t.tms..., s.tms...)
+Base.:+(t::TensorSum, s::LazyTensor) = TensorSum(t.tms..., s)
+Base.:+(t::LazyTensor, s::TensorSum) = TensorSum(t, s.tms...)
 
 # Composing lazy tensors
 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t)