changeset 1060:ff0e819f2075 feature/nested_stencils

Refactor code for regular stencils to use fewer type parameters and allow promotion (grafted from 737cd68318c7118f9f20a5fbc31431c48962bc71)
author Jonatan Werpers <jonatan@werpers.com>
date Sat, 12 Feb 2022 22:02:06 +0100
parents 4d06642174ec
children d1ddcdc098bf
files src/SbpOperators/stencil.jl test/SbpOperators/stencil_test.jl
diffstat 2 files changed, 43 insertions(+), 30 deletions(-) [+]
line wrap: on
line diff
--- a/src/SbpOperators/stencil.jl	Thu Feb 10 11:26:28 2022 +0100
+++ b/src/SbpOperators/stencil.jl	Sat Feb 12 22:02:06 2022 +0100
@@ -2,11 +2,11 @@
 export CenteredNestedStencil
 
 struct Stencil{T,N}
-    range::Tuple{Int,Int}
+    range::UnitRange
     weights::NTuple{N,T}
 
-    function Stencil(range::Tuple{Int,Int},weights::NTuple{N,T}) where {T, N}
-        @assert range[2]-range[1]+1 == N
+    function Stencil(range::UnitRange,weights::NTuple{N,T}) where {T, N}
+        @assert length(range) == N
         new{T,N}(range,weights)
     end
 end
