comparison src/LazyTensors/lazy_tensor_operations.jl @ 487:6a6b7eaf9edf feature/compose_identity_mappings

Move exception handling to end of file and update error message.
author Vidar Stiernström <vidar.stiernstrom@it.uu.se>
date Thu, 05 Nov 2020 11:30:53 +0100
parents 4b49f03bdb98
children df566372bb4f
comparison
equal deleted inserted replaced
486:8082d43103c1 487:6a6b7eaf9edf
89 @boundscheck check_domain_size(t1, range_size(t2)) 89 @boundscheck check_domain_size(t1, range_size(t2))
90 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) 90 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2)
91 end 91 end
92 end 92 end
93 export TensorMappingComposition 93 export TensorMappingComposition
94
95 function check_domain_size(tm::TensorMapping, sz)
96 if domain_size(tm) != sz
97 throw(SizeMismatch(tm,sz))
98 end
99 end
100
101 struct SizeMismatch <: Exception
102 tm::TensorMapping
103 sz
104 end
105 export SizeMismatch
106
107 function Base.showerror(io::IO, err::SizeMismatch)
108 print(io, "SizeMismatch: ")
109 print(io, "attempt to apply TensorMapping with domain size $(domain_size(err.tm)) to a domain of size $(err.sz)")
110 end
111
112 94
113 range_size(tm::TensorMappingComposition) = range_size(tm.t1) 95 range_size(tm::TensorMappingComposition) = range_size(tm.t1)
114 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) 96 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2)
115 97
116 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,D}, I::Vararg{S,R} where S) where {T,R,K,D} 98 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,D}, I::Vararg{S,R} where S) where {T,R,K,D}
313 Takes a nested tuple and flattens the whole structure 295 Takes a nested tuple and flattens the whole structure
314 """ 296 """
315 flatten_tuple(t::NTuple{N, Number} where N) = t 297 flatten_tuple(t::NTuple{N, Number} where N) = t
316 flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify? 298 flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify?
317 flatten_tuple(ts::Vararg) = flatten_tuple(ts) 299 flatten_tuple(ts::Vararg) = flatten_tuple(ts)
300
301 function check_domain_size(tm::TensorMapping, sz)
302 if domain_size(tm) != sz
303 throw(SizeMismatch(tm,sz))
304 end
305 end
306
307 struct SizeMismatch <: Exception
308 tm::TensorMapping
309 sz
310 end
311 export SizeMismatch
312
313 function Base.showerror(io::IO, err::SizeMismatch)
314 print(io, "SizeMismatch: ")
315 print(io, "domain size $(domain_size(err.tm)) of TensorMapping not matching size $(err.sz)")
316 end