Mercurial > repos > public > sbplib_julia
comparison src/Grids/tensor_grid.jl @ 1360:f59228534d3a tooling/benchmarks
Merge default
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Sat, 20 May 2023 15:15:22 +0200 |
parents | 42ecd4b3e215 |
children | 4d628c83987e c0208286234e |
comparison
equal
deleted
inserted
replaced
1321:42738616422e | 1360:f59228534d3a |
---|---|
"""
    TensorGrid{T,D} <: Grid{T,D}

A grid formed as the tensor product of the grids stored in the `grids` field.

The element type `T` is obtained by reducing the element types of the
constituent grids with `combined_coordinate_vector_type`, and the dimension
`D` is the sum of their `ndims`.

Currently only supports grids with the `HasShape`-trait.
"""
struct TensorGrid{T,D,GT<:NTuple{N,Grid} where N} <: Grid{T,D}
    grids::GT

    function TensorGrid(gs...)
        coordinate_type = mapreduce(eltype, combined_coordinate_vector_type, gs)
        dimension = sum(ndims, gs)

        return new{coordinate_type,dimension,typeof(gs)}(gs)
    end
end
18 | |
19 # Indexing interface | |
# Return the point of `g` at index `I...`.
#
# The index tuple is divided by `LazyTensors.split_tuple` according to the
# `ndims` of each constituent grid. Every grid is indexed with its own part,
# each result is wrapped in an `SVector`, and the pieces are concatenated.
function Base.getindex(g::TensorGrid, I...)
    dims_per_grid = map(ndims, g.grids)
    index_parts = LazyTensors.split_tuple(I, dims_per_grid)

    points = map(g.grids, index_parts) do grid, idx
        return SVector(grid[idx...])
    end

    return vcat(points...)
end
28 | |
# Index with a `CartesianIndex` by unpacking it and forwarding to the
# scalar-index method.
function Base.getindex(g::TensorGrid, I::CartesianIndex)
    return getindex(g, Tuple(I)...)
end
30 | |
# All indices of `g`, as `CartesianIndices` over the concatenation of the
# sizes of the constituent grids.
function Base.eachindex(g::TensorGrid)
    sizes_per_grid = map(size, g.grids)
    return CartesianIndices(LazyTensors.concatenate_tuples(sizes_per_grid...))
end
35 | |
36 # Iteration interface | |
# Iteration delegates to `Iterators.product` over the constituent grids and
# merges each tuple of per-grid coordinates into a single coordinate.
function Base.iterate(g::TensorGrid)
    return _iterate_combine_coords(iterate(Iterators.product(g.grids...)))
end

function Base.iterate(g::TensorGrid, state)
    return _iterate_combine_coords(iterate(Iterators.product(g.grids...), state))
end

# End of iteration: pass `nothing` through unchanged.
_iterate_combine_coords(::Nothing) = nothing

# Combine the per-grid coordinates of one iteration step, keeping the state.
function _iterate_combine_coords(step)
    coords, state = step
    return combine_coordinates(coords...), state
end
41 | |
# A `D`-dimensional tensor grid reports a `D`-dimensional shape.
function Base.IteratorSize(::Type{<:TensorGrid{<:Any,D}}) where {D}
    return Base.HasShape{D}()
end
# The element type is the first type parameter of the grid.
function Base.eltype(::Type{<:TensorGrid{T}}) where {T}
    return T
end
# Total number of points in the grid. A tensor product contains one point for
# every combination of points of the constituent grids, so the count is the
# product (not the sum) of their lengths. This keeps `length(g)` equal to
# `prod(size(g))` and to `length(eachindex(g))`.
Base.length(g::TensorGrid) = prod(length, g.grids)
# The size of the grid is the concatenation of the sizes of the constituent
# grids, in order.
function Base.size(g::TensorGrid)
    sizes_per_grid = map(size, g.grids)
    return LazyTensors.concatenate_tuples(sizes_per_grid...)
end
46 | |
47 | |
# Refine every constituent grid by `r` and assemble the results into a single
# flat `TensorGrid`. Building the grid from the mapped tuple (rather than
# reducing with `TensorGrid`) keeps one factor per original grid: a
# single-factor grid stays a `TensorGrid`, and three or more factors are not
# nested pairwise, so factor positions used by `TensorGridBoundary` still apply.
refine(g::TensorGrid, r::Int) = TensorGrid(map(grid -> refine(grid, r), g.grids)...)
# Coarsen every constituent grid by `r` and assemble the results into a single
# flat `TensorGrid`. Building the grid from the mapped tuple (rather than
# reducing with `TensorGrid`) keeps one factor per original grid: a
# single-factor grid stays a `TensorGrid`, and three or more factors are not
# nested pairwise, so factor positions used by `TensorGridBoundary` still apply.
coarsen(g::TensorGrid, r::Int) = TensorGrid(map(grid -> coarsen(grid, r), g.grids)...)
50 | |
"""
    TensorGridBoundary{N, BID} <: BoundaryIdentifier

Identifies a boundary of a tensor grid. `N` selects the grid within the
tensor product and `BID` is the type of the boundary identifier on that grid.
"""
struct TensorGridBoundary{N, BID} <: BoundaryIdentifier
end
# The `N` parameter of the identifier: which grid in the tensor product.
function grid_id(::TensorGridBoundary{N, BID}) where {N, BID}
    return N
end
# An instance of the `BID` parameter: the boundary identifier on the selected grid.
function boundary_id(::TensorGridBoundary{N, BID}) where {N, BID}
    return BID()
end
60 | |
"""
    boundary_identifiers(g::TensorGrid)

Returns a tuple containing the boundary identifiers of `g`.

Each boundary identifier `bid` of the `i`th constituent grid is represented
as `TensorGridBoundary{i, typeof(bid)}()`. The identifiers are ordered grid by
grid.
"""
function boundary_identifiers(g::TensorGrid)
    # `ntuple` keeps the per-grid collection a tuple; mapping over
    # `eachindex(g.grids)` (a range) would build a `Vector` only to splat it.
    per_grid = ntuple(length(g.grids)) do i
        return map(
            bid -> TensorGridBoundary{i, typeof(bid)}(),
            boundary_identifiers(g.grids[i]),
        )
    end
    return LazyTensors.concatenate_tuples(per_grid...)
end
72 | |
73 | |
"""
    boundary_grid(g::TensorGrid, id::TensorGridBoundary)

The grid for the boundary of `g` specified by `id`.

The constituent grid selected by `grid_id(id)` is replaced by its own boundary
grid for `boundary_id(id)`; all other constituent grids are kept as they are.
"""
function boundary_grid(g::TensorGrid, id::TensorGridBoundary)
    k = grid_id(id)
    replacement = boundary_grid(g.grids[k], boundary_id(id))

    return TensorGrid(Base.setindex(g.grids, replacement, k)...)
end
84 | |
85 | |
# Determine the coordinate type obtained by combining `coordinate_types`.
#
# The component counts (`_ncomponents`) are added up and the element types are
# promoted to a common type. A single component yields the promoted scalar
# type itself; otherwise an `SVector` type of that length is returned.
function combined_coordinate_vector_type(coordinate_types...)
    n = mapreduce(_ncomponents, +, coordinate_types)
    S = mapreduce(eltype, promote_type, coordinate_types)

    return n == 1 ? S : SVector{n, S}
end
96 | |
# Wrap each coordinate in an `SVector` and concatenate them, in order, with `vcat`.
combine_coordinates(coords...) = reduce(vcat, map(SVector, coords))