changeset 1255:1989d432731a refactor/grids

Implement the interfaces for iteration and indexing on EquidistantGrid. Make collect() work
author Jonatan Werpers <jonatan@werpers.com>
date Wed, 22 Feb 2023 22:38:25 +0100
parents f98d8ede0e90
children 3fc78ad26d03
files src/Grids/equidistant_grid.jl test/Grids/equidistant_grid_test.jl
diffstat 2 files changed, 11 insertions(+), 6 deletions(-) [+]
line wrap: on
line diff
--- a/src/Grids/equidistant_grid.jl	Wed Feb 22 21:58:45 2023 +0100
+++ b/src/Grids/equidistant_grid.jl	Wed Feb 22 22:38:25 2023 +0100
@@ -6,16 +6,21 @@
     points::R
 end
 
-Base.eltype(g::EquidistantGrid{T}) where T = T
-Base.getindex(g::EquidistantGrid, i) = g.points[i]
-Base.size(g::EquidistantGrid) = size(g.points)
-Base.length(g::EquidistantGrid) = length(g.points)
 Base.eachindex(g::EquidistantGrid) = eachindex(g.points)
 
+# Indexing interface
+Base.getindex(g::EquidistantGrid, i) = g.points[i]
 Base.firstindex(g::EquidistantGrid) = firstindex(g.points)
 Base.lastindex(g::EquidistantGrid) = lastindex(g.points)
 
-# TODO: Make sure collect works!
+# Iteration interface
+Base.iterate(g::EquidistantGrid) = iterate(g.points)
+Base.iterate(g::EquidistantGrid, state) = iterate(g.points, state)
+
+Base.IteratorSize(::Type{EquidistantGrid}) = Base.HasShape{1}()
+Base.eltype(::Type{EquidistantGrid{T}}) where T = T
+Base.length(g::EquidistantGrid) = length(g.points)
+Base.size(g::EquidistantGrid) = size(g.points)
 
 
 """
--- a/test/Grids/equidistant_grid_test.jl	Wed Feb 22 21:58:45 2023 +0100
+++ b/test/Grids/equidistant_grid_test.jl	Wed Feb 22 22:38:25 2023 +0100
@@ -30,7 +30,7 @@
 
     @testset "collect" begin
         g = EquidistantGrid(0:0.1:0.5)
-        @test_broken collect(g) == [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
+        @test collect(g) == [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
     end
 
     @testset "getindex" begin