diff src/LazyTensors/lazy_tensor_operations.jl @ 473:3041f8578bba

Merge in feature/inflated_tensormapping.
author Jonatan Werpers <jonatan@werpers.com>
date Mon, 02 Nov 2020 21:33:35 +0100
parents f270d82fc9ad
children 481e86e77c22 95f3b9036801
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Tue Oct 20 09:23:16 2020 +0200
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Mon Nov 02 21:33:35 2020 +0100
@@ -170,3 +170,112 @@
 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
 
+
+"""
+    InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D}
+
+An inflated `TensorMapping` with dimensions added before and afer its actual dimensions.
+"""
+struct InflatedTensorMapping{T,R,D,D_before,R_middle,D_middle,D_after, TM<:TensorMapping{T,R_middle,D_middle}} <: TensorMapping{T,R,D}
+    before::IdentityMapping{T,D_before}
+    tm::TM
+    after::IdentityMapping{T,D_after}
+
+    function InflatedTensorMapping(before, tm::TensorMapping{T}, after) where T
+        R_before = range_dim(before)
+        R_middle = range_dim(tm)
+        R_after = range_dim(after)
+        R = R_before+R_middle+R_after
+
+        D_before = domain_dim(before)
+        D_middle = domain_dim(tm)
+        D_after = domain_dim(after)
+        D = D_before+D_middle+D_after
+        return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after)
+    end
+end
+export InflatedTensorMapping
+"""
+    InflatedTensorMapping(before, tm, after)
+    InflatedTensorMapping(before,tm)
+    InflatedTensorMapping(tm,after)
+
+The outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityMapping`s.
+
+If one of `before` or `after` is left out, a 0-dimensional `IdentityMapping` is used as the default value.
+"""
+InflatedTensorMapping(::IdentityMapping, ::TensorMapping, ::IdentityMapping)
+InflatedTensorMapping(before::IdentityMapping, tm::TensorMapping{T}) where T = InflatedTensorMapping(before,tm,IdentityMapping{T}())
+InflatedTensorMapping(tm::TensorMapping{T}, after::IdentityMapping) where T = InflatedTensorMapping(IdentityMapping{T}(),tm,after)
+# Resolve ambiguity between the two previous methods
+InflatedTensorMapping(I1::IdentityMapping{T}, I2::IdentityMapping{T}) where T = InflatedTensorMapping(I1,I2,IdentityMapping{T}())
+
+# TODO: Implement syntax and constructors for products of different combinations of InflatedTensorMapping and IdentityMapping
+
+# TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensorMapping(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
+
+function range_size(itm::InflatedTensorMapping)
+    return flatten_tuple(
+        range_size(itm.before),
+        range_size(itm.tm),
+        range_size(itm.after),
+    )
+end
+
+function domain_size(itm::InflatedTensorMapping)
+    return flatten_tuple(
+        domain_size(itm.before),
+        domain_size(itm.tm),
+        domain_size(itm.after),
+    )
+end
+
+function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D}
+    view_index, inner_index = split_index(itm, I...)
+
+    v_inner = view(v, view_index...)
+    return apply(itm.tm, v_inner, inner_index...)
+end
+
+
+"""
+    split_index(...)
+
+Splits the multi-index into two parts. One part for the view that the inner TensorMapping acts on, and one part for indexing the result
+Eg.
+```
+(1,2,3,4) -> (1,:,:,4), (2,3)
+```
+"""
+function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{Any,R}) where {T,R,D}
+    I_before = slice_tuple(I, Val(1), Val(range_dim(itm.before)))
+    I_after = slice_tuple(I, Val(R-range_dim(itm.after)+1), Val(R))
+
+    view_index = (I_before..., ntuple((i)->:,domain_dim(itm.tm))..., I_after...)
+    inner_index = slice_tuple(I, Val(range_dim(itm.before)+1), Val(R-range_dim(itm.after)))
+
+    return (view_index, inner_index)
+end
+
+# TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21
+# See:
+# https://github.com/JuliaLang/julia/issues/34884
+# https://github.com/JuliaLang/julia/issues/30386
+"""
+    slice_tuple(t, Val(l), Val(u))
+
+Get a slice of a tuple in a type stable way.
+Equivalent to t[l:u] but type stable.
+"""
+function slice_tuple(t,::Val{L},::Val{U}) where {L,U}
+    return ntuple(i->t[i+L-1], U-L+1)
+end
+
+"""
+    flatten_tuple(t)
+
+Takes a nested tuple and flattens the whole structure
+"""
+flatten_tuple(t::NTuple{N, Number} where N) = t
+flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify?
+flatten_tuple(ts::Vararg) = flatten_tuple(ts)