sbplib_julia
changeset 1900:418566cdd689
Merge refactor/lazy_tensors/elementwise_ops
author     Jonatan Werpers <jonatan@werpers.com>
date       Fri, 31 Jan 2025 20:35:28 +0100
parents    9d708f3300d5 (current diff) 0f2b33d60f49 (diff)
children   edee7d677efb f93ba5832146 e68669552ed8 e4500727f435
diffstat   3 files changed, 119 insertions(+), 17 deletions(-)
--- a/src/LazyTensors/LazyTensors.jl	Fri Jan 31 15:52:49 2025 +0100
+++ b/src/LazyTensors/LazyTensors.jl	Fri Jan 31 20:35:28 2025 +0100
@@ -9,6 +9,8 @@
 export TensorApplication
 export TensorTranspose
 export TensorComposition
+export TensorNegation
+export TensorSum
 export IdentityTensor
 export ScalingTensor
 export DiagonalTensor
@@ -35,8 +37,13 @@
 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...))
 
 # Addition and subtraction of lazy tensors
-Base.:+(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:+}(s,t)
-Base.:-(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:-}(s,t)
+Base.:+(ts::LazyTensor...) = TensorSum(ts...)
+Base.:-(t::LazyTensor) = TensorNegation(t)
+Base.:-(s::LazyTensor, t::LazyTensor) = s + (-t)
+## Specializations to flatten the nesting of tensors. This helps Julia during inference.
+Base.:+(t::TensorSum, s::TensorSum) = TensorSum(t.tms..., s.tms...)
+Base.:+(t::TensorSum, s::LazyTensor) = TensorSum(t.tms..., s)
+Base.:+(t::LazyTensor, s::TensorSum) = TensorSum(t, s.tms...)
 
 # Composing lazy tensors
 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t)
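The overloads above build flat sums: `+` on any number of lazy tensors returns a single TensorSum, unary `-` wraps its argument in TensorNegation, and binary `-` is rewritten as addition of a negation, while the TensorSum specializations splice operand tuples instead of nesting them. A minimal REPL-style sketch of the resulting behaviour, assuming the repository's LazyTensors module is loaded (the `using` path below is an assumption and may differ in your checkout):

    using Sbplib.LazyTensors   # assumption: the top-level module name may differ in your checkout

    A = ScalingTensor(2.0, (3,))
    B = ScalingTensor(3.0, (3,))
    C = ScalingTensor(4.0, (3,))

    s = A + B - C        # `-C` becomes TensorNegation(C); the result is one flat TensorSum
    s isa TensorSum      # true
    length(s.tms)        # 3, thanks to the TensorSum specializations of `+`

    v = rand(3)
    (s*v)[1]             # lazy application; evaluates to 2v[1] + 3v[1] - 4v[1]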
--- a/src/LazyTensors/lazy_tensor_operations.jl	Fri Jan 31 15:52:49 2025 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Fri Jan 31 20:35:28 2025 +0100
@@ -52,24 +52,66 @@
 domain_size(tmt::TensorTranspose) = range_size(tmt.tm)
 
 
-struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,R,D}
-    tm1::T1
-    tm2::T2
+"""
+    TensorNegation{T,R,D,...} <: LazyTensor{T,R,D}
+
+The negation of a LazyTensor.
+"""
+struct TensorNegation{T,R,D,TM<:LazyTensor{T,R,D}} <: LazyTensor{T,R,D}
+    tm::TM
+end
+
+apply(tm::TensorNegation, v, I...) = -apply(tm.tm, v, I...)
+apply_transpose(tm::TensorNegation, v, I...) = -apply_transpose(tm.tm, v, I...)
+
+range_size(tm::TensorNegation) = range_size(tm.tm)
+domain_size(tm::TensorNegation) = domain_size(tm.tm)
+
 
-    function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}}
-        @boundscheck check_domain_size(tm2, domain_size(tm1))
-        @boundscheck check_range_size(tm2, range_size(tm1))
-        return new{Op,T,R,D,T1,T2}(tm1,tm2)
+"""
+    TensorSum{T,R,D,...} <: LazyTensor{T,R,D}
+
+The lazy sum of 2 or more lazy tensors.
+"""
+struct TensorSum{T,R,D,TT<:NTuple{N, LazyTensor{T,R,D}} where N} <: LazyTensor{T,R,D}
+    tms::TT
+
+    function TensorSum{T,R,D}(tms::TT) where {T,R,D, TT<:NTuple{N, LazyTensor{T,R,D}} where N}
+        @boundscheck map(tms) do tm
+            check_domain_size(tm, domain_size(tms[1]))
+            check_range_size(tm, range_size(tms[1]))
+        end
+
+        return new{T,R,D,TT}(tms)
     end
 end
 
-ElementwiseTensorOperation{Op}(s,t) where Op = ElementwiseTensorOperation{Op,eltype(s), range_dim(s), domain_dim(s)}(s,t)
+"""
+    TensorSum(ts::Vararg{LazyTensor})
+
+The lazy sum of the tensors `ts`.
+"""
+function TensorSum(ts::Vararg{LazyTensor})
+    T = eltype(ts[1])
+    R = range_dim(ts[1])
+    D = domain_dim(ts[1])
+    return TensorSum{T,R,D}(ts)
+end
 
-apply(tmBinOp::ElementwiseTensorOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...)
-apply(tmBinOp::ElementwiseTensorOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...)
+function apply(tmBinOp::TensorSum{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
+    return sum(tmBinOp.tms) do tm
+        apply(tm,v,I...)
+    end
+end
 
-range_size(tmBinOp::ElementwiseTensorOperation) = range_size(tmBinOp.tm1)
-domain_size(tmBinOp::ElementwiseTensorOperation) = domain_size(tmBinOp.tm1)
+function apply_transpose(tmBinOp::TensorSum{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
+    return sum(tmBinOp.tms) do tm
+        apply_transpose(tm,v,I...)
+    end
+end
+
+range_size(tmBinOp::TensorSum) = range_size(tmBinOp.tms[1])
+domain_size(tmBinOp::TensorSum) = domain_size(tmBinOp.tms[1])
 
 
 """
@@ -121,7 +163,6 @@
 Base.:*(a::T, tm::LazyTensor{T}) where T = TensorComposition(ScalingTensor{T,range_dim(tm)}(a,range_size(tm)), tm)
 Base.:*(tm::LazyTensor{T}, a::T) where T = a*tm
-Base.:-(tm::LazyTensor) = (-one(eltype(tm)))*tm
 
 
 """
     InflatedTensor{T,R,D} <: LazyTensor{T,R,D}
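TensorNegation simply forwards sizes and negates each applied value, while TensorSum validates that all operands share domain and range sizes and then reduces over them index by index, so no intermediate grid function is materialized. A small sketch of the semantics, under the same module-loading assumption as above (qualify `apply`/`apply_transpose` with the module name if they are not exported in your checkout):

    A = ScalingTensor(2.0, (3,))
    B = ScalingTensor(3.0, (3,))
    S = TensorSum(A, B)          # the same object the `+` overload constructs
    v = [1.0, 2.0, 3.0]

    apply(S, v, 2)               # 2*v[2] + 3*v[2] == 10.0
    apply_transpose(S, v, 2)     # also 10.0; ScalingTensor is symmetric

    N = TensorNegation(A)
    apply(N, v, 2)               # -apply(A, v, 2) == -4.0
    range_size(N) == range_size(A) == (3,)   # sizes are forwarded unchanged

    # Mismatched operand sizes are rejected by the @boundscheck in the inner
    # constructor (when bounds checking is active):
    # TensorSum(A, ScalingTensor(3.0, (4,)))  # throws a size-mismatch error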
--- a/test/LazyTensors/lazy_tensor_operations_test.jl	Fri Jan 31 15:52:49 2025 +0100
+++ b/test/LazyTensors/lazy_tensor_operations_test.jl	Fri Jan 31 20:35:28 2025 +0100
@@ -22,7 +22,6 @@
 LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size
 
 
-
 @testset "Mapping transpose" begin
     m = TransposableDummyMapping{Float64,2,3}()
     @test m' isa LazyTensor{Float64, 3,2}
@@ -128,8 +127,35 @@
     end
 end
 
 
+@testset "TensorNegation" begin
+    A = rand(2,3)
+    B = rand(3,4)
 
-@testset "LazyTensor binary operations" begin
+    Ã = DenseTensor(A, (1,), (2,))
+    B̃ = DenseTensor(B, (1,), (2,))
+
+    @test -Ã isa TensorNegation
+
+    v = rand(3)
+    @test (-Ã)*v == -(Ã*v)
+
+    v = rand(4)
+    @test (-B̃)*v == -(B̃*v)
+
+    v = rand(2)
+    @test (-Ã)'*v == -(Ã'*v)
+
+    v = rand(3)
+    @test (-B̃)'*v == -(B̃'*v)
+
+    @test domain_size(-Ã) == (3,)
+    @test domain_size(-B̃) == (4,)
+
+    @test range_size(-Ã) == (2,)
+    @test range_size(-B̃) == (3,)
+end
+
+@testset "TensorSum" begin
     A = ScalingTensor(2.0, (3,))
     B = ScalingTensor(3.0, (3,))
@@ -142,6 +168,10 @@
         @test ((A-B)*v)[i] == 2*v[i] - 3*v[i]
     end
 
+    for i ∈ eachindex(v)
+        @test ((A+B)'*v)[i] == 2*v[i] + 3*v[i]
+    end
+
     @test range_size(A+B) == range_size(A) == range_size(B)
     @test domain_size(A+B) == domain_size(A) == domain_size(B)
 
@@ -156,6 +186,30 @@
         @test_throws RangeSizeMismatch ScalingTensor(2.0, (2,)) + SizeDoublingMapping{Float64,1,1}((2,))
         @test_throws RangeSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (2,))
     end
+
+    @testset "Chained operators" begin
+        A = ScalingTensor(1.0, (3,))
+        B = ScalingTensor(2.0, (3,))
+        C = ScalingTensor(3.0, (3,))
+        D = ScalingTensor(4.0, (3,))
+
+        @test A+B+C+D isa TensorSum
+        @test length((A+B+C+D).tms) == 4
+
+
+        @test A+B-C+D isa TensorSum
+        @test length((A+B-C+D).tms) == 4
+
+        v = rand(3)
+        @test (A+B-C+D)*v == 1v + 2v - 3v + 4v
+
+
+        @test -A-B-C-D isa TensorSum
+        @test length((-A-B-C-D).tms) == 4
+
+        v = rand(3)
+        @test (-A-B-C-D)*v == -1v - 2v - 3v - 4v
+    end
 end
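The "Chained operators" testset is what pins down the flattening behaviour: without the TensorSum specializations of `+`, a chain such as `(A + B) + (C + D)` would nest one sum inside another, which is exactly the inference burden the refactor avoids. A short illustrative snippet, under the same assumptions as the sketches above:

    A, B, C, D = (ScalingTensor(Float64(k), (3,)) for k in 1:4)

    s = (A + B) + (C + D)   # the TensorSum + TensorSum specialization splices both tuples
    length(s.tms)           # 4, not 2

    v = rand(3)
    (s*v)[1]                # evaluates to (1 + 2 + 3 + 4) * v[1]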