Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 509:b7e42384053a feature/boundary_ops
Merge w. default
author | Vidar Stiernström <vidar.stiernstrom@it.uu.se> |
---|---|
date | Sun, 08 Nov 2020 16:01:39 +0100 |
parents | 4b9d124fe984 |
children | a5caa934b35f fe86ac896377 |
comparison
equal
deleted
inserted
replaced
478:2ab687b1d221 | 509:b7e42384053a |
---|---|
14 # TODO: Do boundschecking on creation! | 14 # TODO: Do boundschecking on creation! |
15 export LazyTensorMappingApplication | 15 export LazyTensorMappingApplication |
16 | 16 |
17 # TODO: Go through and remove unneccerary type parameters on functions | 17 # TODO: Go through and remove unneccerary type parameters on functions |
18 | 18 |
19 Base.:*(tm::TensorMapping{T,R,D}, o::AbstractArray{T,D}) where {T,R,D} = LazyTensorMappingApplication(tm,o) | |
20 Base.getindex(ta::LazyTensorMappingApplication{T,R,D}, I::Vararg{Index,R}) where {T,R,D} = apply(ta.t, ta.o, I...) | 19 Base.getindex(ta::LazyTensorMappingApplication{T,R,D}, I::Vararg{Index,R}) where {T,R,D} = apply(ta.t, ta.o, I...) |
21 Base.getindex(ta::LazyTensorMappingApplication{T,R,D}, I::Vararg{Int,R}) where {T,R,D} = apply(ta.t, ta.o, Index{Unknown}.(I)...) | 20 Base.getindex(ta::LazyTensorMappingApplication{T,R,D}, I::Vararg{Int,R}) where {T,R,D} = apply(ta.t, ta.o, Index{Unknown}.(I)...) |
22 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t) | 21 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t) |
23 # TODO: What else is needed to implement the AbstractArray interface? | 22 # TODO: What else is needed to implement the AbstractArray interface? |
24 | 23 |
24 Base.:*(a::TensorMapping, v::AbstractArray) = LazyTensorMappingApplication(a,v) | |
25 Base.:*(a::TensorMapping, b::TensorMapping) = throw(MethodError(Base.:*,(a,b))) | |
26 Base.:*(a::TensorMapping, args::Union{TensorMapping, AbstractArray}...) = foldr(*,(a,args...)) | |
27 | |
25 # # We need the associativity to be a→b→c = a→(b→c), which is the case for '→' | 28 # # We need the associativity to be a→b→c = a→(b→c), which is the case for '→' |
26 Base.:*(a::TensorMapping{T,R,D}, b::TensorMapping{T,D,K}, args::Union{TensorMapping{T}, AbstractArray{T}}...) where {T,R,D,K} = foldr(*,(a,b,args...)) | |
27 # # Should we overload some other infix binary opesrator? | 29 # # Should we overload some other infix binary opesrator? |
28 # →(tm::TensorMapping{T,R,D}, o::AbstractArray{T,D}) where {T,R,D} = LazyTensorMappingApplication(tm,o) | 30 # →(tm::TensorMapping{T,R,D}, o::AbstractArray{T,D}) where {T,R,D} = LazyTensorMappingApplication(tm,o) |
29 # TODO: We need to be really careful about good error messages. | 31 # TODO: We need to be really careful about good error messages. |
30 # For example what happens if you try to multiply LazyTensorMappingApplication with a TensorMapping(wrong order)? | 32 # For example what happens if you try to multiply LazyTensorMappingApplication with a TensorMapping(wrong order)? |
31 | 33 |
36 | 38 |
37 If a mapping implements the the `apply_transpose` method this allows working with | 39 If a mapping implements the the `apply_transpose` method this allows working with |
38 the transpose of mapping `m` by using `m'`. `m'` will work as a regular TensorMapping lazily calling | 40 the transpose of mapping `m` by using `m'`. `m'` will work as a regular TensorMapping lazily calling |
39 the appropriate methods of `m`. | 41 the appropriate methods of `m`. |
40 """ | 42 """ |
41 struct LazyTensorMappingTranspose{T,R,D} <: TensorMapping{T,D,R} | 43 struct LazyTensorMappingTranspose{T,R,D, TM<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R} |
42 tm::TensorMapping{T,R,D} | 44 tm::TM |
43 end | 45 end |
44 export LazyTensorMappingTranspose | 46 export LazyTensorMappingTranspose |
45 | 47 |
46 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? | 48 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? |
47 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? | 49 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? |
82 struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D} | 84 struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D} |
83 t1::TM1 | 85 t1::TM1 |
84 t2::TM2 | 86 t2::TM2 |
85 | 87 |
86 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} | 88 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} |
87 @boundscheck if domain_size(t1) != range_size(t2) | 89 @boundscheck check_domain_size(t1, range_size(t2)) |
88 throw(DimensionMismatch("the first argument has domain size $(domain_size(t1)) while the second has range size $(range_size(t2)) ")) | |
89 end | |
90 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) | 90 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) |
91 end | 91 end |
92 # Add check for matching sizes as a boundscheck | |
93 end | 92 end |
94 export TensorMappingComposition | 93 export TensorMappingComposition |
95 | 94 |
96 range_size(tm::TensorMappingComposition) = range_size(tm.t1) | 95 range_size(tm::TensorMappingComposition) = range_size(tm.t1) |
97 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) | 96 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) |
143 end | 142 end |
144 | 143 |
145 function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Index,D}) where {T,R,D} | 144 function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Index,D}) where {T,R,D} |
146 apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...) | 145 apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...) |
147 end | 146 end |
147 | |
148 | |
149 """ | |
150 IdentityMapping{T,D} <: TensorMapping{T,D,D} | |
151 | |
152 The lazy identity TensorMapping for a given size. Usefull for building up higher dimensional tensor mappings from lower | |
153 dimensional ones through outer products. Also used in the Implementation for InflatedTensorMapping. | |
154 """ | |
155 struct IdentityMapping{T,D} <: TensorMapping{T,D,D} | |
156 size::NTuple{D,Int} | |
157 end | |
158 export IdentityMapping | |
159 | |
160 IdentityMapping{T}(size::NTuple{D,Int}) where {T,D} = IdentityMapping{T,D}(size) | |
161 IdentityMapping{T}(size::Vararg{Int,D}) where {T,D} = IdentityMapping{T,D}(size) | |
162 IdentityMapping(size::Vararg{Int,D}) where D = IdentityMapping{Float64,D}(size) | |
163 | |
164 range_size(tmi::IdentityMapping) = tmi.size | |
165 domain_size(tmi::IdentityMapping) = tmi.size | |
166 | |
167 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] | |
168 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] | |
169 | |
170 """ | |
171 Base.:∘(tm, tmi) | |
172 Base.:∘(tmi, tm) | |
173 | |
174 Composes a `Tensormapping` `tm` with an `IdentityMapping` `tmi`, by returning `tm` | |
175 """ | |
176 @inline function Base.:∘(tm::TensorMapping{T,R,D}, tmi::IdentityMapping{T,D}) where {T,R,D} | |
177 @boundscheck check_domain_size(tm, range_size(tmi)) | |
178 return tm | |
179 end | |
180 | |
181 @inline function Base.:∘(tmi::IdentityMapping{T,R}, tm::TensorMapping{T,R,D}) where {T,R,D} | |
182 @boundscheck check_domain_size(tmi, range_size(tm)) | |
183 return tm | |
184 end | |
185 # Specialization for the case where tm is an IdentityMapping. Required to resolve ambiguity. | |
186 @inline function Base.:∘(tm::IdentityMapping{T,D}, tmi::IdentityMapping{T,D}) where {T,D} | |
187 @boundscheck check_domain_size(tm, range_size(tmi)) | |
188 return tmi | |
189 end | |
190 | |
191 | |
192 """ | |
193 InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D} | |
194 | |
195 An inflated `TensorMapping` with dimensions added before and afer its actual dimensions. | |
196 """ | |
197 struct InflatedTensorMapping{T,R,D,D_before,R_middle,D_middle,D_after, TM<:TensorMapping{T,R_middle,D_middle}} <: TensorMapping{T,R,D} | |
198 before::IdentityMapping{T,D_before} | |
199 tm::TM | |
200 after::IdentityMapping{T,D_after} | |
201 | |
202 function InflatedTensorMapping(before, tm::TensorMapping{T}, after) where T | |
203 R_before = range_dim(before) | |
204 R_middle = range_dim(tm) | |
205 R_after = range_dim(after) | |
206 R = R_before+R_middle+R_after | |
207 | |
208 D_before = domain_dim(before) | |
209 D_middle = domain_dim(tm) | |
210 D_after = domain_dim(after) | |
211 D = D_before+D_middle+D_after | |
212 return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after) | |
213 end | |
214 end | |
215 export InflatedTensorMapping | |
216 """ | |
217 InflatedTensorMapping(before, tm, after) | |
218 InflatedTensorMapping(before,tm) | |
219 InflatedTensorMapping(tm,after) | |
220 | |
221 The outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityMapping`s. | |
222 | |
223 If one of `before` or `after` is left out, a 0-dimensional `IdentityMapping` is used as the default value. | |
224 | |
225 If `tm` already is an `InflatedTensorMapping`, `before` and `after` will be extended instead of | |
226 creating a nested `InflatedTensorMapping`. | |
227 """ | |
228 InflatedTensorMapping(::IdentityMapping, ::TensorMapping, ::IdentityMapping) | |
229 | |
230 function InflatedTensorMapping(before, itm::InflatedTensorMapping, after) | |
231 return InflatedTensorMapping( | |
232 IdentityMapping(before.size..., itm.before.size...), | |
233 itm.tm, | |
234 IdentityMapping(itm.after.size..., after.size...), | |
235 ) | |
236 end | |
237 | |
238 InflatedTensorMapping(before::IdentityMapping, tm::TensorMapping{T}) where T = InflatedTensorMapping(before,tm,IdentityMapping{T}()) | |
239 InflatedTensorMapping(tm::TensorMapping{T}, after::IdentityMapping) where T = InflatedTensorMapping(IdentityMapping{T}(),tm,after) | |
240 # Resolve ambiguity between the two previous methods | |
241 InflatedTensorMapping(I1::IdentityMapping{T}, I2::IdentityMapping{T}) where T = InflatedTensorMapping(I1,I2,IdentityMapping{T}()) | |
242 | |
243 # TODO: Implement syntax and constructors for products of different combinations of InflatedTensorMapping and IdentityMapping | |
244 | |
245 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensorMapping(I(3),B,I(2)) -> I(3)⊗B⊗I(2) | |
246 | |
247 function range_size(itm::InflatedTensorMapping) | |
248 return flatten_tuple( | |
249 range_size(itm.before), | |
250 range_size(itm.tm), | |
251 range_size(itm.after), | |
252 ) | |
253 end | |
254 | |
255 function domain_size(itm::InflatedTensorMapping) | |
256 return flatten_tuple( | |
257 domain_size(itm.before), | |
258 domain_size(itm.tm), | |
259 domain_size(itm.after), | |
260 ) | |
261 end | |
262 | |
263 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} | |
264 view_index, inner_index = split_index(itm, I...) | |
265 | |
266 v_inner = view(v, view_index...) | |
267 return apply(itm.tm, v_inner, inner_index...) | |
268 end | |
269 | |
270 | |
271 """ | |
272 split_index(...) | |
273 | |
274 Splits the multi-index into two parts. One part for the view that the inner TensorMapping acts on, and one part for indexing the result | |
275 Eg. | |
276 ``` | |
277 (1,2,3,4) -> (1,:,:,4), (2,3) | |
278 ``` | |
279 """ | |
280 function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{Any,R}) where {T,R,D} | |
281 I_before = slice_tuple(I, Val(1), Val(range_dim(itm.before))) | |
282 I_after = slice_tuple(I, Val(R-range_dim(itm.after)+1), Val(R)) | |
283 | |
284 view_index = (I_before..., ntuple((i)->:,domain_dim(itm.tm))..., I_after...) | |
285 inner_index = slice_tuple(I, Val(range_dim(itm.before)+1), Val(R-range_dim(itm.after))) | |
286 | |
287 return (view_index, inner_index) | |
288 end | |
289 | |
290 # TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21 | |
291 # See: | |
292 # https://github.com/JuliaLang/julia/issues/34884 | |
293 # https://github.com/JuliaLang/julia/issues/30386 | |
294 """ | |
295 slice_tuple(t, Val(l), Val(u)) | |
296 | |
297 Get a slice of a tuple in a type stable way. | |
298 Equivalent to t[l:u] but type stable. | |
299 """ | |
300 function slice_tuple(t,::Val{L},::Val{U}) where {L,U} | |
301 return ntuple(i->t[i+L-1], U-L+1) | |
302 end | |
303 | |
304 """ | |
305 flatten_tuple(t) | |
306 | |
307 Takes a nested tuple and flattens the whole structure | |
308 """ | |
309 flatten_tuple(t::NTuple{N, Number} where N) = t | |
310 flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify? | |
311 flatten_tuple(ts::Vararg) = flatten_tuple(ts) | |
312 | |
313 """ | |
314 LazyOuterProduct(tms...) | |
315 | |
316 Creates a `TensorMappingComposition` for the outerproduct of `tms...`. | |
317 This is done by separating the outer product into regular products of outer products involving only identity mappings and one non-identity mapping. | |
318 | |
319 First let | |
320 ```math | |
321 A = A_{I,J} | |
322 B = B_{M,N} | |
323 C = C_{P,Q} | |
324 ``` | |
325 | |
326 where ``I``, ``M``, ``P`` are multi-indexes for the ranges of ``A``, ``B``, ``C``, and ``J``, ``N``, ``Q`` are multi-indexes of the domains. | |
327 | |
328 We use ``⊗`` to denote the outer product | |
329 ```math | |
330 (A⊗B)_{IM,JN} = A_{I,J}B_{M,N} | |
331 ``` | |
332 | |
333 We note that | |
334 ```math | |
335 A⊗B⊗C = (A⊗B⊗C)_{IMP,JNQ} = A_{I,J}B_{M,N}C_{P,Q} | |
336 ``` | |
337 And that | |
338 ```math | |
339 A⊗B⊗C = (A⊗I_{|M|}⊗I_{|P|})(I_{|J|}⊗B⊗I_{|P|})(I_{|J|}⊗I_{|N|}⊗C) | |
340 ``` | |
341 where |.| of a multi-index is a vector of sizes for each dimension. ``I_v`` denotes the identity tensor of size ``v[i]`` in each direction | |
342 To apply ``A⊗B⊗C`` we evaluate | |
343 | |
344 (A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]] | |
345 """ | |
346 function LazyOuterProduct end | |
347 export LazyOuterProduct | |
348 | |
349 function LazyOuterProduct(tm1::TensorMapping{T}, tm2::TensorMapping{T}) where T | |
350 itm1 = InflatedTensorMapping(tm1, IdentityMapping{T}(range_size(tm2))) | |
351 itm2 = InflatedTensorMapping(IdentityMapping{T}(domain_size(tm1)),tm2) | |
352 | |
353 return itm1∘itm2 | |
354 end | |
355 | |
356 LazyOuterProduct(t1::IdentityMapping{T}, t2::IdentityMapping{T}) where T = IdentityMapping{T}(t1.size...,t2.size...) | |
357 LazyOuterProduct(t1::TensorMapping, t2::IdentityMapping) = InflatedTensorMapping(t1, t2) | |
358 LazyOuterProduct(t1::IdentityMapping, t2::TensorMapping) = InflatedTensorMapping(t1, t2) | |
359 | |
360 LazyOuterProduct(tms::Vararg{TensorMapping}) = foldl(LazyOuterProduct, tms) | |
361 | |
362 ⊗(a::TensorMapping, b::TensorMapping) = LazyOuterProduct(a,b) | |
363 export ⊗ | |
364 | |
365 | |
366 function check_domain_size(tm::TensorMapping, sz) | |
367 if domain_size(tm) != sz | |
368 throw(SizeMismatch(tm,sz)) | |
369 end | |
370 end | |
371 | |
372 struct SizeMismatch <: Exception | |
373 tm::TensorMapping | |
374 sz | |
375 end | |
376 export SizeMismatch | |
377 | |
378 function Base.showerror(io::IO, err::SizeMismatch) | |
379 print(io, "SizeMismatch: ") | |
380 print(io, "domain size $(domain_size(err.tm)) of TensorMapping not matching size $(err.sz)") | |
381 end |