Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/tuple_manipulation.jl @ 1223:5bfb182e24dc refactor/LazyTensors/tuple_manipulation
Start adding simpler code
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Sat, 18 Feb 2023 12:06:59 +0100 |
parents | 07c213167f7c |
children | 6567e38b05ca |
comparison
equal
deleted
inserted
replaced
1219:7ee258e5289e | 1223:5bfb182e24dc |
---|---|
23 | 23 |
24 view_index = (I_before..., ntuple((i)->:, dim_view)..., I_after...) | 24 view_index = (I_before..., ntuple((i)->:, dim_view)..., I_after...) |
25 | 25 |
26 return view_index, I_middle | 26 return view_index, I_middle |
27 end | 27 end |
28 # TBD: If the nice split_tuple works, can this be cleaned up as well? | |
28 | 29 |
29 # TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21 | 30 # TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21 |
30 # See: | 31 # See: |
31 # https://github.com/JuliaLang/julia/issues/34884 | 32 # https://github.com/JuliaLang/julia/issues/34884 |
32 # https://github.com/JuliaLang/julia/issues/30386 | 33 # https://github.com/JuliaLang/julia/issues/30386 |
63 p1, tail = split_tuple(t, Val(M)) | 64 p1, tail = split_tuple(t, Val(M)) |
64 p2, p3 = split_tuple(tail, Val(K)) | 65 p2, p3 = split_tuple(tail, Val(K)) |
65 return p1,p2,p3 | 66 return p1,p2,p3 |
66 end | 67 end |
67 | 68 |
69 # TBD Are the above defs even needed? Can the below one be used without problems? | |
70 | |
71 """ | |
72 split_tuple(t, szs) | |
73 | |
74 Split the tuple `t` into a set of tuples of the sizes given in `szs`. | |
75 `sum(szs)` should equal `length(t)`. | |
76 """ | |
77 function split_tuple(t, szs) | |
78 if length(t) != sum(szs; init=0) | |
79 throw(ArgumentError("length(t) must equal sum(szs)")) | |
80 end | |
81 | |
82 rs = sizes_to_ranges(szs) | |
83 return map(r->t[r], rs) | |
84 end | |
85 | |
86 function sizes_to_ranges(szs) | |
87 cum_szs = cumsum((0, szs...)) | |
88 return ntuple(i->cum_szs[i]+1:cum_szs[i+1], length(szs)) | |
89 end | |
90 | |
91 | |
92 concatenate_tuples(t::Tuple,ts::Vararg{Tuple}) = (t..., concatenate_tuples(ts...)...) | |
93 concatenate_tuples(t::Tuple) = t | |
68 | 94 |
69 """ | 95 """ |
70 flatten_tuple(t) | 96 flatten_tuple(t) |
71 | 97 |
72 Takes a nested tuple and flattens the whole structure | 98 Takes a nested tuple and flattens the whole structure |
73 """ | 99 """ |
74 flatten_tuple(t::NTuple{N, Number} where N) = t | 100 flatten_tuple(t::NTuple{N, Number} where N) = t |
75 flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify? | 101 flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify? |
76 flatten_tuple(ts::Vararg) = flatten_tuple(ts) | 102 flatten_tuple(ts::Vararg) = flatten_tuple(ts) |
103 # TBD: Can concatenate_tuples be used instead? | |
77 | 104 |
78 """ | 105 """ |
79 left_pad_tuple(t, val, N) | 106 left_pad_tuple(t, val, N) |
80 | 107 |
81 Left pad the tuple `t` to length `N` using the value `val`. | 108 Left pad the tuple `t` to length `N` using the value `val`. |