Mercurial > repos > public > sbplib_julia
changeset 1394:60857d8338cb bugfix/grids/complete_interface_impl
Merge default
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Tue, 15 Aug 2023 22:45:58 +0200 |
parents | 9da927271752 (diff) 7694b35d137d (current diff) |
children | 447833be2ecc |
files | src/Grids/tensor_grid.jl |
diffstat | 3 files changed, 71 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
diff -r 7694b35d137d -r 60857d8338cb src/Grids/equidistant_grid.jl --- a/src/Grids/equidistant_grid.jl Tue Aug 15 22:42:19 2023 +0200 +++ b/src/Grids/equidistant_grid.jl Tue Aug 15 22:45:58 2023 +0200 @@ -20,6 +20,9 @@ Base.firstindex(g::EquidistantGrid) = firstindex(g.points) Base.lastindex(g::EquidistantGrid) = lastindex(g.points) +Base.firstindex(g::EquidistantGrid, d) = firstindex(g) +Base.lastindex(g::EquidistantGrid, d) = lastindex(g) + # Iteration interface Base.iterate(g::EquidistantGrid) = iterate(g.points) Base.iterate(g::EquidistantGrid, state) = iterate(g.points, state)
diff -r 7694b35d137d -r 60857d8338cb src/Grids/tensor_grid.jl --- a/src/Grids/tensor_grid.jl Tue Aug 15 22:42:19 2023 +0200 +++ b/src/Grids/tensor_grid.jl Tue Aug 15 22:45:58 2023 +0200 @@ -31,6 +31,17 @@ return CartesianIndices(szs) end +function Base.firstindex(g::TensorGrid, d) + i, ld = grid_and_local_dim_index(ndims.(g.grids), d) + return firstindex(g.grids[i], ld) +end + +function Base.lastindex(g::TensorGrid, d) + i, ld = grid_and_local_dim_index(ndims.(g.grids), d) + return lastindex(g.grids[i], ld) +end +# TBD: Should the two above functions be supported by implementing `axes` instead? + # Iteration interface Base.iterate(g::TensorGrid) = iterate(Iterators.product(g.grids...)) |> _iterate_combine_coords Base.iterate(g::TensorGrid, state) = iterate(Iterators.product(g.grids...), state) |> _iterate_combine_coords @@ -95,3 +106,20 @@ function combine_coordinates(coords...) return mapreduce(SVector, vcat, coords) end + +""" + grid_and_local_dim_index(nds, d) + +Given a tuple of number of dimensions `nds`, and a global dimension `d`, +calculate which grid index and local dimension `d` corresponds to. +""" +function grid_and_local_dim_index(nds, d) + I = findfirst(>=(d), cumsum(nds)) + + if I == 1 + return (1, d) + else + return (I, d-cumsum(nds)[I-1]) + end + # TBD: Is there a cleaner way to compute this? +end
diff -r 7694b35d137d -r 60857d8338cb test/Grids/tensor_grid_test.jl --- a/test/Grids/tensor_grid_test.jl Tue Aug 15 22:42:19 2023 +0200 +++ b/test/Grids/tensor_grid_test.jl Tue Aug 15 22:45:58 2023 +0200 @@ -59,6 +59,18 @@ @test eachindex(TensorGrid(g₁, g₄)) == CartesianIndices((11,)) @test eachindex(TensorGrid(g₁, g₄, g₂)) == CartesianIndices((11,6)) end + + @testset "firstindex" begin + @test firstindex(TensorGrid(g₁, g₂, g₃), 1) == 1 + @test firstindex(TensorGrid(g₁, g₂, g₃), 2) == 1 + @test firstindex(TensorGrid(g₁, g₂, g₃), 3) == 1 + end + + @testset "lastindex" begin + @test lastindex(TensorGrid(g₁, g₂, g₃), 1) == 11 + @test lastindex(TensorGrid(g₁, g₂, g₃), 2) == 6 + @test lastindex(TensorGrid(g₁, g₂, g₃), 3) == 10 + end end @testset "Iterator interface" begin @@ -144,3 +156,31 @@ @test Grids.combine_coordinates(1,@SVector[2.,3]) isa SVector{3, Float64} @test Grids.combine_coordinates(1,@SVector[2.,3]) == [1,2,3] end + +@testset "grid_and_local_dim_index" begin + cases = [ + ((1,), 1) => (1,1), + + ((1,1), 1) => (1,1), + ((1,1), 2) => (2,1), + + ((1,2), 1) => (1,1), + ((1,2), 2) => (2,1), + ((1,2), 3) => (2,2), + + ((2,1), 1) => (1,1), + ((2,1), 2) => (1,2), + ((2,1), 3) => (2,1), + + ((2,1,3), 1) => (1,1), + ((2,1,3), 2) => (1,2), + ((2,1,3), 3) => (2,1), + ((2,1,3), 4) => (3,1), + ((2,1,3), 5) => (3,2), + ((2,1,3), 6) => (3,3), + ] + + @testset "grid_and_local_dim_index$args" for (args, expected) ∈ cases + @test Grids.grid_and_local_dim_index(args...) == expected + end +end