Mercurial > repos > public > sbplib_julia
comparison src/Grids/tensor_grid.jl @ 1854:654a2b7e6824 tooling/benchmarks
Merge default
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Sat, 11 Jan 2025 10:19:47 +0100 |
parents | 054447ac4b0e |
children | 516eaabf1169 |
comparison
equal
deleted
inserted
replaced
1378:2b5480e2d4bf | 1854:654a2b7e6824 |
---|---|
15 return new{T,D,typeof(gs)}(gs) | 15 return new{T,D,typeof(gs)}(gs) |
16 end | 16 end |
17 end | 17 end |
18 | 18 |
19 # Indexing interface | 19 # Indexing interface |
20 function Base.getindex(g::TensorGrid, I...) | 20 function Base.getindex(g::TensorGrid, I::Vararg{Int}) |
21 szs = ndims.(g.grids) | 21 szs = ndims.(g.grids) |
22 | 22 |
23 Is = LazyTensors.split_tuple(I, szs) | 23 Is = LazyTensors.split_tuple(I, szs) |
24 ps = map((g,I)->SVector(g[I...]), g.grids, Is) | 24 ps = map((g,I)->SVector(g[I...]), g.grids, Is) |
25 | 25 |
26 return vcat(ps...) | 26 return vcat(ps...) |
27 end | 27 end |
28 | 28 |
29 Base.getindex(g::TensorGrid, I::CartesianIndex) = g[Tuple(I)...] | |
30 | |
31 function Base.eachindex(g::TensorGrid) | 29 function Base.eachindex(g::TensorGrid) |
32 szs = LazyTensors.concatenate_tuples(size.(g.grids)...) | 30 szs = LazyTensors.concatenate_tuples(size.(g.grids)...) |
33 return CartesianIndices(szs) | 31 return CartesianIndices(szs) |
32 end | |
33 | |
34 function Base.axes(g::TensorGrid, d) | |
35 i, ld = grid_and_local_dim_index(ndims.(g.grids), d) | |
36 return axes(g.grids[i], ld) | |
34 end | 37 end |
35 | 38 |
36 # Iteration interface | 39 # Iteration interface |
37 Base.iterate(g::TensorGrid) = iterate(Iterators.product(g.grids...)) |> _iterate_combine_coords | 40 Base.iterate(g::TensorGrid) = iterate(Iterators.product(g.grids...)) |> _iterate_combine_coords |
38 Base.iterate(g::TensorGrid, state) = iterate(Iterators.product(g.grids...), state) |> _iterate_combine_coords | 41 Base.iterate(g::TensorGrid, state) = iterate(Iterators.product(g.grids...), state) |> _iterate_combine_coords |
39 _iterate_combine_coords(::Nothing) = nothing | 42 _iterate_combine_coords(::Nothing) = nothing |
40 _iterate_combine_coords((next,state)) = combine_coordinates(next...), state | 43 _iterate_combine_coords((next,state)) = combine_coordinates(next...), state |
41 | 44 |
42 Base.IteratorSize(::Type{<:TensorGrid{<:Any, D}}) where D = Base.HasShape{D}() | 45 Base.IteratorSize(::Type{<:TensorGrid{<:Any, D}}) where D = Base.HasShape{D}() |
43 Base.eltype(::Type{<:TensorGrid{T}}) where T = T | 46 Base.length(g::TensorGrid) = prod(length, g.grids) |
44 Base.length(g::TensorGrid) = sum(length, g.grids) | |
45 Base.size(g::TensorGrid) = LazyTensors.concatenate_tuples(size.(g.grids)...) | 47 Base.size(g::TensorGrid) = LazyTensors.concatenate_tuples(size.(g.grids)...) |
48 Base.size(g::TensorGrid, d) = size(g)[d] | |
46 | 49 |
50 function spacing(g::TensorGrid) | |
51 relevant_grids = filter(g->!isa(g,ZeroDimGrid),g.grids) | |
52 return spacing.(relevant_grids) | |
53 end | |
54 | |
55 function min_spacing(g::TensorGrid) | |
56 relevant_grids = filter(g->!isa(g,ZeroDimGrid),g.grids) | |
57 d = min_spacing.(relevant_grids) | |
58 return minimum(d) | |
59 end | |
47 | 60 |
48 refine(g::TensorGrid, r::Int) = mapreduce(g->refine(g,r), TensorGrid, g.grids) | 61 refine(g::TensorGrid, r::Int) = mapreduce(g->refine(g,r), TensorGrid, g.grids) |
49 coarsen(g::TensorGrid, r::Int) = mapreduce(g->coarsen(g,r), TensorGrid, g.grids) | 62 coarsen(g::TensorGrid, r::Int) = mapreduce(g->coarsen(g,r), TensorGrid, g.grids) |
50 | 63 |
51 """ | 64 """ |
68 return map(bid -> TensorGridBoundary{i, typeof(bid)}(), boundary_identifiers(g.grids[i])) | 81 return map(bid -> TensorGridBoundary{i, typeof(bid)}(), boundary_identifiers(g.grids[i])) |
69 end | 82 end |
70 return LazyTensors.concatenate_tuples(per_grid...) | 83 return LazyTensors.concatenate_tuples(per_grid...) |
71 end | 84 end |
72 | 85 |
73 | |
74 """ | 86 """ |
75 boundary_grid(g::TensorGrid, id::TensorGridBoundary) | 87 boundary_grid(g::TensorGrid, id::TensorGridBoundary) |
76 | 88 |
77 The grid for the boundary of `g` specified by `id`. | 89 The grid for the boundary of `g` specified by `id`. |
78 """ | 90 """ |
80 local_boundary_grid = boundary_grid(g.grids[grid_id(id)], boundary_id(id)) | 92 local_boundary_grid = boundary_grid(g.grids[grid_id(id)], boundary_id(id)) |
81 new_grids = Base.setindex(g.grids, local_boundary_grid, grid_id(id)) | 93 new_grids = Base.setindex(g.grids, local_boundary_grid, grid_id(id)) |
82 return TensorGrid(new_grids...) | 94 return TensorGrid(new_grids...) |
83 end | 95 end |
84 | 96 |
97 function boundary_indices(g::TensorGrid, id::TensorGridBoundary) | |
98 per_grid_ind = map(g.grids) do g | |
99 ntuple(i->:, ndims(g)) | |
100 end | |
101 | |
102 local_b_ind = boundary_indices(g.grids[grid_id(id)], boundary_id(id)) | |
103 b_ind = Base.setindex(per_grid_ind, local_b_ind, grid_id(id)) | |
104 | |
105 return LazyTensors.concatenate_tuples(b_ind...) | |
106 end | |
85 | 107 |
86 function combined_coordinate_vector_type(coordinate_types...) | 108 function combined_coordinate_vector_type(coordinate_types...) |
87 combined_coord_length = mapreduce(_ncomponents, +, coordinate_types) | 109 combined_coord_length = mapreduce(_ncomponents, +, coordinate_types) |
88 combined_coord_type = mapreduce(eltype, promote_type, coordinate_types) | 110 combined_coord_type = mapreduce(eltype, promote_type, coordinate_types) |
89 | 111 |
95 end | 117 end |
96 | 118 |
97 function combine_coordinates(coords...) | 119 function combine_coordinates(coords...) |
98 return mapreduce(SVector, vcat, coords) | 120 return mapreduce(SVector, vcat, coords) |
99 end | 121 end |
122 | |
123 """ | |
124 grid_and_local_dim_index(nds, d) | |
125 | |
126 Given a tuple of number of dimensions `nds`, and a global dimension index `d`, | |
127 calculate which grid index, and local dimension, `d` corresponds to. | |
128 | |
129 `nds` would come from broadcasting `ndims` on the grids tuple of a | |
130 `TensorGrid`. If you are interested in a dimension `d` of a tensor grid `g` | |
131 ```julia | |
132 gi, ldi = grid_and_local_dim_index(ndims.(g.grids), d) | |
133 ``` | |
134 tells you which grid it belongs to (`gi`) and which index it is at within that | |
135 grid (`ldi`). | |
136 """ | |
137 function grid_and_local_dim_index(nds, d) | |
138 I = findfirst(>=(d), cumsum(nds)) | |
139 | |
140 if I == 1 | |
141 return (1, d) | |
142 else | |
143 return (I, d-cumsum(nds)[I-1]) | |
144 end | |
145 end |