comparison src/LazyTensors/lazy_tensor_operations.jl @ 452:aeda2698166d feature/inflated_tensormapping

Add tullio as a test dependency and add a test for apply
author Jonatan Werpers <jonatan@werpers.com>
date Mon, 19 Oct 2020 22:34:58 +0200
parents 6cf234eef780
children b86312d14873
comparison
equal deleted inserted replaced
451:6cf234eef780 452:aeda2698166d
210 domain_size(itm.tm), 210 domain_size(itm.tm),
211 domain_size(itm.after), 211 domain_size(itm.after),
212 ) 212 )
213 end 213 end
214 214
"""
    apply(itm::InflatedTensorMapping, v, I...)

Evaluate the inflated mapping `itm` applied to `v` at the range index `I`.

`split_index` separates `I` into two parts:
- the entries belonging to `itm.before` and `itm.after`, with colons in between,
  which are used to take a `view` of `v`;
- the entries belonging to the inner mapping `itm.tm`, which are forwarded to
  `apply(itm.tm, ...)` together with that view.
"""
function apply(
    itm::InflatedTensorMapping{T,R,D},
    v::AbstractArray{T,D},
    I::Vararg{Any,R},
) where {T,R,D}
    outer_index, tm_index = split_index(itm, I...)
    v_slice = view(v, outer_index...)
    return apply(itm.tm, v_slice, tm_index...)
end
221 221
E.g.
```
(1,2,3,4) -> (1,:,:,4), (2,3)
```
"""
function split_index(
    itm::InflatedTensorMapping{T,R,D},
    I::Vararg{Any,R},
) where {T,R,D}
    # Number of leading / trailing entries of `I` that index `itm.before` / `itm.after`.
    # `I` has exactly `R` entries, as fixed by the method signature.
    n_before = range_dim(itm.before)
    n_after = range_dim(itm.after)

    # Positions in `I` of the entries handled by the inner mapping `itm.tm`.
    inner_range = (n_before + 1):(R - n_after)

    # Keep the outer indices and put one colon per domain dimension of `itm.tm`
    # in between, so the resulting view retains the inner mapping's full domain.
    colons = ntuple(_ -> Colon(), domain_dim(itm.tm))
    view_index = (I[1:n_before]..., colons..., I[(R - n_after + 1):R]...)

    return (view_index, I[inner_range])
end
243 241
"""
    flatten_tuple(t::Tuple)
    flatten_tuple(ts...)

Flatten an arbitrarily nested tuple into a single flat tuple.

Nested tuples are expanded recursively. Every other element is kept as-is as a
leaf, including numbers, colons and symbols. When called with several arguments,
the arguments are treated as the elements of one tuple.

# Examples
```julia
flatten_tuple((1, (2, (3, 4)), 5)) == (1, 2, 3, 4, 5)
flatten_tuple(1, (2, 3))            == (1, 2, 3)
flatten_tuple((1, :, (2, :)))       == (1, :, 2, :)
```
"""
flatten_tuple(t::NTuple{N, Number} where N) = t   # fast path: already flat
flatten_tuple(t::Tuple) = ((map(_flatten_element, t)...)...,)
flatten_tuple(ts::Vararg) = flatten_tuple(ts)

# Turn one element of a tuple being flattened into a flat tuple.
# Non-tuple leaves are wrapped in a 1-tuple so the outer splat in `flatten_tuple`
# removes exactly one level. Routing leaves back through `flatten_tuple` instead
# would bounce between the `Vararg` and `Tuple` methods forever for any
# non-`Number` leaf (e.g. `:`).
_flatten_element(x) = (x,)
_flatten_element(t::Tuple) = flatten_tuple(t)