Mercurial > repos > public > sbplib_julia
diff src/SbpOperators/stencil.jl @ 897:737cd68318c7 feature/variable_derivatives
Refactor code for regular stencils to use fewer type parameters and allow promotion
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Sat, 12 Feb 2022 22:02:06 +0100 |
parents | 004324d7ed35 |
children | cd6d71781137 |
line wrap: on
line diff
--- a/src/SbpOperators/stencil.jl Thu Feb 10 11:26:28 2022 +0100 +++ b/src/SbpOperators/stencil.jl Sat Feb 12 22:02:06 2022 +0100 @@ -2,11 +2,11 @@ export CenteredNestedStencil struct Stencil{T,N} - range::Tuple{Int,Int} + range::UnitRange weights::NTuple{N,T} - function Stencil(range::Tuple{Int,Int},weights::NTuple{N,T}) where {T, N} - @assert range[2]-range[1]+1 == N + function Stencil(range::UnitRange,weights::NTuple{N,T}) where {T, N} + @assert length(range) == N new{T,N}(range,weights) end end @@ -16,27 +16,30 @@ Create a stencil with the given weights with element `center` as the center of the stencil. """ -function Stencil(weights::Vararg{T}; center::Int) where T # Type parameter T makes sure the weights are valid for the Stencil constuctors and throws an earlier, more readable, error +function Stencil(weights...; center::Int) + weights = promote(weights...) N = length(weights) - range = (1, N) .- center + range = (1:N) .- center return Stencil(range, weights) end -function Stencil{T}(s::Stencil) where T - return Stencil(s.range, T.(s.weights)) -end +Stencil{T,N}(s::Stencil{S,N}) where {T,S,N} = Stencil(s.range, T.(s.weights)) +Stencil{T}(s::Stencil) where T = Stencil{T,length(s)}(s) -Base.convert(::Type{Stencil{T}}, stencil) where T = Stencil{T}(stencil) +Base.convert(::Type{Stencil{T1,N}}, s::Stencil{T2,N}) where {T1,T2,N} = Stencil{T1,N}(s) +Base.convert(::Type{Stencil{T1}}, s::Stencil{T2,N}) where {T1,T2,N} = Stencil{T1,N}(s) -function CenteredStencil(weights::Vararg{T}) where T +Base.promote_rule(::Type{Stencil{T1,N}}, ::Type{Stencil{T2,N}}) where {T1,T2,N} = Stencil{promote_type(T1,T2),N} + +function CenteredStencil(weights...) if iseven(length(weights)) throw(ArgumentError("a centered stencil must have an odd number of weights.")) end r = length(weights) ÷ 2 - return Stencil((-r, r), weights) + return Stencil(-r:r, weights) end @@ -59,24 +62,25 @@ # Provides index into the Stencil based on offset for the root element @inline function Base.getindex(s::Stencil, i::Int) - @boundscheck if i < s.range[1] || s.range[2] < i + @boundscheck if i ∉ s.range return zero(eltype(s)) end return s.weights[1 + i - s.range[1]] end -Base.@propagate_inbounds @inline function apply_stencil(s::Stencil{T,N}, v::AbstractVector, i::Int) where {T,N} - w = s.weights[1]*v[i + s.range[1]] - @simd for k ∈ 2:N - w += s.weights[k]*v[i + s.range[1] + k-1] +Base.@propagate_inbounds @inline function apply_stencil(s::Stencil, v::AbstractVector, i::Int) + w = zero(eltype(v)) + @simd for k ∈ 1:length(s) + w += s.weights[k]*v[i + s.range[k]] end + return w end -Base.@propagate_inbounds @inline function apply_stencil_backwards(s::Stencil{T,N}, v::AbstractVector, i::Int) where {T,N} - w = s.weights[N]*v[i - s.range[2]] - @simd for k ∈ N-1:-1:1 - w += s.weights[k]*v[i - s.range[1] - k + 1] +Base.@propagate_inbounds @inline function apply_stencil_backwards(s::Stencil, v::AbstractVector, i::Int) + w = zero(eltype(v)) + @simd for k ∈ length(s):-1:1 + w += s.weights[k]*v[i - s.range[k]] end return w end