comparison src/SbpOperators/stencil.jl @ 1072:14cb97284373 feature/dissipation_operators

Merge default
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 24 Mar 2022 12:00:12 +0100
parents 11767fbb29f4 fdd594b2a15e
children 6baed7b081f2 4f79ab676ebc e1222fbb7c4d c10c6c3e9247 f13857f37b8f
comparison
equal deleted inserted replaced
1069:c89c6b63c7f4 1072:14cb97284373
export CenteredStencil, CenteredNestedStencil
2 3
"""
    Stencil{T,N}

A stencil with `N` weights of type `T`. `range` lists the offsets, relative to
the index the stencil is applied at, that the weights belong to; its length must
equal `N`.
"""
struct Stencil{T,N}
    range::UnitRange{Int64}
    weights::NTuple{N,T}

    function Stencil(range::UnitRange, weights::NTuple{N,T}) where {T,N}
        @assert length(range) == N
        return new{T,N}(range, weights)
    end
end
12 13
"""
    Stencil(weights...; center::Int)

Create a stencil from `weights`, with the `center`:th weight placed at offset 0.
The weights are promoted to a common type before the stencil is built.
"""
function Stencil(weights...; center::Int)
    ws = promote(weights...)
    offsets = (1:length(ws)) .- center
    return Stencil(offsets, ws)
end
24 26
# Rebuild `s` with its weights converted to type `T`; the range is kept as is.
function Stencil{T,N}(s::Stencil{S,N}) where {T,S,N}
    converted = map(T, s.weights)
    return Stencil(s.range, converted)
end

# Same as above, with the length parameter taken from `s`.
function Stencil{T}(s::Stencil) where T
    return Stencil{T,length(s)}(s)
end
28 29
# Conversion between stencils that differ only in weight type.
function Base.convert(::Type{Stencil{T1,N}}, s::Stencil{T2,N}) where {T1,T2,N}
    return Stencil{T1,N}(s)
end

function Base.convert(::Type{Stencil{T1}}, s::Stencil{T2,N}) where {T1,T2,N}
    return Stencil{T1,N}(s)
end

# Stencils of equal length promote to a stencil of the promoted weight type.
function Base.promote_rule(::Type{Stencil{T1,N}}, ::Type{Stencil{T2,N}}) where {T1,T2,N}
    return Stencil{promote_type(T1, T2),N}
end
34
"""
    CenteredStencil(weights...)

Create a stencil whose middle weight sits at offset 0. The range is `-r:r`,
where `r = length(weights) ÷ 2`. As in `Stencil(weights...; center)`, the
weights are promoted to a common type.

Throws an `ArgumentError` if the number of weights is even.
"""
function CenteredStencil(weights...)
    if iseven(length(weights))
        throw(ArgumentError("a centered stencil must have an odd number of weights."))
    end

    r = length(weights) ÷ 2

    # Delegate to the keyword constructor so that mixed-type weights (e.g. an
    # `Int` next to a `Float64`, or inner stencils of different eltypes) are
    # promoted instead of failing to match the `NTuple{N,T}` inner constructor.
    # With center = r + 1 the resulting range is (1:2r+1) .- (r+1) == -r:r.
    return Stencil(weights...; center = r + 1)
end
40 44
41 45
"""
    scale(s::Stencil, a)

Return a stencil with the same range as `s` and the weights `a .* s.weights`.
"""
function scale(s::Stencil, a)
    scaled_weights = a .* s.weights
    return Stencil(s.range, scaled_weights)
end
50 54
# Weight type and number of weights, read off the type parameters.
Base.eltype(::Stencil{T}) where {T} = T
Base.length(::Stencil{<:Any,N}) where {N} = N
52 57
"""
    flip(s::Stencil)

Return the mirror image of `s`: the weights in reverse order and the range
negated. The weight at offset `k` in `s` sits at offset `-k` in the result.
"""
function flip(s::Stencil)
    # `s.range` is a `UnitRange`, so use `first`/`last`. Indexing with `[1]`/`[2]`
    # would give the first and *second* offsets. The result must also be a
    # `UnitRange`, because the `Stencil` constructor no longer accepts a Tuple.
    range = -last(s.range):-first(s.range)
    return Stencil(range, reverse(s.weights))
end
57 62
# Weight of the stencil at offset `i` relative to its root element.
# While bounds checking is active, offsets outside `s.range` give zero.
@inline function Base.getindex(s::Stencil, i::Int)
    @boundscheck if !(i in s.range)
        return zero(eltype(s))
    end
    k = i - s.range[1] + 1
    return s.weights[k]
end
65 70
# Apply `s` to `v` at index `i`: the sum over k of s.weights[k]*v[i + s.range[k]].
# The accumulator starts at zero of the promoted element type of `s` and `v`.
Base.@propagate_inbounds @inline function apply_stencil(s::Stencil, v::AbstractVector, i::Int)
    acc = zero(promote_type(eltype(s), eltype(v)))
    @simd for k in 1:length(s.weights)
        acc += s.weights[k] * v[i + s.range[k]]
    end
    return acc
end
79
# Apply the mirrored `s` to `v` at index `i`: the sum over k of
# s.weights[k]*v[i - s.range[k]], accumulated starting from the last weight.
Base.@propagate_inbounds @inline function apply_stencil_backwards(s::Stencil, v::AbstractVector, i::Int)
    acc = zero(promote_type(eltype(s), eltype(v)))
    @simd for k in length(s.weights):-1:1
        acc += s.weights[k] * v[i - s.range[k]]
    end
    return acc
