changeset 450:ac6d22570a08 feature/inflated_tensormapping

Merge in feature/lazy_identity
author Jonatan Werpers <jonatan@werpers.com>
date Mon, 19 Oct 2020 21:42:57 +0200
parents 14d60de71b72 (diff) e70e47fbfa7c (current diff)
children 6cf234eef780
files src/LazyTensors/lazy_tensor_operations.jl test/testLazyTensors.jl
diffstat 2 files changed, 77 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Mon Oct 19 21:41:34 2020 +0200
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Mon Oct 19 21:42:57 2020 +0200
@@ -170,3 +170,76 @@
 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
 
+struct InflatedTensorMapping{T,R,D,D_before,R_middle,D_middle,D_after, TM<:TensorMapping{T,R_middle,D_middle}} <: TensorMapping{T,R,D}
+    before::IdentityMapping{T,D_before}
+    tm::TM
+    after::IdentityMapping{T,D_after}
+
+    function InflatedTensorMapping(before, tm::TensorMapping{T}, after) where T
+        R_before = range_dim(before)
+        R_middle = range_dim(tm)
+        R_after = range_dim(after)
+        R = R_before+R_middle+R_after
+
+        D_before = domain_dim(before)
+        D_middle = domain_dim(tm)
+        D_after = domain_dim(after)
+        D = D_before+D_middle+D_after
+        return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after)
+    end
+end
+
+# TODO: Implement constructors where one of `before` or `after` is missing
+
+# TODO: Implement syntax and constructors for products of different combinations of InflatedTensorMapping and IdentityMapping
+
+# TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensorMapping(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
+
+function range_size(itm::InflatedTensorMapping)
+    return flatten_tuple(
+        range_size(itm.before),
+        range_size(itm.tm),
+        range_size(itm.after),
+    )
+end
+
+function domain_size(itm::InflatedTensorMapping)
+    return flatten_tuple(
+        domain_size(itm.before),
+        domain_size(itm.tm),
+        domain_size(itm.after),
+    )
+end
+
+function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{S,R} where S) where {T,R,D}
+    view_index, inner_index = split_index(I...)
+
+    v_inner = view(v, view_index...)
+    return apply(itm.tm, v_inner, inner_index...)
+end
+
+
+"""
+    split_index(...)
+
+Splits the multi-index into two parts. One part for the view that the inner TensorMapping acts on, and one part for indexing the result
+Eg.
+```
+(1,2,3,4) -> (1,:,:,4), (2,3)
+```
+"""
+function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{S,R} where S) where {T,R,D}
+    I_before = I[1:range_dim(itm.before)]
+    I_after = I[(end-range_dim(itm.after)+1):end]
+
+    view_index = (I_before..., ntuple((i)->:,domain_dim(itm.tm))..., I_after...)
+    A_view = @view llm.A[view_index...]
+    inner_index = I[range_dim(itm.before)+1:end-range_dim(itm.after)]
+
+    return (view_index, inner_index)
+    return sum(A_view.*v)
+end
+
+flatten_tuple(t::NTuple{N, Number} where N) = t
+flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify?
+flatten_tuple(ts::Vararg) = flatten_tuple(ts)
--- a/test/testLazyTensors.jl	Mon Oct 19 21:41:34 2020 +0200
+++ b/test/testLazyTensors.jl	Mon Oct 19 21:42:57 2020 +0200
@@ -306,4 +306,8 @@
     @inferred range_size(I)
 end
 
+@testset "InflatedTensorMapping" begin
+
 end
+
+end