changeset 493:df566372bb4f feature/avoid_nested_inflated_tensormappings

Implement constructors to avoid creating nested InflatedTensorMappings
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 05 Nov 2020 13:18:24 +0100
parents 7e698030c170
children f906f207571c
files src/LazyTensors/lazy_tensor_operations.jl test/testLazyTensors.jl
diffstat 2 files changed, 22 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Thu Nov 05 11:46:03 2020 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Thu Nov 05 13:18:24 2020 +0100
@@ -221,8 +221,20 @@
 The outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityMapping`s.
 
 If one of `before` or `after` is left out, a 0-dimensional `IdentityMapping` is used as the default value.
+
+If `tm` already is an `InflatedTensorMapping`, `before` and `after` will be extended instead of
+creating a nested `InflatedTensorMapping`.
 """
 InflatedTensorMapping(::IdentityMapping, ::TensorMapping, ::IdentityMapping)
+
+function InflatedTensorMapping(before, itm::InflatedTensorMapping, after)
+    return InflatedTensorMapping(
+        IdentityMapping(before.size...,  itm.before.size...),
+        itm.tm,
+        IdentityMapping(itm.after.size..., after.size...),
+    )
+end
+
 InflatedTensorMapping(before::IdentityMapping, tm::TensorMapping{T}) where T = InflatedTensorMapping(before,tm,IdentityMapping{T}())
 InflatedTensorMapping(tm::TensorMapping{T}, after::IdentityMapping) where T = InflatedTensorMapping(IdentityMapping{T}(),tm,after)
 # Resolve ambiguity between the two previous methods
--- a/test/testLazyTensors.jl	Thu Nov 05 11:46:03 2020 +0100
+++ b/test/testLazyTensors.jl	Thu Nov 05 13:18:24 2020 +0100
@@ -394,6 +394,16 @@
     @inferred apply(tm,v,Index{Unknown}.((1,2,3,2,2,4))...)
     @inferred (tm*v)[1,2,3,2,2,4]
 
+    @testset "InflatedTensorMapping of InflatedTensorMapping" begin
+        A = ScalingOperator(2.0,(2,3))
+        itm = InflatedTensorMapping(I(3,2), A, I(4))
+        @test  InflatedTensorMapping(I(4), itm, I(2)) == InflatedTensorMapping(I(4,3,2), A, I(4,2))
+        @test  InflatedTensorMapping(itm, I(2)) == InflatedTensorMapping(I(3,2), A, I(4,2))
+        @test  InflatedTensorMapping(I(4), itm) == InflatedTensorMapping(I(4,3,2), A, I(4))
+
+        @test InflatedTensorMapping(I(2), I(2), I(2)) isa InflatedTensorMapping # The constructor should always return its type.
+    end
+
 end
 
 @testset "slice_tuple" begin