changeset 479:95f3b9036801 feature/compose_identity_mappings

Specialize composition operator for composing a tensormapping with an identitymapping.
author Vidar Stiernström <vidar.stiernstrom@it.uu.se>
date Wed, 04 Nov 2020 20:03:37 +0100
parents 3041f8578bba
children c1a366331e75
files src/LazyTensors/lazy_tensor_operations.jl test/testLazyTensors.jl
diffstat 2 files changed, 39 insertions(+), 4 deletions(-) [+]
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Mon Nov 02 21:33:35 2020 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Wed Nov 04 20:03:37 2020 +0100
@@ -86,15 +86,18 @@
     t2::TM2
 
     @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D}
-        @boundscheck if domain_size(t1) != range_size(t2)
-            throw(DimensionMismatch("the first argument has domain size $(domain_size(t1)) while the second has range size $(range_size(t2)) "))
-        end
+        @boundscheck check_matching_size(t1::TensorMapping, t2::TensorMapping)
         return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2)
     end
-    # Add check for matching sizes as a boundscheck
 end
 export TensorMappingComposition
 
+function check_matching_size(t1::TensorMapping, t2::TensorMapping)
+    if domain_size(t1) != range_size(t2)
+        throw(DimensionMismatch("the first argument has domain size $(domain_size(t1)) while the second has range size $(range_size(t2)) "))
+    end
+end
+
 range_size(tm::TensorMappingComposition) = range_size(tm.t1)
 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2)
 
@@ -170,6 +173,27 @@
 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
 
+"""
+Base.:∘(tm, tmi)
+Base.:∘(tmi, tm)
+
+Composes a `Tensormapping`s `tm` with an `IdentityMapping`s `tmi`, by returning `tm`
+"""
+@inline function Base.:∘(tm::TensorMapping{T,R,D}, tmi::IdentityMapping{T,D}) where {T,R,D}
+    @boundscheck check_matching_size(tm::TensorMapping, tmi::TensorMapping)
+    return tm
+end
+
+@inline function Base.:∘(tmi::IdentityMapping{T,R}, tm::TensorMapping{T,R,D}) where {T,R,D}
+    @boundscheck check_matching_size(tmi::TensorMapping, tm::TensorMapping)
+    return tm
+end
+# Specialization for the case where tm is an IdentityMapping. Required to resolve ambiguity.
+@inline function Base.:∘(tm::IdentityMapping{T,D}, tmi::IdentityMapping{T,D}) where {T,D}
+    @boundscheck check_matching_size(tm::TensorMapping, tmi::TensorMapping)
+    return tmi
+end
+
 
 """
     InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D}
--- a/test/testLazyTensors.jl	Mon Nov 02 21:33:35 2020 +0100
+++ b/test/testLazyTensors.jl	Wed Nov 04 20:03:37 2020 +0100
@@ -312,6 +312,17 @@
 
     @inferred range_dim(I)
     @inferred domain_dim(I)
+
+    Ã = rand(4,2)
+    A = LazyLinearMap(Ã,(1,),(2,))
+    I1 = IdentityMapping{Float64}(2)
+    I2 = IdentityMapping{Float64}(4)
+    @test A∘I1 == A
+    @test I2∘A == A
+    @test I1∘I1 == I1
+    @test_throws DimensionMismatch I1∘A
+    @test_throws DimensionMismatch A∘I2
+    @test_throws DimensionMismatch I1∘I2
 end
 
 @testset "InflatedTensorMapping" begin