Mercurial > repos > public > sbplib_julia
comparison src/Grids/tensor_grid.jl @ 1428:a936b414283a feature/grids/curvilinear
Merge bugfix/grids/complete_interface_impl
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Thu, 24 Aug 2023 08:50:10 +0200 |
parents | 18e21601da2d |
children | e82240df974d |
comparison
equal
deleted
inserted
replaced
1427:26e168924cf1 | 1428:a936b414283a |
---|---|
27 end | 27 end |
28 | 28 |
29 function Base.eachindex(g::TensorGrid) | 29 function Base.eachindex(g::TensorGrid) |
30 szs = LazyTensors.concatenate_tuples(size.(g.grids)...) | 30 szs = LazyTensors.concatenate_tuples(size.(g.grids)...) |
31 return CartesianIndices(szs) | 31 return CartesianIndices(szs) |
32 end | |
33 | |
34 function Base.axes(g::TensorGrid, d) | |
35 i, ld = grid_and_local_dim_index(ndims.(g.grids), d) | |
36 return axes(g.grids[i], ld) | |
32 end | 37 end |
33 | 38 |
34 # Iteration interface | 39 # Iteration interface |
35 Base.iterate(g::TensorGrid) = iterate(Iterators.product(g.grids...)) |> _iterate_combine_coords | 40 Base.iterate(g::TensorGrid) = iterate(Iterators.product(g.grids...)) |> _iterate_combine_coords |
36 Base.iterate(g::TensorGrid, state) = iterate(Iterators.product(g.grids...), state) |> _iterate_combine_coords | 41 Base.iterate(g::TensorGrid, state) = iterate(Iterators.product(g.grids...), state) |> _iterate_combine_coords |
93 end | 98 end |
94 | 99 |
95 function combine_coordinates(coords...) | 100 function combine_coordinates(coords...) |
96 return mapreduce(SVector, vcat, coords) | 101 return mapreduce(SVector, vcat, coords) |
97 end | 102 end |
103 | |
104 """ | |
105 grid_and_local_dim_index(nds, d) | |
106 | |
107 Given a tuple `nds` of per-grid dimension counts and a global dimension index `d`, | |
108 calculate which grid, and which local dimension within that grid, `d` corresponds to. | |
109 | |
110 `nds` would come from broadcasting `ndims` on the grids tuple of a | |
111 `TensorGrid`. If you are interested in dimension `d` of a tensor grid `g`, then | |
112 ```julia | |
113 gi, ldi = grid_and_local_dim_index(ndims.(g.grids), d) | |
114 ``` | |
115 tells you which grid it belongs to (`gi`) and which local dimension index it has within that | |
116 grid (`ldi`). | |
117 """ | |
118 function grid_and_local_dim_index(nds, d) | |
119 I = findfirst(>=(d), cumsum(nds)) | |
120 | |
121 if I == 1 | |
122 return (1, d) | |
123 else | |
124 return (I, d-cumsum(nds)[I-1]) | |
125 end | |
126 # TBD: Is there a cleaner way to compute this? | |
127 end |