comparison src/LazyTensors/lazy_tensor_operations.jl @ 498:5a600ec40ccc feature/outer_product

Merge in default
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 05 Nov 2020 15:27:04 +0100
parents 2dc2eac27f75 df566372bb4f
children 7b550c714f3f
comparison
equal deleted inserted replaced
491:2dc2eac27f75 498:5a600ec40ccc
struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D}
    # Composition t1∘t2: `t2` is applied first (D -> K dimensions), then `t1` (K -> R dimensions).
    t1::TM1
    t2::TM2

    @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D}
        # The output of t2 must match the input of t1. Being a @boundscheck block,
        # this check may be skipped when the constructor is inlined inside `@inbounds`.
        @boundscheck check_domain_size(t1, range_size(t2))
        TM1 = typeof(t1)
        TM2 = typeof(t2)
        return new{T,R,K,D,TM1,TM2}(t1, t2)
    end
end
export TensorMappingComposition
97 94
# For t1∘t2 the result lives in t1's range and consumes arguments from t2's domain.
function range_size(tm::TensorMappingComposition)
    outer = tm.t1
    return range_size(outer)
end

function domain_size(tm::TensorMappingComposition)
    inner = tm.t2
    return domain_size(inner)
end
# An IdentityMapping reports the same size, `tmi.size`, for both its range and its domain.
function range_size(tmi::IdentityMapping)
    return tmi.size
end

function domain_size(tmi::IdentityMapping)
    return tmi.size
end

# Applying the identity, or its transpose (the same mapping), reads `v` at the requested index.
function apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D}
    return v[I...]
end

function apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D}
    return v[I...]
end
169
"""
    Base.:∘(tm, tmi)
    Base.:∘(tmi, tm)

Composes a `TensorMapping` `tm` with an `IdentityMapping` `tmi` by returning `tm` itself.

Composing with an identity is a no-op, so no `TensorMappingComposition` is created.
When bounds checking is active, the sizes are still verified. A `SizeMismatch` is thrown
if the domain size of the left operand does not equal the range size of the right operand.
When both operands are `IdentityMapping`s, the right operand is returned.
"""
@inline function Base.:∘(tm::TensorMapping{T,R,D}, tmi::IdentityMapping{T,D}) where {T,R,D}
    # tmi is applied first, so its range must fit the domain of tm.
    @boundscheck check_domain_size(tm, range_size(tmi))
    return tm
end

@inline function Base.:∘(tmi::IdentityMapping{T,R}, tm::TensorMapping{T,R,D}) where {T,R,D}
    # tm is applied first, so its range must fit the domain of tmi.
    @boundscheck check_domain_size(tmi, range_size(tm))
    return tm
end

# Specialization for the case where tm is an IdentityMapping. Required to resolve the
# ambiguity between the two methods above when both operands are IdentityMappings.
@inline function Base.:∘(tm::IdentityMapping{T,D}, tmi::IdentityMapping{T,D}) where {T,D}
    @boundscheck check_domain_size(tm, range_size(tmi))
    return tmi
end
172 190
173 191
174 """ 192 """
175 InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D} 193 InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D}
176 194
201 InflatedTensorMapping(tm,after) 219 InflatedTensorMapping(tm,after)
202 220
203 The outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityMapping`s. 221 The outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityMapping`s.
204 222
205 If one of `before` or `after` is left out, a 0-dimensional `IdentityMapping` is used as the default value. 223 If one of `before` or `after` is left out, a 0-dimensional `IdentityMapping` is used as the default value.
224
225 If `tm` is already an `InflatedTensorMapping`, `before` and `after` are extended instead of
226 creating a nested `InflatedTensorMapping`.
206 """ 227 """
207 InflatedTensorMapping(::IdentityMapping, ::TensorMapping, ::IdentityMapping) 228 InflatedTensorMapping(::IdentityMapping, ::TensorMapping, ::IdentityMapping)
229
# Inflating an already inflated mapping merges the new identities with the existing ones.
# The result is therefore a single InflatedTensorMapping around the innermost `tm`,
# rather than a nested one.
# NOTE(review): `IdentityMapping(sz...)` is called without an element type parameter.
# Confirm that it yields the same element type as `itm`; otherwise a non-default
# element type may not be preserved.
function InflatedTensorMapping(before, itm::InflatedTensorMapping, after)
    merged_before = IdentityMapping(before.size..., itm.before.size...)
    merged_after = IdentityMapping(itm.after.size..., after.size...)
    return InflatedTensorMapping(merged_before, itm.tm, merged_after)
