diff src/LazyTensors/lazy_tensor_operations.jl @ 489:1a7d6da3cc45

Merge feature/compose_identity_mappings
author Vidar Stiernström <vidar.stiernstrom@it.uu.se>
date Thu, 05 Nov 2020 11:32:49 +0100
parents 6a6b7eaf9edf
children df566372bb4f
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Mon Nov 02 21:33:35 2020 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Thu Nov 05 11:32:49 2020 +0100
@@ -86,12 +86,9 @@
     t2::TM2
 
     @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D}
-        @boundscheck if domain_size(t1) != range_size(t2)
-            throw(DimensionMismatch("the first argument has domain size $(domain_size(t1)) while the second has range size $(range_size(t2)) "))
-        end
+        @boundscheck check_domain_size(t1, range_size(t2))
         return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2)
     end
-    # Add check for matching sizes as a boundscheck
 end
 export TensorMappingComposition
 
@@ -170,6 +167,27 @@
 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
 
+"""
+Base.:∘(tm, tmi)
+Base.:∘(tmi, tm)
+
+Composes a `Tensormapping` `tm` with an `IdentityMapping` `tmi`, by returning `tm`
+"""
+@inline function Base.:∘(tm::TensorMapping{T,R,D}, tmi::IdentityMapping{T,D}) where {T,R,D}
+    @boundscheck check_domain_size(tm, range_size(tmi))
+    return tm
+end
+
+@inline function Base.:∘(tmi::IdentityMapping{T,R}, tm::TensorMapping{T,R,D}) where {T,R,D}
+    @boundscheck check_domain_size(tmi, range_size(tm))
+    return tm
+end
+# Specialization for the case where tm is an IdentityMapping. Required to resolve ambiguity.
+@inline function Base.:∘(tm::IdentityMapping{T,D}, tmi::IdentityMapping{T,D}) where {T,D}
+    @boundscheck check_domain_size(tm, range_size(tmi))
+    return tmi
+end
+
 
 """
     InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D}
@@ -279,3 +297,20 @@
 flatten_tuple(t::NTuple{N, Number} where N) = t
 flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify?
 flatten_tuple(ts::Vararg) = flatten_tuple(ts)
+
+function check_domain_size(tm::TensorMapping, sz)
+    if domain_size(tm) != sz
+        throw(SizeMismatch(tm,sz))
+    end
+end
+
+struct SizeMismatch <: Exception
+    tm::TensorMapping
+    sz
+end
+export SizeMismatch
+
+function Base.showerror(io::IO, err::SizeMismatch)
+    print(io, "SizeMismatch: ")
+    print(io, "domain size $(domain_size(err.tm)) of TensorMapping not matching size $(err.sz)")
+end