Mercurial > repos > public > sbplib_julia
annotate src/Grids/tensor_grid.jl @ 1330:5f05a708d730 refactor/grids
grid.l: More documentation
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Tue, 02 May 2023 22:42:25 +0200 |
parents | 3b7ebd135918 |
children | ed3ea0630825 |
rev | line source |
---|---|
"""
    TensorGrid{T,D} <: Grid{T,D}

A grid formed as the tensor product of the component grids passed to the
constructor `TensorGrid(gs...)`. The component grids are stored, in order, in
the field `grids`.

The dimension `D` is the sum of the dimensions (`ndims`) of the components, and
the coordinate type `T` is obtained by combining the component coordinate types
with `combined_coordinate_vector_type`.

* Only supports HasShape grids at the moment
"""
struct TensorGrid{T,D,GT<:NTuple{N,Grid} where N} <: Grid{T,D}
    grids::GT

    function TensorGrid(gs...)
        # The reductions below have no neutral element for an empty tuple, so
        # fail early with a clear message instead of an obscure reduction error.
        isempty(gs) && throw(ArgumentError("TensorGrid requires at least one component grid"))

        T = mapreduce(eltype, combined_coordinate_vector_type, gs)
        D = sum(ndims, gs)

        return new{T,D,typeof(gs)}(gs)
    end
end

# Indexing interface
#
# `g[I...]` splits the indices `I` into consecutive groups, one group per
# component grid with as many indices as that component has dimensions
# (`ndims`). Each component is indexed with its group, and the resulting
# component coordinates are concatenated into a single vector.
function Base.getindex(g::TensorGrid, I...)
    component_ndims = ndims.(g.grids)

    component_indices = LazyTensors.split_tuple(I, component_ndims)
    component_points = map((grid, idx) -> SVector(grid[idx...]), g.grids, component_indices)

    return vcat(component_points...)
end

# `eachindex(g)` yields `CartesianIndex` values. Unpack them into plain indices
# so that `g[i] for i in eachindex(g)` works instead of handing a single
# CartesianIndex to `split_tuple` above.
Base.getindex(g::TensorGrid, I::CartesianIndex) = g[Tuple(I)...]

# The index space of a tensor grid is the Cartesian product of the component
# index spaces, i.e. all Cartesian indices for the concatenated size `size(g)`.
Base.eachindex(g::TensorGrid) = CartesianIndices(size(g))

# Iteration interface
#
# Iterating a TensorGrid walks `Iterators.product` over the component grids and
# turns each tuple of component points into one combined coordinate.
function Base.iterate(g::TensorGrid)
    product_iter = Iterators.product(g.grids...)
    return _iterate_combine_coords(iterate(product_iter))
end

function Base.iterate(g::TensorGrid, state)
    product_iter = Iterators.product(g.grids...)
    return _iterate_combine_coords(iterate(product_iter, state))
end

# Post-process the result of iterating the product iterator: `nothing` marks the
# end of iteration; otherwise merge the tuple of component points with
# `combine_coordinates` and pass the iterator state through unchanged.
_iterate_combine_coords(::Nothing) = nothing
function _iterate_combine_coords(item_and_state)
    points, state = item_and_state
    return combine_coordinates(points...), state
end

# A TensorGrid has shape, with one array dimension per grid dimension `D`.
Base.IteratorSize(::Type{<:TensorGrid{<:Any, D}}) where D = Base.HasShape{D}()
# The element type is the combined coordinate type `T` fixed by the constructor.
Base.eltype(::Type{<:TensorGrid{T}}) where T = T
# The points are all combinations of component grid points, so the total count
# is the product (not the sum) of the component lengths. This keeps `length`
# consistent with `prod(size(g))`, as the HasShape contract requires.
Base.length(g::TensorGrid) = prod(length, g.grids)
# The size is the concatenation of the component sizes.
Base.size(g::TensorGrid) = LazyTensors.concatenate_tuples(size.(g.grids)...)


# Apply `refine(·, r)` / `coarsen(·, r)` to every component grid and rebuild a
# flat TensorGrid with the same number and order of components. The previous
# pairwise `mapreduce(..., TensorGrid, ...)` differed in two ways: it nested
# TensorGrids for three or more components, and it returned a bare component
# grid for a single component.
refine(g::TensorGrid, r::Int) = TensorGrid(map(grid -> refine(grid, r), g.grids)...)
coarsen(g::TensorGrid, r::Int) = TensorGrid(map(grid -> coarsen(grid, r), g.grids)...)

