comparison src/SbpOperators/stencil.jl @ 1207:f1c2a4fa0ee1 performance/get_region_type_inference

Merge default
author Jonatan Werpers <jonatan@werpers.com>
date Fri, 03 Feb 2023 22:14:47 +0100
parents fdd594b2a15e
children 14cb97284373
comparison
equal deleted inserted replaced
919:b41180efb6c2 1207:f1c2a4fa0ee1
1 export CenteredStencil 1 export CenteredStencil
2 export CenteredNestedStencil
2 3
3 struct Stencil{T,N} 4 struct Stencil{T,N}
4 range::Tuple{Int,Int} 5 range::UnitRange{Int64}
5 weights::NTuple{N,T} 6 weights::NTuple{N,T}
6 7
7 function Stencil(range::Tuple{Int,Int},weights::NTuple{N,T}) where {T, N} 8 function Stencil(range::UnitRange,weights::NTuple{N,T}) where {T, N}
8 @assert range[2]-range[1]+1 == N 9 @assert length(range) == N
9 new{T,N}(range,weights) 10 new{T,N}(range,weights)
10 end 11 end
11 end 12 end
12 13
13 """ 14 """
14 Stencil(weights::NTuple; center::Int) 15 Stencil(weights::NTuple; center::Int)
15 16
16 Create a stencil with the given weights with element `center` as the center of the stencil. 17 Create a stencil with the given weights with element `center` as the center of the stencil.
17 """ 18 """
18 function Stencil(weights::Vararg{T}; center::Int) where T # Type parameter T makes sure the weights are valid for the Stencil constuctors and throws an earlier, more readable, error 19 function Stencil(weights...; center::Int)
20 weights = promote(weights...)
19 N = length(weights) 21 N = length(weights)
20 range = (1, N) .- center 22 range = (1:N) .- center
21 23
22 return Stencil(range, weights) 24 return Stencil(range, weights)
23 end 25 end
24 26
25 function Stencil{T}(s::Stencil) where T 27 Stencil{T,N}(s::Stencil{S,N}) where {T,S,N} = Stencil(s.range, T.(s.weights))
26 return Stencil(s.range, T.(s.weights)) 28 Stencil{T}(s::Stencil) where T = Stencil{T,length(s)}(s)
27 end
28 29
29 Base.convert(::Type{Stencil{T}}, stencil) where T = Stencil{T}(stencil) 30 Base.convert(::Type{Stencil{T1,N}}, s::Stencil{T2,N}) where {T1,T2,N} = Stencil{T1,N}(s)
31 Base.convert(::Type{Stencil{T1}}, s::Stencil{T2,N}) where {T1,T2,N} = Stencil{T1,N}(s)
30 32
31 function CenteredStencil(weights::Vararg) 33 Base.promote_rule(::Type{Stencil{T1,N}}, ::Type{Stencil{T2,N}}) where {T1,T2,N} = Stencil{promote_type(T1,T2),N}
34
35 function CenteredStencil(weights...)
32 if iseven(length(weights)) 36 if iseven(length(weights))
33 throw(ArgumentError("a centered stencil must have an odd number of weights.")) 37 throw(ArgumentError("a centered stencil must have an odd number of weights."))
34 end 38 end
35 39
36 r = length(weights) ÷ 2 40 r = length(weights) ÷ 2
37 41
38 return Stencil((-r, r), weights) 42 return Stencil(-r:r, weights)
39 end 43 end
40 44
41 45
42 """ 46 """
43 scale(s::Stencil, a) 47 scale(s::Stencil, a)
46 """ 50 """
47 function scale(s::Stencil, a) 51 function scale(s::Stencil, a)
48 return Stencil(s.range, a.*s.weights) 52 return Stencil(s.range, a.*s.weights)
49 end 53 end
50 54
51 Base.eltype(::Stencil{T}) where T = T 55 Base.eltype(::Stencil{T,N}) where {T,N} = T
56 Base.length(::Stencil{T,N}) where {T,N} = N
52 57
53 function flip(s::Stencil) 58 function flip(s::Stencil)
54 range = (-s.range[2], -s.range[1]) 59 range = (-s.range[2], -s.range[1])
55 return Stencil(range, reverse(s.weights)) 60 return Stencil(range, reverse(s.weights))
56 end 61 end
57 62
58 # Provides index into the Stencil based on offset for the root element 63 # Provides index into the Stencil based on offset for the root element
59 @inline function Base.getindex(s::Stencil, i::Int) 64 @inline function Base.getindex(s::Stencil, i::Int)
60 @boundscheck if i < s.range[1] || s.range[2] < i 65 @boundscheck if i ∉ s.range
61 return zero(eltype(s)) 66 return zero(eltype(s))
62 end 67 end
63 return s.weights[1 + i - s.range[1]] 68 return s.weights[1 + i - s.range[1]]
64 end 69 end
65 70
66 Base.@propagate_inbounds @inline function apply_stencil(s::Stencil{T,N}, v::AbstractVector, i::Int) where {T,N} 71 Base.@propagate_inbounds @inline function apply_stencil(s::Stencil, v::AbstractVector, i::Int)
67 w = s.weights[1]*v[i + s.range[1]] 72 w = zero(promote_type(eltype(s),eltype(v)))
68 @simd for k ∈ 2:N 73 @simd for k ∈ 1:length(s)
69 w += s.weights[k]*v[i + s.range[1] + k-1] 74 w += s.weights[k]*v[i + s.range[k]]
75 end
76
77 return w
78 end
79
80 Base.@propagate_inbounds @inline function apply_stencil_backwards(s::Stencil, v::AbstractVector, i::Int)
81 w = zero(promote_type(eltype(s),eltype(v)))
82 @simd for k ∈ length(s):-1:1
83 w += s.weights[k]*v[i - s.range[k]]
70 end 84 end
71 return w 85 return w
72 end 86 end
73 87
74 Base.@propagate_inbounds @inline function apply_stencil_backwards(s::Stencil{T,N}, v::AbstractVector, i::Int) where {T,N} 88
75 w = s.weights[N]*v[i - s.range[2]] 89 struct NestedStencil{T,N,M}
76 @simd for k ∈ N-1:-1:1 90 s::Stencil{Stencil{T,N},M}
77 w += s.weights[k]*v[i - s.range[1] - k + 1]
78 end
79 return w
80 end 91 end
92
93 # Stencil input
94 NestedStencil(s::Vararg{Stencil}; center) = NestedStencil(Stencil(s... ; center))
95 CenteredNestedStencil(s::Vararg{Stencil}) = NestedStencil(CenteredStencil(s...))
96
97 # Tuple input
98 function NestedStencil(weights::Vararg{NTuple{N,Any}}; center) where N
99 inner_stencils = map(w -> Stencil(w...; center), weights)
100 return NestedStencil(Stencil(inner_stencils... ; center))
101 end
102 function CenteredNestedStencil(weights::Vararg{NTuple{N,Any}}) where N
103 inner_stencils = map(w->CenteredStencil(w...), weights)
104 return CenteredNestedStencil(inner_stencils...)
105 end
106
107
108 # Conversion
109 function NestedStencil{T,N,M}(ns::NestedStencil{S,N,M}) where {T,S,N,M}
110 return NestedStencil(Stencil{Stencil{T}}(ns.s))
111 end
112
113 function NestedStencil{T}(ns::NestedStencil{S,N,M}) where {T,S,N,M}
114 NestedStencil{T,N,M}(ns)
115 end
116
117 function Base.convert(::Type{NestedStencil{T,N,M}}, s::NestedStencil{S,N,M}) where {T,S,N,M}
118 return NestedStencil{T,N,M}(s)
119 end
120 Base.convert(::Type{NestedStencil{T}}, stencil) where T = NestedStencil{T}(stencil)
121
122 function Base.promote_rule(::Type{NestedStencil{T,N,M}}, ::Type{NestedStencil{S,N,M}}) where {T,S,N,M}
123 return NestedStencil{promote_type(T,S),N,M}
124 end
125
126 Base.eltype(::NestedStencil{T}) where T = T
127
128 function scale(ns::NestedStencil, a)
129 range = ns.s.range
130 weights = ns.s.weights
131
132 return NestedStencil(Stencil(range, scale.(weights,a)))
133 end
134
135 function flip(ns::NestedStencil)
136 s_flip = flip(ns.s)
137 return NestedStencil(Stencil(s_flip.range, flip.(s_flip.weights)))
138 end
139
140 Base.getindex(ns::NestedStencil, i::Int) = ns.s[i]
141
142 "Apply inner stencils to `c` and get a concrete stencil"
143 Base.@propagate_inbounds function apply_inner_stencils(ns::NestedStencil, c::AbstractVector, i::Int)
144 weights = apply_stencil.(ns.s.weights, Ref(c), i)
145 return Stencil(ns.s.range, weights)
146 end
147
148 "Apply the whole nested stencil"
149 Base.@propagate_inbounds function apply_stencil(ns::NestedStencil, c::AbstractVector, v::AbstractVector, i::Int)
150 s = apply_inner_stencils(ns,c,i)
151 return apply_stencil(s, v, i)
152 end
153
154 "Apply inner stencils backwards to `c` and get a concrete stencil"
155 Base.@propagate_inbounds @inline function apply_inner_stencils_backwards(ns::NestedStencil, c::AbstractVector, i::Int)
156 weights = apply_stencil_backwards.(ns.s.weights, Ref(c), i)
157 return Stencil(ns.s.range, weights)
158 end
159
160 "Apply the whole nested stencil backwards"
161 Base.@propagate_inbounds @inline function apply_stencil_backwards(ns::NestedStencil, c::AbstractVector, v::AbstractVector, i::Int)
162 s = apply_inner_stencils_backwards(ns,c,i)
163 return apply_stencil_backwards(s, v, i)
164 end