comparison src/LazyTensors/LazyTensors.jl @ 1837:200971c71657 refactor/lazy_tensors/elementwise_ops

Refactor ElementwiseTensorOperation into TensorSum and use TensorNegation for handling subtraction
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 09 Jan 2025 21:46:01 +0100
parents 368999a2e243
children e1077273eda5
comparison
equal deleted inserted replaced
1836:368999a2e243 1837:200971c71657
8 8
9 export TensorApplication 9 export TensorApplication
10 export TensorTranspose 10 export TensorTranspose
11 export TensorComposition 11 export TensorComposition
12 export TensorNegation 12 export TensorNegation
13 export TensorSum
13 export IdentityTensor 14 export IdentityTensor
14 export ScalingTensor 15 export ScalingTensor
15 export DiagonalTensor 16 export DiagonalTensor
16 export DenseTensor 17 export DenseTensor
17 export InflatedTensor 18 export InflatedTensor
34 Base.:*(a::LazyTensor, v::AbstractArray) = TensorApplication(a,v) 35 Base.:*(a::LazyTensor, v::AbstractArray) = TensorApplication(a,v)
35 Base.:*(a::LazyTensor, b::LazyTensor) = throw(MethodError(Base.:*,(a,b))) 36 Base.:*(a::LazyTensor, b::LazyTensor) = throw(MethodError(Base.:*,(a,b)))
36 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...)) 37 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...))
37 38
38 # Addition and subtraction of lazy tensors 39 # Addition and subtraction of lazy tensors
39 Base.:+(ts::LazyTensor...) = ElementwiseTensorOperation{:+}(ts...) 40 Base.:+(ts::LazyTensor...) = TensorSum(ts...)
40 Base.:-(t::LazyTensor) = TensorNegation(t) 41 Base.:-(t::LazyTensor) = TensorNegation(t)
41 Base.:-(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:-}(s,t) 42 Base.:-(s::LazyTensor, t::LazyTensor) = TensorSum(s,-t)
42 43
43 # Composing lazy tensors 44 # Composing lazy tensors
44 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t) 45 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t)
45 Base.:∘(s::TensorComposition, t::LazyTensor) = s.t1∘(s.t2∘t) 46 Base.:∘(s::TensorComposition, t::LazyTensor) = s.t1∘(s.t2∘t)
46 47