comparison src/LazyTensors/tuple_manipulation.jl @ 1230:723a583cef96 refactor/LazyTensors/tuple_manipulation

Improve examples for split_index and split_tuple
author Jonatan Werpers <jonatan@werpers.com>
date Sun, 19 Feb 2023 11:45:48 +0100
parents 8f4259fbd39c
children de6a9635f293
1229:8f4259fbd39c 1230:723a583cef96
2 split_index(dim_before, dim_view, dim_index, dim_after, I...) 2 split_index(dim_before, dim_view, dim_index, dim_after, I...)
3 3
4 Splits the multi-index `I` into two parts. One part which is expected to be 4 Splits the multi-index `I` into two parts. One part which is expected to be
5 used as a view, and one which is expected to be used as an index. 5 used as a view, and one which is expected to be used as an index.
6 Eg. 6 Eg.
7 ``` 7 ```julia-repl
8 split_index(1, 3, 2, 1, (1,2,3,4)) -> (1,:,:,:,4), (2,3) 8 julia> LazyTensors.split_index(1, 3, 2, 1, (1,2,3,4)...)
9 ((1, Colon(), Colon(), Colon(), 4), (2, 3))
9 ``` 10 ```
10 11
11 `dim_view` controls how many colons are in the view, and `dim_index` controls 12 `dim_view` controls how many colons are in the view, and `dim_index` controls
12 how many elements are extracted from the middle. 13 how many elements are extracted from the middle.
13 `dim_before` and `dim_after` decides the length of the index parts before and after the colons in the view index. 14 `dim_before` and `dim_after` decides the length of the index parts before and after the colons in the view index.
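
Reviewer note: the splitting rule described in this docstring can be sketched as a standalone function. This is only an illustration based on the docstring text and the REPL example above. `split_index_sketch` is a hypothetical name, and the body is an assumption about the behaviour, not the code in `tuple_manipulation.jl`.

```julia
# Hypothetical stand-in for LazyTensors.split_index, written from the docstring
# only. It is not the package's implementation.
function split_index_sketch(dim_before, dim_view, dim_index, dim_after, I...)
    # The multi-index must supply the "before", "index" and "after" parts.
    @assert length(I) == dim_before + dim_index + dim_after

    before = I[1:dim_before]                           # leading part kept in the view index
    middle = I[dim_before+1:dim_before+dim_index]      # extracted as the plain index
    after  = I[dim_before+dim_index+1:end]             # trailing part kept in the view index

    # dim_view colons take the place of the extracted middle part.
    colons = ntuple(_ -> Colon(), dim_view)

    return (before..., colons..., after...), middle
end

view_index, index = split_index_sketch(1, 3, 2, 1, (1, 2, 3, 4)...)
```

With the arguments from the docstring example, `view_index` is `(1, Colon(), Colon(), Colon(), 4)` and `index` is `(2, 3)`, which matches the REPL output shown in the new revision.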
33 34
34 Split the tuple `t` into a set of tuples of the sizes given in `szs`. 35 Split the tuple `t` into a set of tuples of the sizes given in `szs`.
35 `sum(szs)` should equal `lenght(t)`. 36 `sum(szs)` should equal `lenght(t)`.
36 37
37 E.g 38 E.g
38 ```julia 39 ```julia-repl
39 split_tuple((1,2,3,4,5,6), (3,1,2)) -> (1,2,3),(4,),(5,6) 40 julia> LazyTensors.split_tuple((1,2,3,4,5,6), (3,1,2))
41 ((1, 2, 3), (4,), (5, 6))
40 ``` 42 ```
41 """ 43 """
42 function split_tuple(t, szs) 44 function split_tuple(t, szs)
43 @inline 45 @inline
44 if length(t) != sum(szs; init=0) 46 if length(t) != sum(szs; init=0)