view src/Grids/tensor_grid.jl @ 1256:3fc78ad26d03 refactor/grids

Add notes and todos about interface implementations for grids
author Jonatan Werpers <jonatan@werpers.com>
date Wed, 22 Feb 2023 22:38:54 +0100
parents 6f75f2d2bf5c
children 198ccda331a6
line wrap: on
line source

"""
    TensorGrid(gs...)

A grid formed from the component grids `gs`, indexed with one group of
indices per component grid.

The type parameters are derived from the component grids:
* `T` is the `eltype` of the first grid,
* `D` is the sum of the component grids' `ndims`,
* `RD` is the sum of the component grids' `nrangedims`,
* `GT` is the type of the tuple of component grids.

Throws an `ArgumentError` if no grids are given.
"""
struct TensorGrid{T,D,RD,GT<:NTuple{N,Grid} where N} <: Grid{T,D,RD}
    grids::GT

    function TensorGrid(gs...)
        # At least one grid is needed to determine the element type `T`.
        isempty(gs) && throw(ArgumentError("TensorGrid requires at least one grid"))

        # NOTE(review): all grids are assumed to share the same `T`; only the
        # first grid is inspected and the assumption is not checked.
        T = eltype(gs[1])
        D = sum(ndims, gs)
        RD = sum(nrangedims, gs)

        return new{T,D,RD,typeof(gs)}(gs)
    end
end

# Indexing interface
# TODO
# Iteration interface
# TODO


"""
    size(g::TensorGrid)

Return the size of `g`: the sizes of all component grids concatenated, in
the order the component grids are stored.
"""
function Base.size(g::TensorGrid)
    component_sizes = map(size, g.grids)
    return LazyTensors.concatenate_tuples(component_sizes...)
end

"""
    getindex(g::TensorGrid, I...)
    getindex(g::TensorGrid, I::CartesianIndex)

Return the grid point of `g` at index `I`.

The indices are split into consecutive groups, one per component grid, with
group sizes given by each component grid's `ndims`. Each component grid is
indexed with its group, each result is wrapped in an `SVector`, and the
pieces are concatenated with `vcat`.
"""
function Base.getindex(g::TensorGrid, I...)
    szs = ndims.(g.grids)

    Is = LazyTensors.split_tuple(I, szs)
    ps = map((subgrid, subI) -> SVector(subgrid[subI...]), g.grids, Is)

    return vcat(ps...)
end

# `eachindex(g)` yields `CartesianIndex` values. Without this method such an
# index would reach the varargs method as a single unsplit argument, so unwrap
# it into its integer components first.
Base.getindex(g::TensorGrid, I::CartesianIndex) = g[Tuple(I)...]

# Extend Base's indexing trait. Unless `IndexStyle` is explicitly imported, an
# unqualified definition would create a separate module-local function
# instead. The trait is declared on the type (as Base's array interface
# expects) and on instances (for direct calls).
Base.IndexStyle(::Type{<:TensorGrid}) = IndexCartesian()
Base.IndexStyle(::TensorGrid) = IndexCartesian()

"""
    eachindex(g::TensorGrid)

Return `CartesianIndices` covering the full index space of `g`, i.e. the
concatenation of the component grids' sizes as given by `size(g)`.
"""
function Base.eachindex(g::TensorGrid)
    # Reuse `size(g)` rather than duplicating its concatenation logic.
    return CartesianIndices(size(g))
end


"""
    TensorBoundary{N, BID}

Boundary identifier for a `TensorGrid`. It refers to the boundary `BID()`
of the `N`:th component grid.
"""
struct TensorBoundary{N, BID<:BoundaryIdentifier} <: BoundaryIdentifier end

"Return the position `N` of the component grid that the boundary belongs to."
grid_id(::TensorBoundary{N}) where {N} = N

"Return the boundary identifier within the component grid, as an instance of `BID`."
boundary_id(::TensorBoundary{N, BID}) where {N, BID} = BID()


"""
    boundary_identifiers(::TensorGrid)

Returns a tuple containing the boundary identifiers for the grid.
"""
function boundary_identifiers(g::TensorGrid)
    n = length(g.grids)
    per_grid = map(eachindex(g.grids)) do i
        return map(bid -> TensorBoundary{i, bid}(), boundary_identifiers(g.grids[i]))
    end
    return LazyTensors.concatenate_tuples(per_grid...)
end


"""
    boundary_grid(grid::TensorGrid, id::TensorBoundary)

The grid for the boundary specified by `id`.
"""
function boundary_grid(g::TensorGrid, bid::TensorBoundary)
    local_boundary_grid = boundary_grid(g.grids[grid_id(bid)], boundary_id(bid))
    new_grids = Base.setindex(g.grids, local_boundary_grid, grid_id(bid))
    return TensorGrid(new_grids...)
end