comparison src/LazyTensors/lazy_tensor_operations.jl @ 446:904aae1899df feature/inflated_tensormapping

Start implementing InflatedTensorMapping
author Jonatan Werpers <jonatan@werpers.com>
date Mon, 19 Oct 2020 08:37:35 +0200
parents 46acb2560451
children 27e0e256e5d9
comparison
equal deleted inserted replaced
429:46acb2560451 446:904aae1899df
"""
    domain_size(tmi::LazyIdentity)

Return the `size` field stored in the identity mapping `tmi`.
"""
function domain_size(tmi::LazyIdentity)
    return tmi.size
end
161 161
"""
    apply(tmi::LazyIdentity, v, I...)

Return the element of `v` at the multi-index `I`. Each entry of `I` is
converted with `Int` before it is used to index `v`.
"""
function apply(tmi::LazyIdentity{T,D}, v::AbstractArray{T,D}, I::Vararg{Index,D}) where {T,D}
    return v[map(Int, I)...]
end
"""
    apply_transpose(tmi::LazyIdentity, v, I...)

Return the element of `v` at the multi-index `I`. Each entry of `I` is
converted with `Int` before it is used to index `v`; the result is the same
expression as the one used by `apply` for a `LazyIdentity`.
"""
function apply_transpose(tmi::LazyIdentity{T,D}, v::AbstractArray{T,D}, I::Vararg{Index,D}) where {T,D}
    return v[map(Int, I)...]
end
164 164
"""
    InflatedTensorMapping(before, tm, after)

A `TensorMapping` that wraps the mapping `tm` between the two identity
mappings `before` and `after`.

The range dimension `R` of the result is the sum of the range dimensions of
`before`, `tm` and `after`; the domain dimension `D` is the sum of their domain
dimensions. The remaining type parameters record the domain dimension of
`before` and `after` and the range/domain dimension of `tm`.
"""
struct InflatedTensorMapping{T,R,D,D_before,R_middle,D_middle,D_after} <: TensorMapping{T,R,D}
    before::LazyIdentity{T,D_before}
    # NOTE(review): this field has an abstract type, so accesses to `tm` are
    # not concretely typed. Consider adding a type parameter for it.
    tm::TensorMapping{T,R_middle,D_middle}
    after::LazyIdentity{T,D_after}

    function InflatedTensorMapping(before, tm::TensorMapping{T}, after) where T
        # Total range dimension: the three parts stacked after each other.
        R_before = range_dim(before)
        R_middle = range_dim(tm)
        R_after = range_dim(after)
        R = R_before + R_middle + R_after

        # Total domain dimension, computed the same way.
        D_before = domain_dim(before)
        D_middle = domain_dim(tm)
        D_after = domain_dim(after)
        D = D_before + D_middle + D_after

        return new{T,R,D,D_before,R_middle,D_middle,D_after}(before, tm, after)
    end
end
182
183 # TODO: Implement constructors where one of `before` or `after` is missing
184
185 # TODO: Implement syntax and constructors for products of different combinations of InflatedTensorMapping and LazyIdentity
186
187 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensorMapping(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
188
"""
    range_size(itm::InflatedTensorMapping)

Return the range size of `itm` as a flat tuple: the range sizes of
`itm.before`, `itm.tm` and `itm.after`, in that order, joined by
`flatten_tuple`.
"""
function range_size(itm::InflatedTensorMapping)
    parts = (itm.before, itm.tm, itm.after)
    return flatten_tuple(map(range_size, parts)...)
end
196
"""
    domain_size(itm::InflatedTensorMapping)

Return the domain size of `itm` as a flat tuple: the domain sizes of
`itm.before`, `itm.tm` and `itm.after`, in that order, joined by
`flatten_tuple`.
"""
function domain_size(itm::InflatedTensorMapping)
    parts = (itm.before, itm.tm, itm.after)
    return flatten_tuple(map(domain_size, parts)...)
end
204
"""
    apply(itm::InflatedTensorMapping, v, I...)

Evaluate `itm` applied to `v` at the multi-index `I`.

The index is split with `split_index` into the indices that select a view of
`v` and the indices passed on to the inner mapping `itm.tm`, which is then
applied to that view.
"""
function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D}
    # `split_index` needs `itm` to know how many indices belong to each part.
    view_index, inner_index = split_index(itm, I...)

    v_inner = view(v, view_index...)
    return apply(itm.tm, v_inner, inner_index...)
end
211
212
"""
    split_index(itm::InflatedTensorMapping, I...)

Split the multi-index `I` into two parts and return them as
`(view_index, inner_index)`.

`view_index` selects the view that the inner `TensorMapping` acts on: it keeps
the leading `range_dim(itm.before)` and trailing `range_dim(itm.after)` entries
of `I` and puts `domain_dim(itm.tm)` colons between them. `inner_index` holds
the remaining middle entries of `I`, which index the result of the inner
mapping.

E.g.
```
(1,2,3,4) -> (1,:,:,4), (2,3)
```
"""
function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{Any,R}) where {T,R,D}
    n_before = range_dim(itm.before)
    n_after = range_dim(itm.after)

    I_before = I[1:n_before]
    I_after = I[(end - n_after + 1):end]

    # One colon per domain dimension of the inner mapping.
    view_index = (I_before..., ntuple(_ -> :, domain_dim(itm.tm))..., I_after...)
    inner_index = I[(n_before + 1):(end - n_after)]

    return (view_index, inner_index)
end
233
"""
    flatten_tuple(t::Tuple)
    flatten_tuple(ts...)

Flatten an arbitrarily nested tuple into a single flat tuple, keeping the
order of the leaves. When called with several arguments, the arguments are
treated as the elements of one tuple.

Any element that is not a `Tuple` is treated as a leaf, so non-numeric leaves
are supported as well.

# Examples
```
flatten_tuple((1, (2, 3), (4, (5, 6)))) == (1, 2, 3, 4, 5, 6)
flatten_tuple((1, 2), (3,), (4, 5))     == (1, 2, 3, 4, 5)
```
"""
flatten_tuple(t::NTuple{N, Number} where N) = t  # already flat
flatten_tuple(t::Tuple) = _flatten_tuple_args(t...)
flatten_tuple(ts::Vararg) = flatten_tuple(ts)

# Recursive helper: peel off the first element, flattening it if it is a tuple.
_flatten_tuple_args() = ()
_flatten_tuple_args(x, rest...) = (x, _flatten_tuple_args(rest...)...)
function _flatten_tuple_args(x::Tuple, rest...)
    return (_flatten_tuple_args(x...)..., _flatten_tuple_args(rest...)...)
end