changeset 1229:8f4259fbd39c refactor/LazyTensors/tuple_manipulation

Simplify split_index
author Jonatan Werpers <jonatan@werpers.com>
date Sun, 19 Feb 2023 11:43:29 +0100
parents 73f262a0a384
children 723a583cef96
files src/LazyTensors/lazy_tensor_operations.jl src/LazyTensors/tuple_manipulation.jl test/LazyTensors/tuple_manipulation_test.jl
diffstat 3 files changed, 17 insertions(+), 15 deletions(-) [+]
line wrap: on
line diff
diff -r 73f262a0a384 -r 8f4259fbd39c src/LazyTensors/lazy_tensor_operations.jl
--- a/src/LazyTensors/lazy_tensor_operations.jl	Sun Feb 19 11:41:40 2023 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Sun Feb 19 11:43:29 2023 +0100
@@ -197,7 +197,7 @@
     dim_range = range_dim(itm.tm)
     dim_after = range_dim(itm.after)
 
-    view_index, inner_index = split_index(Val(dim_before), Val(dim_domain), Val(dim_range), Val(dim_after), I...)
+    view_index, inner_index = split_index(dim_before, dim_domain, dim_range, dim_after, I...)
 
     v_inner = view(v, view_index...)
     return apply(itm.tm, v_inner, inner_index...)
@@ -209,7 +209,7 @@
     dim_range = range_dim(itm.tm)
     dim_after = range_dim(itm.after)
 
-    view_index, inner_index = split_index(Val(dim_before), Val(dim_range), Val(dim_domain), Val(dim_after), I...)
+    view_index, inner_index = split_index(dim_before, dim_range, dim_domain, dim_after, I...)
 
     v_inner = view(v, view_index...)
     return apply_transpose(itm.tm, v_inner, inner_index...)
diff -r 73f262a0a384 -r 8f4259fbd39c src/LazyTensors/tuple_manipulation.jl
--- a/src/LazyTensors/tuple_manipulation.jl	Sun Feb 19 11:41:40 2023 +0100
+++ b/src/LazyTensors/tuple_manipulation.jl	Sun Feb 19 11:43:29 2023 +0100
@@ -1,11 +1,11 @@
 """
-    split_index(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...)
+    split_index(dim_before, dim_view, dim_index, dim_after, I...)
 
 Splits the multi-index `I` into two parts. One part which is expected to be
 used as a view, and one which is expected to be used as an index.
 Eg.
 ```
-split_index(Val(1),Val(3),Val(2),Val(1),(1,2,3,4)) -> (1,:,:,:,4), (2,3)
+split_index(1, 3, 2, 1, (1,2,3,4)) -> (1,:,:,:,4), (2,3)
 ```
 
 `dim_view` controls how many colons are in the view, and `dim_index` controls
@@ -18,8 +18,9 @@
  * `length(view_index) == dim_before + dim_view + dim_after`
  * `length(I_middle) == dim_index`
 """
-function split_index(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...) where {dim_before,dim_view, dim_index,dim_after}
-    I_before, I_middle, I_after = @inline split_tuple(I, (dim_before, dim_index, dim_after))
+function split_index(dim_before, dim_view, dim_index, dim_after, I...)
+    @inline
+    I_before, I_middle, I_after = split_tuple(I, (dim_before, dim_index, dim_after))
 
     view_index = (I_before..., ntuple((i)->:, dim_view)..., I_after...)
 
diff -r 73f262a0a384 -r 8f4259fbd39c test/LazyTensors/tuple_manipulation_test.jl
--- a/test/LazyTensors/tuple_manipulation_test.jl	Sun Feb 19 11:41:40 2023 +0100
+++ b/test/LazyTensors/tuple_manipulation_test.jl	Sun Feb 19 11:43:29 2023 +0100
@@ -2,17 +2,18 @@
 using Sbplib.LazyTensors
 
 @testset "split_index" begin
-    @test LazyTensors.split_index(Val(2),Val(1),Val(2),Val(2),1,2,3,4,5,6) == ((1,2,:,5,6),(3,4))
-    @test LazyTensors.split_index(Val(2),Val(3),Val(2),Val(2),1,2,3,4,5,6) == ((1,2,:,:,:,5,6),(3,4))
-    @test LazyTensors.split_index(Val(3),Val(1),Val(1),Val(2),1,2,3,4,5,6) == ((1,2,3,:,5,6),(4,))
-    @test LazyTensors.split_index(Val(3),Val(2),Val(1),Val(2),1,2,3,4,5,6) == ((1,2,3,:,:,5,6),(4,))
-    @test LazyTensors.split_index(Val(1),Val(1),Val(2),Val(3),1,2,3,4,5,6) == ((1,:,4,5,6),(2,3))
-    @test LazyTensors.split_index(Val(1),Val(2),Val(2),Val(3),1,2,3,4,5,6) == ((1,:,:,4,5,6),(2,3))
+    @test LazyTensors.split_index(2,1,2,2, 1,2,3,4,5,6) == ((1,2,:,5,6),(3,4))
+    @test LazyTensors.split_index(2,3,2,2, 1,2,3,4,5,6) == ((1,2,:,:,:,5,6),(3,4))
+    @test LazyTensors.split_index(3,1,1,2, 1,2,3,4,5,6) == ((1,2,3,:,5,6),(4,))
+    @test LazyTensors.split_index(3,2,1,2, 1,2,3,4,5,6) == ((1,2,3,:,:,5,6),(4,))
+    @test LazyTensors.split_index(1,1,2,3, 1,2,3,4,5,6) == ((1,:,4,5,6),(2,3))
+    @test LazyTensors.split_index(1,2,2,3, 1,2,3,4,5,6) == ((1,:,:,4,5,6),(2,3))
 
-    @test LazyTensors.split_index(Val(0),Val(1),Val(3),Val(3),1,2,3,4,5,6) == ((:,4,5,6),(1,2,3))
-    @test LazyTensors.split_index(Val(3),Val(1),Val(3),Val(0),1,2,3,4,5,6) == ((1,2,3,:),(4,5,6))
+    @test LazyTensors.split_index(0,1,3,3, 1,2,3,4,5,6) == ((:,4,5,6),(1,2,3))
+    @test LazyTensors.split_index(3,1,3,0, 1,2,3,4,5,6) == ((1,2,3,:),(4,5,6))
 
-    @inferred LazyTensors.split_index(Val(2),Val(3),Val(2),Val(2),1,2,3,2,2,4)
+    split_index_static(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...) where {dim_before,dim_view,dim_index,dim_after} = LazyTensors.split_index(dim_before, dim_view, dim_index, dim_after, I...)
+    @inferred split_index_static(Val(2),Val(3),Val(2),Val(2),1,2,3,2,2,4)
 end
 
 @testset "split_tuple" begin