Mercurial > repos > public > sbplib_julia
diff src/LazyTensors/lazy_tensor_operations.jl @ 1395:bdcdbd4ea9cd feature/boundary_conditions
Merge with default. Comment out broken tests for boundary_conditions at sat
author | Vidar Stiernström <vidar.stiernstrom@it.uu.se> |
---|---|
date | Wed, 26 Jul 2023 21:35:50 +0200 |
parents | b4ee47f2aafb 4684c7f1c4cb |
children | d68d02dd882f |
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl Tue Feb 07 21:55:07 2023 +0100 +++ b/src/LazyTensors/lazy_tensor_operations.jl Wed Jul 26 21:35:50 2023 +0200 @@ -52,7 +52,7 @@ domain_size(tmt::TensorTranspose) = range_size(tmt.tm) -struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R} +struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,R,D} tm1::T1 tm2::T2 @@ -177,7 +177,7 @@ # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2) function range_size(itm::InflatedTensor) - return flatten_tuple( + return concatenate_tuples( range_size(itm.before), range_size(itm.tm), range_size(itm.after), @@ -185,7 +185,7 @@ end function domain_size(itm::InflatedTensor) - return flatten_tuple( + return concatenate_tuples( domain_size(itm.before), domain_size(itm.tm), domain_size(itm.after), @@ -198,7 +198,7 @@ dim_range = range_dim(itm.tm) dim_after = range_dim(itm.after) - view_index, inner_index = split_index(Val(dim_before), Val(dim_domain), Val(dim_range), Val(dim_after), I...) + view_index, inner_index = split_index(dim_before, dim_domain, dim_range, dim_after, I...) v_inner = view(v, view_index...) return apply(itm.tm, v_inner, inner_index...) @@ -210,7 +210,7 @@ dim_range = range_dim(itm.tm) dim_after = range_dim(itm.after) - view_index, inner_index = split_index(Val(dim_before), Val(dim_range), Val(dim_domain), Val(dim_after), I...) + view_index, inner_index = split_index(dim_before, dim_range, dim_domain, dim_after, I...) v_inner = view(v, view_index...) return apply_transpose(itm.tm, v_inner, inner_index...) @@ -270,6 +270,29 @@ LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms) + +""" + inflate(tm::LazyTensor, sz, dir) + +Inflate `tm` such that it gets the size `sz` in all directions except `dir`. +Here `sz[dir]` is ignored and replaced with the range and domains size of +`tm`. + +An example of when this operation is useful is when extending a one +dimensional difference operator `D` to a 2D grid of a ceratin size. In that +case we could have + +```julia +Dx = inflate(D, (10,10), 1) +Dy = inflate(D, (10,10), 2) +``` +""" +function inflate(tm::LazyTensor, sz, dir) + Is = IdentityTensor{eltype(tm)}.(sz) + parts = Base.setindex(Is, tm, dir) + return foldl(⊗, parts) +end + function check_domain_size(tm::LazyTensor, sz) if domain_size(tm) != sz throw(DomainSizeMismatch(tm,sz))