comparison src/LazyTensors/LazyTensors.jl @ 1788:8b64df6cadba refactor/lazy_tensors/elementwise_ops

Refactor ElementwiseTensorOperation to give a flatter structure of tensor compositions, improving type inference
author Jonatan Werpers <jonatan@werpers.com>
date Wed, 25 Sep 2024 10:25:30 +0200
parents a922aa69eb83
children b8cb38fd67ff
comparison
equal deleted inserted replaced
1770:fbbadc6df706 1788:8b64df6cadba
23 Base.:*(a::LazyTensor, v::AbstractArray) = TensorApplication(a,v) 23 Base.:*(a::LazyTensor, v::AbstractArray) = TensorApplication(a,v)
24 Base.:*(a::LazyTensor, b::LazyTensor) = throw(MethodError(Base.:*,(a,b))) 24 Base.:*(a::LazyTensor, b::LazyTensor) = throw(MethodError(Base.:*,(a,b)))
25 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...)) 25 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...))
26 26
27 # Addition and subtraction of lazy tensors 27 # Addition and subtraction of lazy tensors
28 Base.:+(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:+}(s,t) 28 Base.:+(ts::LazyTensor...) = ElementwiseTensorOperation{:+}(ts...)
29 Base.:-(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:-}(s,t) 29 Base.:-(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:-}(s,t)
30 30
31 # Composing lazy tensors 31 # Composing lazy tensors
32 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t) 32 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t)
33 Base.:∘(s::TensorComposition, t::LazyTensor) = s.t1∘(s.t2∘t) 33 Base.:∘(s::TensorComposition, t::LazyTensor) = s.t1∘(s.t2∘t)