Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 446:904aae1899df feature/inflated_tensormapping
Start implementing InflatedTensorMapping
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Mon, 19 Oct 2020 08:37:35 +0200 |
parents | 46acb2560451 |
children | 27e0e256e5d9 |
comparison
equal
deleted
inserted
replaced
429:46acb2560451 | 446:904aae1899df |
---|---|
# The domain of the identity mapping has the same size as its range.
function domain_size(tmi::LazyIdentity)
    return tmi.size
end

# Applying the identity just reads `v` at the requested point; the indices are
# converted to plain `Int`s before indexing.
function apply(tmi::LazyIdentity{T,D}, v::AbstractArray{T,D}, I::Vararg{Index,D}) where {T,D}
    int_index = map(Int, I)
    return v[int_index...]
end

# The identity is its own transpose, so this reads `v` exactly like `apply`.
function apply_transpose(tmi::LazyIdentity{T,D}, v::AbstractArray{T,D}, I::Vararg{Index,D}) where {T,D}
    int_index = map(Int, I)
    return v[int_index...]
end
164 | 164 |
"""
    InflatedTensorMapping(before, tm, after)

A `TensorMapping` built from an inner mapping `tm` surrounded by the identity
mappings `before` and `after`, i.e. `before ⊗ tm ⊗ after`.

The range dimension `R` (domain dimension `D`) is the sum of the range (domain)
dimensions of `before`, `tm` and `after`. `before` and `after` must be
`LazyIdentity` mappings with the same element type `T` as `tm`.
"""
struct InflatedTensorMapping{T,R,D,D_before,R_middle,D_middle,D_after} <: TensorMapping{T,R,D}
    before::LazyIdentity{T,D_before}
    # NOTE(review): abstractly typed field; accesses to `tm` are not specialized
    # on the concrete inner mapping type.
    tm::TensorMapping{T,R_middle,D_middle}
    after::LazyIdentity{T,D_after}

    function InflatedTensorMapping(before::LazyIdentity{T}, tm::TensorMapping{T}, after::LazyIdentity{T}) where T
        R_before = range_dim(before)
        R_middle = range_dim(tm)
        R_after = range_dim(after)
        R = R_before + R_middle + R_after

        D_before = domain_dim(before)
        D_middle = domain_dim(tm)
        D_after = domain_dim(after)
        D = D_before + D_middle + D_after

        return new{T,R,D,D_before,R_middle,D_middle,D_after}(before, tm, after)
    end
end
182 | |
183 # TODO: Implement constructors where one of `before` or `after` is missing | |
184 | |
185 # TODO: Implement syntax and constructors for products of different combinations of InflatedTensorMapping and LazyIdentity | |
186 | |
# TODO: Implement pretty printing in terms of ⊗, e.g. InflatedTensorMapping(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
188 | |
# The range of `before ⊗ tm ⊗ after` is the concatenation of the ranges of
# the three factors, in that order.
function range_size(itm::InflatedTensorMapping)
    factors = (itm.before, itm.tm, itm.after)
    return flatten_tuple(map(range_size, factors))
end
196 | |
# The domain of `before ⊗ tm ⊗ after` is the concatenation of the domains of
# the three factors, in that order.
function domain_size(itm::InflatedTensorMapping)
    sizes = (domain_size(itm.before), domain_size(itm.tm), domain_size(itm.after))
    return flatten_tuple(sizes)
end
204 | |
# Evaluate `itm` applied to `v` at the range index `I`.
#
# The indices belonging to `itm.before` and `itm.after` select a view of `v`.
# The inner mapping `itm.tm` is then applied to that view at the remaining
# (middle) indices.
function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{S,R} where S) where {T,R,D}
    # `split_index` needs the mapping itself to know how many indices belong
    # to `before`, `tm` and `after`.
    view_index, inner_index = split_index(itm, I...)

    v_inner = view(v, view_index...)
    return apply(itm.tm, v_inner, inner_index...)
end
211 | |
212 | |
"""
    split_index(itm::InflatedTensorMapping, I...)

Split the range multi-index `I` of `itm` into two parts:

1. `view_index`: the index of the view of the argument array that the inner
   mapping `itm.tm` acts on. The indices belonging to `itm.before` and
   `itm.after` are kept, and the middle part is replaced by one `:` per domain
   dimension of `itm.tm`.
2. `inner_index`: the indices belonging to `itm.tm`, used to index the result
   of the inner mapping.

E.g. with one-dimensional `before` and `after` and an inner mapping with range
and domain dimension 2:
```
(1,2,3,4) -> (1,:,:,4), (2,3)
```
"""
function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{S,R} where S) where {T,R,D}
    n_before = range_dim(itm.before)
    n_after = range_dim(itm.after)

    I_before = I[1:n_before]
    I_after = I[(end - n_after + 1):end]
    inner_index = I[(n_before + 1):(end - n_after)]

    view_index = (I_before..., ntuple(i -> :, domain_dim(itm.tm))..., I_after...)

    return (view_index, inner_index)
end
233 | |
# Base case: a tuple containing only numbers is already flat.
flatten_tuple(t::NTuple{N, Number} where N) = t

# Recursive case: flatten each element on its own, then concatenate the
# resulting tuples into one.
function flatten_tuple(t::Tuple)
    flattened_parts = map(flatten_tuple, t)
    return (Iterators.flatten(flattened_parts)...,)
end

# Several arguments are flattened as if they were the elements of one tuple.
flatten_tuple(ts::Vararg) = flatten_tuple(ts)