comparison src/SbpOperators/stencil.jl @ 1060:ff0e819f2075 feature/nested_stencils

Refactor code for regular stencils to use fewer type parameters and allow promotion (grafted from 737cd68318c7118f9f20a5fbc31431c48962bc71)
author Jonatan Werpers <jonatan@werpers.com>
date Sat, 12 Feb 2022 22:02:06 +0100
parents 4d06642174ec
children cd6d71781137
comparison
equal deleted inserted replaced
1059:4d06642174ec 1060:ff0e819f2075
1 export CenteredStencil 1 export CenteredStencil
2 export CenteredNestedStencil 2 export CenteredNestedStencil
3 3
"""
    Stencil{T,N}

A stencil with `N` weights of type `T`.

`range` holds the offsets, relative to the index the stencil is applied at, that the
weights act on: `weights[k]` acts at offset `range[k]`.
"""
struct Stencil{T,N}
    # Concrete field type: a bare `UnitRange` is abstract and would make every access to
    # `range` type-unstable.
    range::UnitRange{Int}
    weights::NTuple{N,T}

    # Any `UnitRange` is accepted; `new` converts it to `UnitRange{Int}`.
    function Stencil(range::UnitRange, weights::NTuple{N,T}) where {T, N}
        @assert length(range) == N "the stencil range must have one offset per weight"
        return new{T,N}(range, weights)
    end
end
13 13
"""
    Stencil(weights...; center::Int)

Create a stencil with the given weights with element `center` as the center of the stencil,
i.e. `weights[center]` acts at offset 0 and `weights[k]` acts at offset `k - center`.

The weights are promoted to a common type with `promote`, so mixed inputs such as
`Stencil(1, -2.0, 1; center=2)` are accepted.

Throws an `ArgumentError` if no weights are given.
"""
function Stencil(weights...; center::Int)
    # Without this guard an empty call fails deep inside the inner constructor, where the
    # weight type parameter cannot be determined from `()`.
    if isempty(weights)
        throw(ArgumentError("a stencil must have at least one weight."))
    end

    promoted = promote(weights...)
    N = length(promoted)
    range = (1:N) .- center

    return Stencil(range, promoted)
end
25 26
# Re-type the weights of `s` as `T`, keeping its offsets unchanged.
function Stencil{T,N}(s::Stencil{S,N}) where {T,S,N}
    converted = map(T, s.weights)
    return Stencil(s.range, converted)
end

# Same as above, with the number of weights taken from the source stencil.
function Stencil{T}(s::Stencil{S,N}) where {T,S,N}
    return Stencil{T,N}(s)
end
29 29
# Conversion of a stencil to another weight type; the number of weights is preserved.
function Base.convert(::Type{Stencil{T1,N}}, s::Stencil{T2,N}) where {T1,T2,N}
    return Stencil{T1,N}(s)
end

function Base.convert(::Type{Stencil{T1}}, s::Stencil{T2,N}) where {T1,T2,N}
    return Stencil{T1,N}(s)
end

# Two stencils with the same number of weights promote to a stencil whose weight type is
# the promotion of both weight types.
function Base.promote_rule(::Type{Stencil{T1,N}}, ::Type{Stencil{T2,N}}) where {T1,T2,N}
    return Stencil{promote_type(T1, T2),N}
end
34
"""
    CenteredStencil(weights...)

Create a stencil whose middle weight sits at offset 0, i.e. with offsets `-r:r` where
`r = length(weights) ÷ 2`. The weights are promoted to a common type, exactly as in
`Stencil(weights...; center)`.

Throws an `ArgumentError` if the number of weights is even.
"""
function CenteredStencil(weights...)
    if iseven(length(weights))
        throw(ArgumentError("a centered stencil must have an odd number of weights."))
    end

    r = length(weights) ÷ 2

    # Delegate to the keyword constructor so mixed-type weights (e.g. `1, -2.0, 1`) are
    # promoted instead of failing to match `NTuple{N,T}`. With N = 2r+1 weights,
    # center = r+1 gives the offsets -r:r.
    return Stencil(weights...; center = r + 1)
end
41 44
42 45
43 """ 46 """
44 scale(s::Stencil, a) 47 scale(s::Stencil, a)
57 return Stencil(range, reverse(s.weights)) 60 return Stencil(range, reverse(s.weights))
58 end 61 end
59 62
# Weight of `s` at offset `i` relative to the root element. With bounds checking active,
# offsets outside `s.range` give zero instead of an error.
@inline function Base.getindex(s::Stencil, i::Int)
    @boundscheck begin
        i ∈ s.range || return zero(eltype(s))
    end
    position = i - first(s.range) + 1
    return s.weights[position]
end
67 70
# Apply `s` to `v` at index `i`: sum over k of s.weights[k]*v[i + s.range[k]].
# The accumulator starts at the zero of the promoted weight/element type. `w` then keeps a
# single type through the loop even when the two types differ (e.g. Float64 weights on an
# Int vector), which keeps the function type-stable and the loop vectorizable.
Base.@propagate_inbounds @inline function apply_stencil(s::Stencil, v::AbstractVector, i::Int)
    w = zero(promote_type(eltype(s), eltype(v)))
    @simd for k ∈ 1:length(s)
        w += s.weights[k]*v[i + s.range[k]]
    end

    return w
end
75 79
# Apply `s` mirrored to `v` at index `i`: sum over k of s.weights[k]*v[i - s.range[k]].
# As in `apply_stencil`, the accumulator is typed from the promoted weight/element type,
# so `w` does not change type inside the `@simd` loop.
Base.@propagate_inbounds @inline function apply_stencil_backwards(s::Stencil, v::AbstractVector, i::Int)
    w = zero(promote_type(eltype(s), eltype(v)))
    @simd for k ∈ length(s):-1:1
        w += s.weights[k]*v[i - s.range[k]]
    end
    return w
end
83 87
84 88