end
81
82
# Pad `s` on the left with zero weights up to `N` weights, extending the range
# to the left by the same amount.
function left_pad(s::Stencil, N)
    padded_weights = LazyTensors.left_pad_tuple(s.weights, zero(eltype(s)), N)
    extra = N - length(s.weights)
    padded_range = (first(s.range) - extra):last(s.range)
    return Stencil(padded_range, padded_weights)
end
89 94
# Pad `s` on the right with zero weights up to `N` weights, extending the range
# to the right by the same amount.
function right_pad(s::Stencil, N)
    padded_weights = LazyTensors.right_pad_tuple(s.weights, zero(eltype(s)), N)
    extra = N - length(s.weights)
    padded_range = first(s.range):(last(s.range) + extra)
    return Stencil(padded_range, padded_weights)
end
101
102
103
"""
    NestedStencil{T,N,M}

An outer stencil with `M` weights, each of which is an inner `Stencil{T,N}`.
`apply_stencil(ns, c, v, i)` first applies the inner stencils to `c` and then
applies the resulting concrete stencil to `v`.
"""
struct NestedStencil{T,N,M}
    s::Stencil{Stencil{T,N},M}
end
107
# Stencil input: the arguments are the inner stencils themselves.
function NestedStencil(s::Vararg{Stencil}; center)
    outer = Stencil(s...; center = center)
    return NestedStencil(outer)
end

function CenteredNestedStencil(s::Vararg{Stencil})
    outer = CenteredStencil(s...)
    return NestedStencil(outer)
end
111
# Tuple input: each tuple holds the weights of one inner stencil. The same
# `center` (or centering) is used for the inner and the outer stencil.
function NestedStencil(weights::Vararg{NTuple{N,Any}}; center) where N
    inner = map(w -> Stencil(w...; center = center), weights)
    outer = Stencil(inner...; center = center)
    return NestedStencil(outer)
end

function CenteredNestedStencil(weights::Vararg{NTuple{N,Any}}) where N
    inner = map(w -> CenteredStencil(w...), weights)
    return CenteredNestedStencil(inner...)
end
121
122
# Conversion: rebuild the nested stencil with inner weights of type `T`.
function NestedStencil{T,N,M}(ns::NestedStencil{S,N,M}) where {T,S,N,M}
    converted = Stencil{Stencil{T}}(ns.s)
    return NestedStencil(converted)
end

NestedStencil{T}(ns::NestedStencil{S,N,M}) where {T,S,N,M} = NestedStencil{T,N,M}(ns)
131
# convert/promote for nested stencils with the same shape but different weight types.
Base.convert(::Type{NestedStencil{T,N,M}}, s::NestedStencil{S,N,M}) where {T,S,N,M} =
    NestedStencil{T,N,M}(s)

Base.convert(::Type{NestedStencil{T}}, stencil) where T = NestedStencil{T}(stencil)

Base.promote_rule(::Type{NestedStencil{T,N,M}}, ::Type{NestedStencil{S,N,M}}) where {T,S,N,M} =
    NestedStencil{promote_type(T, S),N,M}
140
# Element type of a nested stencil: the weight type of its inner stencils.
Base.eltype(::NestedStencil{T,N,M}) where {T,N,M} = T
142
# Scale every inner stencil of `ns` by `a`; the outer range is unchanged.
function scale(ns::NestedStencil, a)
    outer = ns.s
    scaled_inner = scale.(outer.weights, a)
    return NestedStencil(Stencil(outer.range, scaled_inner))
end
149
# Mirror a nested stencil: flip the outer stencil, then flip each inner stencil.
function flip(ns::NestedStencil)
    flipped_outer = flip(ns.s)
    flipped_inner = map(flip, flipped_outer.weights)
    return NestedStencil(Stencil(flipped_outer.range, flipped_inner))
end
154
# Inner stencil at offset `i` of the outer stencil.
# NOTE(review): an out-of-range `i` falls through to `zero(Stencil{T,N})`, and no
# method for that is visible here — confirm the intended behaviour.
function Base.getindex(ns::NestedStencil, i::Int)
    return ns.s[i]
end
156
"Apply every inner stencil to `c` at index `i` and return the resulting concrete stencil"
Base.@propagate_inbounds function apply_inner_stencils(ns::NestedStencil, c::AbstractVector, i::Int)
    outer = ns.s
    inner_values = map(inner -> apply_stencil(inner, c, i), outer.weights)
    return Stencil(outer.range, inner_values)
end
162
"Apply the whole nested stencil: inner stencils to `c`, then the result to `v`, at index `i`"
Base.@propagate_inbounds function apply_stencil(ns::NestedStencil, c::AbstractVector, v::AbstractVector, i::Int)
    return apply_stencil(apply_inner_stencils(ns, c, i), v, i)
end
168
"Apply every inner stencil backwards to `c` at index `i` and return the resulting concrete stencil"
Base.@propagate_inbounds @inline function apply_inner_stencils_backwards(ns::NestedStencil, c::AbstractVector, i::Int)
    outer = ns.s
    inner_values = map(inner -> apply_stencil_backwards(inner, c, i), outer.weights)
    return Stencil(outer.range, inner_values)
end
174
"Apply the whole nested stencil backwards: inner stencils to `c`, then the result to `v`, at index `i`"
Base.@propagate_inbounds @inline function apply_stencil_backwards(ns::NestedStencil, c::AbstractVector, v::AbstractVector, i::Int)
    return apply_stencil_backwards(apply_inner_stencils_backwards(ns, c, i), v, i)
end