changeset 1836:368999a2e243 refactor/lazy_tensors/elementwise_ops

Add TensorNegation
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 09 Jan 2025 15:32:47 +0100
parents a6f28a8b8f3f
children 200971c71657
files src/LazyTensors/LazyTensors.jl src/LazyTensors/lazy_tensor_operations.jl test/LazyTensors/lazy_tensor_operations_test.jl
diffstat 3 files changed, 39 insertions(+), 1 deletions(-) [+]
line wrap: on
line diff
--- a/src/LazyTensors/LazyTensors.jl	Thu Jan 09 12:40:49 2025 +0100
+++ b/src/LazyTensors/LazyTensors.jl	Thu Jan 09 15:32:47 2025 +0100
@@ -9,6 +9,7 @@
 export TensorApplication
 export TensorTranspose
 export TensorComposition
+export TensorNegation
 export IdentityTensor
 export ScalingTensor
 export DiagonalTensor
@@ -36,6 +37,7 @@
 
 # Addition and subtraction of lazy tensors
 Base.:+(ts::LazyTensor...) = ElementwiseTensorOperation{:+}(ts...)
+Base.:-(t::LazyTensor) = TensorNegation(t)
 Base.:-(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:-}(s,t)
 
 # Composing lazy tensors
--- a/src/LazyTensors/lazy_tensor_operations.jl	Thu Jan 09 12:40:49 2025 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Thu Jan 09 15:32:47 2025 +0100
@@ -51,6 +51,16 @@
 range_size(tmt::TensorTranspose) = domain_size(tmt.tm)
 domain_size(tmt::TensorTranspose) = range_size(tmt.tm)
 
+struct TensorNegation{T,R,D, TM<:LazyTensor{T,R,D}} <: LazyTensor{T,R,D}
+    tm::TM
+end
+
+apply(tm::TensorNegation, v, I...) = -apply(tm.tm, v, I...)
+apply_transpose(tm::TensorNegation, v, I...) = -apply_transpose(tm.tm, v, I...)
+
+range_size(tm::TensorNegation) = range_size(tm.tm)
+domain_size(tm::TensorNegation) = domain_size(tm.tm)
+
 
 struct ElementwiseTensorOperation{Op,T,R,D,TT<:NTuple{N, LazyTensor{T,R,D}} where N} <: LazyTensor{T,R,D}
     tms::TT
@@ -157,7 +167,6 @@
 
 Base.:*(a::T, tm::LazyTensor{T}) where T = TensorComposition(ScalingTensor{T,range_dim(tm)}(a,range_size(tm)), tm)
 Base.:*(tm::LazyTensor{T}, a::T) where T = a*tm
-Base.:-(tm::LazyTensor) = (-one(eltype(tm)))*tm
 
 """
     InflatedTensor{T,R,D} <: LazyTensor{T,R,D}
--- a/test/LazyTensors/lazy_tensor_operations_test.jl	Thu Jan 09 12:40:49 2025 +0100
+++ b/test/LazyTensors/lazy_tensor_operations_test.jl	Thu Jan 09 15:32:47 2025 +0100
@@ -128,6 +128,33 @@
     end
 end
 
+@testset "TensorNegation" begin
+    A = rand(2,3)
+    B = rand(3,4)
+
+    Ã = DenseTensor(A, (1,), (2,))
+    B̃ = DenseTensor(B, (1,), (2,))
+
+    @test -Ã isa TensorNegation
+
+    v = rand(3)
+    @test (-Ã)*v == -(Ã*v)
+
+    v = rand(4)
+    @test (-B̃)*v == -(B̃*v)
+
+    v = rand(2)
+    @test (-Ã)'*v == -(Ã'*v)
+
+    v = rand(3)
+    @test (-B̃)'*v == -(B̃'*v)
+
+    @test domain_size(-Ã) == (3,)
+    @test domain_size(-B̃) == (4,)
+
+    @test range_size(-Ã) == (2,)
+    @test range_size(-B̃) == (3,)
+end
 
 @testset "LazyTensor binary operations" begin
     A = ScalingTensor(2.0, (3,))