Mercurial > repos > public > sbplib_julia
comparison src/Grids/tensor_grid.jl @ 1266:a4ddae8b5d49 refactor/grids
Add tests for TensorGrid and make them pass
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Fri, 24 Feb 2023 21:42:28 +0100 |
parents | 198ccda331a6 |
children | dcbac783e4c1 |
comparison legend: equal | deleted | inserted | replaced
1265:9c9ea2900250 | 1266:a4ddae8b5d49 |
---|---|
1 """ | |
2 TensorGrid{T,D} <: Grid{T,D} | |
3 | |
4 * Only supports HasShape grids at the moment | |
5 """ | |
1 struct TensorGrid{T,D,GT<:NTuple{N,Grid} where N} <: Grid{T,D} | 6 struct TensorGrid{T,D,GT<:NTuple{N,Grid} where N} <: Grid{T,D} |
2 grids::GT | 7 grids::GT |
3 | 8 |
4 function TensorGrid(gs...) | 9 function TensorGrid(gs...) |
5 T = eltype(gs[1]) # All gs should have the same T | 10 # T = combined_coordinate_vector_type(eltype.(gs)...) |
6 D = sum(ndims,gs) | 11 T = mapreduce(eltype, combined_coordinate_vector_type, gs) |
12 D = sum(ndims, gs) | |
7 | 13 |
8 return new{T,D,typeof(gs)}(gs) | 14 return new{T,D,typeof(gs)}(gs) |
9 end | 15 end |
10 end | 16 end |
11 | 17 |
12 # Indexing interface | 18 # Indexing interface |
13 # TODO | |
14 # Iteration interface | |
15 # TODO | |
16 | |
17 | |
18 function Base.size(g::TensorGrid) | |
19 return LazyTensors.concatenate_tuples(size.(g.grids)...) | |
20 end | |
21 | |
22 function Base.getindex(g::TensorGrid, I...) | 19 function Base.getindex(g::TensorGrid, I...) |
23 szs = ndims.(g.grids) | 20 szs = ndims.(g.grids) |
24 | 21 |
25 Is = LazyTensors.split_tuple(I, szs) | 22 Is = LazyTensors.split_tuple(I, szs) |
26 ps = map((g,I)->SVector(g[I...]), g.grids, Is) | 23 ps = map((g,I)->SVector(g[I...]), g.grids, Is) |
27 | 24 |
28 return vcat(ps...) | 25 return vcat(ps...) |
29 end | 26 end |
30 | 27 |
31 IndexStyle(::TensorGrid) = IndexCartesian() | |
32 | |
33 function Base.eachindex(g::TensorGrid) | 28 function Base.eachindex(g::TensorGrid) |
34 szs = LazyTensors.concatenate_tuples(size.(g.grids)...) | 29 szs = LazyTensors.concatenate_tuples(size.(g.grids)...) |
35 return CartesianIndices(szs) | 30 return CartesianIndices(szs) |
36 end | 31 end |
37 | 32 |
33 # Iteration interface | |
34 Base.iterate(g::TensorGrid) = iterate(Iterators.product(g.grids...)) |> _iterate_combine_coords | |
35 Base.iterate(g::TensorGrid, state) = iterate(Iterators.product(g.grids...), state) |> _iterate_combine_coords | |
36 _iterate_combine_coords(::Nothing) = nothing | |
37 _iterate_combine_coords((next,state)) = combine_coordinates(next...), state | |
38 | 38 |
39 struct TensorBoundary{N, BID<:BoundaryIdentifier} <: BoundaryIdentifier end | 39 Base.IteratorSize(::Type{<:TensorGrid{<:Any, D}}) where D = Base.HasShape{D}() |
40 grid_id(::TensorBoundary{N, BID}) where {N, BID} = N | 40 Base.eltype(::Type{<:TensorGrid{T}}) where T = T |
41 boundary_id(::TensorBoundary{N, BID}) where {N, BID} = BID() | 41 Base.length(g::TensorGrid) = sum(length, g.grids) |
42 Base.size(g::TensorGrid) = LazyTensors.concatenate_tuples(size.(g.grids)...) | |
42 | 43 |
44 | |
45 refine(g::TensorGrid, r::Int) = mapreduce(g->refine(g,r), TensorGrid, g.grids) | |
46 coarsen(g::TensorGrid, r::Int) = mapreduce(g->coarsen(g,r), TensorGrid, g.grids) | |
47 | |
48 """ | |
49 # TODO: | |
50 """ | |
51 struct TensorGridBoundary{N, BID} <: BoundaryIdentifier end | |
52 grid_id(::TensorGridBoundary{N, BID}) where {N, BID} = N | |
53 boundary_id(::TensorGridBoundary{N, BID}) where {N, BID} = BID() | |
43 | 54 |
44 """ | 55 """ |
45 boundary_identifiers(::TensorGrid) | 56 boundary_identifiers(::TensorGrid) |
46 | 57 |
47 Returns a tuple containing the boundary identifiers for the grid. | 58 Returns a tuple containing the boundary identifiers for the grid. |
48 """ | 59 """ |
49 function boundary_identifiers(g::TensorGrid) | 60 function boundary_identifiers(g::TensorGrid) |
50 n = length(g.grids) | |
51 per_grid = map(eachindex(g.grids)) do i | 61 per_grid = map(eachindex(g.grids)) do i |
52 return map(bid -> TensorBoundary{i, bid}(), boundary_identifiers(g.grids[i])) | 62 return map(bid -> TensorGridBoundary{i, typeof(bid)}(), boundary_identifiers(g.grids[i])) |
53 end | 63 end |
54 return LazyTensors.concatenate_tuples(per_grid...) | 64 return LazyTensors.concatenate_tuples(per_grid...) |
55 end | 65 end |
56 | 66 |
57 | 67 |
58 """ | 68 """ |
59 boundary_grid(grid::TensorGrid, id::TensorBoundary) | 69 boundary_grid(grid::TensorGrid, id::TensorGridBoundary) |
60 | 70 |
61 The grid for the boundary specified by `id`. | 71 The grid for the boundary specified by `id`. |
62 """ | 72 """ |
63 function boundary_grid(g::TensorGrid, bid::TensorBoundary) | 73 function boundary_grid(g::TensorGrid, bid::TensorGridBoundary) |
64 local_boundary_grid = boundary_grid(g.grids[grid_id(bid)], boundary_id(bid)) | 74 local_boundary_grid = boundary_grid(g.grids[grid_id(bid)], boundary_id(bid)) |
65 new_grids = Base.setindex(g.grids, local_boundary_grid, grid_id(bid)) | 75 new_grids = Base.setindex(g.grids, local_boundary_grid, grid_id(bid)) |
66 return TensorGrid(new_grids...) | 76 return TensorGrid(new_grids...) |
67 end | 77 end |
78 | |
79 | |
80 function combined_coordinate_vector_type(coordinate_types...) | |
81 coord_length(::Type{<:Number}) = 1 | |
82 coord_length(T::Type{<:SVector}) = length(T) | |
83 | |
84 coord_type(T::Type{<:Number}) = T | |
85 coord_type(T::Type{<:SVector}) = eltype(T) | |
86 | |
87 | |
88 combined_coord_length = mapreduce(coord_length, +, coordinate_types) | |
89 combined_coord_type = mapreduce(coord_type, promote_type, coordinate_types) | |
90 | |
91 if combined_coord_length == 1 | |
92 return combined_coord_type | |
93 else | |
94 return SVector{combined_coord_length, combined_coord_type} | |
95 end | |
96 end | |
97 | |
98 function combine_coordinates(coords...) | |
99 return mapreduce(SVector, vcat, coords) | |
100 # return SVector(coords...) | |
101 end |