Mercurial > repos > public > sbplib_julia
annotate src/Grids/tensor_grid.jl @ 1353:1629ddee4b3a refactor/grids
Close before merge
| author | Jonatan Werpers <jonatan@werpers.com> |
|---|---|
| date | Sat, 20 May 2023 14:17:18 +0200 |
| parents | 42ecd4b3e215 |
| children | 4d628c83987e c0208286234e |
| rev | line source |
|---|---|
|
"""
    TensorGrid{T,D} <: Grid{T,D}

A grid formed as the tensor product of a collection of other grids.

The coordinate type `T` is obtained by combining the coordinate types of the
factor grids, and the dimension `D` is the sum of their dimensions.

Currently only supports grids with the `HasShape`-trait.
"""
struct TensorGrid{T,D,GT<:NTuple{N,Grid} where N} <: Grid{T,D}
    grids::GT

    function TensorGrid(factors...)
        # Fold the factor coordinate types pairwise into a single coordinate type.
        coord_type = mapreduce(eltype, combined_coordinate_vector_type, factors)
        total_dim = sum(ndims, factors)

        return new{coord_type,total_dim,typeof(factors)}(factors)
    end
end

# Indexing interface
function Base.getindex(g::TensorGrid, I...)
    # Split the combined index into one index tuple per factor grid, using the
    # number of dimensions of each factor.
    dims_per_grid = ndims.(g.grids)
    index_parts = LazyTensors.split_tuple(I, dims_per_grid)

    # Look up the point in each factor grid and concatenate the coordinates.
    points = map(g.grids, index_parts) do grid, idx
        SVector(grid[idx...])
    end

    return vcat(points...)
end

# Cartesian indices are unpacked into their integer components.
function Base.getindex(g::TensorGrid, I::CartesianIndex)
    return getindex(g, Tuple(I)...)
end

# The index space of the tensor grid is the Cartesian product of the factor
# grids' index spaces, so the factor sizes are concatenated.
function Base.eachindex(g::TensorGrid)
    combined_size = LazyTensors.concatenate_tuples(map(size, g.grids)...)
    return CartesianIndices(combined_size)
end

# Iteration interface
# Iterate over the Cartesian product of the factor grids, merging each tuple of
# factor points into a single coordinate via `combine_coordinates`.
function Base.iterate(g::TensorGrid)
    return _iterate_combine_coords(iterate(Iterators.product(g.grids...)))
end

function Base.iterate(g::TensorGrid, state)
    return _iterate_combine_coords(iterate(Iterators.product(g.grids...), state))
end

# Propagates the end of iteration.
_iterate_combine_coords(::Nothing) = nothing

# Converts a `(points, state)` pair from the product iterator into
# `(combined_coordinate, state)`.
function _iterate_combine_coords(item_and_state)
    points, state = item_and_state
    return combine_coordinates(points...), state
end

# A `D`-dimensional tensor grid iterates with a `D`-dimensional shape.
function Base.IteratorSize(::Type{<:TensorGrid{<:Any, D}}) where {D}
    return Base.HasShape{D}()
end

# The element type is the combined coordinate type parameter `T`.
function Base.eltype(::Type{<:TensorGrid{T}}) where {T}
    return T
end
# Number of points in the tensor grid. Every combination of one point from each
# factor grid is a point, so this is the product (not the sum) of the factor
# lengths. This keeps `length(g) == prod(size(g))`, as required by the
# `HasShape` iterator trait declared above.
Base.length(g::TensorGrid) = prod(length, g.grids)
# The size is the concatenation of the sizes of all factor grids.
function Base.size(g::TensorGrid)
    factor_sizes = map(size, g.grids)
    return LazyTensors.concatenate_tuples(factor_sizes...)
end


# Refine/coarsen each factor grid by `r` and rebuild a single flat `TensorGrid`
# with the same number and order of factors as `g`. Reducing with the
# `TensorGrid` constructor would instead nest grids, e.g.
# `TensorGrid(TensorGrid(a, b), c)`, and for a single factor would return a bare
# factor grid. Either outcome changes the factor structure and hence the meaning
# of `TensorGridBoundary{N, ...}` identifiers.
refine(g::TensorGrid, r::Int) = TensorGrid(map(factor -> refine(factor, r), g.grids)...)
coarsen(g::TensorGrid, r::Int) = TensorGrid(map(factor -> coarsen(factor, r), g.grids)...)

"""
    TensorGridBoundary{N, BID} <: BoundaryIdentifier

A boundary identifier for a tensor grid. The parameter `N` specifies which
grid in the tensor product the boundary belongs to, and `BID` is the type of
the boundary identifier on that grid.
"""
struct TensorGridBoundary{N, BID} <: BoundaryIdentifier end

# Index of the factor grid that the boundary belongs to.
function grid_id(::TensorGridBoundary{N, BID}) where {N, BID}
    return N
end

# An instance of the boundary identifier on the factor grid.
function boundary_id(::TensorGridBoundary{N, BID}) where {N, BID}
    return BID()
end

"""
    boundary_identifiers(g::TensorGrid)

Returns a tuple containing the boundary identifiers of `g`.

For each factor grid `i`, every boundary identifier `bid` of that grid is
wrapped as `TensorGridBoundary{i, typeof(bid)}()`. The results are concatenated
in factor order.
"""
function boundary_identifiers(g::TensorGrid)
    wrapped = ntuple(length(g.grids)) do i
        map(bid -> TensorGridBoundary{i, typeof(bid)}(), boundary_identifiers(g.grids[i]))
    end
    return LazyTensors.concatenate_tuples(wrapped...)
end


"""
    boundary_grid(g::TensorGrid, id::TensorGridBoundary)

The grid for the boundary of `g` specified by `id`.

The factor grid selected by `grid_id(id)` is replaced by its own boundary grid
for `boundary_id(id)`. All other factors are kept unchanged.
"""
function boundary_grid(g::TensorGrid, id::TensorGridBoundary)
    i = grid_id(id)
    factor_boundary = boundary_grid(g.grids[i], boundary_id(id))
    return TensorGrid(Base.setindex(g.grids, factor_boundary, i)...)
end


# Type of a coordinate formed by concatenating coordinates of the given types.
# The component count is the sum of `_ncomponents` over the inputs, and the
# element type is the promotion of their element types. A single component
# yields the plain element type; otherwise an `SVector` of that length.
function combined_coordinate_vector_type(coordinate_types...)
    n_components = mapreduce(_ncomponents, +, coordinate_types)
    element_type = mapreduce(eltype, promote_type, coordinate_types)

    n_components == 1 && return element_type
    return SVector{n_components, element_type}
end

# Concatenate the given coordinates (scalars or vectors) into one `SVector`.
function combine_coordinates(coords...)
    return reduce(vcat, map(SVector, coords))
end
