Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 473:3041f8578bba
Merge in feature/inflated_tensormapping.
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Mon, 02 Nov 2020 21:33:35 +0100 |
parents | f270d82fc9ad |
children | 481e86e77c22 95f3b9036801 |
comparison
equal
deleted
inserted
replaced
445:a79d7b3209c9 | 473:3041f8578bba |
---|---|
# The domain of an IdentityMapping has exactly the size stored in the mapping.
function domain_size(tmi::IdentityMapping)
    return tmi.size
end
169 | 169 |
# Applying the identity reads `v` at the requested multi-index.
function apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D}
    return getindex(v, I...)
end

# The identity is its own transpose, so the transposed application is the same lookup.
function apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D}
    return getindex(v, I...)
end
172 | 172 |
173 | |
"""
    InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D}

An inflated `TensorMapping` with dimensions added before and after its actual dimensions.

The added dimensions are described by the `IdentityMapping`s `before` and `after`. The
wrapped mapping `tm` acts on the dimensions in between. The type parameters record:
- `R` and `D`: the total range and domain dimensions.
- `D_before` and `D_after`: the domain dimensions of `before` and `after`.
- `R_middle` and `D_middle`: the range and domain dimensions of `tm`.
"""
struct InflatedTensorMapping{T,R,D,D_before,R_middle,D_middle,D_after, TM<:TensorMapping{T,R_middle,D_middle}} <: TensorMapping{T,R,D}
    before::IdentityMapping{T,D_before}
    tm::TM
    after::IdentityMapping{T,D_after}

    # `before` and `after` are typed to match the documented signature
    # `InflatedTensorMapping(::IdentityMapping, ::TensorMapping, ::IdentityMapping)`.
    # A non-identity argument is then rejected at the call site, rather than failing
    # while `new` converts it to the `IdentityMapping` field type.
    function InflatedTensorMapping(before::IdentityMapping, tm::TensorMapping{T}, after::IdentityMapping) where T
        R_before = range_dim(before)
        R_middle = range_dim(tm)
        R_after = range_dim(after)
        R = R_before+R_middle+R_after

        D_before = domain_dim(before)
        D_middle = domain_dim(tm)
        D_after = domain_dim(after)
        D = D_before+D_middle+D_after
        return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after)
    end
end
export InflatedTensorMapping
"""
    InflatedTensorMapping(before, tm, after)
    InflatedTensorMapping(before,tm)
    InflatedTensorMapping(tm,after)

Construct the outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityMapping`s.

With only two arguments, the missing `before` or `after` defaults to a 0-dimensional
`IdentityMapping` with the same element type as `tm`.
"""
InflatedTensorMapping(::IdentityMapping, ::TensorMapping, ::IdentityMapping)

function InflatedTensorMapping(before::IdentityMapping, tm::TensorMapping{T}) where T
    empty_after = IdentityMapping{T}()
    return InflatedTensorMapping(before, tm, empty_after)
end

function InflatedTensorMapping(tm::TensorMapping{T}, after::IdentityMapping) where T
    empty_before = IdentityMapping{T}()
    return InflatedTensorMapping(empty_before, tm, after)
end

# A pair of IdentityMappings matches both two-argument methods above, which is ambiguous.
# This method settles it by taking the first argument as `before` and the second as `tm`.
function InflatedTensorMapping(I1::IdentityMapping{T}, I2::IdentityMapping{T}) where T
    empty_after = IdentityMapping{T}()
    return InflatedTensorMapping(I1, I2, empty_after)
end
212 | |
213 # TODO: Implement syntax and constructors for products of different combinations of InflatedTensorMapping and IdentityMapping | |
214 | |
# TODO: Implement pretty printing in terms of ⊗, e.g. InflatedTensorMapping(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
216 | |
# The range size of an inflated mapping is the range sizes of `before`, `tm` and
# `after` concatenated, in that order.
function range_size(itm::InflatedTensorMapping)
    parts = (range_size(itm.before), range_size(itm.tm), range_size(itm.after))
    return flatten_tuple(parts)
end
224 | |
# The domain size of an inflated mapping is the domain sizes of `before`, `tm` and
# `after` concatenated, in that order.
function domain_size(itm::InflatedTensorMapping)
    parts = (domain_size(itm.before), domain_size(itm.tm), domain_size(itm.after))
    return flatten_tuple(parts)
end
232 | |
# Evaluate `itm` at multi-index `I`.
# The outer (`before`/`after`) indices fix a slice of `v`. The inner mapping `itm.tm`
# is then applied to that slice at the remaining middle indices.
function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D}
    outer_index, middle_index = split_index(itm, I...)
    return apply(itm.tm, view(v, outer_index...), middle_index...)
end
239 | |
240 | |
"""
    split_index(itm::InflatedTensorMapping, I...)

Split the range multi-index `I` of `itm` into two parts:
- a view index that selects the slice of the input array the inner `TensorMapping` acts on;
- the index used into the inner mapping's result.

The positions belonging to `before` and `after` are kept in the view index. The middle
positions are replaced there by one `:` per domain dimension of `itm.tm`.
E.g.
```
(1,2,3,4) -> (1,:,:,4), (2,3)
```
"""
function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{Any,R}) where {T,R,D}
    n_before = range_dim(itm.before)
    n_after = range_dim(itm.after)

    leading = slice_tuple(I, Val(1), Val(n_before))
    trailing = slice_tuple(I, Val(R - n_after + 1), Val(R))
    inner_index = slice_tuple(I, Val(n_before + 1), Val(R - n_after))

    colons = ntuple(_ -> Colon(), domain_dim(itm.tm))
    view_index = (leading..., colons..., trailing...)

    return view_index, inner_index
end
259 | |
# TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21
# See:
# https://github.com/JuliaLang/julia/issues/34884
# https://github.com/JuliaLang/julia/issues/30386
"""
    slice_tuple(t, Val(l), Val(u))

Get a slice of a tuple in a type stable way.
Equivalent to `t[l:u]`, but the length of the result is known from the argument types.
An empty slice (`u == l-1`) returns `()`. Indices outside `t` raise a `BoundsError`.
"""
function slice_tuple(t,::Val{L},::Val{U}) where {L,U}
    # The length is passed as a `Val`, so it is a compile-time constant and no longer
    # depends on constant propagation. L and U are already type parameters.
    return ntuple(i->t[i+L-1], Val(U-L+1))
end
273 | |
"""
    flatten_tuple(t)
    flatten_tuple(ts...)

Flatten an arbitrarily nested tuple into a single flat tuple, keeping the order of the
leaves. Several arguments are treated as the tuple of those arguments. Any non-tuple
value is a leaf.

E.g. `flatten_tuple((1,(2,3)), 4) == (1,2,3,4)`.
"""
# Fast path: a tuple of numbers is already flat.
flatten_tuple(t::NTuple{N, Number} where N) = t
# Flatten each element recursively, then splat twice to concatenate the results.
flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,)
# A non-tuple leaf becomes a 1-tuple. Without this method, a non-Number leaf would fall
# through to the Vararg method below, be re-wrapped as `(x,)` and recurse forever.
flatten_tuple(x) = (x,)
flatten_tuple(ts::Vararg) = flatten_tuple(ts)