end
237
# Two-argument forms: the omitted side defaults to a 0-dimensional IdentityMapping
# with the same element type as `tm`.
function InflatedTensorMapping(before::IdentityMapping, tm::TensorMapping{T}) where T
    empty_after = IdentityMapping{T}()
    return InflatedTensorMapping(before, tm, empty_after)
end

function InflatedTensorMapping(tm::TensorMapping{T}, after::IdentityMapping) where T
    empty_before = IdentityMapping{T}()
    return InflatedTensorMapping(empty_before, tm, after)
end

# Resolve ambiguity between the two previous methods
function InflatedTensorMapping(I1::IdentityMapping{T}, I2::IdentityMapping{T}) where T
    empty_after = IdentityMapping{T}()
    return InflatedTensorMapping(I1, I2, empty_after)
end
212 242
278 """ 308 """
# A tuple containing only numbers is already flat and is returned unchanged.
flatten_tuple(t::NTuple{N, Number} where N) = t
# Otherwise flatten each element into a tuple and concatenate the pieces.
# Bare numbers reach the Vararg method below and come back as 1-tuples.
function flatten_tuple(t::Tuple)
    flat_parts = map(flatten_tuple, t)
    return ((flat_parts...)...,)
end
# Multiple arguments are flattened as a single tuple of those arguments.
flatten_tuple(ts::Vararg) = flatten_tuple(ts)
282 312
283
284 """ 313 """
285 LazyOuterProduct(tms...) 314 LazyOuterProduct(tms...)
286 315
287 Creates a `TensorComposition` for the outer product of `tms...`. 316 Creates a `TensorComposition` for the outer product of `tms...`.
288 This is done by separating the outer product into regular products of outer products involving only identity mappings and one non-identity mapping. 317 This is done by separating the outer product into regular products of outer products involving only identity mappings and one non-identity mapping.
332 361
"""
    a ⊗ b

Infix form of `LazyOuterProduct(a, b)` for two `TensorMapping`s.
"""
function ⊗(a::TensorMapping, b::TensorMapping)
    return LazyOuterProduct(a, b)
end
export ⊗
335 364
336 # TBD: Should we implement simplifications for outer products of LazyIdentities with other LazyIdentities or InflatedTensorMappings? 365 # TBD: Should we implement simplifications for outer products of LazyIdentities with other LazyIdentities or InflatedTensorMappings?
366
"""
    check_domain_size(tm::TensorMapping, sz)

Throw a `SizeMismatch` if `domain_size(tm)` is not equal to `sz`; otherwise return `nothing`.
"""
function check_domain_size(tm::TensorMapping, sz)
    domain_size(tm) == sz || throw(SizeMismatch(tm, sz))
    return nothing
end
372
"""
    SizeMismatch <: Exception

Error thrown by `check_domain_size` when the domain size of a `TensorMapping`
differs from an expected size.

# Fields
- `tm::TensorMapping`: the mapping whose domain size was checked.
- `sz`: the size it was expected to match.
"""
struct SizeMismatch <: Exception
    tm::TensorMapping
    sz::Any
end
export SizeMismatch
378
# Human-readable message: the exception name, then the actual and expected sizes.
function Base.showerror(io::IO, err::SizeMismatch)
    print(io, "SizeMismatch: ")
    return print(io, "domain size $(domain_size(err.tm)) of TensorMapping not matching size $(err.sz)")
end