comparison src/LazyTensors/lazy_tensor_operations.jl @ 485:4b49f03bdb98 feature/compose_identity_mappings

Switch from DimensionMismatch to SizeMismatch for boundschecks on compositions
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 05 Nov 2020 10:49:27 +0100
parents b995f056ad1f
children 6a6b7eaf9edf
comparison
equal deleted inserted replaced
484:b995f056ad1f 485:4b49f03bdb98
84 struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D} 84 struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D}
85 t1::TM1 85 t1::TM1
86 t2::TM2 86 t2::TM2
87 87
88 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} 88 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D}
89 @boundscheck check_matching_size(t1::TensorMapping, t2::TensorMapping) 89 @boundscheck check_domain_size(t1, range_size(t2))
90 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) 90 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2)
91 end 91 end
92 end 92 end
93 export TensorMappingComposition 93 export TensorMappingComposition
94 94
95 function check_matching_size(t1::TensorMapping, t2::TensorMapping) 95 function check_domain_size(tm::TensorMapping, sz)
96 if domain_size(t1) != range_size(t2) 96 if domain_size(tm) != sz
97 throw(DimensionMismatch("the first argument has domain size $(domain_size(t1)) while the second has range size $(range_size(t2)) ")) 97 throw(SizeMismatch(tm,sz))
98 end 98 end
99 end 99 end
100 100
101 struct SizeMismatch <: Exception 101 struct SizeMismatch <: Exception
102 tm::TensorMapping 102 tm::TensorMapping
190 Base.:∘(tmi, tm) 190 Base.:∘(tmi, tm)
191 191
192 Composes a `Tensormapping` `tm` with an `IdentityMapping` `tmi`, by returning `tm` 192 Composes a `Tensormapping` `tm` with an `IdentityMapping` `tmi`, by returning `tm`
193 """ 193 """
194 @inline function Base.:∘(tm::TensorMapping{T,R,D}, tmi::IdentityMapping{T,D}) where {T,R,D} 194 @inline function Base.:∘(tm::TensorMapping{T,R,D}, tmi::IdentityMapping{T,D}) where {T,R,D}
195 @boundscheck check_matching_size(tm::TensorMapping, tmi::TensorMapping) 195 @boundscheck check_domain_size(tm, range_size(tmi))
196 return tm 196 return tm
197 end 197 end
198 198
199 @inline function Base.:∘(tmi::IdentityMapping{T,R}, tm::TensorMapping{T,R,D}) where {T,R,D} 199 @inline function Base.:∘(tmi::IdentityMapping{T,R}, tm::TensorMapping{T,R,D}) where {T,R,D}
200 @boundscheck check_matching_size(tmi::TensorMapping, tm::TensorMapping) 200 @boundscheck check_domain_size(tmi, range_size(tm))
201 return tm 201 return tm
202 end 202 end
203 # Specialization for the case where tm is an IdentityMapping. Required to resolve ambiguity. 203 # Specialization for the case where tm is an IdentityMapping. Required to resolve ambiguity.
204 @inline function Base.:∘(tm::IdentityMapping{T,D}, tmi::IdentityMapping{T,D}) where {T,D} 204 @inline function Base.:∘(tm::IdentityMapping{T,D}, tmi::IdentityMapping{T,D}) where {T,D}
205 @boundscheck check_matching_size(tm::TensorMapping, tmi::TensorMapping) 205 @boundscheck check_domain_size(tm, range_size(tmi))
206 return tmi 206 return tmi
207 end 207 end
208 208
209 209
210 """ 210 """