@@ -16,27 +16,30 @@
 
 Create a stencil with the given weights with element `center` as the center of the stencil.
 """
-function Stencil(weights::Vararg{T}; center::Int) where T # Type parameter T makes sure the weights are valid for the Stencil constuctors and throws an earlier, more readable, error
+function Stencil(weights...; center::Int)
+    weights = promote(weights...)
     N = length(weights)
-    range = (1, N) .- center
+    range = (1:N) .- center
 
     return Stencil(range, weights)
 end
 
-function Stencil{T}(s::Stencil) where T
-    return Stencil(s.range, T.(s.weights))
-end
+Stencil{T,N}(s::Stencil{S,N}) where {T,S,N} = Stencil(s.range, T.(s.weights))
+Stencil{T}(s::Stencil) where T = Stencil{T,length(s)}(s)
 
-Base.convert(::Type{Stencil{T}}, stencil) where T = Stencil{T}(stencil)
+Base.convert(::Type{Stencil{T1,N}}, s::Stencil{T2,N}) where {T1,T2,N} = Stencil{T1,N}(s)
+Base.convert(::Type{Stencil{T1}}, s::Stencil{T2,N}) where {T1,T2,N} = Stencil{T1,N}(s)
 
-function CenteredStencil(weights::Vararg{T}) where T
+Base.promote_rule(::Type{Stencil{T1,N}}, ::Type{Stencil{T2,N}}) where {T1,T2,N} = Stencil{promote_type(T1,T2),N}
+
+function CenteredStencil(weights...)
     if iseven(length(weights))
         throw(ArgumentError("a centered stencil must have an odd number of weights."))
     end
 
     r = length(weights) ÷ 2
 
-    return Stencil((-r, r), weights)
+    return Stencil(-r:r, weights)
 end
 
 
@@ -59,24 +62,25 @@
 
 # Provides index into the Stencil based on offset for the root element
 @inline function Base.getindex(s::Stencil, i::Int)
-    @boundscheck if i < s.range[1] || s.range[2] < i
+    @boundscheck if i ∉ s.range
         return zero(eltype(s))
     end
     return s.weights[1 + i - s.range[1]]
 end
 
-Base.@propagate_inbounds @inline function apply_stencil(s::Stencil{T,N}, v::AbstractVector, i::Int) where {T,N}
-    w = s.weights[1]*v[i + s.range[1]]
-    @simd for k ∈ 2:N
-        w += s.weights[k]*v[i + s.range[1] + k-1]
+Base.@propagate_inbounds @inline function apply_stencil(s::Stencil, v::AbstractVector, i::Int)
+    w = zero(eltype(v))
+    @simd for k ∈ 1:length(s)
+        w += s.weights[k]*v[i + s.range[k]]
     end
+
     return w
 end
 
-Base.@propagate_inbounds @inline function apply_stencil_backwards(s::Stencil{T,N}, v::AbstractVector, i::Int) where {T,N}
-    w = s.weights[N]*v[i - s.range[2]]
-    @simd for k ∈ N-1:-1:1
-        w += s.weights[k]*v[i - s.range[1] - k + 1]
+Base.@propagate_inbounds @inline function apply_stencil_backwards(s::Stencil, v::AbstractVector, i::Int)
+    w = zero(eltype(v))
+    @simd for k ∈ length(s):-1:1
+        w += s.weights[k]*v[i - s.range[k]]
     end
     return w
 end
--- a/test/SbpOperators/stencil_test.jl	Thu Feb 10 11:26:28 2022 +0100
+++ b/test/SbpOperators/stencil_test.jl	Sat Feb 12 22:02:06 2022 +0100
@@ -5,21 +5,21 @@
 import Sbplib.SbpOperators.scale
 
 @testset "Stencil" begin
-    s = Stencil((-2,2), (1.,2.,2.,3.,4.))
+    s = Stencil(-2:2, (1.,2.,2.,3.,4.))
     @test s isa Stencil{Float64, 5}
 
     @test eltype(s) == Float64
 
     @test length(s) == 5
-    @test length(Stencil((-1,2), (1,2,3,4))) == 4
+    @test length(Stencil(-1:2, (1,2,3,4))) == 4
 
-    @test SbpOperators.scale(s, 2) == Stencil((-2,2), (2.,4.,4.,6.,8.))
+    @test SbpOperators.scale(s, 2) == Stencil(-2:2, (2.,4.,4.,6.,8.))
 
-    @test Stencil(1,2,3,4; center=1) == Stencil((0, 3),(1,2,3,4))
-    @test Stencil(1,2,3,4; center=2) == Stencil((-1, 2),(1,2,3,4))
-    @test Stencil(1,2,3,4; center=4) == Stencil((-3, 0),(1,2,3,4))
+    @test Stencil(1,2,3,4; center=1) == Stencil(0:3,(1,2,3,4))
+    @test Stencil(1,2,3,4; center=2) == Stencil(-1:2,(1,2,3,4))
+    @test Stencil(1,2,3,4; center=4) == Stencil(-3:0,(1,2,3,4))
 
-    @test CenteredStencil(1,2,3,4,5) == Stencil((-2, 2), (1,2,3,4,5))
+    @test CenteredStencil(1,2,3,4,5) == Stencil(-2:2, (1,2,3,4,5))
     @test_throws ArgumentError CenteredStencil(1,2,3,4)
 
     # Changing the type of the weights
@@ -30,9 +30,18 @@
 
     @testset "convert" begin
         @test convert(Stencil{Float64}, Stencil(1,2,3,4,5; center=2)) == Stencil(1.,2.,3.,4.,5.; center=2)
-        @test convert(Stencil{Float64}, CenteredStencil(1,2,3,4,5)) == CenteredStencil(1.,2.,3.,4.,5.)
-        @test convert(Stencil{Int}, Stencil(1.,2.,3.,4.,5.; center=2)) == Stencil(1,2,3,4,5; center=2)
-        @test convert(Stencil{Rational}, Stencil(1.,2.,3.,4.,5.; center=2)) == Stencil(1//1,2//1,3//1,4//1,5//1; center=2)
+        @test convert(Stencil{Float64,5}, CenteredStencil(1,2,3,4,5)) == CenteredStencil(1.,2.,3.,4.,5.)
+        @test convert(Stencil{Int,5}, Stencil(1.,2.,3.,4.,5.; center=2)) == Stencil(1,2,3,4,5; center=2)
+        @test convert(Stencil{Rational,5}, Stencil(1.,2.,3.,4.,5.; center=2)) == Stencil(1//1,2//1,3//1,4//1,5//1; center=2)
+    end
+
+    @testset "promotion of weights" begin
+        @test Stencil(1.,2; center = 1) isa Stencil{Float64, 2}
+        @test Stencil(1,2//2; center = 1) isa Stencil{Rational{Int64}, 2}
+    end
+
+    @testset "promotion" begin
+        @test promote(Stencil(1,1;center=1), Stencil(2.,2.;center=2)) == (Stencil(1.,1.;center=1), Stencil(2.,2.;center=2))
     end
 end