diff src/Grids/equidistant_grid.jl @ 1255:1989d432731a refactor/grids

Implement the interfaces for iteration and indexing on EquidistantGrid. Make collect() work
author Jonatan Werpers <jonatan@werpers.com>
date Wed, 22 Feb 2023 22:38:25 +0100
parents ff8f335c32d1
children 198ccda331a6
line wrap: on
line diff
--- a/src/Grids/equidistant_grid.jl	Wed Feb 22 21:58:45 2023 +0100
+++ b/src/Grids/equidistant_grid.jl	Wed Feb 22 22:38:25 2023 +0100
@@ -6,16 +6,21 @@
     points::R
 end
 
-Base.eltype(g::EquidistantGrid{T}) where T = T
-Base.getindex(g::EquidistantGrid, i) = g.points[i]
-Base.size(g::EquidistantGrid) = size(g.points)
-Base.length(g::EquidistantGrid) = length(g.points)
 Base.eachindex(g::EquidistantGrid) = eachindex(g.points)
 
+# Indexing interface
+Base.getindex(g::EquidistantGrid, i) = g.points[i]
 Base.firstindex(g::EquidistantGrid) = firstindex(g.points)
 Base.lastindex(g::EquidistantGrid) = lastindex(g.points)
 
-# TODO: Make sure collect works!
+# Iteration interface
+Base.iterate(g::EquidistantGrid) = iterate(g.points)
+Base.iterate(g::EquidistantGrid, state) = iterate(g.points, state)
+
+Base.IteratorSize(::Type{EquidistantGrid}) = Base.HasShape{1}()
+Base.eltype(::Type{EquidistantGrid{T}}) where T = T
+Base.length(g::EquidistantGrid) = length(g.points)
+Base.size(g::EquidistantGrid) = size(g.points)
 
 
 """