diff src/LazyTensors/lazy_tensor_operations.jl @ 1837:200971c71657 refactor/lazy_tensors/elementwise_ops

Refactor ElementwiseTensorOperation into TensorSum and use TensorNegation for handling subtraction
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 09 Jan 2025 21:46:01 +0100
parents 368999a2e243
children e1077273eda5
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Thu Jan 09 15:32:47 2025 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Thu Jan 09 21:46:01 2025 +0100
@@ -62,60 +62,49 @@
 domain_size(tm::TensorNegation) = domain_size(tm.tm)
 
 
-struct ElementwiseTensorOperation{Op,T,R,D,TT<:NTuple{N, LazyTensor{T,R,D}} where N} <: LazyTensor{T,R,D}
+struct TensorSum{T,R,D,TT<:NTuple{N, LazyTensor{T,R,D}} where N} <: LazyTensor{T,R,D}
     tms::TT
 
-    function ElementwiseTensorOperation{Op,T,R,D}(tms::TT) where {Op,T,R,D, TT<:NTuple{N, LazyTensor{T,R,D}} where N}
+    function TensorSum{T,R,D}(tms::TT) where {T,R,D, TT<:NTuple{N, LazyTensor{T,R,D}} where N}
         @boundscheck map(tms) do tm
             check_domain_size(tm, domain_size(tms[1]))
             check_range_size(tm, range_size(tms[1]))
         end
 
-        return new{Op,T,R,D,TT}(tms)
+        return new{T,R,D,TT}(tms)
     end
 end
-# TBD: Can we introduce negation of LazyTensors? It could be done generically
-# with a ScalingTensor but also using specializations for specific tensor
-# types. This would allow simplification of ElementwiseTensorOperation to
-# TensorSum. The implementation of `-` can be done using negation and the
-# TensorSum type. We should make sure this doesn't impact the efficiency of
-# for example SATs.
 
-
-function ElementwiseTensorOperation{:+}(ts::Vararg{LazyTensor})
-    return ElementwiseTensorOperation{:+,eltype(ts[1]), range_dim(ts[1]), domain_dim(ts[1])}(ts)
+function TensorSum(ts::Vararg{LazyTensor})
+    T = eltype(ts[1])
+    R = range_dim(ts[1])
+    D = domain_dim(ts[1])
+    return TensorSum{T,R,D}(ts)
 end
 
 # The following methods for :+ are intended to reduce the depth of the tree of operations in some caes
-function ElementwiseTensorOperation{:+}(t1::ElementwiseTensorOperation{:+}, t2::ElementwiseTensorOperation{:+})
-    ElementwiseTensorOperation{:+}(t1.tms..., t2.tms...)
-end
-
-function ElementwiseTensorOperation{:+}(t1::ElementwiseTensorOperation{:+}, t2::LazyTensor)
-    ElementwiseTensorOperation{:+}(t1.tms..., t2)
+function TensorSum(t1::TensorSum, t2::TensorSum)
+    TensorSum(t1.tms..., t2.tms...)
 end
 
-function ElementwiseTensorOperation{:+}(t1::LazyTensor, t2::ElementwiseTensorOperation{:+})
-    ElementwiseTensorOperation{:+}(t1, t2.tms...)
+function TensorSum(t1::TensorSum, t2::LazyTensor)
+    TensorSum(t1.tms..., t2)
 end
 
-function ElementwiseTensorOperation{:-}(t1::LazyTensor, t2::LazyTensor)
-    return ElementwiseTensorOperation{:-,eltype(t1), range_dim(t1), domain_dim(t1)}((t1,t2))
+function TensorSum(t1::LazyTensor, t2::TensorSum)
+    TensorSum(t1, t2.tms...)
 end
 
-function apply(tmBinOp::ElementwiseTensorOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
+function apply(tmBinOp::TensorSum{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
     vs = map(tmBinOp.tms) do tm
         apply(tm,v,I...)
     end
 
     return +(vs...)
 end
-function apply(tmBinOp::ElementwiseTensorOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
-    apply(tmBinOp.tms[1], v, I...) - apply(tmBinOp.tms[2], v, I...)
-end
 
-range_size(tmBinOp::ElementwiseTensorOperation) = range_size(tmBinOp.tms[1])
-domain_size(tmBinOp::ElementwiseTensorOperation) = domain_size(tmBinOp.tms[1])
+range_size(tmBinOp::TensorSum) = range_size(tmBinOp.tms[1])
+domain_size(tmBinOp::TensorSum) = domain_size(tmBinOp.tms[1])
 
 
 """