comparison src/LazyTensors/LazyTensors.jl @ 1900:418566cdd689

Merge refactor/lazy_tensors/elementwise_ops
author Jonatan Werpers <jonatan@werpers.com>
date Fri, 31 Jan 2025 20:35:28 +0100
parents e1077273eda5
children
comparison
equal deleted inserted replaced
1896:9d708f3300d5 1900:418566cdd689
7 export range_size, domain_size 7 export range_size, domain_size
8 8
9 export TensorApplication 9 export TensorApplication
10 export TensorTranspose 10 export TensorTranspose
11 export TensorComposition 11 export TensorComposition
12 export TensorNegation
13 export TensorSum
12 export IdentityTensor 14 export IdentityTensor
13 export ScalingTensor 15 export ScalingTensor
14 export DiagonalTensor 16 export DiagonalTensor
15 export DenseTensor 17 export DenseTensor
16 export InflatedTensor 18 export InflatedTensor
33 Base.:*(a::LazyTensor, v::AbstractArray) = TensorApplication(a,v) 35 Base.:*(a::LazyTensor, v::AbstractArray) = TensorApplication(a,v)
34 Base.:*(a::LazyTensor, b::LazyTensor) = throw(MethodError(Base.:*,(a,b))) 36 Base.:*(a::LazyTensor, b::LazyTensor) = throw(MethodError(Base.:*,(a,b)))
35 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...)) 37 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...))
36 38
37 # Addition and subtraction of lazy tensors 39 # Addition and subtraction of lazy tensors
38 Base.:+(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:+}(s,t) 40 Base.:+(ts::LazyTensor...) = TensorSum(ts...)
39 Base.:-(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:-}(s,t) 41 Base.:-(t::LazyTensor) = TensorNegation(t)
42 Base.:-(s::LazyTensor, t::LazyTensor) = s + (-t)
43 ## Specializations to flatten the nesting of tensors. This helps Julia during inference.
44 Base.:+(t::TensorSum, s::TensorSum) = TensorSum(t.tms..., s.tms...)
45 Base.:+(t::TensorSum, s::LazyTensor) = TensorSum(t.tms..., s)
46 Base.:+(t::LazyTensor, s::TensorSum) = TensorSum(t, s.tms...)
40 47
41 # Composing lazy tensors 48 # Composing lazy tensors
42 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t) 49 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t)
43 Base.:∘(s::TensorComposition, t::LazyTensor) = s.t1∘(s.t2∘t) 50 Base.:∘(s::TensorComposition, t::LazyTensor) = s.t1∘(s.t2∘t)
44 51