changeset 520:fe86ac896377 feature/inflated_tensormapping_transpose

Start refactoring split index and apply to accomodate future addition of apply_transpose
author Jonatan Werpers <jonatan@werpers.com>
date Mon, 23 Nov 2020 21:30:11 +0100
parents 27e64b3d3efa
children 41c1760a7770
files src/LazyTensors/lazy_tensor_operations.jl test/testLazyTensors.jl
diffstat 2 files changed, 38 insertions(+), 14 deletions(-) [+]
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Mon Nov 23 21:15:04 2020 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Mon Nov 23 21:30:11 2020 +0100
@@ -261,7 +261,12 @@
 end
 
 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D}
-    view_index, inner_index = split_index(itm, I...)
+    A = range_dim(itm.before)
+    B_domain = domain_dim(itm.tm)
+    B_range = range_dim(itm.tm)
+    C = range_dim(itm.after)
+
+    view_index, inner_index = split_index(Val(A), Val(B_range), Val(B_domain), Val(C), I...)
 
     v_inner = view(v, view_index...)
     return apply(itm.tm, v_inner, inner_index...)
@@ -269,22 +274,31 @@
 
 
 """
-    split_index(...)
+    split_index(:Val{A}, ::Val{B_view}, ::Val{B_middle}, ::Val{C}, I...)
 
-Splits the multi-index into two parts. One part for the view that the inner TensorMapping acts on, and one part for indexing the result
+Splits the multi-index `I` into two parts. One part which is expected to be used as a view, which is expected to be used as an index.
 Eg.
 ```
-(1,2,3,4) -> (1,:,:,4), (2,3)
+(1,2,3,4) -> (1,:,:,:,4), (2,3)
 ```
+
+`B_view` controls how many colons are in the view, and `B_middle` controls how many elements are extracted from the middle.
+`A` and `C` decides the length of the parts before and after the colons in the view index.
+length(I) == A+B_domain+C
+length(I_middle) == B_domain
+length(I_view) == A + B_range + C
+
+TODO: Finish documentation.
 """
-function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{Any,R}) where {T,R,D}
-    I_before = slice_tuple(I, Val(1), Val(range_dim(itm.before)))
-    I_after = slice_tuple(I, Val(R-range_dim(itm.after)+1), Val(R))
+function split_index(::Val{A}, ::Val{B_view}, ::Val{B_middle}, ::Val{C}, I...) where {A,B_view, B_middle,C}
+    I_before = slice_tuple(I, Val(1), Val(A))
+    I_middle = slice_tuple(I, Val(A+1), Val(A+B_middle))
+    I_after = slice_tuple(I, Val(A+B_middle+1), Val(A+B_middle+C))
 
-    view_index = (I_before..., ntuple((i)->:,domain_dim(itm.tm))..., I_after...)
-    inner_index = slice_tuple(I, Val(range_dim(itm.before)+1), Val(R-range_dim(itm.after)))
+    view_index = (I_before..., ntuple((i)->:, B_view)..., I_after...)
+    inner_index =
 
-    return (view_index, inner_index)
+    return view_index, I_middle
 end
 
 # TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21
--- a/test/testLazyTensors.jl	Mon Nov 23 21:15:04 2020 +0100
+++ b/test/testLazyTensors.jl	Mon Nov 23 21:30:11 2020 +0100
@@ -369,21 +369,18 @@
         v = rand(domain_size(tm)...)
         @tullio IAIv[a,b,c,d] := Ã[c,i]*v[a,b,i,d]
         @test tm*v ≈ IAIv rtol=1e-14
-        @inferred LazyTensors.split_index(tm,1,1,1,1)
 
         # Test InflatedTensorMapping mapping w. before
         tm = InflatedTensorMapping(I(3,2), A)
         v = rand(domain_size(tm)...)
         @tullio IAIv[a,b,c] := Ã[c,i]*v[a,b,i]
         @test tm*v ≈ IAIv rtol=1e-14
-        @inferred LazyTensors.split_index(tm,1,1,1)
 
         # Test InflatedTensorMapping mapping w. after
         tm = InflatedTensorMapping(A,I(4))
         v = rand(domain_size(tm)...)
         @tullio IAIv[c,d] := Ã[c,i]*v[i,d]
         @test tm*v ≈ IAIv rtol=1e-14
-        @inferred LazyTensors.split_index(tm,1,1)
 
         @testset "Inference of application" begin
             struct ScalingOperator{T,D} <: TensorMapping{T,D,D}
@@ -398,7 +395,6 @@
             tm = InflatedTensorMapping(I(2,3),ScalingOperator(2.0, (3,2)),I(3,4))
             v = rand(domain_size(tm)...)
 
-            @inferred LazyTensors.split_index(tm,1,2,3,2,2,4)
             @inferred apply(tm,v,Index{Unknown}.((1,2,3,2,2,4))...)
             @inferred (tm*v)[1,2,3,2,2,4]
         end
@@ -415,6 +411,20 @@
     end
 end
 
+@testset "split_index" begin
+    @test LazyTensors.split_index(Val(2),Val(1),Val(2),Val(2),1,2,3,4,5,6) == ((1,2,:,5,6),(3,4))
+    @test LazyTensors.split_index(Val(2),Val(3),Val(2),Val(2),1,2,3,4,5,6) == ((1,2,:,:,:,5,6),(3,4))
+    @test LazyTensors.split_index(Val(3),Val(1),Val(1),Val(2),1,2,3,4,5,6) == ((1,2,3,:,5,6),(4,))
+    @test LazyTensors.split_index(Val(3),Val(2),Val(1),Val(2),1,2,3,4,5,6) == ((1,2,3,:,:,5,6),(4,))
+    @test LazyTensors.split_index(Val(1),Val(1),Val(2),Val(3),1,2,3,4,5,6) == ((1,:,4,5,6),(2,3))
+    @test LazyTensors.split_index(Val(1),Val(2),Val(2),Val(3),1,2,3,4,5,6) == ((1,:,:,4,5,6),(2,3))
+
+    @test LazyTensors.split_index(Val(0),Val(1),Val(3),Val(3),1,2,3,4,5,6) == ((:,4,5,6),(1,2,3))
+    @test LazyTensors.split_index(Val(3),Val(1),Val(3),Val(0),1,2,3,4,5,6) == ((1,2,3,:),(4,5,6))
+
+    @inferred LazyTensors.split_index(Val(2),Val(3),Val(2),Val(2),1,2,3,2,2,4)
+end
+
 @testset "slice_tuple" begin
     @test LazyTensors.slice_tuple((1,2,3),Val(1), Val(3)) == (1,2,3)
     @test LazyTensors.slice_tuple((1,2,3,4,5,6),Val(2), Val(5)) == (2,3,4,5)