Mercurial > repos > public > sbplib_julia
diff src/LazyTensors/LazyTensors.jl @ 1900:418566cdd689
Merge refactor/lazy_tensors/elementwise_ops
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Fri, 31 Jan 2025 20:35:28 +0100 |
parents | e1077273eda5 |
children |
line wrap: on
line diff
--- a/src/LazyTensors/LazyTensors.jl Fri Jan 31 15:52:49 2025 +0100 +++ b/src/LazyTensors/LazyTensors.jl Fri Jan 31 20:35:28 2025 +0100 @@ -9,6 +9,8 @@ export TensorApplication export TensorTranspose export TensorComposition +export TensorNegation +export TensorSum export IdentityTensor export ScalingTensor export DiagonalTensor @@ -35,8 +37,13 @@ Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...)) # Addition and subtraction of lazy tensors -Base.:+(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:+}(s,t) -Base.:-(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:-}(s,t) +Base.:+(ts::LazyTensor...) = TensorSum(ts...) +Base.:-(t::LazyTensor) = TensorNegation(t) +Base.:-(s::LazyTensor, t::LazyTensor) = s + (-t) +## Specializations to flatten the nesting of tensors. This helps Julia during inference. +Base.:+(t::TensorSum, s::TensorSum) = TensorSum(t.tms..., s.tms...) +Base.:+(t::TensorSum, s::LazyTensor) = TensorSum(t.tms..., s) +Base.:+(t::LazyTensor, s::TensorSum) = TensorSum(t, s.tms...) # Composing lazy tensors Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t)