diff src/Grids/tensor_grid.jl @ 1854:654a2b7e6824 tooling/benchmarks

Merge default
author Jonatan Werpers <jonatan@werpers.com>
date Sat, 11 Jan 2025 10:19:47 +0100
parents 054447ac4b0e
children 516eaabf1169
line wrap: on
line diff
--- a/src/Grids/tensor_grid.jl	Wed May 31 08:59:34 2023 +0200
+++ b/src/Grids/tensor_grid.jl	Sat Jan 11 10:19:47 2025 +0100
@@ -17,7 +17,7 @@
 end
 
 # Indexing interface
-function Base.getindex(g::TensorGrid, I...)
+function Base.getindex(g::TensorGrid, I::Vararg{Int})
     szs = ndims.(g.grids)
 
     Is = LazyTensors.split_tuple(I, szs)
@@ -26,13 +26,16 @@
     return vcat(ps...)
 end
 
-Base.getindex(g::TensorGrid, I::CartesianIndex) = g[Tuple(I)...]
-
 function Base.eachindex(g::TensorGrid)
     szs = LazyTensors.concatenate_tuples(size.(g.grids)...)
     return CartesianIndices(szs)
 end
 
+function Base.axes(g::TensorGrid, d)
+    i, ld = grid_and_local_dim_index(ndims.(g.grids), d)
+    return axes(g.grids[i], ld)
+end
+
 # Iteration interface
 Base.iterate(g::TensorGrid) = iterate(Iterators.product(g.grids...)) |> _iterate_combine_coords
 Base.iterate(g::TensorGrid, state) = iterate(Iterators.product(g.grids...), state) |> _iterate_combine_coords
@@ -40,10 +43,20 @@
 _iterate_combine_coords((next,state)) = combine_coordinates(next...), state
 
 Base.IteratorSize(::Type{<:TensorGrid{<:Any, D}}) where D = Base.HasShape{D}()
-Base.eltype(::Type{<:TensorGrid{T}}) where T = T
-Base.length(g::TensorGrid) = sum(length, g.grids)
+Base.length(g::TensorGrid) = prod(length, g.grids)
 Base.size(g::TensorGrid) = LazyTensors.concatenate_tuples(size.(g.grids)...)
+Base.size(g::TensorGrid, d) = size(g)[d]
 
+function spacing(g::TensorGrid)
+    relevant_grids = filter(g->!isa(g,ZeroDimGrid),g.grids)
+    return spacing.(relevant_grids)
+end
+
+function min_spacing(g::TensorGrid)
+    relevant_grids = filter(g->!isa(g,ZeroDimGrid),g.grids)
+    d = min_spacing.(relevant_grids)
+    return minimum(d)
+end
 
 refine(g::TensorGrid, r::Int) = mapreduce(g->refine(g,r), TensorGrid, g.grids)
 coarsen(g::TensorGrid, r::Int) = mapreduce(g->coarsen(g,r), TensorGrid, g.grids)
@@ -70,7 +83,6 @@
     return LazyTensors.concatenate_tuples(per_grid...)
 end
 
-
 """
     boundary_grid(g::TensorGrid, id::TensorGridBoundary)
 
@@ -82,6 +94,16 @@
     return TensorGrid(new_grids...)
 end
 
+function boundary_indices(g::TensorGrid, id::TensorGridBoundary)
+    per_grid_ind = map(g.grids) do g
+        ntuple(i->:, ndims(g))
+    end
+
+    local_b_ind = boundary_indices(g.grids[grid_id(id)], boundary_id(id))
+    b_ind = Base.setindex(per_grid_ind, local_b_ind, grid_id(id))
+
+    return LazyTensors.concatenate_tuples(b_ind...)
+end
 
 function combined_coordinate_vector_type(coordinate_types...)
     combined_coord_length = mapreduce(_ncomponents, +, coordinate_types)
@@ -97,3 +119,27 @@
 function combine_coordinates(coords...)
     return mapreduce(SVector, vcat, coords)
 end
+
+"""
+    grid_and_local_dim_index(nds, d)
+
+Given a tuple of number of dimensions `nds`, and a global dimension index `d`,
+calculate which grid index, and local dimension, `d` corresponds to.
+
+`nds` would come from broadcasting `ndims` on the grids tuple of a
+`TensorGrid`. If you are interested in a dimension `d` of a tensor grid `g`
+```julia
+gi, ldi = grid_and_local_dim_index(ndims.(g.grids), d)
+```
+tells you which grid it belongs to (`gi`) and which index it is at within that
+grid (`ldi`).
+"""
+function grid_and_local_dim_index(nds, d)
+    I = findfirst(>=(d), cumsum(nds))
+
+    if I == 1
+        return (1, d)
+    else
+        return (I, d-cumsum(nds)[I-1])
+    end
+end