Mercurial > repos > public > sbplib_julia
changeset 1788:8b64df6cadba refactor/lazy_tensors/elementwise_ops
Refactor ElementWiseOperation to give a flatter structure of tensor compositions improving type inference
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Wed, 25 Sep 2024 10:25:30 +0200 |
parents | fbbadc6df706 |
children | 48eaa973159a |
files | src/LazyTensors/LazyTensors.jl src/LazyTensors/lazy_tensor_operations.jl |
diffstat | 2 files changed, 43 insertions(+), 13 deletions(-) [+] |
line wrap: on
line diff
--- a/src/LazyTensors/LazyTensors.jl Fri Sep 13 22:41:27 2024 +0200 +++ b/src/LazyTensors/LazyTensors.jl Wed Sep 25 10:25:30 2024 +0200 @@ -25,7 +25,7 @@ Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...)) # Addition and subtraction of lazy tensors -Base.:+(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:+}(s,t) +Base.:+(ts::LazyTensor...) = ElementwiseTensorOperation{:+}(ts...) Base.:-(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:-}(s,t) # Composing lazy tensors
--- a/src/LazyTensors/lazy_tensor_operations.jl Fri Sep 13 22:41:27 2024 +0200 +++ b/src/LazyTensors/lazy_tensor_operations.jl Wed Sep 25 10:25:30 2024 +0200 @@ -52,24 +52,54 @@ domain_size(tmt::TensorTranspose) = range_size(tmt.tm) -struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,R,D} - tm1::T1 - tm2::T2 +struct ElementwiseTensorOperation{Op,T,R,D,TT<:NTuple{N, LazyTensor{T,R,D}} where N} <: LazyTensor{T,R,D} + tms::TT - function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} - @boundscheck check_domain_size(tm2, domain_size(tm1)) - @boundscheck check_range_size(tm2, range_size(tm1)) - return new{Op,T,R,D,T1,T2}(tm1,tm2) + function ElementwiseTensorOperation{Op,T,R,D}(tms::TT) where {Op,T,R,D, TT<:NTuple{N, LazyTensor{T,R,D}} where N} + @boundscheck map(tms) do tm + check_domain_size(tm, domain_size(tms[1])) + check_range_size(tm, range_size(tms[1])) + end + + return new{Op,T,R,D,TT}(tms) end end -ElementwiseTensorOperation{Op}(s,t) where Op = ElementwiseTensorOperation{Op,eltype(s), range_dim(s), domain_dim(s)}(s,t) + +function ElementwiseTensorOperation{:+}(ts::Vararg{LazyTensor}) + return ElementwiseTensorOperation{:+,eltype(ts[1]), range_dim(ts[1]), domain_dim(ts[1])}(ts) +end + +# The following methods for :+ are intended to reduce the depth of the tree of operations in some caes +function ElementwiseTensorOperation{:+}(t1::ElementwiseTensorOperation{:+}, t2::ElementwiseTensorOperation{:+}) + ElementwiseTensorOperation{:+}(t1.tms..., t2.tms...) +end + +function ElementwiseTensorOperation{:+}(t1::ElementwiseTensorOperation{:+}, t2::LazyTensor) + ElementwiseTensorOperation{:+}(t1.tms..., t2) +end + +function ElementwiseTensorOperation{:+}(t1::LazyTensor, t2::ElementwiseTensorOperation{:+}) + ElementwiseTensorOperation{:+}(t1, t2.tms...) +end -apply(tmBinOp::ElementwiseTensorOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) -apply(tmBinOp::ElementwiseTensorOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) +function ElementwiseTensorOperation{:-}(t1::LazyTensor, t2::LazyTensor) + return ElementwiseTensorOperation{:-,eltype(t1), range_dim(t1), domain_dim(t1)}((t1,t2)) +end + +function apply(tmBinOp::ElementwiseTensorOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} + vs = map(tmBinOp.tms) do tm + apply(tm,v,I...) + end -range_size(tmBinOp::ElementwiseTensorOperation) = range_size(tmBinOp.tm1) -domain_size(tmBinOp::ElementwiseTensorOperation) = domain_size(tmBinOp.tm1) + return +(vs...) +end +function apply(tmBinOp::ElementwiseTensorOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} + apply(tmBinOp.tms[1], v, I...) - apply(tmBinOp.tms[2], v, I...) +end + +range_size(tmBinOp::ElementwiseTensorOperation) = range_size(tmBinOp.tms[1]) +domain_size(tmBinOp::ElementwiseTensorOperation) = domain_size(tmBinOp.tms[1]) """