Mercurial > repos > public > sbplib_julia
comparison src/Grids/tensor_grid.jl @ 1266:a4ddae8b5d49 refactor/grids
Add tests for TensorGrid and make them pass
| author | Jonatan Werpers <jonatan@werpers.com> |
|---|---|
| date | Fri, 24 Feb 2023 21:42:28 +0100 |
| parents | 198ccda331a6 |
| children | dcbac783e4c1 |
comparison
equal
deleted
inserted
replaced
| 1265:9c9ea2900250 | 1266:a4ddae8b5d49 |
|---|---|
| | 1 """ |
| | 2 TensorGrid{T,D} <: Grid{T,D} |
| | 3 |
| | 4 * Only supports HasShape grids at the moment |
| | 5 """ |
| 1 struct TensorGrid{T,D,GT<:NTuple{N,Grid} where N} <: Grid{T,D} | 6 struct TensorGrid{T,D,GT<:NTuple{N,Grid} where N} <: Grid{T,D} |
| 2 grids::GT | 7 grids::GT |
| 3 | 8 |
| 4 function TensorGrid(gs...) | 9 function TensorGrid(gs...) |
| 5 T = eltype(gs[1]) # All gs should have the same T | 10 # T = combined_coordinate_vector_type(eltype.(gs)...) |
| 6 D = sum(ndims,gs) | 11 T = mapreduce(eltype, combined_coordinate_vector_type, gs) |
| | 12 D = sum(ndims, gs) |
| 7 | 13 |
| 8 return new{T,D,typeof(gs)}(gs) | 14 return new{T,D,typeof(gs)}(gs) |
| 9 end | 15 end |
| 10 end | 16 end |
| 11 | 17 |
| 12 # Indexing interface | 18 # Indexing interface |
| 13 # TODO | |
| 14 # Iteration interface | |
| 15 # TODO | |
| 16 | |
| 17 | |
| 18 function Base.size(g::TensorGrid) | |
| 19 return LazyTensors.concatenate_tuples(size.(g.grids)...) | |
| 20 end | |
| 21 | |
| 22 function Base.getindex(g::TensorGrid, I...) | 19 function Base.getindex(g::TensorGrid, I...) |
| 23 szs = ndims.(g.grids) | 20 szs = ndims.(g.grids) |
| 24 | 21 |
| 25 Is = LazyTensors.split_tuple(I, szs) | 22 Is = LazyTensors.split_tuple(I, szs) |
| 26 ps = map((g,I)->SVector(g[I...]), g.grids, Is) | 23 ps = map((g,I)->SVector(g[I...]), g.grids, Is) |
| 27 | 24 |
| 28 return vcat(ps...) | 25 return vcat(ps...) |
| 29 end | 26 end |
| 30 | 27 |
| 31 IndexStyle(::TensorGrid) = IndexCartesian() | |
| 32 | |
| 33 function Base.eachindex(g::TensorGrid) | 28 function Base.eachindex(g::TensorGrid) |
| 34 szs = LazyTensors.concatenate_tuples(size.(g.grids)...) | 29 szs = LazyTensors.concatenate_tuples(size.(g.grids)...) |
| 35 return CartesianIndices(szs) | 30 return CartesianIndices(szs) |
| 36 end | 31 end |
| 37 | 32 |
| | 33 # Iteration interface |
| | 34 Base.iterate(g::TensorGrid) = iterate(Iterators.product(g.grids...)) \|> _iterate_combine_coords |
| | 35 Base.iterate(g::TensorGrid, state) = iterate(Iterators.product(g.grids...), state) \|> _iterate_combine_coords |
| | 36 _iterate_combine_coords(::Nothing) = nothing |
| | 37 _iterate_combine_coords((next,state)) = combine_coordinates(next...), state |
| 38 | 38 |
| 39 struct TensorBoundary{N, BID<:BoundaryIdentifier} <: BoundaryIdentifier end | 39 Base.IteratorSize(::Type{<:TensorGrid{<:Any, D}}) where D = Base.HasShape{D}() |
| 40 grid_id(::TensorBoundary{N, BID}) where {N, BID} = N | 40 Base.eltype(::Type{<:TensorGrid{T}}) where T = T |
| 41 boundary_id(::TensorBoundary{N, BID}) where {N, BID} = BID() | 41 Base.length(g::TensorGrid) = sum(length, g.grids) |
| | 42 Base.size(g::TensorGrid) = LazyTensors.concatenate_tuples(size.(g.grids)...) |
| 42 | 43 |
| | 44 |
| | 45 refine(g::TensorGrid, r::Int) = mapreduce(g->refine(g,r), TensorGrid, g.grids) |
| | 46 coarsen(g::TensorGrid, r::Int) = mapreduce(g->coarsen(g,r), TensorGrid, g.grids) |
| | 47 |
| | 48 """ |
| | 49 # TODO: |
| | 50 """ |
| | 51 struct TensorGridBoundary{N, BID} <: BoundaryIdentifier end |
| | 52 grid_id(::TensorGridBoundary{N, BID}) where {N, BID} = N |
| | 53 boundary_id(::TensorGridBoundary{N, BID}) where {N, BID} = BID() |
| 43 | 54 |
| 44 """ | 55 """ |
| 45 boundary_identifiers(::TensorGrid) | 56 boundary_identifiers(::TensorGrid) |
| 46 | 57 |
| 47 Returns a tuple containing the boundary identifiers for the grid. | 58 Returns a tuple containing the boundary identifiers for the grid. |
| 48 """ | 59 """ |
| 49 function boundary_identifiers(g::TensorGrid) | 60 function boundary_identifiers(g::TensorGrid) |
| 50 n = length(g.grids) | |
| 51 per_grid = map(eachindex(g.grids)) do i | 61 per_grid = map(eachindex(g.grids)) do i |
| 52 return map(bid -> TensorBoundary{i, bid}(), boundary_identifiers(g.grids[i])) | 62 return map(bid -> TensorGridBoundary{i, typeof(bid)}(), boundary_identifiers(g.grids[i])) |
| 53 end | 63 end |
| 54 return LazyTensors.concatenate_tuples(per_grid...) | 64 return LazyTensors.concatenate_tuples(per_grid...) |
| 55 end | 65 end |
| 56 | 66 |
| 57 | 67 |
| 58 """ | 68 """ |
| 59 boundary_grid(grid::TensorGrid, id::TensorBoundary) | 69 boundary_grid(grid::TensorGrid, id::TensorGridBoundary) |
| 60 | 70 |
| 61 The grid for the boundary specified by `id`. | 71 The grid for the boundary specified by `id`. |
| 62 """ | 72 """ |
| 63 function boundary_grid(g::TensorGrid, bid::TensorBoundary) | 73 function boundary_grid(g::TensorGrid, bid::TensorGridBoundary) |
| 64 local_boundary_grid = boundary_grid(g.grids[grid_id(bid)], boundary_id(bid)) | 74 local_boundary_grid = boundary_grid(g.grids[grid_id(bid)], boundary_id(bid)) |
| 65 new_grids = Base.setindex(g.grids, local_boundary_grid, grid_id(bid)) | 75 new_grids = Base.setindex(g.grids, local_boundary_grid, grid_id(bid)) |
| 66 return TensorGrid(new_grids...) | 76 return TensorGrid(new_grids...) |
| 67 end | 77 end |
| | 78 |
| | 79 |
| | 80 function combined_coordinate_vector_type(coordinate_types...) |
| | 81 coord_length(::Type{<:Number}) = 1 |
| | 82 coord_length(T::Type{<:SVector}) = length(T) |
| | 83 |
| | 84 coord_type(T::Type{<:Number}) = T |
| | 85 coord_type(T::Type{<:SVector}) = eltype(T) |
| | 86 |
| | 87 |
| | 88 combined_coord_length = mapreduce(coord_length, +, coordinate_types) |
| | 89 combined_coord_type = mapreduce(coord_type, promote_type, coordinate_types) |
| | 90 |
| | 91 if combined_coord_length == 1 |
| | 92 return combined_coord_type |
| | 93 else |
| | 94 return SVector{combined_coord_length, combined_coord_type} |
| | 95 end |
| | 96 end |
| | 97 |
| | 98 function combine_coordinates(coords...) |
| | 99 return mapreduce(SVector, vcat, coords) |
| | 100 # return SVector(coords...) |
| | 101 end |
