diff src/LazyTensors/lazy_tensor_operations.jl @ 532:588a843907de feature/inflated_tensormapping_transpose

Add a split_tuple function to make split_index more readable
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 26 Nov 2020 16:13:58 +0100
parents 36dfc57e8e0b
children aac7cc1fa79a
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Wed Nov 25 22:03:26 2020 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Thu Nov 26 16:13:58 2020 +0100
@@ -302,9 +302,7 @@
  * `length(I_middle) == B_middle`
 """
 function split_index(::Val{A}, ::Val{B_view}, ::Val{B_middle}, ::Val{C}, I...) where {A,B_view, B_middle,C}
-    I_before = slice_tuple(I, Val(1), Val(A))
-    I_middle = slice_tuple(I, Val(A+1), Val(A+B_middle))
-    I_after = slice_tuple(I, Val(A+B_middle+1), Val(A+B_middle+C))
+    I_before, I_middle, I_after = split_tuple(I, Val(A), Val(B_middle))
 
     view_index = (I_before..., ntuple((i)->:, B_view)..., I_after...)
 
@@ -326,6 +324,33 @@
 end
 
 """
+    split_tuple(t::Tuple{...}, ::Val{M}) where {N,M}
+
+Split the tuple `t` into two parts. the first part is `M` long.
+E.g
+```
+split_tuple((1,2,3,4),Val(3)) -> (1,2,3), (4,)
+```
+"""
+function split_tuple(t::NTuple{N},::Val{M}) where {N,M}
+    return slice_tuple(t,Val(1), Val(M)), slice_tuple(t,Val(M+1), Val(N))
+end
+
+"""
+    split_tuple(t::Tuple{...},::Val{M},::Val{K}) where {N,M,K}
+
+Same as `split_tuple(t::NTuple{N},::Val{M})` but splits the tuple in three parts. With the first
+two parts having lenght `M` and `K`.
+"""
+function split_tuple(t::NTuple{N},::Val{M},::Val{K}) where {N,M,K}
+    p1, tail = split_tuple(t, Val(M))
+    p2, p3 = split_tuple(tail, Val(K))
+    return p1,p2,p3
+end
+
+
+
+"""
     flatten_tuple(t)
 
 Takes a nested tuple and flattens the whole structure