Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/LazyTensors.jl @ 2057:8a2a0d678d6f feature/lazy_tensors/pretty_printing
Merge default
| author | Jonatan Werpers <jonatan@werpers.com> |
|---|---|
| date | Tue, 10 Feb 2026 22:41:19 +0100 |
| parents | e1077273eda5 |
| children | |
comparison
equal
deleted
inserted
replaced
| 1110:c0bff9f6e0fb | 2057:8a2a0d678d6f |
|---|---|
| 1 module LazyTensors | 1 module LazyTensors |
| | 2 |
| | 3 export LazyTensor |
| | 4 export apply |
| | 5 export apply_transpose |
| | 6 export range_dim, domain_dim |
| | 7 export range_size, domain_size |
| 2 | 8 |
| 3 export TensorApplication | 9 export TensorApplication |
| 4 export TensorTranspose | 10 export TensorTranspose |
| 5 export TensorComposition | 11 export TensorComposition |
| | 12 export TensorNegation |
| | 13 export TensorSum |
| 6 export IdentityTensor | 14 export IdentityTensor |
| 7 export ScalingTensor | 15 export ScalingTensor |
| 8 export DiagonalTensor | 16 export DiagonalTensor |
| 9 export DenseTensor | 17 export DenseTensor |
| 10 export InflatedTensor | 18 export InflatedTensor |
| 11 export LazyOuterProduct | 19 export LazyOuterProduct |
| 12 export ⊗ | 20 export ⊗ |
| 13 export DomainSizeMismatch | 21 export DomainSizeMismatch |
| 14 export RangeSizeMismatch | 22 export RangeSizeMismatch |
| | 23 |
| | 24 export LazyArray |
| | 25 export LazyFunctionArray |
| | 26 export +̃, -̃, *̃, /̃ |
| 15 | 27 |
| 16 include("lazy_tensor.jl") | 28 include("lazy_tensor.jl") |
| 17 include("tensor_types.jl") | 29 include("tensor_types.jl") |
| 18 include("lazy_array.jl") | 30 include("lazy_array.jl") |
| 19 include("lazy_tensor_operations.jl") | 31 include("lazy_tensor_operations.jl") |
| 23 Base.:*(a::LazyTensor, v::AbstractArray) = TensorApplication(a,v) | 35 Base.:*(a::LazyTensor, v::AbstractArray) = TensorApplication(a,v) |
| 24 Base.:*(a::LazyTensor, b::LazyTensor) = throw(MethodError(Base.:*,(a,b))) | 36 Base.:*(a::LazyTensor, b::LazyTensor) = throw(MethodError(Base.:*,(a,b))) |
| 25 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...)) | 37 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...)) |
| 26 | 38 |
| 27 # Addition and subtraction of lazy tensors | 39 # Addition and subtraction of lazy tensors |
| 28 Base.:+(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:+}(s,t) | 40 Base.:+(ts::LazyTensor...) = TensorSum(ts...) |
| 29 Base.:-(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:-}(s,t) | 41 Base.:-(t::LazyTensor) = TensorNegation(t) |
| | 42 Base.:-(s::LazyTensor, t::LazyTensor) = s + (-t) |
| | 43 ## Specializations to flatten the nesting of tensors. This helps Julia during inference. |
| | 44 Base.:+(t::TensorSum, s::TensorSum) = TensorSum(t.tms..., s.tms...) |
| | 45 Base.:+(t::TensorSum, s::LazyTensor) = TensorSum(t.tms..., s) |
| | 46 Base.:+(t::LazyTensor, s::TensorSum) = TensorSum(t, s.tms...) |
| 30 | 47 |
| 31 # Composing lazy tensors | 48 # Composing lazy tensors |
| 32 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t) | 49 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t) |
| | 50 Base.:∘(s::TensorComposition, t::LazyTensor) = s.t1∘(s.t2∘t) |
| 33 | 51 |
| 34 # Outer products of tensors | 52 # Outer products of tensors |
| 35 ⊗(a::LazyTensor, b::LazyTensor) = LazyOuterProduct(a,b) | 53 ⊗(a::LazyTensor, b::LazyTensor) = LazyOuterProduct(a,b) |
| 36 | 54 |
| 37 end # module | 55 end # module |
