comparison src/LazyTensors/LazyTensors.jl @ 2057:8a2a0d678d6f feature/lazy_tensors/pretty_printing

Merge default
author Jonatan Werpers <jonatan@werpers.com>
date Tue, 10 Feb 2026 22:41:19 +0100
parents e1077273eda5
children
comparison
equal deleted inserted replaced
1110:c0bff9f6e0fb 2057:8a2a0d678d6f
1 module LazyTensors 1 module LazyTensors
2
3 export LazyTensor
4 export apply
5 export apply_transpose
6 export range_dim, domain_dim
7 export range_size, domain_size
2 8
3 export TensorApplication 9 export TensorApplication
4 export TensorTranspose 10 export TensorTranspose
5 export TensorComposition 11 export TensorComposition
12 export TensorNegation
13 export TensorSum
6 export IdentityTensor 14 export IdentityTensor
7 export ScalingTensor 15 export ScalingTensor
8 export DiagonalTensor 16 export DiagonalTensor
9 export DenseTensor 17 export DenseTensor
10 export InflatedTensor 18 export InflatedTensor
11 export LazyOuterProduct 19 export LazyOuterProduct
12 export ⊗ 20 export ⊗
13 export DomainSizeMismatch 21 export DomainSizeMismatch
14 export RangeSizeMismatch 22 export RangeSizeMismatch
23
24 export LazyArray
25 export LazyFunctionArray
26 export +̃, -̃, *̃, /̃
15 27
16 include("lazy_tensor.jl") 28 include("lazy_tensor.jl")
17 include("tensor_types.jl") 29 include("tensor_types.jl")
18 include("lazy_array.jl") 30 include("lazy_array.jl")
19 include("lazy_tensor_operations.jl") 31 include("lazy_tensor_operations.jl")
23 Base.:*(a::LazyTensor, v::AbstractArray) = TensorApplication(a,v) 35 Base.:*(a::LazyTensor, v::AbstractArray) = TensorApplication(a,v)
24 Base.:*(a::LazyTensor, b::LazyTensor) = throw(MethodError(Base.:*,(a,b))) 36 Base.:*(a::LazyTensor, b::LazyTensor) = throw(MethodError(Base.:*,(a,b)))
25 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...)) 37 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...))
26 38
27 # Addition and subtraction of lazy tensors 39 # Addition and subtraction of lazy tensors
28 Base.:+(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:+}(s,t) 40 Base.:+(ts::LazyTensor...) = TensorSum(ts...)
29 Base.:-(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:-}(s,t) 41 Base.:-(t::LazyTensor) = TensorNegation(t)
42 Base.:-(s::LazyTensor, t::LazyTensor) = s + (-t)
43 ## Specializations to flatten the nesting of tensors. This helps Julia during inference.
44 Base.:+(t::TensorSum, s::TensorSum) = TensorSum(t.tms..., s.tms...)
45 Base.:+(t::TensorSum, s::LazyTensor) = TensorSum(t.tms..., s)
46 Base.:+(t::LazyTensor, s::TensorSum) = TensorSum(t, s.tms...)
30 47
31 # Composing lazy tensors 48 # Composing lazy tensors
32 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t) 49 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t)
50 Base.:∘(s::TensorComposition, t::LazyTensor) = s.t1∘(s.t2∘t)
33 51
34 # Outer products of tensors 52 # Outer products of tensors
35 ⊗(a::LazyTensor, b::LazyTensor) = LazyOuterProduct(a,b) 53 ⊗(a::LazyTensor, b::LazyTensor) = LazyOuterProduct(a,b)
36 54
37 end # module 55 end # module