comparison src/SbpOperators/stencil.jl @ 2057:8a2a0d678d6f feature/lazy_tensors/pretty_printing

Merge default
author Jonatan Werpers <jonatan@werpers.com>
date Tue, 10 Feb 2026 22:41:19 +0100
parents 244311761969
children
comparison
equal deleted inserted replaced
1110:c0bff9f6e0fb 2057:8a2a0d678d6f
1 export CenteredStencil
2 export CenteredNestedStencil
3
4 struct Stencil{T,N} 1 struct Stencil{T,N}
5 range::UnitRange{Int64} 2 range::UnitRange{Int64}
6 weights::NTuple{N,T} 3 weights::NTuple{N,T}
7 4
8 function Stencil(range::UnitRange,weights::NTuple{N,T}) where {T, N} 5 function Stencil(range::UnitRange,weights::NTuple{N,Any}) where N
6 T = eltype(weights)
7
9 @assert length(range) == N 8 @assert length(range) == N
10 new{T,N}(range,weights) 9 new{T,N}(range,weights)
11 end 10 end
12 end 11 end
13 12
14 """ 13 """
15 Stencil(weights::NTuple; center::Int) 14 Stencil(weights...; center::Int)
16 15
17 Create a stencil with the given weights with element `center` as the center of the stencil. 16 Create a stencil with the given weights with element `center` as the center of the stencil.
18 """ 17 """
19 function Stencil(weights...; center::Int) 18 function Stencil(weights...; center::Int)
20 weights = promote(weights...) 19 weights = promote(weights...)
67 end 66 end
68 return s.weights[1 + i - s.range[1]] 67 return s.weights[1 + i - s.range[1]]
69 end 68 end
70 69
71 Base.@propagate_inbounds @inline function apply_stencil(s::Stencil, v::AbstractVector, i::Int) 70 Base.@propagate_inbounds @inline function apply_stencil(s::Stencil, v::AbstractVector, i::Int)
72 w = zero(promote_type(eltype(s),eltype(v))) 71 return sum(enumerate(s.weights)) do (k,w) #TBD: Which optimizations are needed here?
73 @simd for k ∈ 1:length(s) 72 w*v[i + @inbounds s.range[k]]
74 w += s.weights[k]*v[i + s.range[k]]
75 end 73 end
76
77 return w
78 end 74 end
79 75
80 Base.@propagate_inbounds @inline function apply_stencil_backwards(s::Stencil, v::AbstractVector, i::Int) 76 Base.@propagate_inbounds @inline function apply_stencil_backwards(s::Stencil, v::AbstractVector, i::Int)
81 w = zero(promote_type(eltype(s),eltype(v))) 77 return sum(enumerate(s.weights)) do (k,w) #TBD: Which optimizations are needed here?
82 @simd for k ∈ length(s):-1:1 78 w*v[i - @inbounds s.range[k]]
83 w += s.weights[k]*v[i - s.range[k]]
84 end 79 end
85 return w
86 end 80 end
81
82 # There are many options for the implementation of `apply_stencil` and
83 # `apply_stencil_backwards`. Some alternatives were tried on the branch
84 # bugfix/sbp_operators/stencil_return_type and can be found at the following
85 # revision:
86 #
87 # * 237b980ffb91 (baseline)
88 # * a72bab15228e (mapreduce)
89 # * ffd735354d54 (multiplication)
90 # * b5abd5191f2c (promote_op)
91 # * 8d56846185fc (return_type)
92 #
93
94 function left_pad(s::Stencil, N)
95 weights = LazyTensors.left_pad_tuple(s.weights, zero(eltype(s)), N)
96 range = (first(s.range) - (N - length(s.weights))):last(s.range)
97
98 return Stencil(range, weights)
99 end
100
101 function right_pad(s::Stencil, N)
102 weights = LazyTensors.right_pad_tuple(s.weights, zero(eltype(s)), N)
103 range = first(s.range):(last(s.range) + (N - length(s.weights)))
104
105 return Stencil(range, weights)
106 end
107
87 108
88 109
89 struct NestedStencil{T,N,M} 110 struct NestedStencil{T,N,M}
90 s::Stencil{Stencil{T,N},M} 111 s::Stencil{Stencil{T,N},M}
91 end 112 end
92 113
114 NestedStencil(;center) = NestedStencil(Stencil(;center))
115 CenteredNestedStencil() = NestedStencil(CenteredStencil())
116
93 # Stencil input 117 # Stencil input
94 NestedStencil(s::Vararg{Stencil}; center) = NestedStencil(Stencil(s... ; center)) 118 NestedStencil(s::Vararg{Stencil}; center) = NestedStencil(Stencil(s... ; center))
95 CenteredNestedStencil(s::Vararg{Stencil}) = NestedStencil(CenteredStencil(s...)) 119 CenteredNestedStencil(s::Vararg{Stencil}) = NestedStencil(CenteredStencil(s...))
96 120
97 # Tuple input 121 # Tuple input
98 function NestedStencil(weights::Vararg{NTuple{N,Any}}; center) where N 122 function NestedStencil(weights::Vararg{NTuple{N,Any} where N}; center)
99 inner_stencils = map(w -> Stencil(w...; center), weights) 123 inner_stencils = map(w -> Stencil(w...; center), weights)
100 return NestedStencil(Stencil(inner_stencils... ; center)) 124 return NestedStencil(Stencil(inner_stencils... ; center))
101 end 125 end
102 function CenteredNestedStencil(weights::Vararg{NTuple{N,Any}}) where N 126
127 function CenteredNestedStencil(weights::Vararg{NTuple{N,Any} where N})
103 inner_stencils = map(w->CenteredStencil(w...), weights) 128 inner_stencils = map(w->CenteredStencil(w...), weights)
104 return CenteredNestedStencil(inner_stencils...) 129 return CenteredNestedStencil(inner_stencils...)
105 end 130 end
106
107 131
108 # Conversion 132 # Conversion
109 function NestedStencil{T,N,M}(ns::NestedStencil{S,N,M}) where {T,S,N,M} 133 function NestedStencil{T,N,M}(ns::NestedStencil{S,N,M}) where {T,S,N,M}
110 return NestedStencil(Stencil{Stencil{T}}(ns.s)) 134 return NestedStencil(Stencil{Stencil{T}}(ns.s))
111 end 135 end
115 end 139 end
116 140
117 function Base.convert(::Type{NestedStencil{T,N,M}}, s::NestedStencil{S,N,M}) where {T,S,N,M} 141 function Base.convert(::Type{NestedStencil{T,N,M}}, s::NestedStencil{S,N,M}) where {T,S,N,M}
118 return NestedStencil{T,N,M}(s) 142 return NestedStencil{T,N,M}(s)
119 end 143 end
120 Base.convert(::Type{NestedStencil{T}}, stencil) where T = NestedStencil{T}(stencil) 144 Base.convert(::Type{NestedStencil{T}}, stencil::NestedStencil) where T = NestedStencil{T}(stencil)
121 145
122 function Base.promote_rule(::Type{NestedStencil{T,N,M}}, ::Type{NestedStencil{S,N,M}}) where {T,S,N,M} 146 function Base.promote_rule(::Type{NestedStencil{T,N,M}}, ::Type{NestedStencil{S,N,M}}) where {T,S,N,M}
123 return NestedStencil{promote_type(T,S),N,M} 147 return NestedStencil{promote_type(T,S),N,M}
124 end 148 end
125 149