comparison src/Grids/tensor_grid.jl @ 1266:a4ddae8b5d49 refactor/grids

Add tests for TensorGrid and make them pass
author Jonatan Werpers <jonatan@werpers.com>
date Fri, 24 Feb 2023 21:42:28 +0100
parents 198ccda331a6
children dcbac783e4c1
comparison
equal deleted inserted replaced
1265:9c9ea2900250 1266:a4ddae8b5d49
1 """
2 TensorGrid{T,D} <: Grid{T,D}
3
4 * Only supports HasShape grids at the moment
5 """
struct TensorGrid{T,D,GT<:NTuple{N,Grid} where N} <: Grid{T,D}
    # The constituent grids, stored as a tuple in the order they were passed.
    grids::GT

    # Inner constructor: `T` and `D` are derived from the grids, never given by the caller.
    function TensorGrid(gs...)
        # A reduction over zero grids has no defined coordinate type, so fail
        # with an explicit message instead of an error from inside `mapreduce`.
        isempty(gs) && throw(ArgumentError("TensorGrid requires at least one grid"))

        # Coordinate type: pairwise combination of the element types of the grids.
        T = mapreduce(eltype, combined_coordinate_vector_type, gs)
        # Dimension: the sum of the dimensions of the grids.
        D = sum(ndims, gs)

        return new{T,D,typeof(gs)}(gs)
    end
end
11 17
12 # Indexing interface 18 # Indexing interface
13 # TODO
14 # Iteration interface
15 # TODO
16
17
18 function Base.size(g::TensorGrid)
19 return LazyTensors.concatenate_tuples(size.(g.grids)...)
20 end
21
# Index the tensor grid: the index tuple `I` is split into one chunk per
# constituent grid (chunk lengths given by each grid's `ndims`), every grid is
# indexed with its own chunk, and the resulting points are concatenated.
function Base.getindex(g::TensorGrid, I...)
    dims_per_grid = map(ndims, g.grids)
    index_parts = LazyTensors.split_tuple(I, dims_per_grid)

    points = map(g.grids, index_parts) do subgrid, idx
        return SVector(subgrid[idx...])
    end

    return vcat(points...)
end
30 27
31 IndexStyle(::TensorGrid) = IndexCartesian()
32
# Cartesian indices spanning the concatenated sizes of the constituent grids.
function Base.eachindex(g::TensorGrid)
    sizes_per_grid = map(size, g.grids)
    return CartesianIndices(LazyTensors.concatenate_tuples(sizes_per_grid...))
end
37 32
# Iteration interface
#
# Iteration walks the product of the constituent grids; the coordinates of
# each product element are merged with `combine_coordinates`.
function Base.iterate(g::TensorGrid)
    return _iterate_combine_coords(iterate(Iterators.product(g.grids...)))
end

function Base.iterate(g::TensorGrid, state)
    return _iterate_combine_coords(iterate(Iterators.product(g.grids...), state))
end

# `iterate` yields `nothing` once exhausted; pass that through unchanged.
_iterate_combine_coords(::Nothing) = nothing

# Otherwise merge the coordinates of the product element and keep the state.
function _iterate_combine_coords(step)
    point, state = step
    return combine_coordinates(point...), state
end
38 38
# The grid has a shape with `D` dimensions.
Base.IteratorSize(::Type{<:TensorGrid{<:Any, D}}) where D = Base.HasShape{D}()

# Elements are coordinates of type `T`.
Base.eltype(::Type{<:TensorGrid{T}}) where T = T

# Iteration runs over `Iterators.product(g.grids...)` and `size` concatenates
# the sizes of the constituent grids, so the number of points is the product
# (not the sum) of the lengths of the constituent grids.
Base.length(g::TensorGrid) = prod(length, g.grids)

