diff src/LazyTensors/LazyTensors.jl @ 2057:8a2a0d678d6f feature/lazy_tensors/pretty_printing

Merge default
author Jonatan Werpers <jonatan@werpers.com>
date Tue, 10 Feb 2026 22:41:19 +0100
parents e1077273eda5
children
line wrap: on
line diff
--- a/src/LazyTensors/LazyTensors.jl	Mon May 23 07:20:27 2022 +0200
+++ b/src/LazyTensors/LazyTensors.jl	Tue Feb 10 22:41:19 2026 +0100
@@ -1,8 +1,16 @@
 module LazyTensors
 
+export LazyTensor
+export apply
+export apply_transpose
+export range_dim, domain_dim
+export range_size, domain_size
+
 export TensorApplication
 export TensorTranspose
 export TensorComposition
+export TensorNegation
+export TensorSum
 export IdentityTensor
 export ScalingTensor
 export DiagonalTensor
@@ -13,6 +21,10 @@
 export DomainSizeMismatch
 export RangeSizeMismatch
 
+export LazyArray
+export LazyFunctionArray
+export +̃, -̃, *̃, /̃
+
 include("lazy_tensor.jl")
 include("tensor_types.jl")
 include("lazy_array.jl")
@@ -25,11 +37,17 @@
 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...))
 
 # Addition and subtraction of lazy tensors
-Base.:+(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:+}(s,t)
-Base.:-(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:-}(s,t)
+Base.:+(ts::LazyTensor...) = TensorSum(ts...)
+Base.:-(t::LazyTensor) = TensorNegation(t)
+Base.:-(s::LazyTensor, t::LazyTensor) = s + (-t)
+## Specializations to flatten the nesting of tensors. This helps Julia during inference.
+Base.:+(t::TensorSum, s::TensorSum) = TensorSum(t.tms..., s.tms...)
+Base.:+(t::TensorSum, s::LazyTensor) = TensorSum(t.tms..., s)
+Base.:+(t::LazyTensor, s::TensorSum) = TensorSum(t, s.tms...)
 
 # Composing lazy tensors
 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t)
+Base.:∘(s::TensorComposition, t::LazyTensor) = s.t1∘(s.t2∘t)
 
 # Outer products of tensors
 ⊗(a::LazyTensor, b::LazyTensor) = LazyOuterProduct(a,b)