comparison src/LazyTensors/lazy_tensor_operations.jl @ 473:3041f8578bba

Merge in feature/inflated_tensormapping.
author Jonatan Werpers <jonatan@werpers.com>
date Mon, 02 Nov 2020 21:33:35 +0100
parents f270d82fc9ad
children 481e86e77c22 95f3b9036801
comparison
equal deleted inserted replaced
445:a79d7b3209c9 473:3041f8578bba
# Domain size of an identity mapping: the size it was constructed with.
function domain_size(tmi::IdentityMapping)
    return tmi.size
end

# Applying the identity simply reads the element of `v` at the multi-index `I`.
function apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D}
    return v[I...]
end

# The identity is its own transpose, so this is the same lookup as `apply`.
function apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D}
    return v[I...]
end
172 172
173
"""
    InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D}

An inflated `TensorMapping` with dimensions added before and after its actual dimensions.

The wrapped mapping `tm` is combined with the identity mappings `before` and
`after`. The resulting range dimension is
`range_dim(before) + range_dim(tm) + range_dim(after)`, and the domain dimension is
`domain_dim(before) + domain_dim(tm) + domain_dim(after)`.
"""
struct InflatedTensorMapping{T,R,D,D_before,R_middle,D_middle,D_after, TM<:TensorMapping{T,R_middle,D_middle}} <: TensorMapping{T,R,D}
    # Type parameters beyond {T,R,D}:
    #   D_before          dimension of the `before` identity
    #   R_middle/D_middle range/domain dimensions of the wrapped mapping `tm`
    #   D_after           dimension of the `after` identity
    before::IdentityMapping{T,D_before}
    tm::TM
    after::IdentityMapping{T,D_after}

    # `before` and `after` are restricted to identity mappings with the same
    # element type as `tm`. The fields require exactly that, and restricting it
    # here gives a clear `MethodError` at the call site instead of a failed
    # conversion inside `new`.
    function InflatedTensorMapping(before::IdentityMapping{T}, tm::TensorMapping{T}, after::IdentityMapping{T}) where T
        R_before = range_dim(before)
        R_middle = range_dim(tm)
        R_after = range_dim(after)
        R = R_before+R_middle+R_after

        D_before = domain_dim(before)
        D_middle = domain_dim(tm)
        D_after = domain_dim(after)
        D = D_before+D_middle+D_after
        return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after)
    end
end
export InflatedTensorMapping
"""
    InflatedTensorMapping(before, tm, after)
    InflatedTensorMapping(before, tm)
    InflatedTensorMapping(tm, after)

Form the outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityMapping`s.

When `before` or `after` is omitted, a 0-dimensional `IdentityMapping` takes its place.
"""
InflatedTensorMapping(::IdentityMapping, ::TensorMapping, ::IdentityMapping)

# Only `before` given: pad with an empty identity after `tm`.
function InflatedTensorMapping(before::IdentityMapping, tm::TensorMapping{T}) where T
    return InflatedTensorMapping(before, tm, IdentityMapping{T}())
end

# Only `after` given: pad with an empty identity before `tm`.
function InflatedTensorMapping(tm::TensorMapping{T}, after::IdentityMapping) where T
    return InflatedTensorMapping(IdentityMapping{T}(), tm, after)
end

# A pair of `IdentityMapping`s matches both two-argument methods above, which
# would make the call ambiguous. This more specific method settles it by using
# the first argument as `before` and the second as `tm`.
function InflatedTensorMapping(I1::IdentityMapping{T}, I2::IdentityMapping{T}) where T
    return InflatedTensorMapping(I1, I2, IdentityMapping{T}())
end
212
213 # TODO: Implement syntax and constructors for products of different combinations of InflatedTensorMapping and IdentityMapping
214
215 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensorMapping(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
216
# The range size of an inflated mapping is the concatenation of the range
# sizes of its three parts, in order `before`, `tm`, `after`.
function range_size(itm::InflatedTensorMapping)
    parts = (range_size(itm.before), range_size(itm.tm), range_size(itm.after))
    return flatten_tuple(parts)
end
224
# The domain size of an inflated mapping is the concatenation of the domain
# sizes of its three parts, in order `before`, `tm`, `after`.
function domain_size(itm::InflatedTensorMapping)
    parts = (domain_size(itm.before), domain_size(itm.tm), domain_size(itm.after))
    return flatten_tuple(parts)
end
232
"""
    apply(itm::InflatedTensorMapping, v, I...)

Evaluate `itm` applied to `v` at the range multi-index `I`.

The components of `I` that belong to the `before` and `after` identities are
used directly to slice `v`. The inner mapping `itm.tm` is then applied to that
slice at the middle components of `I`.
"""
function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D}
    view_index, inner_index = split_index(itm, I...)

    v_inner = view(v, view_index...)
    return apply(itm.tm, v_inner, inner_index...)
