comparison src/LazyTensors/tuple_manipulation.jl @ 1223:5bfb182e24dc refactor/LazyTensors/tuple_manipulation

Start adding simpler code
author Jonatan Werpers <jonatan@werpers.com>
date Sat, 18 Feb 2023 12:06:59 +0100
parents 07c213167f7c
children 6567e38b05ca
comparison
equal deleted inserted replaced
1219:7ee258e5289e 1223:5bfb182e24dc
23 23
24 view_index = (I_before..., ntuple((i)->:, dim_view)..., I_after...) 24 view_index = (I_before..., ntuple((i)->:, dim_view)..., I_after...)
25 25
26 return view_index, I_middle 26 return view_index, I_middle
27 end 27 end
28 # TBD: If the nice split_tuple works, can this be cleaned up as well?
28 29
29 # TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21 30 # TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21
30 # See: 31 # See:
31 # https://github.com/JuliaLang/julia/issues/34884 32 # https://github.com/JuliaLang/julia/issues/34884
32 # https://github.com/JuliaLang/julia/issues/30386 33 # https://github.com/JuliaLang/julia/issues/30386
63 p1, tail = split_tuple(t, Val(M)) 64 p1, tail = split_tuple(t, Val(M))
64 p2, p3 = split_tuple(tail, Val(K)) 65 p2, p3 = split_tuple(tail, Val(K))
65 return p1,p2,p3 66 return p1,p2,p3
66 end 67 end
67 68
69 # TBD: Are the definitions above even needed? Can the one below be used instead without problems?
70
"""
    split_tuple(t, szs)

Split the tuple `t` into consecutive sub-tuples whose lengths are given by `szs`.

Returns a tuple with one sub-tuple per entry of `szs`, taken from `t` in order.
`sum(szs)` must equal `length(t)` and every size must be non-negative; otherwise an
`ArgumentError` is thrown. Zero sizes are allowed and produce empty tuples.

Since the sizes are runtime values, the return type is generally not inferable
(unlike the `Val`-based methods).

# Examples
```julia
julia> split_tuple((1, 2, 3, 4, 5), (2, 0, 3))
((1, 2), (), (3, 4, 5))
```
"""
function split_tuple(t, szs)
    if length(t) != sum(szs; init=0)
        throw(ArgumentError("length(t) must equal sum(szs)"))
    end
    # A negative size makes `sizes_to_ranges` produce a reversed (empty) range and
    # shifts every later range backwards, which can silently yield overlapping
    # sub-tuples even when the total matches `length(t)`.
    if any(<(0), szs)
        throw(ArgumentError("all sizes in szs must be non-negative"))
    end

    rs = sizes_to_ranges(szs)
    return map(r -> t[r], rs)
end
85
"""
    sizes_to_ranges(szs)

Convert a sequence of block sizes into the consecutive index ranges those blocks
occupy, starting at index 1. For example, `sizes_to_ranges((2, 3))` gives `(1:2, 3:5)`.
"""
function sizes_to_ranges(szs)
    # offsets[k] is the combined size of the first k-1 blocks.
    offsets = cumsum((0, szs...))
    return ntuple(length(szs)) do k
        first_index = offsets[k] + 1
        last_index = offsets[k+1]
        first_index:last_index
    end
end
90
91
"""
    concatenate_tuples(ts...)

Concatenate the given tuples, in argument order, into a single tuple.

With no arguments the empty tuple `()` is returned, so splatting an empty
collection of tuples (e.g. `concatenate_tuples(()...)`) works instead of
throwing a `MethodError`.
"""
concatenate_tuples() = ()
concatenate_tuples(t::Tuple) = t
concatenate_tuples(t::Tuple, ts::Vararg{Tuple}) = (t..., concatenate_tuples(ts...)...)
68 94
"""
    flatten_tuple(t)
    flatten_tuple(ts...)

Take a nested tuple and flatten the whole structure into a single flat tuple,
keeping the leaves in left-to-right order. Any value that is not a `Tuple` is
treated as a leaf, e.g. `flatten_tuple((1, (:a, (2.0,)))) == (1, :a, 2.0)`.

Calling it with several arguments is equivalent to calling it on the tuple of
those arguments.
"""
# Base case: a tuple of numbers contains no nested tuples, so it is already flat.
flatten_tuple(t::NTuple{N, Number} where N) = t
# Flatten each element, then splice the resulting tuples together with a nested splat.
flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify?
# Any non-tuple leaf. Without this method, a non-`Number` leaf would dispatch to the
# `Vararg` method below, be re-wrapped as a 1-tuple and recurse forever.
flatten_tuple(x) = (x,)
flatten_tuple(ts::Vararg) = flatten_tuple(ts)
# TBD: Can concatenate_tuples be used instead?
77 104
78 """ 105 """
79 left_pad_tuple(t, val, N) 106 left_pad_tuple(t, val, N)
80 107
81 Left pad the tuple `t` to length `N` using the value `val`. 108 Left pad the tuple `t` to length `N` using the value `val`.