comparison src/LazyTensors/lazy_tensor_operations.jl @ 1395:bdcdbd4ea9cd feature/boundary_conditions

Merge with default. Comment out broken tests for boundary_conditions at sat
author Vidar Stiernström <vidar.stiernstrom@it.uu.se>
date Wed, 26 Jul 2023 21:35:50 +0200
parents b4ee47f2aafb 4684c7f1c4cb
children d68d02dd882f
comparison
equal deleted inserted replaced
1217:ea2e8254820a 1395:bdcdbd4ea9cd
50 50
51 range_size(tmt::TensorTranspose) = domain_size(tmt.tm) 51 range_size(tmt::TensorTranspose) = domain_size(tmt.tm)
52 domain_size(tmt::TensorTranspose) = range_size(tmt.tm) 52 domain_size(tmt::TensorTranspose) = range_size(tmt.tm)
53 53
54 54
55 struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R} 55 struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,R,D}
56 tm1::T1 56 tm1::T1
57 tm2::T2 57 tm2::T2
58 58
59 function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} 59 function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}}
60 @boundscheck check_domain_size(tm2, domain_size(tm1)) 60 @boundscheck check_domain_size(tm2, domain_size(tm1))
# Two-argument `InflatedTensor` for a pair of identity tensors with the same
# element type `T`: forwards to the three-argument constructor, appending an
# `IdentityTensor{T}` built with no size arguments as the last factor.
# NOTE(review): presumably the three arguments are (before, tm, after), matching
# the `itm.before`/`itm.tm`/`itm.after` fields. This method likely exists to
# disambiguate between more general two-argument methods — confirm.
function InflatedTensor(I1::IdentityTensor{T}, I2::IdentityTensor{T}) where T
    trailing = IdentityTensor{T}()
    return InflatedTensor(I1, I2, trailing)
end
176 176
177 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2) 177 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
178 178
# Range size of an `InflatedTensor`: the range sizes of `itm.before`, `itm.tm`
# and `itm.after`, concatenated in that order into a single tuple.
function range_size(itm::InflatedTensor)
    parts = (range_size(itm.before), range_size(itm.tm), range_size(itm.after))
    return concatenate_tuples(parts...)
end
186 186
# Domain size of an `InflatedTensor`: the domain sizes of `itm.before`,
# `itm.tm` and `itm.after`, concatenated in that order into a single tuple.
function domain_size(itm::InflatedTensor)
    parts = (domain_size(itm.before), domain_size(itm.tm), domain_size(itm.after))
    return concatenate_tuples(parts...)
end
196 dim_before = range_dim(itm.before) 196 dim_before = range_dim(itm.before)
197 dim_domain = domain_dim(itm.tm) 197 dim_domain = domain_dim(itm.tm)
198 dim_range = range_dim(itm.tm) 198 dim_range = range_dim(itm.tm)
199 dim_after = range_dim(itm.after) 199 dim_after = range_dim(itm.after)
200 200
201 view_index, inner_index = split_index(Val(dim_before), Val(dim_domain), Val(dim_range), Val(dim_after), I...) 201 view_index, inner_index = split_index(dim_before, dim_domain, dim_range, dim_after, I...)
202 202
203 v_inner = view(v, view_index...) 203 v_inner = view(v, view_index...)
204 return apply(itm.tm, v_inner, inner_index...) 204 return apply(itm.tm, v_inner, inner_index...)
205 end 205 end
206 206
208 dim_before = range_dim(itm.before) 208 dim_before = range_dim(itm.before)
209 dim_domain = domain_dim(itm.tm) 209 dim_domain = domain_dim(itm.tm)
210 dim_range = range_dim(itm.tm) 210 dim_range = range_dim(itm.tm)
211 dim_after = range_dim(itm.after) 211 dim_after = range_dim(itm.after)
212 212
213 view_index, inner_index = split_index(Val(dim_before), Val(dim_range), Val(dim_domain), Val(dim_after), I...) 213 view_index, inner_index = split_index(dim_before, dim_range, dim_domain, dim_after, I...)
214 214
215 v_inner = view(v, view_index...) 215 v_inner = view(v, view_index...)
216 return apply_transpose(itm.tm, v_inner, inner_index...) 216 return apply_transpose(itm.tm, v_inner, inner_index...)
217 end 217 end
218 218
# Outer product with an identity tensor as the left factor: returned as an
# `InflatedTensor` built from `t1` and `t2`. (Presumably `t1` becomes the
# leading identity factor — confirm against the two-argument `InflatedTensor`
# method.)
function LazyOuterProduct(t1::IdentityTensor, t2::LazyTensor)
    return InflatedTensor(t1, t2)
end
269 269
# Outer product of any number of lazy tensors, combined pairwise from the left:
# LazyOuterProduct(LazyOuterProduct(tms[1], tms[2]), tms[3]), and so on.
# A single argument is returned unchanged by `foldl`.
function LazyOuterProduct(tms::Vararg{LazyTensor})
    return foldl(LazyOuterProduct, tms)
end
271 271
272 272
273
"""
    inflate(tm::LazyTensor, sz, dir)

Inflate `tm` so that it acts along direction `dir` of a larger grid.

Every entry of `sz` is turned into an `IdentityTensor` with the element type of
`tm`. The entry at position `dir` is then replaced by `tm` itself. All factors
are combined with `⊗` from left to right. The value `sz[dir]` is therefore
ignored: the range and domain sizes in that direction are those of `tm`.

This is useful, for example, when extending a one-dimensional difference
operator `D` to a 2D grid of a certain size:

```julia
Dx = inflate(D, (10,10), 1)
Dy = inflate(D, (10,10), 2)
```
"""
function inflate(tm::LazyTensor, sz, dir)
    T = eltype(tm)
    identities = map(IdentityTensor{T}, sz)
    # Swap the identity in direction `dir` for `tm`; an out-of-range `dir`
    # makes `Base.setindex` throw.
    factors = Base.setindex(identities, tm, dir)
    return foldl(⊗, factors)
end
295
# Throw a `DomainSizeMismatch` when the domain size of `tm` differs from `sz`;
# otherwise return `nothing`.
function check_domain_size(tm::LazyTensor, sz)
    domain_size(tm) != sz && throw(DomainSizeMismatch(tm, sz))
    return nothing
end