comparison src/LazyTensors/lazy_tensor_operations.jl @ 532:588a843907de feature/inflated_tensormapping_transpose

Add a split_tuple function to make split_index more readable
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 26 Nov 2020 16:13:58 +0100
parents 36dfc57e8e0b
children aac7cc1fa79a
comparison
equal deleted inserted replaced
531:36dfc57e8e0b 532:588a843907de
300 The returned values satisfy 300 The returned values satisfy
301 * `length(view_index) == A + B_view + C` 301 * `length(view_index) == A + B_view + C`
302 * `length(I_middle) == B_middle` 302 * `length(I_middle) == B_middle`
303 """ 303 """
304 function split_index(::Val{A}, ::Val{B_view}, ::Val{B_middle}, ::Val{C}, I...) where {A,B_view, B_middle,C} 304 function split_index(::Val{A}, ::Val{B_view}, ::Val{B_middle}, ::Val{C}, I...) where {A,B_view, B_middle,C}
305 I_before = slice_tuple(I, Val(1), Val(A)) 305 I_before, I_middle, I_after = split_tuple(I, Val(A), Val(B_middle))
306 I_middle = slice_tuple(I, Val(A+1), Val(A+B_middle))
307 I_after = slice_tuple(I, Val(A+B_middle+1), Val(A+B_middle+C))
308 306
309 view_index = (I_before..., ntuple((i)->:, B_view)..., I_after...) 307 view_index = (I_before..., ntuple((i)->:, B_view)..., I_after...)
310 308
311 return view_index, I_middle 309 return view_index, I_middle
312 end 310 end
322 Equivalent to t[l:u] but type stable. 320 Equivalent to t[l:u] but type stable.
323 """ 321 """
324 function slice_tuple(t,::Val{L},::Val{U}) where {L,U} 322 function slice_tuple(t,::Val{L},::Val{U}) where {L,U}
325 return ntuple(i->t[i+L-1], U-L+1) 323 return ntuple(i->t[i+L-1], U-L+1)
326 end 324 end
325
326 """
327 split_tuple(t::Tuple{...}, ::Val{M}) where {N,M}
328
329 Split the tuple `t` into two parts. The first part is `M` long.
330 E.g.
331 ```
332 split_tuple((1,2,3,4),Val(3)) -> (1,2,3), (4,)
333 ```
334 """
335 function split_tuple(t::NTuple{N},::Val{M}) where {N,M}
336 return slice_tuple(t,Val(1), Val(M)), slice_tuple(t,Val(M+1), Val(N))
337 end
338
339 """
340 split_tuple(t::Tuple{...},::Val{M},::Val{K}) where {N,M,K}
341
342 Same as `split_tuple(t::NTuple{N},::Val{M})` but splits the tuple into three parts, with the first
343 two parts having length `M` and `K`.
344 """
345 function split_tuple(t::NTuple{N},::Val{M},::Val{K}) where {N,M,K}
346 p1, tail = split_tuple(t, Val(M))
347 p2, p3 = split_tuple(tail, Val(K))
348 return p1,p2,p3
349 end
350
351
327 352
328 """ 353 """
329 flatten_tuple(t) 354 flatten_tuple(t)
330 355
331 Takes a nested tuple and flattens the whole structure 356 Takes a nested tuple and flattens the whole structure