diff src/Grids/tensor_grid.jl @ 1428:a936b414283a feature/grids/curvilinear

Merge bugfix/grids/complete_interface_impl
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 24 Aug 2023 08:50:10 +0200
parents 18e21601da2d
children e82240df974d
line wrap: on
line diff
--- a/src/Grids/tensor_grid.jl	Wed Aug 23 15:51:32 2023 +0200
+++ b/src/Grids/tensor_grid.jl	Thu Aug 24 08:50:10 2023 +0200
@@ -31,6 +31,11 @@
     return CartesianIndices(szs)
 end
 
+function Base.axes(g::TensorGrid, d)
+    i, ld = grid_and_local_dim_index(ndims.(g.grids), d)
+    return axes(g.grids[i], ld)
+end
+
 # Iteration interface
 Base.iterate(g::TensorGrid) = iterate(Iterators.product(g.grids...)) |> _iterate_combine_coords
 Base.iterate(g::TensorGrid, state) = iterate(Iterators.product(g.grids...), state) |> _iterate_combine_coords
@@ -95,3 +100,28 @@
 function combine_coordinates(coords...)
     return mapreduce(SVector, vcat, coords)
 end
+
+"""
+   grid_and_local_dim_index(nds, d)
+
+Given a tuple of number of dimensions `nds`, and a global dimension index `d`,
+calculate which grid index, and local dimension, `d` corresponds to.
+
+`nds` would come from broadcasting `ndims` on the grids tuple of a
+`TensorGrid`. If you are interested in a dimension `d` of a tensor grid `g`
+```julia
+gi, ldi = grid_and_local_dim_index(ndims.(g.grids), d)
+```
+tells you which grid it belongs to (`gi`) and wich index it is at within that
+grid (`ldi`).
+"""
+function grid_and_local_dim_index(nds, d)
+    I = findfirst(>=(d), cumsum(nds))
+
+    if I == 1
+        return (1, d)
+    else
+        return (I, d-cumsum(nds)[I-1])
+    end
+    # TBD: Is there a cleaner way to compute this?
+end