Mercurial > repos > public > sbplib_julia
diff src/LazyTensors/LazyTensors.jl @ 1839:e1077273eda5 refactor/lazy_tensors/elementwise_ops
Separate the logic of how TensorSum works from the semantics of + and -
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Thu, 09 Jan 2025 22:12:26 +0100 |
parents | 200971c71657 |
children |
line wrap: on
line diff
--- a/src/LazyTensors/LazyTensors.jl Thu Jan 09 21:54:45 2025 +0100 +++ b/src/LazyTensors/LazyTensors.jl Thu Jan 09 22:12:26 2025 +0100 @@ -39,7 +39,11 @@ # Addition and subtraction of lazy tensors Base.:+(ts::LazyTensor...) = TensorSum(ts...) Base.:-(t::LazyTensor) = TensorNegation(t) -Base.:-(s::LazyTensor, t::LazyTensor) = TensorSum(s,-t) +Base.:-(s::LazyTensor, t::LazyTensor) = s + (-t) +## Specializations to flatten the nesting of tensors. This helps Julia during inference. +Base.:+(t::TensorSum, s::TensorSum) = TensorSum(t.tms..., s.tms...) +Base.:+(t::TensorSum, s::LazyTensor) = TensorSum(t.tms..., s) +Base.:+(t::LazyTensor, s::TensorSum) = TensorSum(t, s.tms...) # Composing lazy tensors Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t)