changeset 1067:af64233419d3

Merge feature/nested_stencils
author Jonatan Werpers <jonatan@werpers.com>
date Wed, 23 Mar 2022 13:30:10 +0100
parents 396278072f18 (current diff) d13f84dc267a (diff)
children 0b0444adacd3 2b6298905692
files
diffstat 3 files changed, 260 insertions(+), 33 deletions(-) [+]
line wrap: on
line diff
--- a/src/SbpOperators/stencil.jl	Wed Mar 23 13:09:31 2022 +0100
+++ b/src/SbpOperators/stencil.jl	Wed Mar 23 13:30:10 2022 +0100
@@ -1,11 +1,12 @@
 export CenteredStencil
+export CenteredNestedStencil
 
 struct Stencil{T,N}
-    range::Tuple{Int,Int}
+    range::UnitRange{Int64}
     weights::NTuple{N,T}
 
-    function Stencil(range::Tuple{Int,Int},weights::NTuple{N,T}) where {T, N}
-        @assert range[2]-range[1]+1 == N
+    function Stencil(range::UnitRange,weights::NTuple{N,T}) where {T, N}
+        @assert length(range) == N
         new{T,N}(range,weights)
     end
 end
@@ -15,27 +16,30 @@
 
 Create a stencil with the given weights with element `center` as the center of the stencil.
 """
-function Stencil(weights::Vararg{T}; center::Int) where T # Type parameter T makes sure the weights are valid for the Stencil constuctors and throws an earlier, more readable, error
+function Stencil(weights...; center::Int)
+    weights = promote(weights...)
     N = length(weights)
-    range = (1, N) .- center
+    range = (1:N) .- center
 
     return Stencil(range, weights)
 end
 
-function Stencil{T}(s::Stencil) where T
-    return Stencil(s.range, T.(s.weights))
-end
+Stencil{T,N}(s::Stencil{S,N}) where {T,S,N} = Stencil(s.range, T.(s.weights))
+Stencil{T}(s::Stencil) where T = Stencil{T,length(s)}(s)
 
-Base.convert(::Type{Stencil{T}}, stencil) where T = Stencil{T}(stencil)
+Base.convert(::Type{Stencil{T1,N}}, s::Stencil{T2,N}) where {T1,T2,N} = Stencil{T1,N}(s)
+Base.convert(::Type{Stencil{T1}}, s::Stencil{T2,N}) where {T1,T2,N} = Stencil{T1,N}(s)
 
-function CenteredStencil(weights::Vararg)
+Base.promote_rule(::Type{Stencil{T1,N}}, ::Type{Stencil{T2,N}}) where {T1,T2,N} = Stencil{promote_type(T1,T2),N}
+
+function CenteredStencil(weights...)
     if iseven(length(weights))
         throw(ArgumentError("a centered stencil must have an odd number of weights."))
     end
 
     r = length(weights) ÷ 2
 
-    return Stencil((-r, r), weights)
+    return Stencil(-r:r, weights)
 end
 
 
@@ -48,7 +52,8 @@
     return Stencil(s.range, a.*s.weights)
 end
 
-Base.eltype(::Stencil{T}) where T = T
+Base.eltype(::Stencil{T,N}) where {T,N} = T
+Base.length(::Stencil{T,N}) where {T,N} = N
 
 function flip(s::Stencil)
     range = (-s.range[2], -s.range[1])
@@ -57,24 +62,103 @@
 
 # Provides index into the Stencil based on offset for the root element
 @inline function Base.getindex(s::Stencil, i::Int)
-    @boundscheck if i < s.range[1] || s.range[2] < i
+    @boundscheck if i ∉ s.range
         return zero(eltype(s))
     end
     return s.weights[1 + i - s.range[1]]
 end
 
-Base.@propagate_inbounds @inline function apply_stencil(s::Stencil{T,N}, v::AbstractVector, i::Int) where {T,N}
-    w = s.weights[1]*v[i + s.range[1]]
-    @simd for k ∈ 2:N
-        w += s.weights[k]*v[i + s.range[1] + k-1]
+Base.@propagate_inbounds @inline function apply_stencil(s::Stencil, v::AbstractVector, i::Int)
+    w = zero(promote_type(eltype(s),eltype(v)))
+    @simd for k ∈ 1:length(s)
+        w += s.weights[k]*v[i + s.range[k]]
+    end
+
+    return w
+end
+
+Base.@propagate_inbounds @inline function apply_stencil_backwards(s::Stencil, v::AbstractVector, i::Int)
+    w = zero(promote_type(eltype(s),eltype(v)))
+    @simd for k ∈ length(s):-1:1
+        w += s.weights[k]*v[i - s.range[k]]
     end
     return w
 end
 
-Base.@propagate_inbounds @inline function apply_stencil_backwards(s::Stencil{T,N}, v::AbstractVector, i::Int) where {T,N}
-    w = s.weights[N]*v[i - s.range[2]]
-    @simd for k ∈ N-1:-1:1
-        w += s.weights[k]*v[i - s.range[1] - k + 1]
-    end
-    return w
+
+struct NestedStencil{T,N,M}
+    s::Stencil{Stencil{T,N},M}
+end
+
+# Stencil input
+NestedStencil(s::Vararg{Stencil}; center) = NestedStencil(Stencil(s... ; center))
+CenteredNestedStencil(s::Vararg{Stencil}) = NestedStencil(CenteredStencil(s...))
+
+# Tuple input
+function NestedStencil(weights::Vararg{NTuple{N,Any}}; center) where N
+    inner_stencils = map(w -> Stencil(w...; center), weights)
+    return NestedStencil(Stencil(inner_stencils... ; center))
+end
+function CenteredNestedStencil(weights::Vararg{NTuple{N,Any}}) where N
+    inner_stencils = map(w->CenteredStencil(w...), weights)
+    return CenteredNestedStencil(inner_stencils...)
+end
+
+
+# Conversion
+function NestedStencil{T,N,M}(ns::NestedStencil{S,N,M}) where {T,S,N,M}
+    return NestedStencil(Stencil{Stencil{T}}(ns.s))
+end
+
+function NestedStencil{T}(ns::NestedStencil{S,N,M}) where {T,S,N,M}
+    NestedStencil{T,N,M}(ns)
+end
+
+function Base.convert(::Type{NestedStencil{T,N,M}}, s::NestedStencil{S,N,M}) where {T,S,N,M}
+    return NestedStencil{T,N,M}(s)
+end
+Base.convert(::Type{NestedStencil{T}}, stencil) where T = NestedStencil{T}(stencil)
+
+function Base.promote_rule(::Type{NestedStencil{T,N,M}}, ::Type{NestedStencil{S,N,M}}) where {T,S,N,M}
+    return NestedStencil{promote_type(T,S),N,M}
 end
+
+Base.eltype(::NestedStencil{T}) where T = T
+
+function scale(ns::NestedStencil, a)
+    range = ns.s.range
+    weights = ns.s.weights
+
+    return NestedStencil(Stencil(range, scale.(weights,a)))
+end
+
+function flip(ns::NestedStencil)
+    s_flip = flip(ns.s)
+    return NestedStencil(Stencil(s_flip.range, flip.(s_flip.weights)))
+end
+
+Base.getindex(ns::NestedStencil, i::Int) = ns.s[i]
+
+"Apply inner stencils to `c` and get a concrete stencil"
+Base.@propagate_inbounds function apply_inner_stencils(ns::NestedStencil, c::AbstractVector, i::Int)
+    weights = apply_stencil.(ns.s.weights, Ref(c), i)
+    return Stencil(ns.s.range, weights)
+end
+
+"Apply the whole nested stencil"
+Base.@propagate_inbounds function apply_stencil(ns::NestedStencil, c::AbstractVector, v::AbstractVector, i::Int)
+    s = apply_inner_stencils(ns,c,i)
+    return apply_stencil(s, v, i)
+end
+
+"Apply inner stencils backwards to `c` and get a concrete stencil"
+Base.@propagate_inbounds @inline function apply_inner_stencils_backwards(ns::NestedStencil, c::AbstractVector, i::Int)
+    weights = apply_stencil_backwards.(ns.s.weights, Ref(c), i)
+    return Stencil(ns.s.range, weights)
+end
+
+"Apply the whole nested stencil backwards"
+Base.@propagate_inbounds @inline function apply_stencil_backwards(ns::NestedStencil, c::AbstractVector, v::AbstractVector, i::Int)
+    s = apply_inner_stencils_backwards(ns,c,i)
+    return apply_stencil_backwards(s, v, i)
+end
--- a/test/SbpOperators/boundaryops/boundary_operator_test.jl	Wed Mar 23 13:09:31 2022 +0100
+++ b/test/SbpOperators/boundaryops/boundary_operator_test.jl	Wed Mar 23 13:30:10 2022 +0100
@@ -9,7 +9,7 @@
 import Sbplib.SbpOperators.boundary_operator
 
 @testset "BoundaryOperator" begin
-    closure_stencil = Stencil((0,2), (2.,1.,3.))
+    closure_stencil = Stencil(2.,1.,3.; center = 1)
     g_1D = EquidistantGrid(11, 0.0, 1.0)
     g_2D = EquidistantGrid((11,15), (0.0, 0.0), (1.0,1.0))
 
--- a/test/SbpOperators/stencil_test.jl	Wed Mar 23 13:09:31 2022 +0100
+++ b/test/SbpOperators/stencil_test.jl	Wed Mar 23 13:30:10 2022 +0100
@@ -1,19 +1,25 @@
 using Test
 using Sbplib.SbpOperators
 import Sbplib.SbpOperators.Stencil
+import Sbplib.SbpOperators.NestedStencil
+import Sbplib.SbpOperators.scale
 
 @testset "Stencil" begin
-    s = Stencil((-2,2), (1.,2.,2.,3.,4.))
+    s = Stencil(-2:2, (1.,2.,2.,3.,4.))
     @test s isa Stencil{Float64, 5}
 
     @test eltype(s) == Float64
-    @test SbpOperators.scale(s, 2) == Stencil((-2,2), (2.,4.,4.,6.,8.))
+
+    @test length(s) == 5
+    @test length(Stencil(-1:2, (1,2,3,4))) == 4
+
+    @test SbpOperators.scale(s, 2) == Stencil(-2:2, (2.,4.,4.,6.,8.))
 
-    @test Stencil(1,2,3,4; center=1) == Stencil((0, 3),(1,2,3,4))
-    @test Stencil(1,2,3,4; center=2) == Stencil((-1, 2),(1,2,3,4))
-    @test Stencil(1,2,3,4; center=4) == Stencil((-3, 0),(1,2,3,4))
+    @test Stencil(1,2,3,4; center=1) == Stencil(0:3,(1,2,3,4))
+    @test Stencil(1,2,3,4; center=2) == Stencil(-1:2,(1,2,3,4))
+    @test Stencil(1,2,3,4; center=4) == Stencil(-3:0,(1,2,3,4))
 
-    @test CenteredStencil(1,2,3,4,5) == Stencil((-2, 2), (1,2,3,4,5))
+    @test CenteredStencil(1,2,3,4,5) == Stencil(-2:2, (1,2,3,4,5))
     @test_throws ArgumentError CenteredStencil(1,2,3,4)
 
     # Changing the type of the weights
@@ -24,8 +30,145 @@
 
     @testset "convert" begin
         @test convert(Stencil{Float64}, Stencil(1,2,3,4,5; center=2)) == Stencil(1.,2.,3.,4.,5.; center=2)
-        @test convert(Stencil{Float64}, CenteredStencil(1,2,3,4,5)) == CenteredStencil(1.,2.,3.,4.,5.)
-        @test convert(Stencil{Int}, Stencil(1.,2.,3.,4.,5.; center=2)) == Stencil(1,2,3,4,5; center=2)
-        @test convert(Stencil{Rational}, Stencil(1.,2.,3.,4.,5.; center=2)) == Stencil(1//1,2//1,3//1,4//1,5//1; center=2)
+        @test convert(Stencil{Float64,5}, CenteredStencil(1,2,3,4,5)) == CenteredStencil(1.,2.,3.,4.,5.)
+        @test convert(Stencil{Int,5}, Stencil(1.,2.,3.,4.,5.; center=2)) == Stencil(1,2,3,4,5; center=2)
+        @test convert(Stencil{Rational,5}, Stencil(1.,2.,3.,4.,5.; center=2)) == Stencil(1//1,2//1,3//1,4//1,5//1; center=2)
+    end
+
+    @testset "promotion of weights" begin
+        @test Stencil(1.,2; center = 1) isa Stencil{Float64, 2}
+        @test Stencil(1,2//2; center = 1) isa Stencil{Rational{Int64}, 2}
+    end
+
+    @testset "promotion" begin
+        @test promote(Stencil(1,1;center=1), Stencil(2.,2.;center=2)) == (Stencil(1.,1.;center=1), Stencil(2.,2.;center=2))
+    end
+
+    @testset "type stability" begin
+        s_int = CenteredStencil(1,2,3)
+        s_float = CenteredStencil(1.,2.,3.)
+        v_int = rand(1:10,10);
+        v_float = rand(10);
+
+        @inferred SbpOperators.apply_stencil(s_int, v_int, 2)
+        @inferred SbpOperators.apply_stencil(s_float, v_float, 2)
+        @inferred SbpOperators.apply_stencil(s_int,  v_float, 2)
+        @inferred SbpOperators.apply_stencil(s_float, v_int, 2)
+
+        @inferred SbpOperators.apply_stencil_backwards(s_int, v_int, 5)
+        @inferred SbpOperators.apply_stencil_backwards(s_float, v_float, 5)
+        @inferred SbpOperators.apply_stencil_backwards(s_int,  v_float, 5)
+        @inferred SbpOperators.apply_stencil_backwards(s_float, v_int, 5)
     end
 end
+
+@testset "NestedStencil" begin
+
+    @testset "Constructors" begin
+        s1 = CenteredStencil(-1, 1, 0)
+        s2 = CenteredStencil(-1, 0, 1)
+        s3 = CenteredStencil( 0,-1, 1)
+
+        ns = NestedStencil(CenteredStencil(s1,s2,s3))
+        @test ns isa NestedStencil{Int,3}
+
+        @test CenteredNestedStencil(s1,s2,s3) == ns
+
+        @test NestedStencil(s1,s2,s3, center = 2) == ns
+        @test NestedStencil(s1,s2,s3, center = 1) == NestedStencil(Stencil(s1,s2,s3, center=1))
+
+        @test NestedStencil((-1,1,0),(-1,0,1),(0,-1,1), center=2) == ns
+        @test CenteredNestedStencil((-1,1,0),(-1,0,1),(0,-1,1)) == ns
+        @test NestedStencil((-1,1,0),(-1,0,1),(0,-1,1), center=1) == NestedStencil(Stencil(
+            Stencil(-1, 1, 0; center=1),
+            Stencil(-1, 0, 1; center=1),
+            Stencil( 0,-1, 1; center=1);
+            center=1
+        ))
+
+        @testset "Error handling" begin
+        end
+    end
+
+    @testset "scale" begin
+        ns = NestedStencil((-1,1,0),(-1,0,1),(0,-1,1), center=2)
+        @test SbpOperators.scale(ns, 2) == NestedStencil((-2,2,0),(-2,0,2),(0,-2,2), center=2)
+    end
+
+    @testset "conversion" begin
+        ns = NestedStencil((-1,1,0),(-1,0,1),(0,-1,1), center=2)
+        @test NestedStencil{Float64}(ns) == NestedStencil((-1.,1.,0.),(-1.,0.,1.),(0.,-1.,1.), center=2)
+        @test NestedStencil{Rational}(ns) == NestedStencil((-1//1,1//1,0//1),(-1//1,0//1,1//1),(0//1,-1//1,1//1), center=2)
+
+        @test convert(NestedStencil{Float64}, ns) == NestedStencil((-1.,1.,0.),(-1.,0.,1.),(0.,-1.,1.), center=2)
+        @test convert(NestedStencil{Rational}, ns) == NestedStencil((-1//1,1//1,0//1),(-1//1,0//1,1//1),(0//1,-1//1,1//1), center=2)
+    end
+
+    @testset "promotion of weights" begin
+        @test NestedStencil((-1,1,0),(-1.,0.,1.),(0,-1,1), center=2) isa NestedStencil{Float64,3,3}
+        @test NestedStencil((-1,1,0),(-1,0,1),(0//1,-1,1), center=2) isa NestedStencil{Rational{Int64},3,3}
+    end
+
+    @testset "promotion" begin
+        promote(
+            CenteredNestedStencil((-1,1,0),(-1,0,1),(0,-1,1)),
+            CenteredNestedStencil((-1.,1.,0.),(-1.,0.,1.),(0.,-1.,1.))
+        ) == (
+            CenteredNestedStencil((-1.,1.,0.),(-1.,0.,1.),(0.,-1.,1.)),
+            CenteredNestedStencil((-1.,1.,0.),(-1.,0.,1.),(0.,-1.,1.))
+        )
+    end
+
+    @testset "apply" begin
+        c = [  1,  3,  6, 10, 15, 21, 28, 36, 45, 55]
+        v = [  2,  3,  5,  7, 11, 13, 17, 19, 23, 29]
+
+        # Centered
+        ns = NestedStencil((-1,1,0),(-1,0,1),(0,-2,2), center=2)
+        @test SbpOperators.apply_inner_stencils(ns, c, 4) == Stencil(4,9,10; center=2)
+        @test SbpOperators.apply_inner_stencils_backwards(ns, c, 4) == Stencil(-5,-9,-8; center=2)
+
+        @test SbpOperators.apply_stencil(ns, c, v, 4) == 4*5 + 9*7 + 10*11
+        @test SbpOperators.apply_stencil_backwards(ns, c, v, 4) == -8*5 - 9*7 - 5*11
+
+        # Non-centered
+        ns = NestedStencil((-1,1,0),(-1,0,1),(0,-1,1), center=1)
+        @test SbpOperators.apply_inner_stencils(ns, c, 4) == Stencil(5,11,6; center=1)
+        @test SbpOperators.apply_inner_stencils_backwards(ns, c, 4) == Stencil(-4,-7,-3; center=1)
+
+        @test SbpOperators.apply_stencil(ns, c, v, 4) == 5*7 + 11*11 + 6*13
+        @test SbpOperators.apply_stencil_backwards(ns, c, v, 4) == -3*3 - 7*5 - 4*7
+    end
+
+    @testset "type stability" begin
+        s_int = CenteredNestedStencil((1,2,3),(1,2,3),(1,2,3))
+        s_float = CenteredNestedStencil((1.,2.,3.),(1.,2.,3.),(1.,2.,3.))
+
+        v_int = rand(1:10,10);
+        v_float = rand(10);
+
+        c_int = rand(1:10,10);
+        c_float = rand(10);
+
+        @inferred SbpOperators.apply_stencil(s_int,   c_int, v_int,   2)
+        @inferred SbpOperators.apply_stencil(s_float, c_int, v_float, 2)
+        @inferred SbpOperators.apply_stencil(s_int,   c_int, v_float, 2)
+        @inferred SbpOperators.apply_stencil(s_float, c_int, v_int,   2)
+
+        @inferred SbpOperators.apply_stencil(s_int,   c_float, v_int,   2)
+        @inferred SbpOperators.apply_stencil(s_float, c_float, v_float, 2)
+        @inferred SbpOperators.apply_stencil(s_int,   c_float, v_float, 2)
+        @inferred SbpOperators.apply_stencil(s_float, c_float, v_int,   2)
+
+        @inferred SbpOperators.apply_stencil_backwards(s_int,   c_int, v_int,   2)
+        @inferred SbpOperators.apply_stencil_backwards(s_float, c_int, v_float, 2)
+        @inferred SbpOperators.apply_stencil_backwards(s_int,   c_int, v_float, 2)
+        @inferred SbpOperators.apply_stencil_backwards(s_float, c_int, v_int,   2)
+
+        @inferred SbpOperators.apply_stencil_backwards(s_int,   c_float, v_int,   2)
+        @inferred SbpOperators.apply_stencil_backwards(s_float, c_float, v_float, 2)
+        @inferred SbpOperators.apply_stencil_backwards(s_int,   c_float, v_float, 2)
+        @inferred SbpOperators.apply_stencil_backwards(s_float, c_float, v_int,   2)
+    end
+
+end