Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 452:aeda2698166d feature/inflated_tensormapping
Add tullio as a test dependency and add a test for apply
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Mon, 19 Oct 2020 22:34:58 +0200 |
parents | 6cf234eef780 |
children | b86312d14873 |
comparison
legend: equal | deleted | inserted | replaced
451:6cf234eef780 | 452:aeda2698166d |
---|---|
210 domain_size(itm.tm), | 210 domain_size(itm.tm), |
211 domain_size(itm.after), | 211 domain_size(itm.after), |
212 ) | 212 ) |
213 end | 213 end |
214 | 214 |
215 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{S,R} where S) where {T,R,D} | 215 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} |
216 view_index, inner_index = split_index(I...) | 216 view_index, inner_index = split_index(itm, I...) |
217 | 217 |
218 v_inner = view(v, view_index...) | 218 v_inner = view(v, view_index...) |
219 return apply(itm.tm, v_inner, inner_index...) | 219 return apply(itm.tm, v_inner, inner_index...) |
220 end | 220 end |
221 | 221 |
227 Eg. | 227 Eg. |
228 ``` | 228 ``` |
229 (1,2,3,4) -> (1,:,:,4), (2,3) | 229 (1,2,3,4) -> (1,:,:,4), (2,3) |
230 ``` | 230 ``` |
231 """ | 231 """ |
232 function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{S,R} where S) where {T,R,D} | 232 function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{Any,R}) where {T,R,D} |
233 I_before = I[1:range_dim(itm.before)] | 233 I_before = I[1:range_dim(itm.before)] |
234 I_after = I[(end-range_dim(itm.after)+1):end] | 234 I_after = I[(end-range_dim(itm.after)+1):end] |
235 | 235 |
236 view_index = (I_before..., ntuple((i)->:,domain_dim(itm.tm))..., I_after...) | 236 view_index = (I_before..., ntuple((i)->:,domain_dim(itm.tm))..., I_after...) |
237 A_view = @view llm.A[view_index...] | |
238 inner_index = I[range_dim(itm.before)+1:end-range_dim(itm.after)] | 237 inner_index = I[range_dim(itm.before)+1:end-range_dim(itm.after)] |
239 | 238 |
240 return (view_index, inner_index) | 239 return (view_index, inner_index) |
241 return sum(A_view.*v) | |
242 end | 240 end |
243 | 241 |
244 flatten_tuple(t::NTuple{N, Number} where N) = t | 242 flatten_tuple(t::NTuple{N, Number} where N) = t |
245 flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify? | 243 flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify? |
246 flatten_tuple(ts::Vararg) = flatten_tuple(ts) | 244 flatten_tuple(ts::Vararg) = flatten_tuple(ts) |