end

"""
    apply_transpose(itm::InflatedTensorMapping, v, I...)

Evaluate the transpose of `itm` applied to `v` at the domain multi-index `I`.

This mirrors `apply`, with the roles of range and domain swapped. The `before`
and `after` components of `I` slice `v` directly, and all `range_dim(itm.tm)`
middle dimensions of `v` are kept. The transpose of `itm.tm` is applied to that
slice at the middle `domain_dim(itm.tm)` components of `I`.
"""
function apply_transpose(
    itm::InflatedTensorMapping{T,R,D,D_before,R_middle,D_middle,D_after},
    v::AbstractArray{T,R},
    I::Vararg{Any,D},
) where {T,R,D,D_before,R_middle,D_middle,D_after}
    # `before` and `after` are identities, so their index components pass
    # straight through to `v`.
    I_before = slice_tuple(I, Val(1), Val(D_before))
    I_after = slice_tuple(I, Val(D-D_after+1), Val(D))

    # Keep every range dimension of the inner mapping; index it with the
    # middle (domain) components of `I`.
    view_index = (I_before..., ntuple(i->:, Val(R_middle))..., I_after...)
    inner_index = slice_tuple(I, Val(D_before+1), Val(D-D_after))

    v_inner = view(v, view_index...)
    return apply_transpose(itm.tm, v_inner, inner_index...)
end
239
240
"""
    split_index(itm::InflatedTensorMapping, I...)

Split the range multi-index `I` of `itm` into two tuples, returned as
`(view_index, inner_index)`:

 * `view_index` selects the part of the input array that the inner mapping
   `itm.tm` acts on. It keeps the `before` and `after` components of `I` and
   puts a `:` for every domain dimension of `itm.tm`.
 * `inner_index` holds the middle components of `I`. These are used to index
   the result of the inner mapping.

E.g. with one `before` dimension, one `after` dimension and an inner mapping
with two dimensions:
```
(1,2,3,4) -> (1,:,:,4), (2,3)
```
"""
function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{Any,R}) where {T,R,D}
    n_before = range_dim(itm.before)
    n_after = range_dim(itm.after)

    index_before = slice_tuple(I, Val(1), Val(n_before))
    index_after = slice_tuple(I, Val(R-n_after+1), Val(R))
    index_inner = slice_tuple(I, Val(n_before+1), Val(R-n_after))

    colons = ntuple(_ -> :, domain_dim(itm.tm))
    view_index = (index_before..., colons..., index_after...)

    return view_index, index_inner
end
259
# TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21
# See:
# https://github.com/JuliaLang/julia/issues/34884
# https://github.com/JuliaLang/julia/issues/30386
"""
    slice_tuple(t, Val(l), Val(u))

Get a slice of a tuple in a type stable way.
Equivalent to `t[l:u]` but type stable.

The bounds are given as `Val`s, so the length of the result, `u-l+1`, is known
from the argument types alone. An empty slice (`u == l-1`) returns `()`.
"""
function slice_tuple(t,::Val{L},::Val{U}) where {L,U}
    # The length is passed to `ntuple` as a `Val`, so the result is an `NTuple`
    # whose length is part of its type and does not depend on constant
    # propagation of an `Int` argument.
    return ntuple(i->t[i+L-1], Val(U-L+1))
end
273
"""
    flatten_tuple(t)
    flatten_tuple(ts...)

Take a nested tuple and flatten the whole structure into a single tuple,
keeping the leaves in left-to-right order. Several arguments are treated as
the elements of one tuple. A non-tuple argument is a single leaf.

E.g.
```
flatten_tuple((1,(2,3)),4) == (1,2,3,4)
flatten_tuple((1,2),(),(3,)) == (1,2,3)
```
"""
flatten_tuple(t::NTuple{N, Number} where N) = t # already flat
# Flatten every element, then splat the resulting tuples together.
flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,)
# A single non-tuple leaf. Without this method any non-Number leaf would be
# wrapped into a 1-tuple by the Vararg method and recurse forever.
flatten_tuple(x) = (x,)
flatten_tuple(ts::Vararg) = flatten_tuple(ts)