Mercurial > repos > public > sbplib_julia
comparison src/Grids/tensor_grid.jl @ 2057:8a2a0d678d6f feature/lazy_tensors/pretty_printing
Merge default
| author | Jonatan Werpers <jonatan@werpers.com> |
|---|---|
| date | Tue, 10 Feb 2026 22:41:19 +0100 |
| parents | b1eb33e62d1e |
| children |
comparison
equal
deleted
inserted
replaced
| 1110:c0bff9f6e0fb | 2057:8a2a0d678d6f |
|---|---|
"""
    TensorGrid{T,D} <: Grid{T,D}

The tensor product of a collection of grids, stored in the tuple field `grids`.

The element type `T` is obtained by folding `combined_coordinate_vector_type`
over the element types of the factor grids, and the dimension `D` is the sum
of the dimensions of the factor grids.

Currently only supports grids with the `HasShape`-trait.
"""
struct TensorGrid{T,D,GT<:NTuple{N,Grid} where N} <: Grid{T,D}
    grids::GT

    function TensorGrid(gs...)
        coordinate_type = mapreduce(eltype, combined_coordinate_vector_type, gs)
        total_dims = sum(ndims, gs)

        return new{coordinate_type,total_dims,typeof(gs)}(gs)
    end
end
| 18 | |
| 19 # Indexing interface | |
# Point of `g` at index `I`. `I` is split into one chunk per factor grid
# according to the number of dimensions of each factor; every factor is indexed
# with its chunk, the results are wrapped in `SVector`s and concatenated.
function Base.getindex(g::TensorGrid, I::Vararg{Int})
    dims_per_grid = map(ndims, g.grids)
    local_indices = LazyTensors.split_tuple(I, dims_per_grid)

    points = map(g.grids, local_indices) do grid, idx
        return SVector(grid[idx...])
    end

    return vcat(points...)
end
| 28 | |
# Cartesian indices over all points of `g`; the shape is the concatenation of
# the sizes of the factor grids.
function Base.eachindex(g::TensorGrid)
    sizes_per_grid = map(size, g.grids)
    return CartesianIndices(LazyTensors.concatenate_tuples(sizes_per_grid...))
end
| 33 | |
# Axis of `g` along the global dimension `d`, looked up on the factor grid that
# owns that dimension.
function Base.axes(g::TensorGrid, d)
    grid_index, local_dim = grid_and_local_dim_index(map(ndims, g.grids), d)
    owner = g.grids[grid_index]
    return axes(owner, local_dim)
end
| 38 | |
| 39 # Iteration interface | |
# Iteration walks the Cartesian product of the factor grids and merges the
# per-factor coordinates of every step into a single point.
function Base.iterate(g::TensorGrid)
    return _iterate_combine_coords(iterate(Iterators.product(g.grids...)))
end

function Base.iterate(g::TensorGrid, state)
    return _iterate_combine_coords(iterate(Iterators.product(g.grids...), state))
end

# End of iteration is passed through untouched.
_iterate_combine_coords(::Nothing) = nothing

# Otherwise merge the tuple of per-factor coordinates and keep the state.
function _iterate_combine_coords(step)
    coords, state = step
    return (combine_coordinates(coords...), state)
end
| 44 | |
# A `D`-dimensional tensor grid has a `D`-dimensional shape.
function Base.IteratorSize(::Type{<:TensorGrid{<:Any, D}}) where D
    return Base.HasShape{D}()
end

# Number of points: the product of the lengths of the factor grids.
function Base.length(g::TensorGrid)
    return prod(length, g.grids)
end

# Shape: the sizes of the factor grids concatenated in order.
function Base.size(g::TensorGrid)
    sizes_per_grid = map(size, g.grids)
    return LazyTensors.concatenate_tuples(sizes_per_grid...)
end

# Extent along the global dimension `d`.
function Base.size(g::TensorGrid, d)
    return size(g)[d]
end
| 49 | |
# Spacings of the factor grids of `g`, one entry per factor grid.
# Factors that are `ZeroDimGrid`s are skipped.
function spacing(g::TensorGrid)
    considered = filter(sub -> !(sub isa ZeroDimGrid), g.grids)
    return map(spacing, considered)
end
| 54 | |
# Smallest `min_spacing` among the factor grids of `g`.
# Factors that are `ZeroDimGrid`s are skipped.
function min_spacing(g::TensorGrid)
    considered = filter(sub -> !(sub isa ZeroDimGrid), g.grids)
    return minimum(map(min_spacing, considered))
end
| 60 | |
"""
    refine(g::TensorGrid, r::Int)

Apply `refine(·, r)` to every factor grid of `g` and return a `TensorGrid` built
from the results, with the same number of factors in the same order.
"""
function refine(g::TensorGrid, r::Int)
    # Build the result with one `TensorGrid` call instead of folding pairwise:
    # a pairwise fold returns a bare grid for a single factor and a nested
    # `TensorGrid(TensorGrid(a, b), c)` for three or more factors.
    refined = map(sub -> refine(sub, r), g.grids)
    return TensorGrid(refined...)
end

