Mercurial > repos > public > sbplib_julia
diff src/LazyTensors/lazy_tensor_operations.jl @ 1888:eb70a0941cd6 allocation_testing
Merge default
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Fri, 03 Feb 2023 23:02:46 +0100 |
parents | 2278730f9cee |
children | 6f51160c7ca7 f1c2a4fa0ee1 |
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl Thu Apr 07 17:02:51 2022 +0200 +++ b/src/LazyTensors/lazy_tensor_operations.jl Fri Feb 03 23:02:46 2023 +0100 @@ -98,7 +98,6 @@ apply_transpose(c.t2, c.t1'*v, I...) end - """ TensorComposition(tm, tmi::IdentityTensor) TensorComposition(tmi::IdentityTensor, tm) @@ -120,6 +119,8 @@ return tmi end +Base.:*(a::T, tm::LazyTensor{T}) where T = TensorComposition(ScalingTensor{T,range_dim(tm)}(a,range_size(tm)), tm) +Base.:*(tm::LazyTensor{T}, a::T) where T = a*tm """ InflatedTensor{T,R,D} <: LazyTensor{T,R,D} @@ -268,6 +269,29 @@ LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms) + +""" + inflate(tm::LazyTensor, sz, dir) + +Inflate `tm` such that it gets the size `sz` in all directions except `dir`. +Here `sz[dir]` is ignored and replaced with the range and domains size of +`tm`. + +An example of when this operation is useful is when extending a one +dimensional difference operator `D` to a 2D grid of a ceratin size. In that +case we could have + +```julia +Dx = inflate(D, (10,10), 1) +Dy = inflate(D, (10,10), 2) +``` +""" +function inflate(tm::LazyTensor, sz, dir) + Is = IdentityTensor{eltype(tm)}.(sz) + parts = Base.setindex(Is, tm, dir) + return foldl(⊗, parts) +end + function check_domain_size(tm::LazyTensor, sz) if domain_size(tm) != sz throw(DomainSizeMismatch(tm,sz))