comparison src/Grids/tensor_grid.jl @ 2057:8a2a0d678d6f feature/lazy_tensors/pretty_printing

Merge default
author Jonatan Werpers <jonatan@werpers.com>
date Tue, 10 Feb 2026 22:41:19 +0100
parents b1eb33e62d1e
children
comparison
equal deleted inserted replaced
1110:c0bff9f6e0fb 2057:8a2a0d678d6f
"""
    TensorGrid{T,D} <: Grid{T,D}

A grid formed as the tensor product of a sequence of component grids.

The coordinate type `T` is obtained by folding the component coordinate types
with `combined_coordinate_vector_type`, and the dimension `D` is the sum of the
component dimensions.

Currently only supports grids with the `HasShape`-trait.
"""
struct TensorGrid{T,D,GT<:NTuple{N,Grid} where N} <: Grid{T,D}
    grids::GT

    function TensorGrid(gs...)
        # Fold the component coordinate types pairwise into one combined type.
        coordtype = mapreduce(eltype, combined_coordinate_vector_type, gs)
        dim = sum(ndims, gs)
        return new{coordtype,dim,typeof(gs)}(gs)
    end
end
18
19 # Indexing interface
# Return the coordinate at the global index `I` by splitting `I` into one local
# index per component grid and concatenating the component coordinates.
function Base.getindex(g::TensorGrid, I::Vararg{Int})
    local_indices = LazyTensors.split_tuple(I, ndims.(g.grids))
    points = map(g.grids, local_indices) do grid, idx
        SVector(grid[idx...])
    end
    return vcat(points...)
end
28
# Cartesian indices covering the full shape of the grid, i.e. the concatenation
# of the component grid sizes.
Base.eachindex(g::TensorGrid) = CartesianIndices(size(g))
33
# Axis of global dimension `d`: look up which component grid owns `d` and query
# that grid for the corresponding local dimension.
function Base.axes(g::TensorGrid, d)
    component_ndims = ndims.(g.grids)
    gi, ldi = grid_and_local_dim_index(component_ndims, d)
    owner = g.grids[gi]
    return axes(owner, ldi)
end
38
39 # Iteration interface
# Iterate over the tensor product of the component grids, fusing each tuple of
# component coordinates into a single coordinate.
function Base.iterate(g::TensorGrid)
    inner = Iterators.product(g.grids...)
    return _iterate_combine_coords(iterate(inner))
end
function Base.iterate(g::TensorGrid, state)
    inner = Iterators.product(g.grids...)
    return _iterate_combine_coords(iterate(inner, state))
end
# Propagate the end of iteration unchanged.
_iterate_combine_coords(::Nothing) = nothing
# Combine the component coordinates and pass the product iterator's state through.
_iterate_combine_coords((parts, st)) = (combine_coordinates(parts...), st)
44
# A tensor grid has a shape whose rank equals its total dimension `D`.
Base.IteratorSize(::Type{<:TensorGrid{<:Any,D}}) where {D} = Base.HasShape{D}()

# The number of points is the product of the component grid lengths.
Base.length(g::TensorGrid) = prod(length, g.grids)

# The shape is the concatenation of the component grid shapes.
function Base.size(g::TensorGrid)
    component_sizes = map(size, g.grids)
    return LazyTensors.concatenate_tuples(component_sizes...)
end
Base.size(g::TensorGrid, d) = getindex(size(g), d)
49
# Spacings of all component grids, skipping zero-dimensional components.
function spacing(g::TensorGrid)
    nonzerodim = filter(grid -> !(grid isa ZeroDimGrid), g.grids)
    return map(spacing, nonzerodim)
end
54
# Smallest `min_spacing` among the component grids, skipping zero-dimensional
# components.
function min_spacing(g::TensorGrid)
    nonzerodim = filter(grid -> !(grid isa ZeroDimGrid), g.grids)
    return minimum(min_spacing, nonzerodim)
end
60
# Apply `refine`/`coarsen` with factor `r` to every component grid and rebuild
# a `TensorGrid` with the same component structure as `g`.
# A single splatted constructor call is used on purpose: reducing pairwise with
# `TensorGrid` (as `mapreduce` does) would nest the result as
# `TensorGrid(TensorGrid(a, b), c)` for three or more components. It would also
# return a bare component grid, not a `TensorGrid`, when `g` has a single component.
refine(g::TensorGrid, r::Int) = TensorGrid(map(sub -> refine(sub, r), g.grids)...)
coarsen(g::TensorGrid, r::Int) = TensorGrid(map(sub -> coarsen(sub, r), g.grids)...)
63
"""
    TensorGridBoundary{N, BID} <: BoundaryIdentifier

A boundary identifier for a tensor grid. `N` specifies which grid in the
tensor product and `BID` which boundary on that grid.
"""
struct TensorGridBoundary{N, BID} <: BoundaryIdentifier end

