Mercurial > repos > public > sbplib_julia
diff src/Grids/tensor_grid.jl @ 1266:a4ddae8b5d49 refactor/grids
Add tests for TensorGrid and make them pass
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Fri, 24 Feb 2023 21:42:28 +0100 |
parents | 198ccda331a6 |
children | dcbac783e4c1 |
line wrap: on
line diff
--- a/src/Grids/tensor_grid.jl Fri Feb 24 20:47:56 2023 +0100 +++ b/src/Grids/tensor_grid.jl Fri Feb 24 21:42:28 2023 +0100 @@ -1,24 +1,21 @@ +""" + TensorGrid{T,D} <: Grid{T,D} + +* Only supports HasShape grids at the moment +""" struct TensorGrid{T,D,GT<:NTuple{N,Grid} where N} <: Grid{T,D} grids::GT function TensorGrid(gs...) - T = eltype(gs[1]) # All gs should have the same T - D = sum(ndims,gs) + # T = combined_coordinate_vector_type(eltype.(gs)...) + T = mapreduce(eltype, combined_coordinate_vector_type, gs) + D = sum(ndims, gs) return new{T,D,typeof(gs)}(gs) end end # Indexing interface -# TODO -# Iteration interface -# TODO - - -function Base.size(g::TensorGrid) - return LazyTensors.concatenate_tuples(size.(g.grids)...) -end - function Base.getindex(g::TensorGrid, I...) szs = ndims.(g.grids) @@ -28,18 +25,32 @@ return vcat(ps...) end -IndexStyle(::TensorGrid) = IndexCartesian() - function Base.eachindex(g::TensorGrid) szs = LazyTensors.concatenate_tuples(size.(g.grids)...) return CartesianIndices(szs) end +# Iteration interface +Base.iterate(g::TensorGrid) = iterate(Iterators.product(g.grids...)) |> _iterate_combine_coords +Base.iterate(g::TensorGrid, state) = iterate(Iterators.product(g.grids...), state) |> _iterate_combine_coords +_iterate_combine_coords(::Nothing) = nothing +_iterate_combine_coords((next,state)) = combine_coordinates(next...), state -struct TensorBoundary{N, BID<:BoundaryIdentifier} <: BoundaryIdentifier end -grid_id(::TensorBoundary{N, BID}) where {N, BID} = N -boundary_id(::TensorBoundary{N, BID}) where {N, BID} = BID() +Base.IteratorSize(::Type{<:TensorGrid{<:Any, D}}) where D = Base.HasShape{D}() +Base.eltype(::Type{<:TensorGrid{T}}) where T = T +Base.length(g::TensorGrid) = sum(length, g.grids) +Base.size(g::TensorGrid) = LazyTensors.concatenate_tuples(size.(g.grids)...) + +refine(g::TensorGrid, r::Int) = mapreduce(g->refine(g,r), TensorGrid, g.grids) +coarsen(g::TensorGrid, r::Int) = mapreduce(g->coarsen(g,r), TensorGrid, g.grids) + +""" +# TODO: +""" +struct TensorGridBoundary{N, BID} <: BoundaryIdentifier end +grid_id(::TensorGridBoundary{N, BID}) where {N, BID} = N +boundary_id(::TensorGridBoundary{N, BID}) where {N, BID} = BID() """ boundary_identifiers(::TensorGrid) @@ -47,21 +58,44 @@ Returns a tuple containing the boundary identifiers for the grid. """ function boundary_identifiers(g::TensorGrid) - n = length(g.grids) per_grid = map(eachindex(g.grids)) do i - return map(bid -> TensorBoundary{i, bid}(), boundary_identifiers(g.grids[i])) + return map(bid -> TensorGridBoundary{i, typeof(bid)}(), boundary_identifiers(g.grids[i])) end return LazyTensors.concatenate_tuples(per_grid...) end """ - boundary_grid(grid::TensorGrid, id::TensorBoundary) + boundary_grid(grid::TensorGrid, id::TensorGridBoundary) The grid for the boundary specified by `id`. """ -function boundary_grid(g::TensorGrid, bid::TensorBoundary) +function boundary_grid(g::TensorGrid, bid::TensorGridBoundary) local_boundary_grid = boundary_grid(g.grids[grid_id(bid)], boundary_id(bid)) new_grids = Base.setindex(g.grids, local_boundary_grid, grid_id(bid)) return TensorGrid(new_grids...) end + + +function combined_coordinate_vector_type(coordinate_types...) + coord_length(::Type{<:Number}) = 1 + coord_length(T::Type{<:SVector}) = length(T) + + coord_type(T::Type{<:Number}) = T + coord_type(T::Type{<:SVector}) = eltype(T) + + + combined_coord_length = mapreduce(coord_length, +, coordinate_types) + combined_coord_type = mapreduce(coord_type, promote_type, coordinate_types) + + if combined_coord_length == 1 + return combined_coord_type + else + return SVector{combined_coord_length, combined_coord_type} + end +end + +function combine_coordinates(coords...) + return mapreduce(SVector, vcat, coords) + # return SVector(coords...) +end