diff src/LazyTensors/lazy_tensor_operations.jl @ 2057:8a2a0d678d6f feature/lazy_tensors/pretty_printing

Merge default
author Jonatan Werpers <jonatan@werpers.com>
date Tue, 10 Feb 2026 22:41:19 +0100
parents 1e8270c18edb ed50eec18365
children
--- a/src/LazyTensors/lazy_tensor_operations.jl	Mon May 23 07:20:27 2022 +0200
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Tue Feb 10 22:41:19 2026 +0100
@@ -5,7 +5,7 @@
 
 Allows the result of a `LazyTensor` applied to a vector to be treated as an `AbstractArray`.
 With a mapping `m` and a vector `v` the TensorApplication object can be created by `m*v`.
-The actual result will be calcualted when indexing into `m*v`.
+The actual result will be calculated when indexing into `m*v`.
 """
 struct TensorApplication{T,R,D, TM<:LazyTensor{<:Any,R,D}, AA<:AbstractArray{<:Any,D}} <: LazyArray{T,R}
     t::TM
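
As a rough sketch of the lazy behaviour described in the docstring (assuming the `LazyTensors` module from this repository is in scope, and using the `IdentityTensor` defined elsewhere in it as a stand-in for any `LazyTensor`):

```julia
# Sketch: lazy application of a LazyTensor to a vector.
A = IdentityTensor{Float64}(3)   # any LazyTensor would do; IdentityTensor keeps it simple
v = [1.0, 2.0, 3.0]

w = A*v    # creates a TensorApplication; nothing is computed yet
w[2]       # the element is calculated here, on indexing: 2.0
```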
@@ -52,24 +52,66 @@
 domain_size(tmt::TensorTranspose) = range_size(tmt.tm)
 
 
-struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R}
-    tm1::T1
-    tm2::T2
+"""
+    TensorNegation{T,R,D,...} <: LazyTensor{T,R,D}
+
+The negation of a `LazyTensor`.
+"""
+struct TensorNegation{T,R,D,TM<:LazyTensor{T,R,D}} <: LazyTensor{T,R,D}
+    tm::TM
+end
+
+apply(tm::TensorNegation, v, I...) = -apply(tm.tm, v, I...)
+apply_transpose(tm::TensorNegation, v, I...) = -apply_transpose(tm.tm, v, I...)
+
+range_size(tm::TensorNegation) = range_size(tm.tm)
+domain_size(tm::TensorNegation) = domain_size(tm.tm)
+
 
-    function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}}
-        @boundscheck check_domain_size(tm2, domain_size(tm1))
-        @boundscheck check_range_size(tm2, range_size(tm1))
-        return new{Op,T,R,D,T1,T2}(tm1,tm2)
+"""
+    TensorSum{T,R,D,...} <: LazyTensor{T,R,D}
+
+The lazy sum of 2 or more lazy tensors.
+"""
+struct TensorSum{T,R,D,TT<:NTuple{N, LazyTensor{T,R,D}} where N} <: LazyTensor{T,R,D}
+    tms::TT
+
+    function TensorSum{T,R,D}(tms::TT) where {T,R,D, TT<:NTuple{N, LazyTensor{T,R,D}} where N}
+        @boundscheck map(tms) do tm
+            check_domain_size(tm, domain_size(tms[1]))
+            check_range_size(tm, range_size(tms[1]))
+        end
+
+        return new{T,R,D,TT}(tms)
     end
 end
 
-ElementwiseTensorOperation{Op}(s,t) where Op = ElementwiseTensorOperation{Op,eltype(s), range_dim(s), domain_dim(s)}(s,t)
+"""
+    TensorSum(ts::Vararg{LazyTensor})
+
+The lazy sum of the tensors `ts`.
+"""
+function TensorSum(ts::Vararg{LazyTensor})
+    T = eltype(ts[1])
+    R = range_dim(ts[1])
+    D = domain_dim(ts[1])
+    return TensorSum{T,R,D}(ts)
+end
 
-apply(tmBinOp::ElementwiseTensorOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...)
-apply(tmBinOp::ElementwiseTensorOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...)
+function apply(tmBinOp::TensorSum{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
+    return sum(tmBinOp.tms) do tm
+        apply(tm,v,I...)
+    end
+end
 
-range_size(tmBinOp::ElementwiseTensorOperation) = range_size(tmBinOp.tm1)
-domain_size(tmBinOp::ElementwiseTensorOperation) = domain_size(tmBinOp.tm1)
+function apply_transpose(tmBinOp::TensorSum{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
+    return sum(tmBinOp.tms) do tm
+        apply_transpose(tm,v,I...)
+    end
+end
+
+range_size(tmBinOp::TensorSum) = range_size(tmBinOp.tms[1])
+domain_size(tmBinOp::TensorSum) = domain_size(tmBinOp.tms[1])
 
 
 """
@@ -103,7 +145,7 @@
     TensorComposition(tm, tmi::IdentityTensor)
     TensorComposition(tmi::IdentityTensor, tm)
 
-Composes a `Tensormapping` `tm` with an `IdentityTensor` `tmi`, by returning `tm`
+Composes a `LazyTensor` `tm` with an `IdentityTensor` `tmi` by returning `tm`.
 """
 function TensorComposition(tm::LazyTensor{T,R,D}, tmi::IdentityTensor{T,D}) where {T,R,D}
     @boundscheck check_domain_size(tm, range_size(tmi))
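
A small sketch of the simplification stated in the docstring, using an `InflatedTensor` built from identities as a stand-in for an arbitrary `LazyTensor` (per the docstring, the matching identity is dropped and `tm` itself is returned):

```julia
# Sketch: composing with a matching IdentityTensor returns the other tensor unchanged.
A  = InflatedTensor(IdentityTensor{Float64}(2), IdentityTensor{Float64}(3))  # stands in for any 2D LazyTensor
Id = IdentityTensor{Float64}(2,3)

TensorComposition(A, Id) === A   # identity on the domain side is dropped
TensorComposition(Id, A) === A   # identity on the range side is dropped
```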
@@ -126,7 +168,7 @@
 """
     InflatedTensor{T,R,D} <: LazyTensor{T,R,D}
 
-An inflated `LazyTensor` with dimensions added before and afer its actual dimensions.
+An inflated `LazyTensor` with dimensions added before and after its actual dimensions.
 """
 struct InflatedTensor{T,R,D,D_before,R_middle,D_middle,D_after, TM<:LazyTensor{T,R_middle,D_middle}} <: LazyTensor{T,R,D}
     before::IdentityTensor{T,D_before}
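
A sketch of what inflation means in practice: the inner tensor keeps its own action while identity dimensions are prepended and appended (assuming the three-argument convenience constructor used elsewhere in this file):

```julia
# Sketch: inflating a tensor with identity dimensions before and after it.
tm  = IdentityTensor{Float64}(4)    # stands in for the "actual" tensor
itm = InflatedTensor(IdentityTensor{Float64}(3), tm, IdentityTensor{Float64}(5))

range_size(itm)    # (3, 4, 5)
domain_size(itm)   # (3, 4, 5)
```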
@@ -169,15 +211,15 @@
     )
 end
 
-InflatedTensor(before::IdentityTensor, tm::LazyTensor{T}) where T = InflatedTensor(before,tm,IdentityTensor{T}())
-InflatedTensor(tm::LazyTensor{T}, after::IdentityTensor) where T = InflatedTensor(IdentityTensor{T}(),tm,after)
+InflatedTensor(before::IdentityTensor, tm::LazyTensor) = InflatedTensor(before,tm,IdentityTensor{eltype(tm)}())
+InflatedTensor(tm::LazyTensor, after::IdentityTensor) = InflatedTensor(IdentityTensor{eltype(tm)}(),tm,after)
 # Resolve ambiguity between the two previous methods
-InflatedTensor(I1::IdentityTensor{T}, I2::IdentityTensor{T}) where T = InflatedTensor(I1,I2,IdentityTensor{T}())
+InflatedTensor(I1::IdentityTensor, I2::IdentityTensor) = InflatedTensor(I1,I2,IdentityTensor{promote_type(eltype(I1), eltype(I2))}())
 
 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
 
 function range_size(itm::InflatedTensor)
-    return flatten_tuple(
+    return concatenate_tuples(
         range_size(itm.before),
         range_size(itm.tm),
         range_size(itm.after),
@@ -185,7 +227,7 @@
 end
 
 function domain_size(itm::InflatedTensor)
-    return flatten_tuple(
+    return concatenate_tuples(
         domain_size(itm.before),
         domain_size(itm.tm),
         domain_size(itm.after),
@@ -198,7 +240,7 @@
     dim_range = range_dim(itm.tm)
     dim_after = range_dim(itm.after)
 
-    view_index, inner_index = split_index(Val(dim_before), Val(dim_domain), Val(dim_range), Val(dim_after), I...)
+    view_index, inner_index = split_index(dim_before, dim_domain, dim_range, dim_after, I...)
 
     v_inner = view(v, view_index...)
     return apply(itm.tm, v_inner, inner_index...)
@@ -210,7 +252,7 @@
     dim_range = range_dim(itm.tm)
     dim_after = range_dim(itm.after)
 
-    view_index, inner_index = split_index(Val(dim_before), Val(dim_range), Val(dim_domain), Val(dim_after), I...)
+    view_index, inner_index = split_index(dim_before, dim_range, dim_domain, dim_after, I...)
 
     v_inner = view(v, view_index...)
     return apply_transpose(itm.tm, v_inner, inner_index...)
@@ -229,7 +271,7 @@
 @doc raw"""
     LazyOuterProduct(tms...)
 
-Creates a `TensorComposition` for the outerproduct of `tms...`.
+Creates a `TensorComposition` for the outer product of `tms...`.
 This is done by separating the outer product into regular products of outer products involving only identity mappings and one non-identity mapping.
 
 First let
@@ -272,13 +314,36 @@
     return itm1∘itm2
 end
 
-LazyOuterProduct(t1::IdentityTensor{T}, t2::IdentityTensor{T}) where T = IdentityTensor{T}(t1.size...,t2.size...)
+LazyOuterProduct(t1::IdentityTensor, t2::IdentityTensor) = IdentityTensor{promote_type(eltype(t1),eltype(t2))}(t1.size...,t2.size...)
 LazyOuterProduct(t1::LazyTensor, t2::IdentityTensor) = InflatedTensor(t1, t2)
 LazyOuterProduct(t1::IdentityTensor, t2::LazyTensor) = InflatedTensor(t1, t2)
 
 LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms)
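
A minimal sketch of these methods (assuming the module is in scope): outer products of identities collapse into a single `IdentityTensor`, and the variadic form folds from the left:

```julia
# Sketch: LazyOuterProduct on identity tensors.
I3 = IdentityTensor{Float64}(3)
I4 = IdentityTensor{Float64}(4)

range_size(LazyOuterProduct(I3, I4))        # (3, 4): collapses to a single IdentityTensor
range_size(LazyOuterProduct(I3, I4, I3))    # (3, 4, 3): foldl over the arguments
```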
 
 
+
+"""
+    inflate(tm::LazyTensor, sz, dir)
+
+Inflate `tm` such that it gets the size `sz` in all directions except `dir`.
+Here `sz[dir]` is ignored and replaced with the range and domain sizes of
+`tm`.
+
+This operation is useful, for example, when extending a one-dimensional
+difference operator `D` to a 2D grid of a certain size. In that case we
+could have
+
+```julia
+Dx = inflate(D, (10,10), 1)
+Dy = inflate(D, (10,10), 2)
+```
+"""
+function inflate(tm::LazyTensor, sz, dir)
+    Is = IdentityTensor{eltype(tm)}.(sz)
+    parts = Base.setindex(Is, tm, dir)
+    return foldl(⊗, parts)
+end
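
Following the docstring example, a small sketch showing how `sz[dir]` is replaced by the size of `tm` (using an `IdentityTensor` as a stand-in for a concrete operator):

```julia
# Sketch: the entry of sz at position dir is replaced by the size of tm.
tm = IdentityTensor{Float64}(5)

range_size(inflate(tm, (10,10), 1))    # (5, 10)
range_size(inflate(tm, (10,10), 2))    # (10, 5)
```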
+
 function check_domain_size(tm::LazyTensor, sz)
     if domain_size(tm) != sz
         throw(DomainSizeMismatch(tm,sz))