Mercurial > repos > public > sbplib_julia
comparison src/SbpOperators/stencil.jl @ 1072:14cb97284373 feature/dissipation_operators
Merge default
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Thu, 24 Mar 2022 12:00:12 +0100 |
parents | 11767fbb29f4 fdd594b2a15e |
children | 6baed7b081f2 4f79ab676ebc e1222fbb7c4d c10c6c3e9247 f13857f37b8f |
comparison
equal
deleted
inserted
replaced
1069:c89c6b63c7f4 | 1072:14cb97284373 |
---|---|
1 export CenteredStencil | 1 export CenteredStencil |
2 export CenteredNestedStencil | |
2 | 3 |
3 struct Stencil{T,N} | 4 struct Stencil{T,N} |
4 range::Tuple{Int,Int} | 5 range::UnitRange{Int64} |
5 weights::NTuple{N,T} | 6 weights::NTuple{N,T} |
6 | 7 |
7 function Stencil(range::Tuple{Int,Int},weights::NTuple{N,T}) where {T, N} | 8 function Stencil(range::UnitRange,weights::NTuple{N,T}) where {T, N} |
8 @assert range[2]-range[1]+1 == N | 9 @assert length(range) == N |
9 new{T,N}(range,weights) | 10 new{T,N}(range,weights) |
10 end | 11 end |
11 end | 12 end |
12 | 13 |
13 """ | 14 """ |
14 Stencil(weights::NTuple; center::Int) | 15 Stencil(weights::NTuple; center::Int) |
15 | 16 |
16 Create a stencil with the given weights with element `center` as the center of the stencil. | 17 Create a stencil with the given weights with element `center` as the center of the stencil. |
17 """ | 18 """ |
18 function Stencil(weights::Vararg{T}; center::Int) where T # Type parameter T makes sure the weights are valid for the Stencil constuctors and throws an earlier, more readable, error | 19 function Stencil(weights...; center::Int) |
20 weights = promote(weights...) | |
19 N = length(weights) | 21 N = length(weights) |
20 range = (1, N) .- center | 22 range = (1:N) .- center |
21 | 23 |
22 return Stencil(range, weights) | 24 return Stencil(range, weights) |
23 end | 25 end |
24 | 26 |
25 function Stencil{T}(s::Stencil) where T | 27 Stencil{T,N}(s::Stencil{S,N}) where {T,S,N} = Stencil(s.range, T.(s.weights)) |
26 return Stencil(s.range, T.(s.weights)) | 28 Stencil{T}(s::Stencil) where T = Stencil{T,length(s)}(s) |
27 end | |
28 | 29 |
29 Base.convert(::Type{Stencil{T}}, stencil) where T = Stencil{T}(stencil) | 30 Base.convert(::Type{Stencil{T1,N}}, s::Stencil{T2,N}) where {T1,T2,N} = Stencil{T1,N}(s) |
31 Base.convert(::Type{Stencil{T1}}, s::Stencil{T2,N}) where {T1,T2,N} = Stencil{T1,N}(s) | |
30 | 32 |
31 function CenteredStencil(weights::Vararg) | 33 Base.promote_rule(::Type{Stencil{T1,N}}, ::Type{Stencil{T2,N}}) where {T1,T2,N} = Stencil{promote_type(T1,T2),N} |
34 | |
35 function CenteredStencil(weights...) | |
32 if iseven(length(weights)) | 36 if iseven(length(weights)) |
33 throw(ArgumentError("a centered stencil must have an odd number of weights.")) | 37 throw(ArgumentError("a centered stencil must have an odd number of weights.")) |
34 end | 38 end |
35 | 39 |
36 r = length(weights) ÷ 2 | 40 r = length(weights) ÷ 2 |
37 | 41 |
38 return Stencil((-r, r), weights) | 42 return Stencil(-r:r, weights) |
39 end | 43 end |
40 | 44 |
41 | 45 |
42 """ | 46 """ |
43 scale(s::Stencil, a) | 47 scale(s::Stencil, a) |
46 """ | 50 """ |
47 function scale(s::Stencil, a) | 51 function scale(s::Stencil, a) |
48 return Stencil(s.range, a.*s.weights) | 52 return Stencil(s.range, a.*s.weights) |
49 end | 53 end |
50 | 54 |
51 Base.eltype(::Stencil{T}) where T = T | 55 Base.eltype(::Stencil{T,N}) where {T,N} = T |
56 Base.length(::Stencil{T,N}) where {T,N} = N | |
52 | 57 |
53 function flip(s::Stencil) | 58 function flip(s::Stencil) |
54 range = (-s.range[2], -s.range[1]) | 59 range = (-s.range[2], -s.range[1]) |
55 return Stencil(range, reverse(s.weights)) | 60 return Stencil(range, reverse(s.weights)) |
56 end | 61 end |
57 | 62 |
58 # Provides index into the Stencil based on offset for the root element | 63 # Provides index into the Stencil based on offset for the root element |
59 @inline function Base.getindex(s::Stencil, i::Int) | 64 @inline function Base.getindex(s::Stencil, i::Int) |
60 @boundscheck if i < s.range[1] || s.range[2] < i | 65 @boundscheck if i ∉ s.range |
61 return zero(eltype(s)) | 66 return zero(eltype(s)) |
62 end | 67 end |
63 return s.weights[1 + i - s.range[1]] | 68 return s.weights[1 + i - s.range[1]] |
64 end | 69 end |
65 | 70 |
66 Base.@propagate_inbounds @inline function apply_stencil(s::Stencil{T,N}, v::AbstractVector, i::Int) where {T,N} | 71 Base.@propagate_inbounds @inline function apply_stencil(s::Stencil, v::AbstractVector, i::Int) |
67 w = s.weights[1]*v[i + s.range[1]] | 72 w = zero(promote_type(eltype(s),eltype(v))) |
68 @simd for k ∈ 2:N | 73 @simd for k ∈ 1:length(s) |
69 w += s.weights[k]*v[i + s.range[1] + k-1] | 74 w += s.weights[k]*v[i + s.range[k]] |
75 end | |
76 | |
77 return w | |
78 end | |
79 | |
80 Base.@propagate_inbounds @inline function apply_stencil_backwards(s::Stencil, v::AbstractVector, i::Int) | |
81 w = zero(promote_type(eltype(s),eltype(v))) | |
82 @simd for k ∈ length(s):-1:1 | |
83 w += s.weights[k]*v[i - s.range[k]] | |
70 end | 84 end |
71 return w | 85 return w |
72 end | 86 end |
73 | 87 |
74 Base.@propagate_inbounds @inline function apply_stencil_backwards(s::Stencil{T,N}, v::AbstractVector, i::Int) where {T,N} | |
75 w = s.weights[N]*v[i - s.range[2]] | |
76 @simd for k ∈ N-1:-1:1 | |
77 w += s.weights[k]*v[i - s.range[1] - k + 1] | |
78 end | |
79 return w | |
80 end | |
81 | |
82 | |
83 function left_pad(s::Stencil, N) | 88 function left_pad(s::Stencil, N) |
84 weights = LazyTensors.left_pad_tuple(s.weights, zero(eltype(s)), N) | 89 weights = LazyTensors.left_pad_tuple(s.weights, zero(eltype(s)), N) |
85 range = (s.range[1] - (N - length(s.weights)) ,s.range[2]) | 90 range = (first(s.range) - (N - length(s.weights))):last(s.range) |
86 | 91 |
87 return Stencil(range, weights) | 92 return Stencil(range, weights) |
88 end | 93 end |
89 | 94 |
90 function right_pad(s::Stencil, N) | 95 function right_pad(s::Stencil, N) |
91 weights = LazyTensors.right_pad_tuple(s.weights, zero(eltype(s)), N) | 96 weights = LazyTensors.right_pad_tuple(s.weights, zero(eltype(s)), N) |
92 range = (s.range[1], s.range[2] + (N - length(s.weights))) | 97 range = first(s.range):(last(s.range) + (N - length(s.weights))) |
93 | 98 |
94 return Stencil(range, weights) | 99 return Stencil(range, weights) |
95 end | 100 end |
101 | |
102 | |
103 | |
104 struct NestedStencil{T,N,M} | |
105 s::Stencil{Stencil{T,N},M} | |
106 end | |
107 | |
108 # Stencil input | |
109 NestedStencil(s::Vararg{Stencil}; center) = NestedStencil(Stencil(s... ; center)) | |
110 CenteredNestedStencil(s::Vararg{Stencil}) = NestedStencil(CenteredStencil(s...)) | |
111 | |
112 # Tuple input | |
113 function NestedStencil(weights::Vararg{NTuple{N,Any}}; center) where N | |
114 inner_stencils = map(w -> Stencil(w...; center), weights) | |
115 return NestedStencil(Stencil(inner_stencils... ; center)) | |
116 end | |
117 function CenteredNestedStencil(weights::Vararg{NTuple{N,Any}}) where N | |
118 inner_stencils = map(w->CenteredStencil(w...), weights) | |
119 return CenteredNestedStencil(inner_stencils...) | |
120 end | |
121 | |
122 | |
123 # Conversion | |
124 function NestedStencil{T,N,M}(ns::NestedStencil{S,N,M}) where {T,S,N,M} | |
125 return NestedStencil(Stencil{Stencil{T}}(ns.s)) | |
126 end | |
127 | |
128 function NestedStencil{T}(ns::NestedStencil{S,N,M}) where {T,S,N,M} | |
129 NestedStencil{T,N,M}(ns) | |
130 end | |
131 | |
132 function Base.convert(::Type{NestedStencil{T,N,M}}, s::NestedStencil{S,N,M}) where {T,S,N,M} | |
133 return NestedStencil{T,N,M}(s) | |
134 end | |
135 Base.convert(::Type{NestedStencil{T}}, stencil) where T = NestedStencil{T}(stencil) | |
136 | |
137 function Base.promote_rule(::Type{NestedStencil{T,N,M}}, ::Type{NestedStencil{S,N,M}}) where {T,S,N,M} | |
138 return NestedStencil{promote_type(T,S),N,M} | |
139 end | |
140 | |
141 Base.eltype(::NestedStencil{T}) where T = T | |
142 | |
143 function scale(ns::NestedStencil, a) | |
144 range = ns.s.range | |
145 weights = ns.s.weights | |
146 | |
147 return NestedStencil(Stencil(range, scale.(weights,a))) | |
148 end | |
149 | |
150 function flip(ns::NestedStencil) | |
151 s_flip = flip(ns.s) | |
152 return NestedStencil(Stencil(s_flip.range, flip.(s_flip.weights))) | |
153 end | |
154 | |
155 Base.getindex(ns::NestedStencil, i::Int) = ns.s[i] | |
156 | |
157 "Apply inner stencils to `c` and get a concrete stencil" | |
158 Base.@propagate_inbounds function apply_inner_stencils(ns::NestedStencil, c::AbstractVector, i::Int) | |
159 weights = apply_stencil.(ns.s.weights, Ref(c), i) | |
160 return Stencil(ns.s.range, weights) | |
161 end | |
162 | |
163 "Apply the whole nested stencil" | |
164 Base.@propagate_inbounds function apply_stencil(ns::NestedStencil, c::AbstractVector, v::AbstractVector, i::Int) | |
165 s = apply_inner_stencils(ns,c,i) | |
166 return apply_stencil(s, v, i) | |
167 end | |
168 | |
169 "Apply inner stencils backwards to `c` and get a concrete stencil" | |
170 Base.@propagate_inbounds @inline function apply_inner_stencils_backwards(ns::NestedStencil, c::AbstractVector, i::Int) | |
171 weights = apply_stencil_backwards.(ns.s.weights, Ref(c), i) | |
172 return Stencil(ns.s.range, weights) | |
173 end | |
174 | |
175 "Apply the whole nested stencil backwards" | |
176 Base.@propagate_inbounds @inline function apply_stencil_backwards(ns::NestedStencil, c::AbstractVector, v::AbstractVector, i::Int) | |
177 s = apply_inner_stencils_backwards(ns,c,i) | |
178 return apply_stencil_backwards(s, v, i) | |
179 end |