changeset 1900:418566cdd689

Merge refactor/lazy_tensors/elementwise_ops
author Jonatan Werpers <jonatan@werpers.com>
date Fri, 31 Jan 2025 20:35:28 +0100
parents 9d708f3300d5 (current diff) 0f2b33d60f49 (diff)
children edee7d677efb f93ba5832146 e68669552ed8 e4500727f435
files
diffstat 3 files changed, 119 insertions(+), 17 deletions(-) [+]
line wrap: on
line diff
--- a/src/LazyTensors/LazyTensors.jl	Fri Jan 31 15:52:49 2025 +0100
+++ b/src/LazyTensors/LazyTensors.jl	Fri Jan 31 20:35:28 2025 +0100
@@ -9,6 +9,8 @@
 export TensorApplication
 export TensorTranspose
 export TensorComposition
+export TensorNegation
+export TensorSum
 export IdentityTensor
 export ScalingTensor
 export DiagonalTensor
@@ -35,8 +37,13 @@
 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...))
 
 # Addition and subtraction of lazy tensors
-Base.:+(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:+}(s,t)
-Base.:-(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:-}(s,t)
+Base.:+(ts::LazyTensor...) = TensorSum(ts...)
+Base.:-(t::LazyTensor) = TensorNegation(t)
+Base.:-(s::LazyTensor, t::LazyTensor) = s + (-t)
+## Specializations to flatten the nesting of tensors. This helps Julia during inference.
+Base.:+(t::TensorSum, s::TensorSum) = TensorSum(t.tms..., s.tms...)
+Base.:+(t::TensorSum, s::LazyTensor) = TensorSum(t.tms..., s)
+Base.:+(t::LazyTensor, s::TensorSum) = TensorSum(t, s.tms...)
 
 # Composing lazy tensors
 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t)
--- a/src/LazyTensors/lazy_tensor_operations.jl	Fri Jan 31 15:52:49 2025 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Fri Jan 31 20:35:28 2025 +0100
@@ -52,24 +52,66 @@
 domain_size(tmt::TensorTranspose) = range_size(tmt.tm)
 
 
-struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,R,D}
-    tm1::T1
-    tm2::T2
+"""
+    TensorNegation{T,R,D,...} <: LazyTensor{T,R,D}
+
+The negation of a LazyTensor.
+"""
+struct TensorNegation{T,R,D,TM<:LazyTensor{T,R,D}} <: LazyTensor{T,R,D}
+    tm::TM
+end
+
+apply(tm::TensorNegation, v, I...) = -apply(tm.tm, v, I...)
+apply_transpose(tm::TensorNegation, v, I...) = -apply_transpose(tm.tm, v, I...)
+
+range_size(tm::TensorNegation) = range_size(tm.tm)
+domain_size(tm::TensorNegation) = domain_size(tm.tm)
+
 
-    function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}}
-        @boundscheck check_domain_size(tm2, domain_size(tm1))
-        @boundscheck check_range_size(tm2, range_size(tm1))
-        return new{Op,T,R,D,T1,T2}(tm1,tm2)
+"""
+    TensorSum{T,R,D,...} <: LazyTensor{T,R,D}
+
+The lazy sum of 2 or more lazy tensors.
+"""
+struct TensorSum{T,R,D,TT<:NTuple{N, LazyTensor{T,R,D}} where N} <: LazyTensor{T,R,D}
+    tms::TT
+
+    function TensorSum{T,R,D}(tms::TT) where {T,R,D, TT<:NTuple{N, LazyTensor{T,R,D}} where N}
+        @boundscheck map(tms) do tm
+            check_domain_size(tm, domain_size(tms[1]))
+            check_range_size(tm, range_size(tms[1]))
+        end
+
+        return new{T,R,D,TT}(tms)
     end
 end
 
-ElementwiseTensorOperation{Op}(s,t) where Op = ElementwiseTensorOperation{Op,eltype(s), range_dim(s), domain_dim(s)}(s,t)
+"""
+    TensorSum(ts::Vararg{LazyTensor})
+
+The lazy sum of the tensors `ts`.
+"""
+function TensorSum(ts::Vararg{LazyTensor})
+    T = eltype(ts[1])
+    R = range_dim(ts[1])
+    D = domain_dim(ts[1])
+    return TensorSum{T,R,D}(ts)
+end
 