"""
    TensorGridBoundary{N, BID} <: BoundaryIdentifier

Boundary identifier for a `TensorGrid`. It refers to the boundary identified by
`BID()` on the `N`:th component grid.

`grid_id` returns `N` and `boundary_id` returns the identifier instance `BID()`.
"""
struct TensorGridBoundary{N, BID} <: BoundaryIdentifier end

grid_id(::TensorGridBoundary{N}) where N = N
boundary_id(::TensorGridBoundary{<:Any, BID}) where BID = BID()

"""
    boundary_identifiers(::TensorGrid)

Return a tuple with the boundary identifiers of the grid. For each component
grid `i`, in order, every identifier `bid` from `boundary_identifiers` of that
component is wrapped as `TensorGridBoundary{i, typeof(bid)}()`.
"""
function boundary_identifiers(g::TensorGrid)
    component_numbers = ntuple(identity, length(g.grids))
    per_grid = map(component_numbers, g.grids) do i, grid
        return map(bid -> TensorGridBoundary{i, typeof(bid)}(), boundary_identifiers(grid))
    end
    return LazyTensors.concatenate_tuples(per_grid...)
end


"""
    boundary_grid(g::TensorGrid, bid::TensorGridBoundary)

The grid for the boundary specified by `bid`. The component grid number
`grid_id(bid)` is replaced by its own boundary grid for `boundary_id(bid)`, and
a new `TensorGrid` is built from the resulting components.
"""
function boundary_grid(g::TensorGrid, bid::TensorGridBoundary)
    n = grid_id(bid)
    component_boundary = boundary_grid(g.grids[n], boundary_id(bid))
    return TensorGrid(Base.setindex(g.grids, component_boundary, n)...)
end
1266
a4ddae8b5d49
Add tests for TensorGrid and make them pass
Jonatan Werpers <jonatan@werpers.com>
parents:
1257
diff
changeset
|
79 |
a4ddae8b5d49
Add tests for TensorGrid and make them pass
Jonatan Werpers <jonatan@werpers.com>
parents:
1257
diff
changeset
|
80 |
a4ddae8b5d49
Add tests for TensorGrid and make them pass
Jonatan Werpers <jonatan@werpers.com>
parents:
1257
diff
changeset
|
81 function combined_coordinate_vector_type(coordinate_types...) |
1270
dcbac783e4c1
Factor out functions for getting the type and number of components in a type
Jonatan Werpers <jonatan@werpers.com>
parents:
1266
diff
changeset
|
82 combined_coord_length = mapreduce(_ncomponents, +, coordinate_types) |
1289
3b7ebd135918
Remove _component_type and replace with eltype
Jonatan Werpers <jonatan@werpers.com>
parents:
1270
diff
changeset
|
83 combined_coord_type = mapreduce(eltype, promote_type, coordinate_types) |
1266
a4ddae8b5d49
Add tests for TensorGrid and make them pass
Jonatan Werpers <jonatan@werpers.com>
parents:
1257
diff
changeset
|
84 |
a4ddae8b5d49
Add tests for TensorGrid and make them pass
Jonatan Werpers <jonatan@werpers.com>
parents:
1257
diff
changeset
|
85 if combined_coord_length == 1 |
a4ddae8b5d49
Add tests for TensorGrid and make them pass
Jonatan Werpers <jonatan@werpers.com>
parents:
1257
diff
changeset
|
86 return combined_coord_type |
a4ddae8b5d49
Add tests for TensorGrid and make them pass
Jonatan Werpers <jonatan@werpers.com>
parents:
1257
diff
changeset
|
87 else |
a4ddae8b5d49
Add tests for TensorGrid and make them pass
Jonatan Werpers <jonatan@werpers.com>
parents:
1257
diff
changeset
|
88 return SVector{combined_coord_length, combined_coord_type} |
a4ddae8b5d49
Add tests for TensorGrid and make them pass
Jonatan Werpers <jonatan@werpers.com>
parents:
1257
diff
changeset
|
89 end |
a4ddae8b5d49
Add tests for TensorGrid and make them pass
Jonatan Werpers <jonatan@werpers.com>
parents:
1257
diff
changeset
|
90 end |
a4ddae8b5d49
Add tests for TensorGrid and make them pass
Jonatan Werpers <jonatan@werpers.com>
parents:
1257
diff
changeset
|
91 |
"""
    combine_coordinates(coords...)

Concatenate the coordinates `coords` into a single vector. Each coordinate is
first converted with `SVector`; the results are then joined with `vcat`, from
left to right.
"""
function combine_coordinates(coords...)
    as_vectors = map(SVector, coords)
    return reduce(vcat, as_vectors)
end