comparison src/LazyTensors/lazy_tensor_operations.jl @ 509:b7e42384053a feature/boundary_ops

Merge w. default
author Vidar Stiernström <vidar.stiernstrom@it.uu.se>
date Sun, 08 Nov 2020 16:01:39 +0100
parents 4b9d124fe984
children a5caa934b35f fe86ac896377
comparison
478:2ab687b1d221 509:b7e42384053a
14 14 # TODO: Do boundschecking on creation!
15 15 export LazyTensorMappingApplication
16 16
17 17 # TODO: Go through and remove unnecessary type parameters on functions
18 18
19 Base.:*(tm::TensorMapping{T,R,D}, o::AbstractArray{T,D}) where {T,R,D} = LazyTensorMappingApplication(tm,o)
20 19 Base.getindex(ta::LazyTensorMappingApplication{T,R,D}, I::Vararg{Index,R}) where {T,R,D} = apply(ta.t, ta.o, I...)
21 20 Base.getindex(ta::LazyTensorMappingApplication{T,R,D}, I::Vararg{Int,R}) where {T,R,D} = apply(ta.t, ta.o, Index{Unknown}.(I)...)
22 21 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t)
23 22 # TODO: What else is needed to implement the AbstractArray interface?
24 23
24 Base.:*(a::TensorMapping, v::AbstractArray) = LazyTensorMappingApplication(a,v)
25 Base.:*(a::TensorMapping, b::TensorMapping) = throw(MethodError(Base.:*,(a,b)))
26 Base.:*(a::TensorMapping, args::Union{TensorMapping, AbstractArray}...) = foldr(*,(a,args...))
27
25 28 # # We need the associativity to be a→b→c = a→(b→c), which is the case for '→'
26 Base.:*(a::TensorMapping{T,R,D}, b::TensorMapping{T,D,K}, args::Union{TensorMapping{T}, AbstractArray{T}}...) where {T,R,D,K} = foldr(*,(a,b,args...))
27 29 # # Should we overload some other infix binary operator?
28 30 # →(tm::TensorMapping{T,R,D}, o::AbstractArray{T,D}) where {T,R,D} = LazyTensorMappingApplication(tm,o)
29 31 # TODO: We need to be really careful about good error messages.
30 32 # For example, what happens if you try to multiply a LazyTensorMappingApplication with a TensorMapping (wrong order)?
31 33
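The new `*` methods (new lines 24–26) are worth a brief illustration. The sketch below is written as if evaluated inside the module once the whole file is loaded; it uses the `IdentityMapping` introduced further down in this changeset, and the chained form assumes that `LazyTensorMappingApplication` is an `AbstractArray` (its definition lies outside this hunk).

```julia
Id = IdentityMapping(3)   # defined later in this file; eltype defaults to Float64
v  = [1.0, 2.0, 3.0]

w = Id*v                  # a LazyTensorMappingApplication; nothing is evaluated yet
size(w) == (3,)           # the size is propagated from range_size(Id)

# The variadic method folds right, so the chain groups as Id*(Id*v):
# each factor lazily wraps the application to its right.
u = Id*Id*v
```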
36 38
37 39 If a mapping implements the `apply_transpose` method, this allows working with
38 40 the transpose of mapping `m` by using `m'`. `m'` will work as a regular TensorMapping lazily calling
39 41 the appropriate methods of `m`.
40 42 """
41 struct LazyTensorMappingTranspose{T,R,D} <: TensorMapping{T,D,R} 43 struct LazyTensorMappingTranspose{T,R,D, TM<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R}
42 tm::TensorMapping{T,R,D} 44 tm::TM
43 45 end
44 46 export LazyTensorMappingTranspose
45 47
46 48 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors?
47 49 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`?
82 84 struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D}
83 85 t1::TM1
84 86 t2::TM2
85 87
86 88 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D}
87 @boundscheck if domain_size(t1) != range_size(t2) 89 @boundscheck check_domain_size(t1, range_size(t2))
88 throw(DimensionMismatch("the first argument has domain size $(domain_size(t1)) while the second has range size $(range_size(t2)) "))
89 end
90 90 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2)
91 91 end
92 # Add check for matching sizes as a boundscheck
93 92 end
94 93 export TensorMappingComposition
95 94
96 95 range_size(tm::TensorMappingComposition) = range_size(tm.t1)
97 96 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2)
143 142 end
144 143
145 144 function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Index,D}) where {T,R,D}
146 145 apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...)
147 146 end
147
148
149 """
150 IdentityMapping{T,D} <: TensorMapping{T,D,D}
151
152 The lazy identity TensorMapping for a given size. Useful for building up higher dimensional tensor mappings from lower
153 dimensional ones through outer products. Also used in the implementation of InflatedTensorMapping.
154 """
155 struct IdentityMapping{T,D} <: TensorMapping{T,D,D}
156 size::NTuple{D,Int}
157 end
158 export IdentityMapping
159
160 IdentityMapping{T}(size::NTuple{D,Int}) where {T,D} = IdentityMapping{T,D}(size)
161 IdentityMapping{T}(size::Vararg{Int,D}) where {T,D} = IdentityMapping{T,D}(size)
162 IdentityMapping(size::Vararg{Int,D}) where D = IdentityMapping{Float64,D}(size)
163
164 range_size(tmi::IdentityMapping) = tmi.size
165 domain_size(tmi::IdentityMapping) = tmi.size
166
167 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
168 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
169
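A brief usage sketch of `IdentityMapping`, written as if evaluated inside the module (whether these names are exported is not shown in this diff):

```julia
I2 = IdentityMapping{Float64}(3, 4)   # explicit eltype, two dimensions
I1 = IdentityMapping(5)               # eltype defaults to Float64

range_size(I2) == (3, 4)
domain_size(I2) == (3, 4)

v = ones(3, 4)
apply(I2, v, 2, 3) == 1.0             # the identity simply returns the indexed element
```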
170 """
171 Base.:∘(tm, tmi)
172 Base.:∘(tmi, tm)
173
174 Composes a `TensorMapping` `tm` with an `IdentityMapping` `tmi` by returning `tm`.
175 """
176 @inline function Base.:∘(tm::TensorMapping{T,R,D}, tmi::IdentityMapping{T,D}) where {T,R,D}
177 @boundscheck check_domain_size(tm, range_size(tmi))
178 return tm
179 end
180
181 @inline function Base.:∘(tmi::IdentityMapping{T,R}, tm::TensorMapping{T,R,D}) where {T,R,D}
182 @boundscheck check_domain_size(tmi, range_size(tm))
183 return tm
184 end
185 # Specialization for the case where tm is an IdentityMapping. Required to resolve ambiguity.
186 @inline function Base.:∘(tm::IdentityMapping{T,D}, tmi::IdentityMapping{T,D}) where {T,D}
187 @boundscheck check_domain_size(tm, range_size(tmi))
188 return tmi
189 end
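A sketch of the composition rules above. Both operands are identities here so that the example is self-contained, but for a general `tm` the behaviour is the same: a size check followed by returning the non-identity mapping.

```julia
Id3 = IdentityMapping(3)

(Id3 ∘ IdentityMapping(3)) === IdentityMapping(3)   # composing with the identity is a no-op

try
    Id3 ∘ IdentityMapping(4)    # domain and range sizes do not match
catch err
    err isa SizeMismatch        # thrown via check_domain_size in the @boundscheck
end
```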
190
191
192 """
193 InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D}
194
195 An inflated `TensorMapping` with dimensions added before and after its actual dimensions.
196 """
197 struct InflatedTensorMapping{T,R,D,D_before,R_middle,D_middle,D_after, TM<:TensorMapping{T,R_middle,D_middle}} <: TensorMapping{T,R,D}
198 before::IdentityMapping{T,D_before}
199 tm::TM
200 after::IdentityMapping{T,D_after}
201
202 function InflatedTensorMapping(before, tm::TensorMapping{T}, after) where T
203 R_before = range_dim(before)
204 R_middle = range_dim(tm)
205 R_after = range_dim(after)
206 R = R_before+R_middle+R_after
207
208 D_before = domain_dim(before)
209 D_middle = domain_dim(tm)
210 D_after = domain_dim(after)
211 D = D_before+D_middle+D_after
212 return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after)
213 end
214 end
215 export InflatedTensorMapping
216 """
217 InflatedTensorMapping(before, tm, after)
218 InflatedTensorMapping(before,tm)
219 InflatedTensorMapping(tm,after)
220
221 The outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityMapping`s.
222
223 If one of `before` or `after` is left out, a 0-dimensional `IdentityMapping` is used as the default value.
224
225 If `tm` is already an `InflatedTensorMapping`, `before` and `after` will be extended instead of
226 creating a nested `InflatedTensorMapping`.
227 """
228 InflatedTensorMapping(::IdentityMapping, ::TensorMapping, ::IdentityMapping)
229
230 function InflatedTensorMapping(before, itm::InflatedTensorMapping, after)
231 return InflatedTensorMapping(
232 IdentityMapping(before.size..., itm.before.size...),
233 itm.tm,
234 IdentityMapping(itm.after.size..., after.size...),
235 )
236 end
237
238 InflatedTensorMapping(before::IdentityMapping, tm::TensorMapping{T}) where T = InflatedTensorMapping(before,tm,IdentityMapping{T}())
239 InflatedTensorMapping(tm::TensorMapping{T}, after::IdentityMapping) where T = InflatedTensorMapping(IdentityMapping{T}(),tm,after)
240 # Resolve ambiguity between the two previous methods
241 InflatedTensorMapping(I1::IdentityMapping{T}, I2::IdentityMapping{T}) where T = InflatedTensorMapping(I1,I2,IdentityMapping{T}())
242
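A construction sketch for `InflatedTensorMapping`. An `IdentityMapping` stands in for the middle mapping so the example is self-contained and runnable; in practice `tm` would typically be a non-identity low-dimensional mapping. `range_size`/`domain_size` for the new type are defined just below.

```julia
B   = IdentityMapping(2)                  # stand-in for a 1-D mapping
itm = InflatedTensorMapping(IdentityMapping(3), B, IdentityMapping(4))

range_size(itm)  == (3, 2, 4)             # sizes of `before`, `tm` and `after` concatenated
domain_size(itm) == (3, 2, 4)

# Leaving out `before` or `after` defaults it to a 0-dimensional IdentityMapping:
InflatedTensorMapping(IdentityMapping(3), B)
```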
243 # TODO: Implement syntax and constructors for products of different combinations of InflatedTensorMapping and IdentityMapping
244
245 # TODO: Implement some pretty printing in terms of ⊗. E.g. InflatedTensorMapping(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
246
247 function range_size(itm::InflatedTensorMapping)
248 return flatten_tuple(
249 range_size(itm.before),
250 range_size(itm.tm),
251 range_size(itm.after),
252 )
253 end
254
255 function domain_size(itm::InflatedTensorMapping)
256 return flatten_tuple(
257 domain_size(itm.before),
258 domain_size(itm.tm),
259 domain_size(itm.after),
260 )
261 end
262
263 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D}
264 view_index, inner_index = split_index(itm, I...)
265
266 v_inner = view(v, view_index...)
267 return apply(itm.tm, v_inner, inner_index...)
268 end
269
270
271 """
272 split_index(...)
273
274 Splits the multi-index into two parts: one part for the view that the inner TensorMapping acts on, and one part for indexing the result.
275 E.g.
276 ```
277 (1,2,3,4) -> (1,:,:,4), (2,3)
278 ```
279 """
280 function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{Any,R}) where {T,R,D}
281 I_before = slice_tuple(I, Val(1), Val(range_dim(itm.before)))
282 I_after = slice_tuple(I, Val(R-range_dim(itm.after)+1), Val(R))
283
284 view_index = (I_before..., ntuple((i)->:,domain_dim(itm.tm))..., I_after...)
285 inner_index = slice_tuple(I, Val(range_dim(itm.before)+1), Val(R-range_dim(itm.after)))
286
287 return (view_index, inner_index)
288 end
289
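A sketch of how `split_index` splits a multi-index, matching the docstring example. Identities are used for all three factors so the example is self-contained; the middle mapping has one dimension, which becomes the colon in the view index.

```julia
itm = InflatedTensorMapping(IdentityMapping(3), IdentityMapping(2), IdentityMapping(4))

view_index, inner_index = split_index(itm, 1, 2, 3)
view_index  == (1, Colon(), 3)   # used to take a view of the argument array
inner_index == (2,)              # passed on to the inner (middle) mapping
```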
290 # TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21
291 # See:
292 # https://github.com/JuliaLang/julia/issues/34884
293 # https://github.com/JuliaLang/julia/issues/30386
294 """
295 slice_tuple(t, Val(l), Val(u))
296
297 Get a slice of a tuple in a type stable way.
298 Equivalent to t[l:u] but type stable.
299 """
300 function slice_tuple(t,::Val{L},::Val{U}) where {L,U}
301 return ntuple(i->t[i+L-1], U-L+1)
302 end
303
304 """
305 flatten_tuple(t)
306
307 Takes a nested tuple and flattens the whole structure.
308 """
309 flatten_tuple(t::NTuple{N, Number} where N) = t
310 flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify?
311 flatten_tuple(ts::Vararg) = flatten_tuple(ts)
312
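Small sketches of the two tuple helpers:

```julia
slice_tuple((10, 20, 30, 40), Val(2), Val(3)) == (20, 30)   # type-stable version of t[2:3]

flatten_tuple(((1, 2), (3, (4, 5)))) == (1, 2, 3, 4, 5)     # nesting is removed recursively
flatten_tuple((1,), (2, 3)) == (1, 2, 3)                    # varargs are collected and flattened
```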
313 """
314 LazyOuterProduct(tms...)
315
316 Creates a `TensorMappingComposition` for the outer product of `tms...`.
317 This is done by separating the outer product into regular products of outer products involving only identity mappings and one non-identity mapping.
318
319 First let
320 ```math
321 A = A_{I,J}
322 B = B_{M,N}
323 C = C_{P,Q}
324 ```
325
326 where ``I``, ``M``, ``P`` are multi-indexes for the ranges of ``A``, ``B``, ``C``, and ``J``, ``N``, ``Q`` are multi-indexes of the domains.
327
328 We use ``⊗`` to denote the outer product
329 ```math
330 (A⊗B)_{IM,JN} = A_{I,J}B_{M,N}
331 ```
332
333 We note that
334 ```math
335 A⊗B⊗C = (A⊗B⊗C)_{IMP,JNQ} = A_{I,J}B_{M,N}C_{P,Q}
336 ```
337 And that
338 ```math
339 A⊗B⊗C = (A⊗I_{|M|}⊗I_{|P|})(I_{|J|}⊗B⊗I_{|P|})(I_{|J|}⊗I_{|N|}⊗C)
340 ```
341 where ``|.|`` of a multi-index denotes the vector of sizes in each dimension, and ``I_v`` denotes the identity tensor of size ``v[i]`` in each direction.
342 To apply ``A⊗B⊗C`` we evaluate
343
344 (A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]]
345 """
346 function LazyOuterProduct end
347 export LazyOuterProduct
348
349 function LazyOuterProduct(tm1::TensorMapping{T}, tm2::TensorMapping{T}) where T
350 itm1 = InflatedTensorMapping(tm1, IdentityMapping{T}(range_size(tm2)))
351 itm2 = InflatedTensorMapping(IdentityMapping{T}(domain_size(tm1)),tm2)
352
353 return itm1∘itm2
354 end
355
356 LazyOuterProduct(t1::IdentityMapping{T}, t2::IdentityMapping{T}) where T = IdentityMapping{T}(t1.size...,t2.size...)
357 LazyOuterProduct(t1::TensorMapping, t2::IdentityMapping) = InflatedTensorMapping(t1, t2)
358 LazyOuterProduct(t1::IdentityMapping, t2::TensorMapping) = InflatedTensorMapping(t1, t2)
359
360 LazyOuterProduct(tms::Vararg{TensorMapping}) = foldl(LazyOuterProduct, tms)
361
362 ⊗(a::TensorMapping, b::TensorMapping) = LazyOuterProduct(a,b)
363 export ⊗
364
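A sketch of `⊗`/`LazyOuterProduct`. Only `IdentityMapping`s are used here so the example is self-contained; in that case the product collapses to a larger identity, whereas a non-identity factor would produce the composition of `InflatedTensorMapping`s described in the docstring above.

```julia
Id3, Id4 = IdentityMapping(3), IdentityMapping(4)

Id34 = Id3 ⊗ Id4
Id34 isa IdentityMapping          # identity ⊗ identity is just a larger identity
range_size(Id34) == (3, 4)

# The variadic form folds from the left over all factors:
range_size(LazyOuterProduct(Id3, IdentityMapping(2), Id4)) == (3, 2, 4)
```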
365
366 function check_domain_size(tm::TensorMapping, sz)
367 if domain_size(tm) != sz
368 throw(SizeMismatch(tm,sz))
369 end
370 end
371
372 struct SizeMismatch <: Exception
373 tm::TensorMapping
374 sz
375 end
376 export SizeMismatch
377
378 function Base.showerror(io::IO, err::SizeMismatch)
379 print(io, "SizeMismatch: ")
380 print(io, "domain size $(domain_size(err.tm)) of TensorMapping does not match size $(err.sz)")
381 end