Mercurial > repos > public > sbplib_julia
comparison src/SbpOperators/stencil.jl @ 1207:f1c2a4fa0ee1 performance/get_region_type_inference
Merge default
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Fri, 03 Feb 2023 22:14:47 +0100 |
parents | fdd594b2a15e |
children | 14cb97284373 |
comparison
equal
deleted
inserted
replaced
919:b41180efb6c2 | 1207:f1c2a4fa0ee1 |
---|---|
1 export CenteredStencil | 1 export CenteredStencil |
2 export CenteredNestedStencil | |
2 | 3 |
3 struct Stencil{T,N} | 4 struct Stencil{T,N} |
4 range::Tuple{Int,Int} | 5 range::UnitRange{Int64} |
5 weights::NTuple{N,T} | 6 weights::NTuple{N,T} |
6 | 7 |
7 function Stencil(range::Tuple{Int,Int},weights::NTuple{N,T}) where {T, N} | 8 function Stencil(range::UnitRange,weights::NTuple{N,T}) where {T, N} |
8 @assert range[2]-range[1]+1 == N | 9 @assert length(range) == N |
9 new{T,N}(range,weights) | 10 new{T,N}(range,weights) |
10 end | 11 end |
11 end | 12 end |
12 | 13 |
13 """ | 14 """ |
14 Stencil(weights::NTuple; center::Int) | 15 Stencil(weights::NTuple; center::Int) |
15 | 16 |
16 Create a stencil with the given weights with element `center` as the center of the stencil. | 17 Create a stencil with the given weights with element `center` as the center of the stencil. |
17 """ | 18 """ |
18 function Stencil(weights::Vararg{T}; center::Int) where T # Type parameter T makes sure the weights are valid for the Stencil constuctors and throws an earlier, more readable, error | 19 function Stencil(weights...; center::Int) |
20 weights = promote(weights...) | |
19 N = length(weights) | 21 N = length(weights) |
20 range = (1, N) .- center | 22 range = (1:N) .- center |
21 | 23 |
22 return Stencil(range, weights) | 24 return Stencil(range, weights) |
23 end | 25 end |
24 | 26 |
25 function Stencil{T}(s::Stencil) where T | 27 Stencil{T,N}(s::Stencil{S,N}) where {T,S,N} = Stencil(s.range, T.(s.weights)) |
26 return Stencil(s.range, T.(s.weights)) | 28 Stencil{T}(s::Stencil) where T = Stencil{T,length(s)}(s) |
27 end | |
28 | 29 |
29 Base.convert(::Type{Stencil{T}}, stencil) where T = Stencil{T}(stencil) | 30 Base.convert(::Type{Stencil{T1,N}}, s::Stencil{T2,N}) where {T1,T2,N} = Stencil{T1,N}(s) |
31 Base.convert(::Type{Stencil{T1}}, s::Stencil{T2,N}) where {T1,T2,N} = Stencil{T1,N}(s) | |
30 | 32 |
31 function CenteredStencil(weights::Vararg) | 33 Base.promote_rule(::Type{Stencil{T1,N}}, ::Type{Stencil{T2,N}}) where {T1,T2,N} = Stencil{promote_type(T1,T2),N} |
34 | |
35 function CenteredStencil(weights...) | |
32 if iseven(length(weights)) | 36 if iseven(length(weights)) |
33 throw(ArgumentError("a centered stencil must have an odd number of weights.")) | 37 throw(ArgumentError("a centered stencil must have an odd number of weights.")) |
34 end | 38 end |
35 | 39 |
36 r = length(weights) ÷ 2 | 40 r = length(weights) ÷ 2 |
37 | 41 |
38 return Stencil((-r, r), weights) | 42 return Stencil(-r:r, weights) |
39 end | 43 end |
40 | 44 |
41 | 45 |
42 """ | 46 """ |
43 scale(s::Stencil, a) | 47 scale(s::Stencil, a) |
46 """ | 50 """ |
47 function scale(s::Stencil, a) | 51 function scale(s::Stencil, a) |
48 return Stencil(s.range, a.*s.weights) | 52 return Stencil(s.range, a.*s.weights) |
49 end | 53 end |
50 | 54 |
51 Base.eltype(::Stencil{T}) where T = T | 55 Base.eltype(::Stencil{T,N}) where {T,N} = T |
56 Base.length(::Stencil{T,N}) where {T,N} = N | |
52 | 57 |
53 function flip(s::Stencil) | 58 function flip(s::Stencil) |
54 range = (-s.range[2], -s.range[1]) | 59 range = (-s.range[2], -s.range[1]) |
55 return Stencil(range, reverse(s.weights)) | 60 return Stencil(range, reverse(s.weights)) |
56 end | 61 end |
57 | 62 |
58 # Provides index into the Stencil based on offset for the root element | 63 # Provides index into the Stencil based on offset for the root element |
59 @inline function Base.getindex(s::Stencil, i::Int) | 64 @inline function Base.getindex(s::Stencil, i::Int) |
60 @boundscheck if i < s.range[1] || s.range[2] < i | 65 @boundscheck if i ∉ s.range |
61 return zero(eltype(s)) | 66 return zero(eltype(s)) |
62 end | 67 end |
63 return s.weights[1 + i - s.range[1]] | 68 return s.weights[1 + i - s.range[1]] |
64 end | 69 end |
65 | 70 |
66 Base.@propagate_inbounds @inline function apply_stencil(s::Stencil{T,N}, v::AbstractVector, i::Int) where {T,N} | 71 Base.@propagate_inbounds @inline function apply_stencil(s::Stencil, v::AbstractVector, i::Int) |
67 w = s.weights[1]*v[i + s.range[1]] | 72 w = zero(promote_type(eltype(s),eltype(v))) |
68 @simd for k ∈ 2:N | 73 @simd for k ∈ 1:length(s) |
69 w += s.weights[k]*v[i + s.range[1] + k-1] | 74 w += s.weights[k]*v[i + s.range[k]] |
75 end | |
76 | |
77 return w | |
78 end | |
79 | |
80 Base.@propagate_inbounds @inline function apply_stencil_backwards(s::Stencil, v::AbstractVector, i::Int) | |
81 w = zero(promote_type(eltype(s),eltype(v))) | |
82 @simd for k ∈ length(s):-1:1 | |
83 w += s.weights[k]*v[i - s.range[k]] | |
70 end | 84 end |
71 return w | 85 return w |
72 end | 86 end |
73 | 87 |
74 Base.@propagate_inbounds @inline function apply_stencil_backwards(s::Stencil{T,N}, v::AbstractVector, i::Int) where {T,N} | 88 |
75 w = s.weights[N]*v[i - s.range[2]] | 89 struct NestedStencil{T,N,M} |
76 @simd for k ∈ N-1:-1:1 | 90 s::Stencil{Stencil{T,N},M} |
77 w += s.weights[k]*v[i - s.range[1] - k + 1] | |
78 end | |
79 return w | |
80 end | 91 end |
92 | |
93 # Stencil input | |
94 NestedStencil(s::Vararg{Stencil}; center) = NestedStencil(Stencil(s... ; center)) | |
95 CenteredNestedStencil(s::Vararg{Stencil}) = NestedStencil(CenteredStencil(s...)) | |
96 | |
97 # Tuple input | |
98 function NestedStencil(weights::Vararg{NTuple{N,Any}}; center) where N | |
99 inner_stencils = map(w -> Stencil(w...; center), weights) | |
100 return NestedStencil(Stencil(inner_stencils... ; center)) | |
101 end | |
102 function CenteredNestedStencil(weights::Vararg{NTuple{N,Any}}) where N | |
103 inner_stencils = map(w->CenteredStencil(w...), weights) | |
104 return CenteredNestedStencil(inner_stencils...) | |
105 end | |
106 | |
107 | |
108 # Conversion | |
109 function NestedStencil{T,N,M}(ns::NestedStencil{S,N,M}) where {T,S,N,M} | |
110 return NestedStencil(Stencil{Stencil{T}}(ns.s)) | |
111 end | |
112 | |
113 function NestedStencil{T}(ns::NestedStencil{S,N,M}) where {T,S,N,M} | |
114 NestedStencil{T,N,M}(ns) | |
115 end | |
116 | |
117 function Base.convert(::Type{NestedStencil{T,N,M}}, s::NestedStencil{S,N,M}) where {T,S,N,M} | |
118 return NestedStencil{T,N,M}(s) | |
119 end | |
120 Base.convert(::Type{NestedStencil{T}}, stencil) where T = NestedStencil{T}(stencil) | |
121 | |
122 function Base.promote_rule(::Type{NestedStencil{T,N,M}}, ::Type{NestedStencil{S,N,M}}) where {T,S,N,M} | |
123 return NestedStencil{promote_type(T,S),N,M} | |
124 end | |
125 | |
126 Base.eltype(::NestedStencil{T}) where T = T | |
127 | |
128 function scale(ns::NestedStencil, a) | |
129 range = ns.s.range | |
130 weights = ns.s.weights | |
131 | |
132 return NestedStencil(Stencil(range, scale.(weights,a))) | |
133 end | |
134 | |
135 function flip(ns::NestedStencil) | |
136 s_flip = flip(ns.s) | |
137 return NestedStencil(Stencil(s_flip.range, flip.(s_flip.weights))) | |
138 end | |
139 | |
140 Base.getindex(ns::NestedStencil, i::Int) = ns.s[i] | |
141 | |
142 "Apply inner stencils to `c` and get a concrete stencil" | |
143 Base.@propagate_inbounds function apply_inner_stencils(ns::NestedStencil, c::AbstractVector, i::Int) | |
144 weights = apply_stencil.(ns.s.weights, Ref(c), i) | |
145 return Stencil(ns.s.range, weights) | |
146 end | |
147 | |
148 "Apply the whole nested stencil" | |
149 Base.@propagate_inbounds function apply_stencil(ns::NestedStencil, c::AbstractVector, v::AbstractVector, i::Int) | |
150 s = apply_inner_stencils(ns,c,i) | |
151 return apply_stencil(s, v, i) | |
152 end | |
153 | |
154 "Apply inner stencils backwards to `c` and get a concrete stencil" | |
155 Base.@propagate_inbounds @inline function apply_inner_stencils_backwards(ns::NestedStencil, c::AbstractVector, i::Int) | |
156 weights = apply_stencil_backwards.(ns.s.weights, Ref(c), i) | |
157 return Stencil(ns.s.range, weights) | |
158 end | |
159 | |
160 "Apply the whole nested stencil backwards" | |
161 Base.@propagate_inbounds @inline function apply_stencil_backwards(ns::NestedStencil, c::AbstractVector, v::AbstractVector, i::Int) | |
162 s = apply_inner_stencils_backwards(ns,c,i) | |
163 return apply_stencil_backwards(s, v, i) | |
164 end |