diff src/LazyTensors/lazy_tensor_operations.jl @ 455:b86312d14873 feature/inflated_tensormapping

Make split_index type stable
author Jonatan Werpers <jonatan@werpers.com>
date Wed, 21 Oct 2020 16:29:59 +0200
parents aeda2698166d
children 8f4c31e06689
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Tue Oct 20 09:59:44 2020 +0200
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Wed Oct 21 16:29:59 2020 +0200
@@ -231,14 +231,18 @@
 """
 function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{Any,R}) where {T,R,D}
     I_before = I[1:range_dim(itm.before)]
-    I_after = I[(end-range_dim(itm.after)+1):end]
+    I_after = slice_tuple(I,Val(R-range_dim(itm.after)+1),Val(R))
 
     view_index = (I_before..., ntuple((i)->:,domain_dim(itm.tm))..., I_after...)
-    inner_index = I[range_dim(itm.before)+1:end-range_dim(itm.after)]
+    inner_index = slice_tuple(I, Val(range_dim(itm.before)+1), Val(R-range_dim(itm.after)))
 
     return (view_index, inner_index)
 end
 
+function slice_tuple(t,::Val{L},::Val{U}) where {L,U}
+    return ntuple(i->t[i+L-1], U-L+1)
+end
+
 flatten_tuple(t::NTuple{N, Number} where N) = t
 flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify?
 flatten_tuple(ts::Vararg) = flatten_tuple(ts)