Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 1395:bdcdbd4ea9cd feature/boundary_conditions
Merge with default. Comment out broken tests for boundary_conditions at sat
author | Vidar Stiernström <vidar.stiernstrom@it.uu.se> |
---|---|
date | Wed, 26 Jul 2023 21:35:50 +0200 |
parents | b4ee47f2aafb 4684c7f1c4cb |
children | d68d02dd882f |
comparison
equal
deleted
inserted
replaced
1217:ea2e8254820a | 1395:bdcdbd4ea9cd |
---|---|
50 | 50 |
51 range_size(tmt::TensorTranspose) = domain_size(tmt.tm) | 51 range_size(tmt::TensorTranspose) = domain_size(tmt.tm) |
52 domain_size(tmt::TensorTranspose) = range_size(tmt.tm) | 52 domain_size(tmt::TensorTranspose) = range_size(tmt.tm) |
53 | 53 |
54 | 54 |
55 struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R} | 55 struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,R,D} |
56 tm1::T1 | 56 tm1::T1 |
57 tm2::T2 | 57 tm2::T2 |
58 | 58 |
59 function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} | 59 function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} |
60 @boundscheck check_domain_size(tm2, domain_size(tm1)) | 60 @boundscheck check_domain_size(tm2, domain_size(tm1)) |
175 InflatedTensor(I1::IdentityTensor{T}, I2::IdentityTensor{T}) where T = InflatedTensor(I1,I2,IdentityTensor{T}()) | 175 InflatedTensor(I1::IdentityTensor{T}, I2::IdentityTensor{T}) where T = InflatedTensor(I1,I2,IdentityTensor{T}()) |
176 | 176 |
177 # TODO: Implement some pretty printing in terms of ⊗. E.g. InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2) | 177 # TODO: Implement some pretty printing in terms of ⊗. E.g. InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2) |
178 | 178 |
179 function range_size(itm::InflatedTensor) | 179 function range_size(itm::InflatedTensor) |
180 return flatten_tuple( | 180 return concatenate_tuples( |
181 range_size(itm.before), | 181 range_size(itm.before), |
182 range_size(itm.tm), | 182 range_size(itm.tm), |
183 range_size(itm.after), | 183 range_size(itm.after), |
184 ) | 184 ) |
185 end | 185 end |
186 | 186 |
187 function domain_size(itm::InflatedTensor) | 187 function domain_size(itm::InflatedTensor) |
188 return flatten_tuple( | 188 return concatenate_tuples( |
189 domain_size(itm.before), | 189 domain_size(itm.before), |
190 domain_size(itm.tm), | 190 domain_size(itm.tm), |
191 domain_size(itm.after), | 191 domain_size(itm.after), |
192 ) | 192 ) |
193 end | 193 end |
196 dim_before = range_dim(itm.before) | 196 dim_before = range_dim(itm.before) |
197 dim_domain = domain_dim(itm.tm) | 197 dim_domain = domain_dim(itm.tm) |
198 dim_range = range_dim(itm.tm) | 198 dim_range = range_dim(itm.tm) |
199 dim_after = range_dim(itm.after) | 199 dim_after = range_dim(itm.after) |
200 | 200 |
201 view_index, inner_index = split_index(Val(dim_before), Val(dim_domain), Val(dim_range), Val(dim_after), I...) | 201 view_index, inner_index = split_index(dim_before, dim_domain, dim_range, dim_after, I...) |
202 | 202 |
203 v_inner = view(v, view_index...) | 203 v_inner = view(v, view_index...) |
204 return apply(itm.tm, v_inner, inner_index...) | 204 return apply(itm.tm, v_inner, inner_index...) |
205 end | 205 end |
206 | 206 |
208 dim_before = range_dim(itm.before) | 208 dim_before = range_dim(itm.before) |
209 dim_domain = domain_dim(itm.tm) | 209 dim_domain = domain_dim(itm.tm) |
210 dim_range = range_dim(itm.tm) | 210 dim_range = range_dim(itm.tm) |
211 dim_after = range_dim(itm.after) | 211 dim_after = range_dim(itm.after) |
212 | 212 |
213 view_index, inner_index = split_index(Val(dim_before), Val(dim_range), Val(dim_domain), Val(dim_after), I...) | 213 view_index, inner_index = split_index(dim_before, dim_range, dim_domain, dim_after, I...) |
214 | 214 |
215 v_inner = view(v, view_index...) | 215 v_inner = view(v, view_index...) |
216 return apply_transpose(itm.tm, v_inner, inner_index...) | 216 return apply_transpose(itm.tm, v_inner, inner_index...) |
217 end | 217 end |
218 | 218 |
268 LazyOuterProduct(t1::IdentityTensor, t2::LazyTensor) = InflatedTensor(t1, t2) | 268 LazyOuterProduct(t1::IdentityTensor, t2::LazyTensor) = InflatedTensor(t1, t2) |
269 | 269 |
270 LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms) | 270 LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms) |
271 | 271 |
272 | 272 |
273 | |
274 """ | |
275 inflate(tm::LazyTensor, sz, dir) | |
276 | |
277 Inflate `tm` such that it gets the size `sz` in all directions except `dir`. | |
278 Here `sz[dir]` is ignored and replaced with the range and domain size of | |
279 `tm`. | |
280 | |
281 An example of when this operation is useful is when extending a one | |
282 dimensional difference operator `D` to a 2D grid of a certain size. In that | |
283 case we could have | |
284 | |
285 ```julia | |
286 Dx = inflate(D, (10,10), 1) | |
287 Dy = inflate(D, (10,10), 2) | |
288 ``` | |
289 """ | |
290 function inflate(tm::LazyTensor, sz, dir) | |
291 Is = IdentityTensor{eltype(tm)}.(sz) | |
292 parts = Base.setindex(Is, tm, dir) | |
293 return foldl(⊗, parts) | |
294 end | |
295 | |
273 function check_domain_size(tm::LazyTensor, sz) | 296 function check_domain_size(tm::LazyTensor, sz) |
274 if domain_size(tm) != sz | 297 if domain_size(tm) != sz |
275 throw(DomainSizeMismatch(tm,sz)) | 298 throw(DomainSizeMismatch(tm,sz)) |
276 end | 299 end |
277 end | 300 end |