Mercurial > repos > public > sbplib_julia
diff src/LazyTensors/lazy_tensor_operations.jl @ 1900:418566cdd689
Merge refactor/lazy_tensors/elementwise_ops
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Fri, 31 Jan 2025 20:35:28 +0100 |
parents | ed50eec18365 |
children |
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl Fri Jan 31 15:52:49 2025 +0100 +++ b/src/LazyTensors/lazy_tensor_operations.jl Fri Jan 31 20:35:28 2025 +0100 @@ -52,24 +52,66 @@ domain_size(tmt::TensorTranspose) = range_size(tmt.tm) -struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,R,D} - tm1::T1 - tm2::T2 +""" + TensorNegation{T,R,D,...} <: LazyTensor{T,R,D} + +The negation of a LazyTensor. +""" +struct TensorNegation{T,R,D,TM<:LazyTensor{T,R,D}} <: LazyTensor{T,R,D} + tm::TM +end + +apply(tm::TensorNegation, v, I...) = -apply(tm.tm, v, I...) +apply_transpose(tm::TensorNegation, v, I...) = -apply_transpose(tm.tm, v, I...) + +range_size(tm::TensorNegation) = range_size(tm.tm) +domain_size(tm::TensorNegation) = domain_size(tm.tm) + - function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} - @boundscheck check_domain_size(tm2, domain_size(tm1)) - @boundscheck check_range_size(tm2, range_size(tm1)) - return new{Op,T,R,D,T1,T2}(tm1,tm2) +""" + TensorSum{T,R,D,...} <: LazyTensor{T,R,D} + +The lazy sum of 2 or more lazy tensors. +""" +struct TensorSum{T,R,D,TT<:NTuple{N, LazyTensor{T,R,D}} where N} <: LazyTensor{T,R,D} + tms::TT + + function TensorSum{T,R,D}(tms::TT) where {T,R,D, TT<:NTuple{N, LazyTensor{T,R,D}} where N} + @boundscheck map(tms) do tm + check_domain_size(tm, domain_size(tms[1])) + check_range_size(tm, range_size(tms[1])) + end + + return new{T,R,D,TT}(tms) end end -ElementwiseTensorOperation{Op}(s,t) where Op = ElementwiseTensorOperation{Op,eltype(s), range_dim(s), domain_dim(s)}(s,t) +""" + TensorSum(ts::Vararg{LazyTensor}) + +The lazy sum of the tensors `ts`. +""" +function TensorSum(ts::Vararg{LazyTensor}) + T = eltype(ts[1]) + R = range_dim(ts[1]) + D = domain_dim(ts[1]) + return TensorSum{T,R,D}(ts) +end -apply(tmBinOp::ElementwiseTensorOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) -apply(tmBinOp::ElementwiseTensorOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) +function apply(tmBinOp::TensorSum{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} + return sum(tmBinOp.tms) do tm + apply(tm,v,I...) + end +end -range_size(tmBinOp::ElementwiseTensorOperation) = range_size(tmBinOp.tm1) -domain_size(tmBinOp::ElementwiseTensorOperation) = domain_size(tmBinOp.tm1) +function apply_transpose(tmBinOp::TensorSum{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} + return sum(tmBinOp.tms) do tm + apply_transpose(tm,v,I...) + end +end + +range_size(tmBinOp::TensorSum) = range_size(tmBinOp.tms[1]) +domain_size(tmBinOp::TensorSum) = domain_size(tmBinOp.tms[1]) """ @@ -121,7 +163,6 @@ Base.:*(a::T, tm::LazyTensor{T}) where T = TensorComposition(ScalingTensor{T,range_dim(tm)}(a,range_size(tm)), tm) Base.:*(tm::LazyTensor{T}, a::T) where T = a*tm -Base.:-(tm::LazyTensor) = (-one(eltype(tm)))*tm """ InflatedTensor{T,R,D} <: LazyTensor{T,R,D}