Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 485:4b49f03bdb98 feature/compose_identity_mappings
Switch from DimensionMismatch to SizeMismatch for boundschecks on compositions
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Thu, 05 Nov 2020 10:49:27 +0100 |
parents | b995f056ad1f |
children | 6a6b7eaf9edf |
comparison
equal
deleted
inserted
replaced
484:b995f056ad1f | 485:4b49f03bdb98 |
---|---|
84 struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D} | 84 struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D} |
85 t1::TM1 | 85 t1::TM1 |
86 t2::TM2 | 86 t2::TM2 |
87 | 87 |
88 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} | 88 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} |
89 @boundscheck check_matching_size(t1::TensorMapping, t2::TensorMapping) | 89 @boundscheck check_domain_size(t1, range_size(t2)) |
90 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) | 90 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) |
91 end | 91 end |
92 end | 92 end |
93 export TensorMappingComposition | 93 export TensorMappingComposition |
94 | 94 |
95 function check_matching_size(t1::TensorMapping, t2::TensorMapping) | 95 function check_domain_size(tm::TensorMapping, sz) |
96 if domain_size(t1) != range_size(t2) | 96 if domain_size(tm) != sz |
97 throw(DimensionMismatch("the first argument has domain size $(domain_size(t1)) while the second has range size $(range_size(t2)) ")) | 97 throw(SizeMismatch(tm,sz)) |
98 end | 98 end |
99 end | 99 end |
100 | 100 |
101 struct SizeMismatch <: Exception | 101 struct SizeMismatch <: Exception |
102 tm::TensorMapping | 102 tm::TensorMapping |
190 Base.:∘(tmi, tm) | 190 Base.:∘(tmi, tm) |
191 | 191 |
192 Composes a `Tensormapping` `tm` with an `IdentityMapping` `tmi`, by returning `tm` | 192 Composes a `Tensormapping` `tm` with an `IdentityMapping` `tmi`, by returning `tm` |
193 """ | 193 """ |
194 @inline function Base.:∘(tm::TensorMapping{T,R,D}, tmi::IdentityMapping{T,D}) where {T,R,D} | 194 @inline function Base.:∘(tm::TensorMapping{T,R,D}, tmi::IdentityMapping{T,D}) where {T,R,D} |
195 @boundscheck check_matching_size(tm::TensorMapping, tmi::TensorMapping) | 195 @boundscheck check_domain_size(tm, range_size(tmi)) |
196 return tm | 196 return tm |
197 end | 197 end |
198 | 198 |
199 @inline function Base.:∘(tmi::IdentityMapping{T,R}, tm::TensorMapping{T,R,D}) where {T,R,D} | 199 @inline function Base.:∘(tmi::IdentityMapping{T,R}, tm::TensorMapping{T,R,D}) where {T,R,D} |
200 @boundscheck check_matching_size(tmi::TensorMapping, tm::TensorMapping) | 200 @boundscheck check_domain_size(tmi, range_size(tm)) |
201 return tm | 201 return tm |
202 end | 202 end |
203 # Specialization for the case where tm is an IdentityMapping. Required to resolve ambiguity. | 203 # Specialization for the case where tm is an IdentityMapping. Required to resolve ambiguity. |
204 @inline function Base.:∘(tm::IdentityMapping{T,D}, tmi::IdentityMapping{T,D}) where {T,D} | 204 @inline function Base.:∘(tm::IdentityMapping{T,D}, tmi::IdentityMapping{T,D}) where {T,D} |
205 @boundscheck check_matching_size(tm::TensorMapping, tmi::TensorMapping) | 205 @boundscheck check_domain_size(tm, range_size(tmi)) |
206 return tmi | 206 return tmi |
207 end | 207 end |
208 | 208 |
209 | 209 |
210 """ | 210 """ |