Mercurial > repos > public > sbplib_julia
diff src/LazyTensors/lazy_tensor_operations.jl @ 485:4b49f03bdb98 feature/compose_identity_mappings
Switch from DimensionMismatch to SizeMismatch for boundschecks on compositions
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Thu, 05 Nov 2020 10:49:27 +0100 |
parents | b995f056ad1f |
children | 6a6b7eaf9edf |
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl Thu Nov 05 10:47:31 2020 +0100 +++ b/src/LazyTensors/lazy_tensor_operations.jl Thu Nov 05 10:49:27 2020 +0100 @@ -86,15 +86,15 @@ t2::TM2 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} - @boundscheck check_matching_size(t1::TensorMapping, t2::TensorMapping) + @boundscheck check_domain_size(t1, range_size(t2)) return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) end end export TensorMappingComposition -function check_matching_size(t1::TensorMapping, t2::TensorMapping) - if domain_size(t1) != range_size(t2) - throw(DimensionMismatch("the first argument has domain size $(domain_size(t1)) while the second has range size $(range_size(t2)) ")) +function check_domain_size(tm::TensorMapping, sz) + if domain_size(tm) != sz + throw(SizeMismatch(tm,sz)) end end @@ -192,17 +192,17 @@ Composes a `Tensormapping` `tm` with an `IdentityMapping` `tmi`, by returning `tm` """ @inline function Base.:∘(tm::TensorMapping{T,R,D}, tmi::IdentityMapping{T,D}) where {T,R,D} - @boundscheck check_matching_size(tm::TensorMapping, tmi::TensorMapping) + @boundscheck check_domain_size(tm, range_size(tmi)) return tm end @inline function Base.:∘(tmi::IdentityMapping{T,R}, tm::TensorMapping{T,R,D}) where {T,R,D} - @boundscheck check_matching_size(tmi::TensorMapping, tm::TensorMapping) + @boundscheck check_domain_size(tmi, range_size(tm)) return tm end # Specialization for the case where tm is an IdentityMapping. Required to resolve ambiguity. @inline function Base.:∘(tm::IdentityMapping{T,D}, tmi::IdentityMapping{T,D}) where {T,D} - @boundscheck check_matching_size(tm::TensorMapping, tmi::TensorMapping) + @boundscheck check_domain_size(tm, range_size(tmi)) return tmi end