"""
    coarsen(g::TensorGrid, r::Int)

Apply `coarsen(·, r)` to every factor grid of `g` and return a `TensorGrid` built
from the results, with the same number of factors in the same order.
"""
function coarsen(g::TensorGrid, r::Int)
    # Same reasoning as in `refine`: keep the factor structure flat.
    coarsened = map(sub -> coarsen(sub, r), g.grids)
    return TensorGrid(coarsened...)
end
| 63 | |
"""
    TensorGridBoundary{N, BID} <: BoundaryIdentifier

Identifies a boundary of a tensor grid. `N` selects the grid in the tensor
product and `BID` is the type of the boundary identifier on that grid.
"""
struct TensorGridBoundary{N, BID} <: BoundaryIdentifier end

# Index of the factor grid the boundary belongs to.
function grid_id(::TensorGridBoundary{N, BID}) where {N, BID}
    return N
end

# An instance of the boundary identifier type `BID`.
function boundary_id(::TensorGridBoundary{N, BID}) where {N, BID}
    return BID()
end
| 73 | |
"""
    boundary_identifiers(g::TensorGrid)

Returns a tuple containing the boundary identifiers of `g`.

Every boundary identifier `bid` of the `i`th factor grid is represented as
`TensorGridBoundary{i, typeof(bid)}()`. The identifiers are ordered by factor
grid, and within a factor grid in the order given by that grid.
"""
function boundary_identifiers(g::TensorGrid)
    # `ntuple` keeps the per-grid results in a tuple. Mapping over
    # `eachindex(g.grids)` would iterate a range and build a `Vector` of
    # differently typed tuples, only to splat it again.
    per_grid = ntuple(length(g.grids)) do i
        return map(boundary_identifiers(g.grids[i])) do bid
            return TensorGridBoundary{i, typeof(bid)}()
        end
    end
    return LazyTensors.concatenate_tuples(per_grid...)
end
| 85 | |
"""
    boundary_grid(g::TensorGrid, id::TensorGridBoundary)

The grid for the boundary of `g` specified by `id`.

The factor grid selected by `id` is replaced by its own boundary grid; all
other factor grids are kept unchanged.
"""
function boundary_grid(g::TensorGrid, id::TensorGridBoundary)
    n = grid_id(id)
    replacement = boundary_grid(g.grids[n], boundary_id(id))
    return TensorGrid(Base.setindex(g.grids, replacement, n)...)
end
| 96 | |
| 97 | |
"""
    boundary_indices(g::TensorGrid, id::TensorGridBoundary)

Indices of `g` on the boundary specified by `id`.

For a one-dimensional tensor grid the boundary indices of the factor grid
selected by `id` are returned directly. Otherwise the result is a view into
`eachindex(g)` in which the selected factor grid is restricted to its boundary
indices and every other factor grid keeps all of its indices.
"""
function boundary_indices(g::TensorGrid{<:Any, 1}, id::TensorGridBoundary)
    n = grid_id(id)
    return boundary_indices(g.grids[n], boundary_id(id))
end

function boundary_indices(g::TensorGrid, id::TensorGridBoundary)
    n = grid_id(id)
    on_boundary = boundary_indices(g.grids[n], boundary_id(id))

    all_indices = map(eachindex, g.grids)
    selection = Base.setindex(all_indices, on_boundary, n)

    return view(eachindex(g), selection...)
end
| 108 | |
# Type of a coordinate formed by combining coordinates of the given types.
# The component counts (`_ncomponents`) are added up and the element types are
# promoted. A single component yields the promoted element type itself;
# otherwise the result is an `SVector` of that length and element type.
function combined_coordinate_vector_type(coordinate_types...)
    ncomponents = mapreduce(_ncomponents, +, coordinate_types)
    scalar_type = mapreduce(eltype, promote_type, coordinate_types)

    return ncomponents == 1 ? scalar_type : SVector{ncomponents, scalar_type}
end
| 119 | |
# Wrap each coordinate in an `SVector` and concatenate them, left to right,
# into a single vector.
function combine_coordinates(coords...)
    wrapped = map(SVector, coords)
    return reduce(vcat, wrapped)
end
| 123 | |
"""
    grid_and_local_dim_index(nds, d)

Given a tuple of number of dimensions `nds`, and a global dimension index `d`,
calculate which grid index, and local dimension, `d` corresponds to.

`nds` would come from broadcasting `ndims` on the grids tuple of a
`TensorGrid`. If you are interested in a dimension `d` of a tensor grid `g`
```julia
gi, ldi = grid_and_local_dim_index(ndims.(g.grids), d)
```
tells you which grid it belongs to (`gi`) and which index it is at within that
grid (`ldi`).

Throws an `ArgumentError` if `d` is smaller than 1 or larger than the total
number of dimensions described by `nds`.
"""
function grid_and_local_dim_index(nds, d)
    # offsets[i] is the last global dimension index covered by grid i.
    offsets = cumsum(nds)
    I = findfirst(>=(d), offsets)

    # `findfirst` returns `nothing` when `d` lies beyond the last dimension;
    # fail with a clear message instead of an error on `nothing - 1`.
    if d < 1 || isnothing(I)
        throw(ArgumentError(
            "dimension index d = $d does not match any dimension described by nds = $nds"
        ))
    end

    if I == 1
        return (1, d)
    else
        # Subtract the dimensions of all preceding grids.
        return (I, d - offsets[I-1])
    end
end
