Mercurial > repos > public > sbplib_julia
comparison src/Grids/tensor_grid.jl @ 1833:0e0833663dee refactor/grids/iterable_boundary_indices
Simplify code for tensor grids
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Mon, 21 Oct 2024 23:50:14 +0200 |
parents | 871f3f1decea |
children | 516eaabf1169 |
comparison
equal
deleted
inserted
replaced
1832:85f8855473ab | 1833:0e0833663dee |
---|---|
93 | 93 |
94 function boundary_indices(g::TensorGrid{T,1} where T, id::TensorGridBoundary) | 94 function boundary_indices(g::TensorGrid{T,1} where T, id::TensorGridBoundary) |
95 return boundary_indices(g.grids[grid_id(id)], boundary_id(id)) | 95 return boundary_indices(g.grids[grid_id(id)], boundary_id(id)) |
96 end | 96 end |
97 function boundary_indices(g::TensorGrid, id::TensorGridBoundary) | 97 function boundary_indices(g::TensorGrid, id::TensorGridBoundary) |
98 all_indices = map(eachindex, g.grids) | |
99 | |
100 local_b_ind = boundary_indices(g.grids[grid_id(id)], boundary_id(id)) | 98 local_b_ind = boundary_indices(g.grids[grid_id(id)], boundary_id(id)) |
101 | 99 |
102 b_ind = Base.setindex(all_indices, local_b_ind, grid_id(id)) | 100 b_ind = Base.setindex(map(eachindex, g.grids), local_b_ind, grid_id(id)) |
103 | 101 |
104 return view(_combine_indices(all_indices...), LazyTensors.concatenate_tuples(bla.(b_ind)...)...) | 102 return view(eachindex(g), b_ind...) |
105 end | 103 end |
106 # TODO: There must be a way to make the above code cleaner? | |
107 | |
108 # function _combine_indices(Is::Vararg{Union{Int, <:AbstractRange}}) | |
109 function _combine_indices(Is...) | |
110 return CartesianIndices(LazyTensors.concatenate_tuples(bla.(Is)...)) | |
111 end | |
112 | |
113 bla(a) = (a,) | |
114 bla(a::CartesianIndices) = a.indices | |
115 | 104 |
116 function combined_coordinate_vector_type(coordinate_types...) | 105 function combined_coordinate_vector_type(coordinate_types...) |
117 combined_coord_length = mapreduce(_ncomponents, +, coordinate_types) | 106 combined_coord_length = mapreduce(_ncomponents, +, coordinate_types) |
118 combined_coord_type = mapreduce(eltype, promote_type, coordinate_types) | 107 combined_coord_type = mapreduce(eltype, promote_type, coordinate_types) |
119 | 108 |