diff src/Grids/tensor_grid.jl @ 1421:69c9e6eae686

Merge bugfix/grids/complete_interface_impl
author Jonatan Werpers <jonatan@werpers.com>
date Fri, 25 Aug 2023 08:49:07 +0200
parents e82240df974d
children 48e16efaac7a 25af92b4a7ea af73340a8f0e e3a80ef08d09
line wrap: on
line diff
--- a/src/Grids/tensor_grid.jl	Tue Aug 22 21:57:38 2023 +0200
+++ b/src/Grids/tensor_grid.jl	Fri Aug 25 08:49:07 2023 +0200
@@ -31,6 +31,11 @@
     return CartesianIndices(szs)
 end
 
+function Base.axes(g::TensorGrid, d)
+    i, ld = grid_and_local_dim_index(ndims.(g.grids), d)
+    return axes(g.grids[i], ld)
+end
+
 # Iteration interface
 Base.iterate(g::TensorGrid) = iterate(Iterators.product(g.grids...)) |> _iterate_combine_coords
 Base.iterate(g::TensorGrid, state) = iterate(Iterators.product(g.grids...), state) |> _iterate_combine_coords
@@ -95,3 +100,27 @@
 function combine_coordinates(coords...)
     return mapreduce(SVector, vcat, coords)
 end
+
+"""
+   grid_and_local_dim_index(nds, d)
+
+Given a tuple of number of dimensions `nds`, and a global dimension index `d`,
+calculate which grid index, and local dimension, `d` corresponds to.
+
+`nds` would come from broadcasting `ndims` on the grids tuple of a
+`TensorGrid`. If you are interested in a dimension `d` of a tensor grid `g`
+```julia
+gi, ldi = grid_and_local_dim_index(ndims.(g.grids), d)
+```
+tells you which grid it belongs to (`gi`) and which index it is at within that
+grid (`ldi`).
+"""
+function grid_and_local_dim_index(nds, d)
+    I = findfirst(>=(d), cumsum(nds))
+
+    if I == 1
+        return (1, d)
+    else
+        return (I, d-cumsum(nds)[I-1])
+    end
+end