comparison src/Grids/tensor_grid.jl @ 1360:f59228534d3a tooling/benchmarks

Merge default
author Jonatan Werpers <jonatan@werpers.com>
date Sat, 20 May 2023 15:15:22 +0200
parents 42ecd4b3e215
children 4d628c83987e c0208286234e
comparison
equal deleted inserted replaced
1321:42738616422e 1360:f59228534d3a
"""
    TensorGrid{T,D} <: Grid{T,D}

A grid constructed as the tensor product of other grids.

The coordinate type `T` is obtained by folding the coordinate types of the
component grids together with `combined_coordinate_vector_type`. The dimension
`D` is the sum of the dimensions of the component grids.

Currently only supports grids with the `HasShape`-trait.
"""
struct TensorGrid{T,D,GT<:NTuple{N,Grid} where N} <: Grid{T,D}
    grids::GT

    function TensorGrid(gs...)
        # Fold the component coordinate types into one combined coordinate type.
        coord_type = mapreduce(eltype, combined_coordinate_vector_type, gs)
        # Each component grid contributes all of its own dimensions.
        dim = mapreduce(ndims, +, gs)
        return new{coord_type,dim,typeof(gs)}(gs)
    end
end
18
19 # Indexing interface
function Base.getindex(g::TensorGrid, I...)
    # Split the flat index list into one sub-tuple per component grid, sized
    # by the number of dimensions of that grid.
    Is = LazyTensors.split_tuple(I, ndims.(g.grids))

    # Look up the point in each component grid and wrap it as an SVector so
    # scalar and vector coordinates can be concatenated uniformly.
    coords = map(g.grids, Is) do gᵢ, Iᵢ
        SVector(gᵢ[Iᵢ...])
    end

    return vcat(coords...)
end
28
# Forward a CartesianIndex to the splatted integer-index method.
Base.getindex(g::TensorGrid, I::CartesianIndex) = getindex(g, Tuple(I)...)
30
# Cartesian indices spanning the concatenated sizes of all component grids.
Base.eachindex(g::TensorGrid) = CartesianIndices(size(g))
35
36 # Iteration interface
# Iterate over the Cartesian product of the component grids, turning each
# tuple of component coordinates into a single combined coordinate.
function Base.iterate(g::TensorGrid)
    return _iterate_combine_coords(iterate(Iterators.product(g.grids...)))
end

function Base.iterate(g::TensorGrid, state)
    return _iterate_combine_coords(iterate(Iterators.product(g.grids...), state))
end

# End of iteration: propagate `nothing`.
_iterate_combine_coords(::Nothing) = nothing

# Combine the tuple of component coordinates and keep the product iterator state.
function _iterate_combine_coords((point, st))
    return combine_coordinates(point...), st
end
41
# A TensorGrid of total dimension D has a D-dimensional shape.
Base.IteratorSize(::Type{<:TensorGrid{<:Any, D}}) where D = Base.HasShape{D}()
Base.eltype(::Type{<:TensorGrid{T}}) where T = T

# The number of points in a tensor product grid is the PRODUCT of the number of
# points in each component grid (the previous `sum` undercounted, e.g. 7 instead
# of 12 for a 3×4 grid). Assuming each component's length equals the product of
# its size, this agrees with `prod(size(g))`, as the `HasShape` trait requires.
Base.length(g::TensorGrid) = prod(length, g.grids)

# The size is the concatenation of the component grid sizes.
Base.size(g::TensorGrid) = LazyTensors.concatenate_tuples(size.(g.grids)...)
46
47
# Refine/coarsen every component grid by the factor `r` and rebuild ONE flat
# TensorGrid from the results. Using `TensorGrid` as a `mapreduce` operator
# had two problems:
# - with three or more components it nested TensorGrids, which changed the
#   boundary identifiers;
# - with a single component it returned the bare component grid instead of
#   a TensorGrid.
refine(g::TensorGrid, r::Int) = TensorGrid(map(gᵢ -> refine(gᵢ, r), g.grids)...)
coarsen(g::TensorGrid, r::Int) = TensorGrid(map(gᵢ -> coarsen(gᵢ, r), g.grids)...)
50
"""
    TensorGridBoundary{N, BID} <: BoundaryIdentifier

A boundary identifier for a tensor grid. `N` specifies which grid in the
tensor product and `BID` which boundary on that grid.
"""
struct TensorGridBoundary{N, BID} <: BoundaryIdentifier end

# Position of the component grid that the boundary belongs to.
grid_id(::TensorGridBoundary{N}) where N = N

# Boundary identifier on the component grid, reconstructed from the stored type
# (assumes `BID` has a zero-argument constructor — true for the identifiers
# created in `boundary_identifiers`, which store `typeof(bid)`).
boundary_id(::TensorGridBoundary{<:Any, BID}) where BID = BID()
60
"""
    boundary_identifiers(g::TensorGrid)

Returns a tuple containing the boundary identifiers of `g`.

Each boundary identifier `bid` of the `i`:th component grid is wrapped as
`TensorGridBoundary{i, typeof(bid)}()`. The per-grid tuples are concatenated in
component order.
"""
function boundary_identifiers(g::TensorGrid)
    # `ntuple` keeps the per-grid results in a tuple. Mapping over
    # `eachindex(g.grids)` would instead build a `Vector` of differently typed
    # tuples and splat it, which defeats type inference.
    per_grid = ntuple(length(g.grids)) do i
        return map(bid -> TensorGridBoundary{i, typeof(bid)}(), boundary_identifiers(g.grids[i]))
    end
    return LazyTensors.concatenate_tuples(per_grid...)
end
72
73
"""
    boundary_grid(g::TensorGrid, id::TensorGridBoundary)

The grid for the boundary of `g` specified by `id`.

The component grid selected by `grid_id(id)` is replaced by its own boundary
grid; all other component grids are kept unchanged.
"""
function boundary_grid(g::TensorGrid, id::TensorGridBoundary)
    k = grid_id(id)
    component_boundary = boundary_grid(g.grids[k], boundary_id(id))
    # `Base.setindex` on a tuple returns a new tuple with entry `k` replaced.
    return TensorGrid(Base.setindex(g.grids, component_boundary, k)...)
end
84
85
# Coordinate type obtained by concatenating coordinates of the given types:
# a plain scalar type when the total number of components is 1, otherwise an
# `SVector` of the total length with the promoted element type.
function combined_coordinate_vector_type(coordinate_types...)
    n_components = mapreduce(_ncomponents, +, coordinate_types)
    elem_type = mapreduce(eltype, promote_type, coordinate_types)
    return n_components == 1 ? elem_type : SVector{n_components, elem_type}
end
96
# Concatenate the given coordinates (scalars or vectors) into one SVector,
# left to right.
function combine_coordinates(coords...)
    as_vectors = map(SVector, coords)
    return foldl(vcat, as_vectors)
end