comparison src/Grids/tensor_grid.jl @ 1428:a936b414283a feature/grids/curvilinear

Merge bugfix/grids/complete_interface_impl
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 24 Aug 2023 08:50:10 +0200
parents 18e21601da2d
children e82240df974d
comparison
equal deleted inserted replaced
1427:26e168924cf1 1428:a936b414283a
27 end 27 end
28 28
# Indices of a tensor grid: one Cartesian index per point. The size tuple is
# formed by concatenating the size tuples of the component grids in order.
function Base.eachindex(g::TensorGrid)
    grid_sizes = map(size, g.grids)
    return CartesianIndices(LazyTensors.concatenate_tuples(grid_sizes...))
end

# Axis `d` of a tensor grid. Find which component grid owns global dimension
# `d` and which local dimension it is there, then ask that grid for the axis.
function Base.axes(g::TensorGrid, d)
    grid_index, local_dim = grid_and_local_dim_index(ndims.(g.grids), d)
    subgrid = g.grids[grid_index]
    return axes(subgrid, local_dim)
end
33 38
34 # Iteration interface 39 # Iteration interface
# Start iterating a tensor grid. This walks the lazy product of the component
# grids and lets `_iterate_combine_coords` post-process each product element.
function Base.iterate(g::TensorGrid)
    product_iter = Iterators.product(g.grids...)
    return _iterate_combine_coords(iterate(product_iter))
end

# Continue iterating a tensor grid from `state`, which is the state of the
# underlying `Iterators.product` iterator.
function Base.iterate(g::TensorGrid, state)
    product_iter = Iterators.product(g.grids...)
    return _iterate_combine_coords(iterate(product_iter, state))
end
93 end 98 end
94 99
# Concatenate the given coordinates into one `SVector`. Each argument is first
# wrapped with `SVector`, then the pieces are joined with `vcat`.
function combine_coordinates(coords...)
    pieces = map(SVector, coords)
    return reduce(vcat, pieces)
end
103
"""
    grid_and_local_dim_index(nds, d)

Given a tuple `nds` with the number of dimensions of each component grid, and
a global dimension index `d`, return `(gi, ldi)`. Here `gi` is the index of the
grid that dimension `d` belongs to, and `ldi` is the local dimension index
within that grid.

`nds` would typically come from broadcasting `ndims` over the grids tuple of a
`TensorGrid`. If you are interested in dimension `d` of a tensor grid `g`,
```julia
gi, ldi = grid_and_local_dim_index(ndims.(g.grids), d)
```
tells you which grid it belongs to (`gi`) and which index it has within that
grid (`ldi`).

# Throws
- `ArgumentError` if `d < 1` or `d` exceeds the total number of dimensions
  `sum(nds)`.
"""
function grid_and_local_dim_index(nds, d)
    # Without this check, d < 1 would match the first grid and silently
    # return a non-positive local index.
    d >= 1 || throw(ArgumentError("dimension index $d is out of range for grid dimensions $nds"))

    # cs[i] is the last global dimension owned by grid i.
    cs = cumsum(nds)
    I = findfirst(>=(d), cs)

    # `findfirst` returns `nothing` when d exceeds the total dimension count.
    isnothing(I) && throw(ArgumentError("dimension index $d is out of range for grid dimensions $nds"))

    if I == 1
        return (1, d)
    else
        # Subtract the dimensions owned by all preceding grids.
        return (I, d - cs[I-1])
    end
end