Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 520:fe86ac896377 feature/inflated_tensormapping_transpose
Start refactoring split index and apply to accommodate future addition of apply_transpose
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Mon, 23 Nov 2020 21:30:11 +0100 |
parents | 4b9d124fe984 |
children | 7e6250c51eb2 |
comparison
equal
deleted
inserted
replaced
508:27e64b3d3efa | 520:fe86ac896377 |
---|---|
259 domain_size(itm.after), | 259 domain_size(itm.after), |
260 ) | 260 ) |
261 end | 261 end |
262 | 262 |
263 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} | 263 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} |
264 view_index, inner_index = split_index(itm, I...) | 264 A = range_dim(itm.before) |
265 B_domain = domain_dim(itm.tm) | |
266 B_range = range_dim(itm.tm) | |
267 C = range_dim(itm.after) | |
268 | |
269 view_index, inner_index = split_index(Val(A), Val(B_range), Val(B_domain), Val(C), I...) | |
265 | 270 |
266 v_inner = view(v, view_index...) | 271 v_inner = view(v, view_index...) |
267 return apply(itm.tm, v_inner, inner_index...) | 272 return apply(itm.tm, v_inner, inner_index...) |
268 end | 273 end |
269 | 274 |
270 | 275 |
271 """ | 276 """ |
272 split_index(...) | 277 split_index(:Val{A}, ::Val{B_view}, ::Val{B_middle}, ::Val{C}, I...) |
273 | 278 |
274 Splits the multi-index into two parts. One part for the view that the inner TensorMapping acts on, and one part for indexing the result | 279 Splits the multi-index `I` into two parts. One part which is expected to be used as a view, which is expected to be used as an index. |
275 Eg. | 280 Eg. |
276 ``` | 281 ``` |
277 (1,2,3,4) -> (1,:,:,4), (2,3) | 282 (1,2,3,4) -> (1,:,:,:,4), (2,3) |
278 ``` | 283 ``` |
279 """ | 284 |
280 function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{Any,R}) where {T,R,D} | 285 `B_view` controls how many colons are in the view, and `B_middle` controls how many elements are extracted from the middle. |
281 I_before = slice_tuple(I, Val(1), Val(range_dim(itm.before))) | 286 `A` and `C` decides the length of the parts before and after the colons in the view index. |
282 I_after = slice_tuple(I, Val(R-range_dim(itm.after)+1), Val(R)) | 287 length(I) == A+B_domain+C |
283 | 288 length(I_middle) == B_domain |
284 view_index = (I_before..., ntuple((i)->:,domain_dim(itm.tm))..., I_after...) | 289 length(I_view) == A + B_range + C |
285 inner_index = slice_tuple(I, Val(range_dim(itm.before)+1), Val(R-range_dim(itm.after))) | 290 |
286 | 291 TODO: Finish documentation. |
287 return (view_index, inner_index) | 292 """ |
293 function split_index(::Val{A}, ::Val{B_view}, ::Val{B_middle}, ::Val{C}, I...) where {A,B_view, B_middle,C} | |
294 I_before = slice_tuple(I, Val(1), Val(A)) | |
295 I_middle = slice_tuple(I, Val(A+1), Val(A+B_middle)) | |
296 I_after = slice_tuple(I, Val(A+B_middle+1), Val(A+B_middle+C)) | |
297 | |
298 view_index = (I_before..., ntuple((i)->:, B_view)..., I_after...) | |
299 inner_index = | |
300 | |
301 return view_index, I_middle | |
288 end | 302 end |
289 | 303 |
290 # TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21 | 304 # TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21 |
291 # See: | 305 # See: |
292 # https://github.com/JuliaLang/julia/issues/34884 | 306 # https://github.com/JuliaLang/julia/issues/34884 |