diff src/LazyTensors/lazy_tensor_operations.jl @ 1888:eb70a0941cd6 allocation_testing

Merge default
author Jonatan Werpers <jonatan@werpers.com>
date Fri, 03 Feb 2023 23:02:46 +0100
parents 2278730f9cee
children 6f51160c7ca7 f1c2a4fa0ee1
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Thu Apr 07 17:02:51 2022 +0200
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Fri Feb 03 23:02:46 2023 +0100
@@ -98,7 +98,6 @@
     apply_transpose(c.t2, c.t1'*v, I...)
 end
 
-
 """
     TensorComposition(tm, tmi::IdentityTensor)
     TensorComposition(tmi::IdentityTensor, tm)
@@ -120,6 +119,8 @@
     return tmi
 end
 
+Base.:*(a::T, tm::LazyTensor{T}) where T = TensorComposition(ScalingTensor{T,range_dim(tm)}(a,range_size(tm)), tm)
+Base.:*(tm::LazyTensor{T}, a::T) where T = a*tm
 
 """
     InflatedTensor{T,R,D} <: LazyTensor{T,R,D}
@@ -268,6 +269,29 @@
 LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms)
 
 
+
+"""
+    inflate(tm::LazyTensor, sz, dir)
+
+Inflate `tm` such that it gets the size `sz` in all directions except `dir`.
+Here `sz[dir]` is ignored and replaced with the range and domains size of
+`tm`.
+
+An example of when this operation is useful is when extending a one
+dimensional difference operator `D` to a 2D grid of a ceratin size. In that
+case we could have
+
+```julia
+Dx = inflate(D, (10,10), 1)
+Dy = inflate(D, (10,10), 2)
+```
+"""
+function inflate(tm::LazyTensor, sz, dir)
+    Is = IdentityTensor{eltype(tm)}.(sz)
+    parts = Base.setindex(Is, tm, dir)
+    return foldl(⊗, parts)
+end
+
 function check_domain_size(tm::LazyTensor, sz)
     if domain_size(tm) != sz
         throw(DomainSizeMismatch(tm,sz))