diff src/LazyTensors/lazy_tensor_operations.jl @ 533:aac7cc1fa79a feature/inflated_tensormapping_transpose

Try to improve the naming in split_index()
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 26 Nov 2020 17:18:32 +0100
parents 588a843907de
children 41e82a5d4d48
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Thu Nov 26 16:13:58 2020 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Thu Nov 26 17:18:32 2020 +0100
@@ -284,27 +284,29 @@
 
 
 """
-    split_index(::Val{A}, ::Val{B_view}, ::Val{B_middle}, ::Val{C}, I...)
+    split_index(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...)
 
-Splits the multi-index `I` into two parts. One part which is expected to be used as a view, and one which is expected to be used as an index.
+Splits the multi-index `I` into two parts. One part which is expected to be
+used as a view, and one which is expected to be used as an index.
 Eg.
 ```
 (1,2,3,4) -> (1,:,:,:,4), (2,3)
 ```
 
-`B_view` controls how many colons are in the view, and `B_middle` controls how many elements are extracted from the middle.
-`A` and `C` decides the length of the index parts before and after the colons in the view index.
+`dim_view` controls how many colons are in the view, and `dim_index` controls
+how many elements are extracted from the middle.
+`dim_before` and `dim_after` decides the length of the index parts before and after the colons in the view index.
 
-Arguments should satisfy `length(I) == A+B_domain+C`.
+Arguments should satisfy `length(I) == dim_before+B_domain+dim_after`.
 
 The returned values satisfy
- * `length(view_index) == A + B_view + C`
- * `length(I_middle) == B_middle`
+ * `length(view_index) == dim_before + dim_view + dim_after`
+ * `length(I_middle) == dim_index`
 """
-function split_index(::Val{A}, ::Val{B_view}, ::Val{B_middle}, ::Val{C}, I...) where {A,B_view, B_middle,C}
-    I_before, I_middle, I_after = split_tuple(I, Val(A), Val(B_middle))
+function split_index(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...) where {dim_before,dim_view, dim_index,dim_after}
+    I_before, I_middle, I_after = split_tuple(I, Val(dim_before), Val(dim_index))
 
-    view_index = (I_before..., ntuple((i)->:, B_view)..., I_after...)
+    view_index = (I_before..., ntuple((i)->:, dim_view)..., I_after...)
 
     return view_index, I_middle
 end
@@ -349,7 +351,6 @@
 end
 
 
-
 """
     flatten_tuple(t)