changeset 540:013ca4892540

Merge feature/inflated_tensormapping_transpose
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 26 Nov 2020 17:53:40 +0100
parents b6768d769f46 (current diff) 804791159fa6 (diff)
children 62d96e2cd165 ff412b29db31 37a81dad36b9
files
diffstat 2 files changed, 128 insertions(+), 19 deletions(-) [+]
line wrap: on
line diff
diff -r b6768d769f46 -r 013ca4892540 src/LazyTensors/lazy_tensor_operations.jl
--- a/src/LazyTensors/lazy_tensor_operations.jl	Thu Nov 26 17:20:22 2020 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Thu Nov 26 17:53:40 2020 +0100
@@ -240,8 +240,6 @@
 # Resolve ambiguity between the two previous methods
 InflatedTensorMapping(I1::IdentityMapping{T}, I2::IdentityMapping{T}) where T = InflatedTensorMapping(I1,I2,IdentityMapping{T}())
 
-# TODO: Implement syntax and constructors for products of different combinations of InflatedTensorMapping and IdentityMapping
-
 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensorMapping(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
 
 function range_size(itm::InflatedTensorMapping)
@@ -261,30 +259,56 @@
 end
 
 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D}
-    view_index, inner_index = split_index(itm, I...)
+    dim_before = range_dim(itm.before)
+    dim_domain = domain_dim(itm.tm)
+    dim_range = range_dim(itm.tm)
+    dim_after = range_dim(itm.after)
+
+    view_index, inner_index = split_index(Val(dim_before), Val(dim_domain), Val(dim_range), Val(dim_after), I...)
 
     v_inner = view(v, view_index...)
     return apply(itm.tm, v_inner, inner_index...)
 end
 
+function apply_transpose(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D}
+    dim_before = range_dim(itm.before)
+    dim_domain = domain_dim(itm.tm)
+    dim_range = range_dim(itm.tm)
+    dim_after = range_dim(itm.after)
+
+    view_index, inner_index = split_index(Val(dim_before), Val(dim_range), Val(dim_domain), Val(dim_after), I...)
+
+    v_inner = view(v, view_index...)
+    return apply_transpose(itm.tm, v_inner, inner_index...)
+end
+
 
 """
-    split_index(...)
+    split_index(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...)
 
-Splits the multi-index into two parts. One part for the view that the inner TensorMapping acts on, and one part for indexing the result
+Splits the multi-index `I` into two parts. One part which is expected to be
+used as a view, and one which is expected to be used as an index.
 Eg.
 ```
-(1,2,3,4) -> (1,:,:,4), (2,3)
+split_index(Val(1),Val(3),Val(2),Val(1),(1,2,3,4)) -> (1,:,:,:,4), (2,3)
 ```
+
+`dim_view` controls how many colons are in the view, and `dim_index` controls
+how many elements are extracted from the middle.
+`dim_before` and `dim_after` decides the length of the index parts before and after the colons in the view index.
+
+Arguments should satisfy `length(I) == dim_before+B_domain+dim_after`.
+
+The returned values satisfy
+ * `length(view_index) == dim_before + dim_view + dim_after`
+ * `length(I_middle) == dim_index`
 """
-function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{Any,R}) where {T,R,D}
-    I_before = slice_tuple(I, Val(1), Val(range_dim(itm.before)))
-    I_after = slice_tuple(I, Val(R-range_dim(itm.after)+1), Val(R))
+function split_index(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...) where {dim_before,dim_view, dim_index,dim_after}
+    I_before, I_middle, I_after = split_tuple(I, Val(dim_before), Val(dim_index))
 
-    view_index = (I_before..., ntuple((i)->:,domain_dim(itm.tm))..., I_after...)
-    inner_index = slice_tuple(I, Val(range_dim(itm.before)+1), Val(R-range_dim(itm.after)))
+    view_index = (I_before..., ntuple((i)->:, dim_view)..., I_after...)
 
-    return (view_index, inner_index)
+    return view_index, I_middle
 end
 
 # TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21