# The sizes of the constituent grids, concatenated in order.
Base.size(g::TensorGrid) = LazyTensors.concatenate_tuples(size.(g.grids)...)
42 43
44
"""
    refine(g::TensorGrid, r::Int)

Apply `refine(·, r)` to every constituent grid of `g` and return the
`TensorGrid` of the results.

The result is always a flat `TensorGrid` with the same number of constituent
grids as `g`.
"""
refine(g::TensorGrid, r::Int) = TensorGrid(map(sub -> refine(sub, r), g.grids)...)

"""
    coarsen(g::TensorGrid, r::Int)

Apply `coarsen(·, r)` to every constituent grid of `g` and return the
`TensorGrid` of the results.

The result is always a flat `TensorGrid` with the same number of constituent
grids as `g`.
"""
coarsen(g::TensorGrid, r::Int) = TensorGrid(map(sub -> coarsen(sub, r), g.grids)...)
47
48 """
49 # TODO:
50 """
# Identifies a boundary of a `TensorGrid`: `N` is the position in `grids` of
# the grid owning the boundary and `BID` is the type of that grid's own
# boundary identifier.
struct TensorGridBoundary{N, BID} <: BoundaryIdentifier end

# Position of the owning grid.
grid_id(::TensorGridBoundary{N}) where {N} = N

# An instance of the owning grid's boundary identifier type.
boundary_id(::TensorGridBoundary{<:Any, BID}) where {BID} = BID()
43 54
44 """ 55 """
45 boundary_identifiers(::TensorGrid) 56 boundary_identifiers(::TensorGrid)
46 57
47 Returns a tuple containing the boundary identifiers for the grid. 58 Returns a tuple containing the boundary identifiers for the grid.
48 """ 59 """
function boundary_identifiers(g::TensorGrid)
    # Tag a constituent grid's identifier with the position of that grid.
    wrap(i, bid) = TensorGridBoundary{i, typeof(bid)}()

    ids_per_grid = map(eachindex(g.grids)) do i
        local_ids = boundary_identifiers(g.grids[i])
        return map(bid -> wrap(i, bid), local_ids)
    end

    # Flatten the per-grid tuples into a single tuple.
    return LazyTensors.concatenate_tuples(ids_per_grid...)
end
56 66
57 67
58 """ 68 """
59 boundary_grid(grid::TensorGrid, id::TensorBoundary) 69 boundary_grid(grid::TensorGrid, id::TensorGridBoundary)
60 70
61 The grid for the boundary specified by `id`. 71 The grid for the boundary specified by `id`.
62 """ 72 """
function boundary_grid(g::TensorGrid, bid::TensorGridBoundary)
    k = grid_id(bid)

    # Boundary grid of the one constituent grid that owns this boundary.
    replacement = boundary_grid(g.grids[k], boundary_id(bid))

    # Same tuple of grids, with the owning grid swapped for its boundary grid.
    return TensorGrid(Base.setindex(g.grids, replacement, k)...)
end
78
79
"""
    combined_coordinate_vector_type(coordinate_types...)

Return the coordinate type obtained by combining `coordinate_types`, each of
which is a `Number` type or an `SVector` type.

A `Number` type contributes one component and an `SVector` type contributes
its length; the component types are promoted with `promote_type`. If the
total number of components is 1 the promoted component type itself is
returned, otherwise an `SVector` type with that many components.
"""
function combined_coordinate_vector_type(coordinate_types...)
    # Number of components contributed by one coordinate type.
    ncomponents(::Type{<:Number}) = 1
    ncomponents(T::Type{<:SVector}) = length(T)

    # Scalar type of the components of one coordinate type.
    component_type(T::Type{<:Number}) = T
    component_type(T::Type{<:SVector}) = eltype(T)

    total = mapreduce(ncomponents, +, coordinate_types)
    scalar = mapreduce(component_type, promote_type, coordinate_types)

    return total == 1 ? scalar : SVector{total, scalar}
end
97
"""
    combine_coordinates(coords...)

Wrap each of `coords` with `SVector` and concatenate the results with `vcat`.
"""
combine_coordinates(coords...) = mapreduce(SVector, vcat, coords)