diff src/LazyTensors/lazy_tensor_operations.jl @ 446:904aae1899df feature/inflated_tensormapping

Start implementing InflatedTensorMapping
author Jonatan Werpers <jonatan@werpers.com>
date Mon, 19 Oct 2020 08:37:35 +0200
parents 46acb2560451
children 27e0e256e5d9
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Sun Oct 18 22:30:17 2020 +0200
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Mon Oct 19 08:37:35 2020 +0200
@@ -162,3 +162,75 @@
 apply(tmi::LazyIdentity{T,D}, v::AbstractArray{T,D}, I::Vararg{Index,D}) where {T,D} = v[Int.(I)...]
 apply_transpose(tmi::LazyIdentity{T,D}, v::AbstractArray{T,D}, I::Vararg{Index,D}) where {T,D} = v[Int.(I)...]
 
+struct InflatedTensorMapping{T,R,D,D_before,R_middle,D_middle,D_after} <: TensorMapping{T,R,D}
+    before::LazyIdentity{T,D_before}
+    tm::TensorMapping{T,R_middle,D_middle}
+    after::LazyIdentity{T,D_after}
+
+    function InflatedTensorMapping(before, tm::TensorMapping{T}, after) where T
+        R_before = range_dim(before)
+        R_middle = range_dim(tm)
+        R_after = range_dim(after)
+        R = R_before+R_middle+R_after
+
+        D_before = domain_dim(before)
+        D_middle = domain_dim(tm)
+        D_after = domain_dim(after)
+        D = D_before+D_middle+D_after
+        return new{T,R,D,D_before,R_middle,D_middle,D_after}(before, tm, after)
+end
+
+# TODO: Implement constructors where one of `before` or `after` is missing
+
+# TODO: Implement syntax and constructors for products of different combinations of InflatedTensorMapping and LazyIdentity
+
+# TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensorMapping(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
+
+function range_size(itm::InflatedTensorMapping)
+    return flatten_tuple(
+        range_size(itm.before),
+        range_size(itm.tm),
+        range_size(itm.after),
+    )
+end
+
+function domain_size(itm::InflatedTensorMapping)
+    return flatten_tuple(
+        domain_size(itm.before),
+        domain_size(itm.tm),
+        domain_size(itm.after),
+    )
+end
+
+function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{S,R} where S) where {T,R,D}
+    view_index, inner_index = split_index(I...)
+
+    v_inner = view(v, view_index...)
+    return apply(itm.tm, v_inner, inner_index...)
+end
+
+
+"""
+    split_index(...)
+
+Splits the multi-index into two parts. One part for the view that the inner TensorMapping acts on, and one part for indexing the result
+Eg.
+```
+(1,2,3,4) -> (1,:,:,4), (2,3)
+```
+"""
+function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{S,R} where S) where {T,R,D}
+    I_before = I[1:range_dim(itm.before)]
+    I_after = I[(end-range_dim(itm.after)+1):end]
+
+    view_index = (I_before..., ntuple((i)->:,domain_dim(itm.tm))..., I_after...)
+    A_view = @view llm.A[view_index...]
+    inner_index = I[range_dim(itm.before)+1:end-range_dim(itm.after)]
+
+    return (view_index, inner_index)
+    return sum(A_view.*v)
+end
+
+flatten_tuple(t::NTuple{N, Number} where N) = t
+flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify?
+flatten_tuple(ts::Vararg) = flatten_tuple(ts)