-apply(tmBinOp::ElementwiseTensorOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...)
-apply(tmBinOp::ElementwiseTensorOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...)
+function apply(tmBinOp::TensorSum{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
+    return sum(tmBinOp.tms) do tm
+        apply(tm,v,I...)
+    end
+end
 
-range_size(tmBinOp::ElementwiseTensorOperation) = range_size(tmBinOp.tm1)
-domain_size(tmBinOp::ElementwiseTensorOperation) = domain_size(tmBinOp.tm1)
+function apply_transpose(tmBinOp::TensorSum{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
+    return sum(tmBinOp.tms) do tm
+        apply_transpose(tm,v,I...)
+    end
+end
+
+range_size(tmBinOp::TensorSum) = range_size(tmBinOp.tms[1])
+domain_size(tmBinOp::TensorSum) = domain_size(tmBinOp.tms[1])
 
 
 """
@@ -121,7 +163,6 @@
 
 Base.:*(a::T, tm::LazyTensor{T}) where T = TensorComposition(ScalingTensor{T,range_dim(tm)}(a,range_size(tm)), tm)
 Base.:*(tm::LazyTensor{T}, a::T) where T = a*tm
-Base.:-(tm::LazyTensor) = (-one(eltype(tm)))*tm
 
 """
     InflatedTensor{T,R,D} <: LazyTensor{T,R,D}
--- a/test/LazyTensors/lazy_tensor_operations_test.jl	Fri Jan 31 15:52:49 2025 +0100
+++ b/test/LazyTensors/lazy_tensor_operations_test.jl	Fri Jan 31 20:35:28 2025 +0100
@@ -22,7 +22,6 @@
 LazyTensors.domain_size(m::SizeDoublingMapping) = m.domain_size
 
 
-
 @testset "Mapping transpose" begin
     m = TransposableDummyMapping{Float64,2,3}()
     @test m' isa LazyTensor{Float64, 3,2}
@@ -128,8 +127,35 @@
     end
 end
 
+@testset "TensorNegation" begin
+    A = rand(2,3)
+    B = rand(3,4)
 
-@testset "LazyTensor binary operations" begin
+    Ã = DenseTensor(A, (1,), (2,))
+    B̃ = DenseTensor(B, (1,), (2,))
+
+    @test -Ã isa TensorNegation
+
+    v = rand(3)
+    @test (-Ã)*v == -(Ã*v)
+
+    v = rand(4)
+    @test (-B̃)*v == -(B̃*v)
+
+    v = rand(2)
+    @test (-Ã)'*v == -(Ã'*v)
+
+    v = rand(3)
+    @test (-B̃)'*v == -(B̃'*v)
+
+    @test domain_size(-Ã) == (3,)
+    @test domain_size(-B̃) == (4,)
+
+    @test range_size(-Ã) == (2,)
+    @test range_size(-B̃) == (3,)
+end
+
+@testset "TensorSum" begin
     A = ScalingTensor(2.0, (3,))
     B = ScalingTensor(3.0, (3,))
 
@@ -142,6 +168,10 @@
         @test ((A-B)*v)[i] == 2*v[i] - 3*v[i]
     end
 
+    for i ∈ eachindex(v)
+        @test ((A+B)'*v)[i] == 2*v[i] + 3*v[i]
+    end
+
 
     @test range_size(A+B) == range_size(A) == range_size(B)
     @test domain_size(A+B) == domain_size(A) == domain_size(B)
@@ -156,6 +186,30 @@
         @test_throws RangeSizeMismatch ScalingTensor(2.0, (2,)) + SizeDoublingMapping{Float64,1,1}((2,))
         @test_throws RangeSizeMismatch SizeDoublingMapping{Float64,1,1}((2,)) + ScalingTensor(2.0, (2,))
     end
+
+    @testset "Chained operators" begin
+        A = ScalingTensor(1.0, (3,))
+        B = ScalingTensor(2.0, (3,))
+        C = ScalingTensor(3.0, (3,))
+        D = ScalingTensor(4.0, (3,))
+
+        @test A+B+C+D isa TensorSum
+        @test length((A+B+C+D).tms) == 4
+
+
+        @test A+B-C+D isa TensorSum
+        @test length((A+B-C+D).tms) == 4
+
+        v = rand(3)
+        @test (A+B-C+D)*v == 1v + 2v - 3v + 4v
+
+
+        @test -A-B-C-D isa TensorSum
+        @test length((-A-B-C-D).tms) == 4
+
+        v = rand(3)
+        @test (-A-B-C-D)*v == -1v - 2v - 3v - 4v
+    end
 end