# Position of the component grid that the boundary belongs to.
grid_id(::TensorGridBoundary{N}) where {N} = N
# Boundary identifier on that component grid, instantiated from its type parameter.
boundary_id(::TensorGridBoundary{<:Any, BID}) where {BID} = BID()
73
"""
    boundary_identifiers(g::TensorGrid)

Returns a tuple containing the boundary identifiers of `g`.

Each boundary identifier of the `i`-th component grid is wrapped in a
`TensorGridBoundary{i, ...}`, and the results are concatenated in component order.
"""
function boundary_identifiers(g::TensorGrid)
    wrap(i) = map(bid -> TensorGridBoundary{i, typeof(bid)}(), boundary_identifiers(g.grids[i]))
    wrapped = ntuple(wrap, length(g.grids))
    return LazyTensors.concatenate_tuples(wrapped...)
end
85
"""
    boundary_grid(g::TensorGrid, id::TensorGridBoundary)

The grid for the boundary of `g` specified by `id`.

The component grid selected by `id` is replaced by its own boundary grid; all
other components are kept as they are.
"""
function boundary_grid(g::TensorGrid, id::TensorGridBoundary)
    k = grid_id(id)
    replaced = boundary_grid(g.grids[k], boundary_id(id))
    grids = Base.setindex(g.grids, replaced, k)
    return TensorGrid(grids...)
end
96
97
# For a one-dimensional tensor grid, the boundary indices are those of the
# component grid selected by `id`.
boundary_indices(g::TensorGrid{<:Any, 1}, id::TensorGridBoundary) =
    boundary_indices(g.grids[grid_id(id)], boundary_id(id))

# In general, the component selected by `id` contributes its boundary indices
# while every other component contributes its full index range. The result is
# a view into the grid's Cartesian indices.
function boundary_indices(g::TensorGrid, id::TensorGridBoundary)
    k = grid_id(id)
    local_inds = boundary_indices(g.grids[k], boundary_id(id))
    per_grid = Base.setindex(map(eachindex, g.grids), local_inds, k)
    return view(eachindex(g), per_grid...)
end
108
# Coordinate type obtained by concatenating coordinates of the given types: the
# total component count decides between a scalar (one component) and an
# `SVector`, with the element types promoted to a common type.
function combined_coordinate_vector_type(coordinate_types...)
    n = mapreduce(_ncomponents, +, coordinate_types)
    S = promote_type(map(eltype, coordinate_types)...)
    return n == 1 ? S : SVector{n, S}
end
119
# Concatenate the given coordinates into a single `SVector`, wrapping each one
# in an `SVector` first.
function combine_coordinates(coords...)
    wrapped = map(SVector, coords)
    return reduce(vcat, wrapped)
end
123
"""
    grid_and_local_dim_index(nds, d)

Given a tuple of number of dimensions `nds`, and a global dimension index `d`,
calculate which grid index, and local dimension, `d` corresponds to.

`nds` would come from broadcasting `ndims` on the grids tuple of a
`TensorGrid`. If you are interested in a dimension `d` of a tensor grid `g`
```julia
gi, ldi = grid_and_local_dim_index(ndims.(g.grids), d)
```
tells you which grid it belongs to (`gi`) and which index it is at within that
grid (`ldi`).

Throws an `ArgumentError` if `d` is not in `1:sum(nds)`.
"""
function grid_and_local_dim_index(nds, d)
    cumulative = cumsum(nds)
    # Grid `I` owns the global dimensions cumulative[I-1]+1 through cumulative[I];
    # the first one that reaches `d` is the owner. Indices below 1 have no owner.
    I = d >= 1 ? findfirst(>=(d), cumulative) : nothing
    if isnothing(I)
        msg = "dimension index $d is out of range for grids with dimensions $nds"
        throw(ArgumentError(msg))
    end

    if I == 1
        return (1, d)
    else
        return (I, d - cumulative[I-1])
    end
end
137 """
138 function grid_and_local_dim_index(nds, d)
139 I = findfirst(>=(d), cumsum(nds))
140
141 if I == 1
142 return (1, d)
143 else
144 return (I, d-cumsum(nds)[I-1])
145 end
146 end