@@ -302,6 +326,32 @@
 end
 
 """
+    split_tuple(t::Tuple{...}, ::Val{M}) where {N,M}
+
+Split the tuple `t` into two parts. the first part is `M` long.
+E.g
+```
+split_tuple((1,2,3,4),Val(3)) -> (1,2,3), (4,)
+```
+"""
+function split_tuple(t::NTuple{N},::Val{M}) where {N,M}
+    return slice_tuple(t,Val(1), Val(M)), slice_tuple(t,Val(M+1), Val(N))
+end
+
+"""
+    split_tuple(t::Tuple{...},::Val{M},::Val{K}) where {N,M,K}
+
+Same as `split_tuple(t::NTuple{N},::Val{M})` but splits the tuple in three parts. With the first
+two parts having lenght `M` and `K`.
+"""
+function split_tuple(t::NTuple{N},::Val{M},::Val{K}) where {N,M,K}
+    p1, tail = split_tuple(t, Val(M))
+    p2, p3 = split_tuple(tail, Val(K))
+    return p1,p2,p3
+end
+
+
+"""
     flatten_tuple(t)
 
 Takes a nested tuple and flattens the whole structure
diff -r b6768d769f46 -r 013ca4892540 test/testLazyTensors.jl
--- a/test/testLazyTensors.jl	Thu Nov 26 17:20:22 2020 +0100
+++ b/test/testLazyTensors.jl	Thu Nov 26 17:53:40 2020 +0100
@@ -369,47 +369,67 @@
         tests = [
             (
                 InflatedTensorMapping(I(3,2), A, I(4)),
-                (v-> @tullio res[a,b,c,d] := Ã[c,i]*v[a,b,i,d]),
+                (v-> @tullio res[a,b,c,d] := Ã[c,i]*v[a,b,i,d]), # Expected result of apply
+                (v-> @tullio res[a,b,c,d] := Ã[i,c]*v[a,b,i,d]), # Expected result of apply_transpose
             ),
             (
                 InflatedTensorMapping(I(3,2), B, I(4)),
                 (v-> @tullio res[a,b,c,d,e] := B̃[c,d,i]*v[a,b,i,e]),
+                (v-> @tullio res[a,b,c,d] := B̃[i,j,c]*v[a,b,i,j,d]),
             ),
             (
                 InflatedTensorMapping(I(3,2), C, I(4)),
                 (v-> @tullio res[a,b,c,d] := C̃[c,i,j]*v[a,b,i,j,d]),
+                (v-> @tullio res[a,b,c,d,e] := C̃[i,c,d]*v[a,b,i,e]),
             ),
             (
                 InflatedTensorMapping(I(3,2), A),
                 (v-> @tullio res[a,b,c] := Ã[c,i]*v[a,b,i]),
+                (v-> @tullio res[a,b,c] := Ã[i,c]*v[a,b,i]),
             ),
             (
                 InflatedTensorMapping(I(3,2), B),
                 (v-> @tullio res[a,b,c,d] := B̃[c,d,i]*v[a,b,i]),
+                (v-> @tullio res[a,b,c] := B̃[i,j,c]*v[a,b,i,j]),
             ),
             (
                 InflatedTensorMapping(I(3,2), C),
                 (v-> @tullio res[a,b,c] := C̃[c,i,j]*v[a,b,i,j]),
+                (v-> @tullio res[a,b,c,d] := C̃[i,c,d]*v[a,b,i]),
             ),
             (
                 InflatedTensorMapping(A,I(4)),
                 (v-> @tullio res[a,b] := Ã[a,i]*v[i,b]),
+                (v-> @tullio res[a,b] := Ã[i,a]*v[i,b]),
             ),
             (
                 InflatedTensorMapping(B,I(4)),
                 (v-> @tullio res[a,b,c] := B̃[a,b,i]*v[i,c]),
+                (v-> @tullio res[a,b] := B̃[i,j,a]*v[i,j,b]),
             ),
             (
                 InflatedTensorMapping(C,I(4)),
                 (v-> @tullio res[a,b] := C̃[a,i,j]*v[i,j,b]),
+                (v-> @tullio res[a,b,c] := C̃[i,a,b]*v[i,c]),
             ),
         ]
 
-        for i ∈ 1:length(tests)
-            tm = tests[i][1]
-            v = rand(domain_size(tm)...)
-            true_value = tests[i][2](v)
-            @test tm*v ≈ true_value rtol=1e-14
+        @testset "apply" begin
+            for i ∈ 1:length(tests)
+                tm = tests[i][1]
+                v = rand(domain_size(tm)...)
+                true_value = tests[i][2](v)
+                @test tm*v ≈ true_value rtol=1e-14
+            end
+        end
+
+        @testset "apply_transpose" begin
+            for i ∈ 1:length(tests)
+                tm = tests[i][1]
+                v = rand(range_size(tm)...)
+                true_value = tests[i][3](v)
+                @test tm'*v ≈ true_value rtol=1e-14
+            end
         end
 
         @testset "Inference of application" begin
@@ -425,7 +445,6 @@
             tm = InflatedTensorMapping(I(2,3),ScalingOperator(2.0, (3,2)),I(3,4))
             v = rand(domain_size(tm)...)
 
-            @inferred LazyTensors.split_index(tm,1,2,3,2,2,4)
             @inferred apply(tm,v,Index{Unknown}.((1,2,3,2,2,4))...)
             @inferred (tm*v)[1,2,3,2,2,4]
         end
@@ -442,6 +461,20 @@
     end
 end
 
+@testset "split_index" begin
+    @test LazyTensors.split_index(Val(2),Val(1),Val(2),Val(2),1,2,3,4,5,6) == ((1,2,:,5,6),(3,4))
+    @test LazyTensors.split_index(Val(2),Val(3),Val(2),Val(2),1,2,3,4,5,6) == ((1,2,:,:,:,5,6),(3,4))
+    @test LazyTensors.split_index(Val(3),Val(1),Val(1),Val(2),1,2,3,4,5,6) == ((1,2,3,:,5,6),(4,))
+    @test LazyTensors.split_index(Val(3),Val(2),Val(1),Val(2),1,2,3,4,5,6) == ((1,2,3,:,:,5,6),(4,))
+    @test LazyTensors.split_index(Val(1),Val(1),Val(2),Val(3),1,2,3,4,5,6) == ((1,:,4,5,6),(2,3))
+    @test LazyTensors.split_index(Val(1),Val(2),Val(2),Val(3),1,2,3,4,5,6) == ((1,:,:,4,5,6),(2,3))
+
+    @test LazyTensors.split_index(Val(0),Val(1),Val(3),Val(3),1,2,3,4,5,6) == ((:,4,5,6),(1,2,3))
+    @test LazyTensors.split_index(Val(3),Val(1),Val(3),Val(0),1,2,3,4,5,6) == ((1,2,3,:),(4,5,6))
+
+    @inferred LazyTensors.split_index(Val(2),Val(3),Val(2),Val(2),1,2,3,2,2,4)
+end
+
 @testset "slice_tuple" begin
     @test LazyTensors.slice_tuple((1,2,3),Val(1), Val(3)) == (1,2,3)
     @test LazyTensors.slice_tuple((1,2,3,4,5,6),Val(2), Val(5)) == (2,3,4,5)
@@ -449,6 +482,32 @@
     @test LazyTensors.slice_tuple((1,2,3,4,5,6),Val(4), Val(6)) == (4,5,6)
 end
 
+@testset "split_tuple" begin
+    @testset "2 parts" begin
+        @test LazyTensors.split_tuple((),Val(0)) == ((),())
+        @test LazyTensors.split_tuple((1,),Val(0)) == ((),(1,))
+        @test LazyTensors.split_tuple((1,),Val(1)) == ((1,),())
+
+        @test LazyTensors.split_tuple((1,2,3,4),Val(0)) == ((),(1,2,3,4))
+        @test LazyTensors.split_tuple((1,2,3,4),Val(1)) == ((1,),(2,3,4))
+        @test LazyTensors.split_tuple((1,2,3,4),Val(2)) == ((1,2),(3,4))
+        @test LazyTensors.split_tuple((1,2,3,4),Val(3)) == ((1,2,3),(4,))
+        @test LazyTensors.split_tuple((1,2,3,4),Val(4)) == ((1,2,3,4),())
+
+        @inferred LazyTensors.split_tuple((1,2,3,4),Val(3))
+    end
+
+    @testset "3 parts" begin
+        @test LazyTensors.split_tuple((),Val(0),Val(0)) == ((),(),())
+        @test LazyTensors.split_tuple((1,2,3),Val(1), Val(1)) == ((1,),(2,),(3,))
+
+        @test LazyTensors.split_tuple((1,2,3,4,5,6),Val(1),Val(2)) == ((1,),(2,3),(4,5,6))
+        @test LazyTensors.split_tuple((1,2,3,4,5,6),Val(3),Val(2)) == ((1,2,3),(4,5),(6,))
+
+        @inferred LazyTensors.split_tuple((1,2,3,4,5,6),Val(3),Val(2))
+    end
+end
+
 @testset "flatten_tuple" begin
     @test LazyTensors.flatten_tuple((1,)) == (1,)
     @test LazyTensors.flatten_tuple((1,2,3,4,5,6)) == (1,2,3,4,5,6)