comparison src/LazyTensors/lazy_tensor_operations.jl @ 2057:8a2a0d678d6f feature/lazy_tensors/pretty_printing

Merge default
author Jonatan Werpers <jonatan@werpers.com>
date Tue, 10 Feb 2026 22:41:19 +0100
parents 1e8270c18edb ed50eec18365
children
comparison
equal deleted inserted replaced
1110:c0bff9f6e0fb 2057:8a2a0d678d6f
3 3
4 Struct for lazy application of a LazyTensor. Created using `*`. 4 Struct for lazy application of a LazyTensor. Created using `*`.
5 5
6 Allows the result of a `LazyTensor` applied to a vector to be treated as an `AbstractArray`. 6 Allows the result of a `LazyTensor` applied to a vector to be treated as an `AbstractArray`.
7 With a mapping `m` and a vector `v` the TensorApplication object can be created by `m*v`. 7 With a mapping `m` and a vector `v` the TensorApplication object can be created by `m*v`.
8 The actual result will be calcualted when indexing into `m*v`. 8 The actual result will be calculated when indexing into `m*v`.
9 """ 9 """
10 struct TensorApplication{T,R,D, TM<:LazyTensor{<:Any,R,D}, AA<:AbstractArray{<:Any,D}} <: LazyArray{T,R} 10 struct TensorApplication{T,R,D, TM<:LazyTensor{<:Any,R,D}, AA<:AbstractArray{<:Any,D}} <: LazyArray{T,R}
11 t::TM 11 t::TM
12 o::AA 12 o::AA
13 13
50 50
51 range_size(tmt::TensorTranspose) = domain_size(tmt.tm) 51 range_size(tmt::TensorTranspose) = domain_size(tmt.tm)
52 domain_size(tmt::TensorTranspose) = range_size(tmt.tm) 52 domain_size(tmt::TensorTranspose) = range_size(tmt.tm)
53 53
54 54
55 struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R} 55 """
56 tm1::T1 56 TensorNegation{T,R,D,...} <: LazyTensor{T,R,D}
57 tm2::T2 57
58 58 The negation of a LazyTensor.
59 function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} 59 """
60 @boundscheck check_domain_size(tm2, domain_size(tm1)) 60 struct TensorNegation{T,R,D,TM<:LazyTensor{T,R,D}} <: LazyTensor{T,R,D}
61 @boundscheck check_range_size(tm2, range_size(tm1)) 61 tm::TM
62 return new{Op,T,R,D,T1,T2}(tm1,tm2) 62 end
63 end 63
64 end 64 apply(tm::TensorNegation, v, I...) = -apply(tm.tm, v, I...)
65 65 apply_transpose(tm::TensorNegation, v, I...) = -apply_transpose(tm.tm, v, I...)
66 ElementwiseTensorOperation{Op}(s,t) where Op = ElementwiseTensorOperation{Op,eltype(s), range_dim(s), domain_dim(s)}(s,t) 66
67 67 range_size(tm::TensorNegation) = range_size(tm.tm)
68 apply(tmBinOp::ElementwiseTensorOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) 68 domain_size(tm::TensorNegation) = domain_size(tm.tm)
69 apply(tmBinOp::ElementwiseTensorOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) 69
70 70
71 range_size(tmBinOp::ElementwiseTensorOperation) = range_size(tmBinOp.tm1) 71 """
72 domain_size(tmBinOp::ElementwiseTensorOperation) = domain_size(tmBinOp.tm1) 72 TensorSum{T,R,D,...} <: LazyTensor{T,R,D}
73
74 The lazy sum of 2 or more lazy tensors.
75 """
76 struct TensorSum{T,R,D,TT<:NTuple{N, LazyTensor{T,R,D}} where N} <: LazyTensor{T,R,D}
77 tms::TT
78
79 function TensorSum{T,R,D}(tms::TT) where {T,R,D, TT<:NTuple{N, LazyTensor{T,R,D}} where N}
80 @boundscheck map(tms) do tm
81 check_domain_size(tm, domain_size(tms[1]))
82 check_range_size(tm, range_size(tms[1]))
83 end
84
85 return new{T,R,D,TT}(tms)
86 end
87 end
88
89 """
90 TensorSum(ts::Vararg{LazyTensor})
91
92 The lazy sum of the tensors `ts`.
93 """
94 function TensorSum(ts::Vararg{LazyTensor})
95 T = eltype(ts[1])
96 R = range_dim(ts[1])
97 D = domain_dim(ts[1])
98 return TensorSum{T,R,D}(ts)
99 end
100
101 function apply(tmBinOp::TensorSum{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
102 return sum(tmBinOp.tms) do tm
103 apply(tm,v,I...)
104 end
105 end
106
107 function apply_transpose(tmBinOp::TensorSum{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
108 return sum(tmBinOp.tms) do tm
109 apply_transpose(tm,v,I...)
110 end
111 end
112
113 range_size(tmBinOp::TensorSum) = range_size(tmBinOp.tms[1])
114 domain_size(tmBinOp::TensorSum) = domain_size(tmBinOp.tms[1])
73 115
74 116
75 """ 117 """
76 TensorComposition{T,R,K,D} 118 TensorComposition{T,R,K,D}
77 119
101 143
102 """ 144 """
103 TensorComposition(tm, tmi::IdentityTensor) 145 TensorComposition(tm, tmi::IdentityTensor)
104 TensorComposition(tmi::IdentityTensor, tm) 146 TensorComposition(tmi::IdentityTensor, tm)
105 147
106 Composes a `Tensormapping` `tm` with an `IdentityTensor` `tmi`, by returning `tm` 148 Composes a `LazyTensor` `tm` with an `IdentityTensor` `tmi`, by returning `tm`
107 """ 149 """
108 function TensorComposition(tm::LazyTensor{T,R,D}, tmi::IdentityTensor{T,D}) where {T,R,D} 150 function TensorComposition(tm::LazyTensor{T,R,D}, tmi::IdentityTensor{T,D}) where {T,R,D}
109 @boundscheck check_domain_size(tm, range_size(tmi)) 151 @boundscheck check_domain_size(tm, range_size(tmi))
110 return tm 152 return tm
111 end 153 end
124 Base.:*(tm::LazyTensor{T}, a::T) where T = a*tm 166 Base.:*(tm::LazyTensor{T}, a::T) where T = a*tm
125 167
126 """ 168 """
127 InflatedTensor{T,R,D} <: LazyTensor{T,R,D} 169 InflatedTensor{T,R,D} <: LazyTensor{T,R,D}
128 170
129 An inflated `LazyTensor` with dimensions added before and afer its actual dimensions. 171 An inflated `LazyTensor` with dimensions added before and after its actual dimensions.
130 """ 172 """
131 struct InflatedTensor{T,R,D,D_before,R_middle,D_middle,D_after, TM<:LazyTensor{T,R_middle,D_middle}} <: LazyTensor{T,R,D} 173 struct InflatedTensor{T,R,D,D_before,R_middle,D_middle,D_after, TM<:LazyTensor{T,R_middle,D_middle}} <: LazyTensor{T,R,D}
132 before::IdentityTensor{T,D_before} 174 before::IdentityTensor{T,D_before}
133 tm::TM 175 tm::TM
134 after::IdentityTensor{T,D_after} 176 after::IdentityTensor{T,D_after}
167 itm.tm, 209 itm.tm,
168 IdentityTensor(itm.after.size..., after.size...), 210 IdentityTensor(itm.after.size..., after.size...),
169 ) 211 )
170 end 212 end
171 213
172 InflatedTensor(before::IdentityTensor, tm::LazyTensor{T}) where T = InflatedTensor(before,tm,IdentityTensor{T}()) 214 InflatedTensor(before::IdentityTensor, tm::LazyTensor) = InflatedTensor(before,tm,IdentityTensor{eltype(tm)}())
173 InflatedTensor(tm::LazyTensor{T}, after::IdentityTensor) where T = InflatedTensor(IdentityTensor{T}(),tm,after) 215 InflatedTensor(tm::LazyTensor, after::IdentityTensor) = InflatedTensor(IdentityTensor{eltype(tm)}(),tm,after)
174 # Resolve ambiguity between the two previous methods 216 # Resolve ambiguity between the two previous methods
175 InflatedTensor(I1::IdentityTensor{T}, I2::IdentityTensor{T}) where T = InflatedTensor(I1,I2,IdentityTensor{T}()) 217 InflatedTensor(I1::IdentityTensor, I2::IdentityTensor) = InflatedTensor(I1,I2,IdentityTensor{promote_type(eltype(I1), eltype(I2))}())
176 218
177 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2) 219 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
178 220
179 function range_size(itm::InflatedTensor) 221 function range_size(itm::InflatedTensor)
180 return flatten_tuple( 222 return concatenate_tuples(
181 range_size(itm.before), 223 range_size(itm.before),
182 range_size(itm.tm), 224 range_size(itm.tm),
183 range_size(itm.after), 225 range_size(itm.after),
184 ) 226 )
185 end 227 end
186 228
187 function domain_size(itm::InflatedTensor) 229 function domain_size(itm::InflatedTensor)
188 return flatten_tuple( 230 return concatenate_tuples(
189 domain_size(itm.before), 231 domain_size(itm.before),
190 domain_size(itm.tm), 232 domain_size(itm.tm),
191 domain_size(itm.after), 233 domain_size(itm.after),
192 ) 234 )
193 end 235 end
196 dim_before = range_dim(itm.before) 238 dim_before = range_dim(itm.before)
197 dim_domain = domain_dim(itm.tm) 239 dim_domain = domain_dim(itm.tm)
198 dim_range = range_dim(itm.tm) 240 dim_range = range_dim(itm.tm)
199 dim_after = range_dim(itm.after) 241 dim_after = range_dim(itm.after)
200 242
201 view_index, inner_index = split_index(Val(dim_before), Val(dim_domain), Val(dim_range), Val(dim_after), I...) 243 view_index, inner_index = split_index(dim_before, dim_domain, dim_range, dim_after, I...)
202 244
203 v_inner = view(v, view_index...) 245 v_inner = view(v, view_index...)
204 return apply(itm.tm, v_inner, inner_index...) 246 return apply(itm.tm, v_inner, inner_index...)
205 end 247 end
206 248
208 dim_before = range_dim(itm.before) 250 dim_before = range_dim(itm.before)
209 dim_domain = domain_dim(itm.tm) 251 dim_domain = domain_dim(itm.tm)
210 dim_range = range_dim(itm.tm) 252 dim_range = range_dim(itm.tm)
211 dim_after = range_dim(itm.after) 253 dim_after = range_dim(itm.after)
212 254
213 view_index, inner_index = split_index(Val(dim_before), Val(dim_range), Val(dim_domain), Val(dim_after), I...) 255 view_index, inner_index = split_index(dim_before, dim_range, dim_domain, dim_after, I...)
214 256
215 v_inner = view(v, view_index...) 257 v_inner = view(v, view_index...)
216 return apply_transpose(itm.tm, v_inner, inner_index...) 258 return apply_transpose(itm.tm, v_inner, inner_index...)
217 end 259 end
218 260
227 269
228 270
229 @doc raw""" 271 @doc raw"""
230 LazyOuterProduct(tms...) 272 LazyOuterProduct(tms...)
231 273
232 Creates a `TensorComposition` for the outerproduct of `tms...`. 274 Creates a `TensorComposition` for the outer product of `tms...`.
233 This is done by separating the outer product into regular products of outer products involving only identity mappings and one non-identity mapping. 275 This is done by separating the outer product into regular products of outer products involving only identity mappings and one non-identity mapping.
234 276
235 First let 277 First let
236 ```math 278 ```math
237 \begin{aligned} 279 \begin{aligned}
270 itm2 = InflatedTensor(IdentityTensor{T}(domain_size(tm1)),tm2) 312 itm2 = InflatedTensor(IdentityTensor{T}(domain_size(tm1)),tm2)
271 313
272 return itm1∘itm2 314 return itm1∘itm2
273 end 315 end
274 316
275 LazyOuterProduct(t1::IdentityTensor{T}, t2::IdentityTensor{T}) where T = IdentityTensor{T}(t1.size...,t2.size...) 317 LazyOuterProduct(t1::IdentityTensor, t2::IdentityTensor) = IdentityTensor{promote_type(eltype(t1),eltype(t2))}(t1.size...,t2.size...)
276 LazyOuterProduct(t1::LazyTensor, t2::IdentityTensor) = InflatedTensor(t1, t2) 318 LazyOuterProduct(t1::LazyTensor, t2::IdentityTensor) = InflatedTensor(t1, t2)
277 LazyOuterProduct(t1::IdentityTensor, t2::LazyTensor) = InflatedTensor(t1, t2) 319 LazyOuterProduct(t1::IdentityTensor, t2::LazyTensor) = InflatedTensor(t1, t2)
278 320
279 LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms) 321 LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms)
280 322
323
324
325 """
326 inflate(tm::LazyTensor, sz, dir)
327
328 Inflate `tm` such that it gets the size `sz` in all directions except `dir`.
329 Here `sz[dir]` is ignored and replaced with the range and domain sizes of
330 `tm`.
331
332 An example of when this operation is useful is when extending a
333 one-dimensional difference operator `D` to a 2D grid of a certain size. In that
334 case we could have
335
336 ```julia
337 Dx = inflate(D, (10,10), 1)
338 Dy = inflate(D, (10,10), 2)
339 ```
340 """
341 function inflate(tm::LazyTensor, sz, dir)
342 Is = IdentityTensor{eltype(tm)}.(sz)
343 parts = Base.setindex(Is, tm, dir)
344 return foldl(⊗, parts)
345 end
281 346
282 function check_domain_size(tm::LazyTensor, sz) 347 function check_domain_size(tm::LazyTensor, sz)
283 if domain_size(tm) != sz 348 if domain_size(tm) != sz
284 throw(DomainSizeMismatch(tm,sz)) 349 throw(DomainSizeMismatch(tm,sz))
285 end 350 end