Mercurial > repos > public > sbplib_julia
changeset 1229:8f4259fbd39c refactor/LazyTensors/tuple_manipulation
Simplify split_index
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Sun, 19 Feb 2023 11:43:29 +0100 |
parents | 73f262a0a384 |
children | 723a583cef96 |
files | src/LazyTensors/lazy_tensor_operations.jl src/LazyTensors/tuple_manipulation.jl test/LazyTensors/tuple_manipulation_test.jl |
diffstat | 3 files changed, 17 insertions(+), 15 deletions(-) [+] |
line wrap: on
line diff
diff -r 73f262a0a384 -r 8f4259fbd39c src/LazyTensors/lazy_tensor_operations.jl --- a/src/LazyTensors/lazy_tensor_operations.jl Sun Feb 19 11:41:40 2023 +0100 +++ b/src/LazyTensors/lazy_tensor_operations.jl Sun Feb 19 11:43:29 2023 +0100 @@ -197,7 +197,7 @@ dim_range = range_dim(itm.tm) dim_after = range_dim(itm.after) - view_index, inner_index = split_index(Val(dim_before), Val(dim_domain), Val(dim_range), Val(dim_after), I...) + view_index, inner_index = split_index(dim_before, dim_domain, dim_range, dim_after, I...) v_inner = view(v, view_index...) return apply(itm.tm, v_inner, inner_index...) @@ -209,7 +209,7 @@ dim_range = range_dim(itm.tm) dim_after = range_dim(itm.after) - view_index, inner_index = split_index(Val(dim_before), Val(dim_range), Val(dim_domain), Val(dim_after), I...) + view_index, inner_index = split_index(dim_before, dim_range, dim_domain, dim_after, I...) v_inner = view(v, view_index...) return apply_transpose(itm.tm, v_inner, inner_index...)
diff -r 73f262a0a384 -r 8f4259fbd39c src/LazyTensors/tuple_manipulation.jl --- a/src/LazyTensors/tuple_manipulation.jl Sun Feb 19 11:41:40 2023 +0100 +++ b/src/LazyTensors/tuple_manipulation.jl Sun Feb 19 11:43:29 2023 +0100 @@ -1,11 +1,11 @@ """ - split_index(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...) + split_index(dim_before, dim_view, dim_index, dim_after, I...) Splits the multi-index `I` into two parts. One part which is expected to be used as a view, and one which is expected to be used as an index. Eg. ``` -split_index(Val(1),Val(3),Val(2),Val(1),(1,2,3,4)) -> (1,:,:,:,4), (2,3) +split_index(1, 3, 2, 1, (1,2,3,4)) -> (1,:,:,:,4), (2,3) ``` `dim_view` controls how many colons are in the view, and `dim_index` controls @@ -18,8 +18,9 @@ * `length(view_index) == dim_before + dim_view + dim_after` * `length(I_middle) == dim_index` """ -function split_index(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...) where {dim_before,dim_view, dim_index,dim_after} - I_before, I_middle, I_after = @inline split_tuple(I, (dim_before, dim_index, dim_after)) +function split_index(dim_before, dim_view, dim_index, dim_after, I...) + @inline + I_before, I_middle, I_after = split_tuple(I, (dim_before, dim_index, dim_after)) view_index = (I_before..., ntuple((i)->:, dim_view)..., I_after...)
diff -r 73f262a0a384 -r 8f4259fbd39c test/LazyTensors/tuple_manipulation_test.jl --- a/test/LazyTensors/tuple_manipulation_test.jl Sun Feb 19 11:41:40 2023 +0100 +++ b/test/LazyTensors/tuple_manipulation_test.jl Sun Feb 19 11:43:29 2023 +0100 @@ -2,17 +2,18 @@ using Sbplib.LazyTensors @testset "split_index" begin - @test LazyTensors.split_index(Val(2),Val(1),Val(2),Val(2),1,2,3,4,5,6) == ((1,2,:,5,6),(3,4)) - @test LazyTensors.split_index(Val(2),Val(3),Val(2),Val(2),1,2,3,4,5,6) == ((1,2,:,:,:,5,6),(3,4)) - @test LazyTensors.split_index(Val(3),Val(1),Val(1),Val(2),1,2,3,4,5,6) == ((1,2,3,:,5,6),(4,)) - @test LazyTensors.split_index(Val(3),Val(2),Val(1),Val(2),1,2,3,4,5,6) == ((1,2,3,:,:,5,6),(4,)) - @test LazyTensors.split_index(Val(1),Val(1),Val(2),Val(3),1,2,3,4,5,6) == ((1,:,4,5,6),(2,3)) - @test LazyTensors.split_index(Val(1),Val(2),Val(2),Val(3),1,2,3,4,5,6) == ((1,:,:,4,5,6),(2,3)) + @test LazyTensors.split_index(2,1,2,2, 1,2,3,4,5,6) == ((1,2,:,5,6),(3,4)) + @test LazyTensors.split_index(2,3,2,2, 1,2,3,4,5,6) == ((1,2,:,:,:,5,6),(3,4)) + @test LazyTensors.split_index(3,1,1,2, 1,2,3,4,5,6) == ((1,2,3,:,5,6),(4,)) + @test LazyTensors.split_index(3,2,1,2, 1,2,3,4,5,6) == ((1,2,3,:,:,5,6),(4,)) + @test LazyTensors.split_index(1,1,2,3, 1,2,3,4,5,6) == ((1,:,4,5,6),(2,3)) + @test LazyTensors.split_index(1,2,2,3, 1,2,3,4,5,6) == ((1,:,:,4,5,6),(2,3)) - @test LazyTensors.split_index(Val(0),Val(1),Val(3),Val(3),1,2,3,4,5,6) == ((:,4,5,6),(1,2,3)) - @test LazyTensors.split_index(Val(3),Val(1),Val(3),Val(0),1,2,3,4,5,6) == ((1,2,3,:),(4,5,6)) + @test LazyTensors.split_index(0,1,3,3, 1,2,3,4,5,6) == ((:,4,5,6),(1,2,3)) + @test LazyTensors.split_index(3,1,3,0, 1,2,3,4,5,6) == ((1,2,3,:),(4,5,6)) - @inferred LazyTensors.split_index(Val(2),Val(3),Val(2),Val(2),1,2,3,2,2,4) + split_index_static(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...) where {dim_before,dim_view,dim_index,dim_after} = LazyTensors.split_index(dim_before, dim_view, dim_index, dim_after, I...) + @inferred split_index_static(Val(2),Val(3),Val(2),Val(2),1,2,3,2,2,4) end @testset "split_tuple" begin