changeset 1790:602104ac0e41 feature/sbp_operators/laplace_curvilinear

Merge refactor/lazy_tensors/elementwise_ops
author Jonatan Werpers <jonatan@werpers.com>
date Wed, 25 Sep 2024 10:33:48 +0200
parents 1f42944d4a72 (current diff) 48eaa973159a (diff)
children b8cb38fd67ff
files
diffstat 2 files changed, 49 insertions(+), 13 deletions(-) [+]
line wrap: on
line diff
--- a/src/LazyTensors/LazyTensors.jl	Mon Sep 16 11:03:37 2024 +0200
+++ b/src/LazyTensors/LazyTensors.jl	Wed Sep 25 10:33:48 2024 +0200
@@ -25,7 +25,7 @@
 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...))
 
 # Addition and subtraction of lazy tensors
-Base.:+(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:+}(s,t)
+Base.:+(ts::LazyTensor...) = ElementwiseTensorOperation{:+}(ts...)
 Base.:-(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:-}(s,t)
 
 # Composing lazy tensors
--- a/src/LazyTensors/lazy_tensor_operations.jl	Mon Sep 16 11:03:37 2024 +0200
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Wed Sep 25 10:33:48 2024 +0200
@@ -52,24 +52,60 @@
 domain_size(tmt::TensorTranspose) = range_size(tmt.tm)
 
 
-struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,R,D}
-    tm1::T1
-    tm2::T2
+struct ElementwiseTensorOperation{Op,T,R,D,TT<:NTuple{N, LazyTensor{T,R,D}} where N} <: LazyTensor{T,R,D}
+    tms::TT
 
-    function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}}
-        @boundscheck check_domain_size(tm2, domain_size(tm1))
-        @boundscheck check_range_size(tm2, range_size(tm1))
-        return new{Op,T,R,D,T1,T2}(tm1,tm2)
+    function ElementwiseTensorOperation{Op,T,R,D}(tms::TT) where {Op,T,R,D, TT<:NTuple{N, LazyTensor{T,R,D}} where N}
+        @boundscheck map(tms) do tm
+            check_domain_size(tm, domain_size(tms[1]))
+            check_range_size(tm, range_size(tms[1]))
+        end
+
+        return new{Op,T,R,D,TT}(tms)
     end
 end
+# TBD: Can we introduce negation of LazyTensors? It could be done generically
+# with a ScalingTensor but also using specializations for specific tensor
+# types. This would allow simplification of ElementwiseTensorOperation to
+# TensorSum. The implementation of `-` can be done using negation and the
+# TensorSum type. We should make sure this doesn't impact the efficiency of
+# for example SATs.
 
-ElementwiseTensorOperation{Op}(s,t) where Op = ElementwiseTensorOperation{Op,eltype(s), range_dim(s), domain_dim(s)}(s,t)
+
+function ElementwiseTensorOperation{:+}(ts::Vararg{LazyTensor})
+    return ElementwiseTensorOperation{:+,eltype(ts[1]), range_dim(ts[1]), domain_dim(ts[1])}(ts)
+end
+
+# The following methods for :+ are intended to reduce the depth of the tree of operations in some caes
+function ElementwiseTensorOperation{:+}(t1::ElementwiseTensorOperation{:+}, t2::ElementwiseTensorOperation{:+})
+    ElementwiseTensorOperation{:+}(t1.tms..., t2.tms...)
+end
+
+function ElementwiseTensorOperation{:+}(t1::ElementwiseTensorOperation{:+}, t2::LazyTensor)
+    ElementwiseTensorOperation{:+}(t1.tms..., t2)
+end
 
-apply(tmBinOp::ElementwiseTensorOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...)
-apply(tmBinOp::ElementwiseTensorOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...)
+function ElementwiseTensorOperation{:+}(t1::LazyTensor, t2::ElementwiseTensorOperation{:+})
+    ElementwiseTensorOperation{:+}(t1, t2.tms...)
+end
+
+function ElementwiseTensorOperation{:-}(t1::LazyTensor, t2::LazyTensor)
+    return ElementwiseTensorOperation{:-,eltype(t1), range_dim(t1), domain_dim(t1)}((t1,t2))
+end
 
-range_size(tmBinOp::ElementwiseTensorOperation) = range_size(tmBinOp.tm1)
-domain_size(tmBinOp::ElementwiseTensorOperation) = domain_size(tmBinOp.tm1)
+function apply(tmBinOp::ElementwiseTensorOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
+    vs = map(tmBinOp.tms) do tm
+        apply(tm,v,I...)
+    end
+
+    return +(vs...)
+end
+function apply(tmBinOp::ElementwiseTensorOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
+    apply(tmBinOp.tms[1], v, I...) - apply(tmBinOp.tms[2], v, I...)
+end
+
+range_size(tmBinOp::ElementwiseTensorOperation) = range_size(tmBinOp.tms[1])
+domain_size(tmBinOp::ElementwiseTensorOperation) = domain_size(tmBinOp.tms[1])
 
 
 """