Mercurial > repos > public > sbplib_julia
changeset 532:588a843907de feature/inflated_tensormapping_transpose
Add a split_tuple function to make split_index more readable
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Thu, 26 Nov 2020 16:13:58 +0100 |
parents | 36dfc57e8e0b |
children | aac7cc1fa79a |
files | src/LazyTensors/lazy_tensor_operations.jl test/testLazyTensors.jl |
diffstat | 2 files changed, 54 insertions(+), 3 deletions(-) [+] |
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl Wed Nov 25 22:03:26 2020 +0100 +++ b/src/LazyTensors/lazy_tensor_operations.jl Thu Nov 26 16:13:58 2020 +0100 @@ -302,9 +302,7 @@ * `length(I_middle) == B_middle` """ function split_index(::Val{A}, ::Val{B_view}, ::Val{B_middle}, ::Val{C}, I...) where {A,B_view, B_middle,C} - I_before = slice_tuple(I, Val(1), Val(A)) - I_middle = slice_tuple(I, Val(A+1), Val(A+B_middle)) - I_after = slice_tuple(I, Val(A+B_middle+1), Val(A+B_middle+C)) + I_before, I_middle, I_after = split_tuple(I, Val(A), Val(B_middle)) view_index = (I_before..., ntuple((i)->:, B_view)..., I_after...) @@ -326,6 +324,33 @@ end """ + split_tuple(t::Tuple{...}, ::Val{M}) where {N,M} + +Split the tuple `t` into two parts. the first part is `M` long. +E.g +``` +split_tuple((1,2,3,4),Val(3)) -> (1,2,3), (4,) +``` +""" +function split_tuple(t::NTuple{N},::Val{M}) where {N,M} + return slice_tuple(t,Val(1), Val(M)), slice_tuple(t,Val(M+1), Val(N)) +end + +""" + split_tuple(t::Tuple{...},::Val{M},::Val{K}) where {N,M,K} + +Same as `split_tuple(t::NTuple{N},::Val{M})` but splits the tuple in three parts. With the first +two parts having lenght `M` and `K`. +""" +function split_tuple(t::NTuple{N},::Val{M},::Val{K}) where {N,M,K} + p1, tail = split_tuple(t, Val(M)) + p2, p3 = split_tuple(tail, Val(K)) + return p1,p2,p3 +end + + + +""" flatten_tuple(t) Takes a nested tuple and flattens the whole structure
--- a/test/testLazyTensors.jl Wed Nov 25 22:03:26 2020 +0100 +++ b/test/testLazyTensors.jl Thu Nov 26 16:13:58 2020 +0100 @@ -480,6 +480,32 @@ @test LazyTensors.slice_tuple((1,2,3,4,5,6),Val(4), Val(6)) == (4,5,6) end +@testset "split_tuple" begin + @testset "2 parts" begin + @test LazyTensors.split_tuple((),Val(0)) == ((),()) + @test LazyTensors.split_tuple((1,),Val(0)) == ((),(1,)) + @test LazyTensors.split_tuple((1,),Val(1)) == ((1,),()) + + @test LazyTensors.split_tuple((1,2,3,4),Val(0)) == ((),(1,2,3,4)) + @test LazyTensors.split_tuple((1,2,3,4),Val(1)) == ((1,),(2,3,4)) + @test LazyTensors.split_tuple((1,2,3,4),Val(2)) == ((1,2),(3,4)) + @test LazyTensors.split_tuple((1,2,3,4),Val(3)) == ((1,2,3),(4,)) + @test LazyTensors.split_tuple((1,2,3,4),Val(4)) == ((1,2,3,4),()) + + @inferred LazyTensors.split_tuple((1,2,3,4),Val(3)) + end + + @testset "3 parts" begin + @test LazyTensors.split_tuple((),Val(0),Val(0)) == ((),(),()) + @test LazyTensors.split_tuple((1,2,3),Val(1), Val(1)) == ((1,),(2,),(3,)) + + @test LazyTensors.split_tuple((1,2,3,4,5,6),Val(1),Val(2)) == ((1,),(2,3),(4,5,6)) + @test LazyTensors.split_tuple((1,2,3,4,5,6),Val(3),Val(2)) == ((1,2,3),(4,5),(6,)) + + @inferred LazyTensors.split_tuple((1,2,3,4,5,6),Val(3),Val(2)) + end +end + @testset "flatten_tuple" begin @test LazyTensors.flatten_tuple((1,)) == (1,) @test LazyTensors.flatten_tuple((1,2,3,4,5,6)) == (1,2,3,4,5,6)