Mercurial > repos > public > sbplib_julia
changeset 455:b86312d14873 feature/inflated_tensormapping
Make split_index type stable
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Wed, 21 Oct 2020 16:29:59 +0200 |
parents | eb4c34438e30 |
children | 8f4c31e06689 |
files | src/LazyTensors/lazy_tensor_operations.jl test/testLazyTensors.jl |
diffstat | 2 files changed, 8 insertions(+), 2 deletions(-) [+] |
line wrap: on
line diff
diff -r eb4c34438e30 -r b86312d14873 src/LazyTensors/lazy_tensor_operations.jl --- a/src/LazyTensors/lazy_tensor_operations.jl Tue Oct 20 09:59:44 2020 +0200 +++ b/src/LazyTensors/lazy_tensor_operations.jl Wed Oct 21 16:29:59 2020 +0200 @@ -231,14 +231,18 @@ """ function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{Any,R}) where {T,R,D} I_before = I[1:range_dim(itm.before)] - I_after = I[(end-range_dim(itm.after)+1):end] + I_after = slice_tuple(I,Val(R-range_dim(itm.after)+1),Val(R)) view_index = (I_before..., ntuple((i)->:,domain_dim(itm.tm))..., I_after...) - inner_index = I[range_dim(itm.before)+1:end-range_dim(itm.after)] + inner_index = slice_tuple(I, Val(range_dim(itm.before)+1), Val(R-range_dim(itm.after))) return (view_index, inner_index) end +function slice_tuple(t,::Val{L},::Val{U}) where {L,U} + return ntuple(i->t[i+L-1], U-L+1) +end + flatten_tuple(t::NTuple{N, Number} where N) = t flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify? flatten_tuple(ts::Vararg) = flatten_tuple(ts)
diff -r eb4c34438e30 -r b86312d14873 test/testLazyTensors.jl --- a/test/testLazyTensors.jl Tue Oct 20 09:59:44 2020 +0200 +++ b/test/testLazyTensors.jl Wed Oct 21 16:29:59 2020 +0200 @@ -281,6 +281,8 @@ @test B̃*v ≈ B[1,:,1]*v[1,1] + B[2,:,1]*v[2,1] + B[3,:,1]*v[3,1] + B[1,:,2]v[1,2] + B[2,:,2]*v[2,2] + B[3,:,2]*v[3,2] atol=5e-13 + + @inferred (B̃*v)[2] end