Mercurial > repos > public > sbplib_julia
changeset 1837:200971c71657 refactor/lazy_tensors/elementwise_ops
Refactor ElementwiseTensorOperation into TensorSum and use TensorNegation for handling subtraction
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Thu, 09 Jan 2025 21:46:01 +0100 |
parents | 368999a2e243 |
children | 4bd998069053 |
files | src/LazyTensors/LazyTensors.jl src/LazyTensors/lazy_tensor_operations.jl test/LazyTensors/lazy_tensor_operations_test.jl |
diffstat | 3 files changed, 21 insertions(+), 32 deletions(-) [+] |
line wrap: on
line diff
--- a/src/LazyTensors/LazyTensors.jl Thu Jan 09 15:32:47 2025 +0100 +++ b/src/LazyTensors/LazyTensors.jl Thu Jan 09 21:46:01 2025 +0100 @@ -10,6 +10,7 @@ export TensorTranspose export TensorComposition export TensorNegation +export TensorSum export IdentityTensor export ScalingTensor export DiagonalTensor @@ -36,9 +37,9 @@ Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...)) # Addition and subtraction of lazy tensors -Base.:+(ts::LazyTensor...) = ElementwiseTensorOperation{:+}(ts...) +Base.:+(ts::LazyTensor...) = TensorSum(ts...) Base.:-(t::LazyTensor) = TensorNegation(t) -Base.:-(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:-}(s,t) +Base.:-(s::LazyTensor, t::LazyTensor) = TensorSum(s,-t) # Composing lazy tensors Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t)
--- a/src/LazyTensors/lazy_tensor_operations.jl Thu Jan 09 15:32:47 2025 +0100 +++ b/src/LazyTensors/lazy_tensor_operations.jl Thu Jan 09 21:46:01 2025 +0100 @@ -62,60 +62,49 @@ domain_size(tm::TensorNegation) = domain_size(tm.tm) -struct ElementwiseTensorOperation{Op,T,R,D,TT<:NTuple{N, LazyTensor{T,R,D}} where N} <: LazyTensor{T,R,D} +struct TensorSum{T,R,D,TT<:NTuple{N, LazyTensor{T,R,D}} where N} <: LazyTensor{T,R,D} tms::TT - function ElementwiseTensorOperation{Op,T,R,D}(tms::TT) where {Op,T,R,D, TT<:NTuple{N, LazyTensor{T,R,D}} where N} + function TensorSum{T,R,D}(tms::TT) where {T,R,D, TT<:NTuple{N, LazyTensor{T,R,D}} where N} @boundscheck map(tms) do tm check_domain_size(tm, domain_size(tms[1])) check_range_size(tm, range_size(tms[1])) end - return new{Op,T,R,D,TT}(tms) + return new{T,R,D,TT}(tms) end end -# TBD: Can we introduce negation of LazyTensors? It could be done generically -# with a ScalingTensor but also using specializations for specific tensor -# types. This would allow simplification of ElementwiseTensorOperation to -# TensorSum. The implementation of `-` can be done using negation and the -# TensorSum type. We should make sure this doesn't impact the efficiency of -# for example SATs. - -function ElementwiseTensorOperation{:+}(ts::Vararg{LazyTensor}) - return ElementwiseTensorOperation{:+,eltype(ts[1]), range_dim(ts[1]), domain_dim(ts[1])}(ts) +function TensorSum(ts::Vararg{LazyTensor}) + T = eltype(ts[1]) + R = range_dim(ts[1]) + D = domain_dim(ts[1]) + return TensorSum{T,R,D}(ts) end # The following methods for :+ are intended to reduce the depth of the tree of operations in some caes -function ElementwiseTensorOperation{:+}(t1::ElementwiseTensorOperation{:+}, t2::ElementwiseTensorOperation{:+}) - ElementwiseTensorOperation{:+}(t1.tms..., t2.tms...) -end - -function ElementwiseTensorOperation{:+}(t1::ElementwiseTensorOperation{:+}, t2::LazyTensor) - ElementwiseTensorOperation{:+}(t1.tms..., t2) +function TensorSum(t1::TensorSum, t2::TensorSum) + TensorSum(t1.tms..., t2.tms...) end -function ElementwiseTensorOperation{:+}(t1::LazyTensor, t2::ElementwiseTensorOperation{:+}) - ElementwiseTensorOperation{:+}(t1, t2.tms...) +function TensorSum(t1::TensorSum, t2::LazyTensor) + TensorSum(t1.tms..., t2) end -function ElementwiseTensorOperation{:-}(t1::LazyTensor, t2::LazyTensor) - return ElementwiseTensorOperation{:-,eltype(t1), range_dim(t1), domain_dim(t1)}((t1,t2)) +function TensorSum(t1::LazyTensor, t2::TensorSum) + TensorSum(t1, t2.tms...) end -function apply(tmBinOp::ElementwiseTensorOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} +function apply(tmBinOp::TensorSum{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} vs = map(tmBinOp.tms) do tm apply(tm,v,I...) end return +(vs...) end -function apply(tmBinOp::ElementwiseTensorOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} - apply(tmBinOp.tms[1], v, I...) - apply(tmBinOp.tms[2], v, I...) -end -range_size(tmBinOp::ElementwiseTensorOperation) = range_size(tmBinOp.tms[1]) -domain_size(tmBinOp::ElementwiseTensorOperation) = domain_size(tmBinOp.tms[1]) +range_size(tmBinOp::TensorSum) = range_size(tmBinOp.tms[1]) +domain_size(tmBinOp::TensorSum) = domain_size(tmBinOp.tms[1]) """
--- a/test/LazyTensors/lazy_tensor_operations_test.jl Thu Jan 09 15:32:47 2025 +0100 +++ b/test/LazyTensors/lazy_tensor_operations_test.jl Thu Jan 09 21:46:01 2025 +0100 @@ -22,7 +22,6 @@ LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size - @testset "Mapping transpose" begin m = TransposableDummyMapping{Float64,2,3}() @test m' isa LazyTensor{Float64, 3,2} @@ -156,7 +155,7 @@ @test range_size(-B̃) == (3,) end -@testset "LazyTensor binary operations" begin +@testset "TensorSum" begin A = ScalingTensor(2.0, (3,)) B = ScalingTensor(3.0, (3,))