Mercurial > repos > public > sbplib_julia
diff src/LazyTensors/lazy_tensor_operations.jl @ 473:3041f8578bba
Merge in feature/inflated_tensormapping.
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Mon, 02 Nov 2020 21:33:35 +0100 |
parents | f270d82fc9ad |
children | 481e86e77c22 95f3b9036801 |
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl Tue Oct 20 09:23:16 2020 +0200 +++ b/src/LazyTensors/lazy_tensor_operations.jl Mon Nov 02 21:33:35 2020 +0100 @@ -170,3 +170,112 @@ apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] + +""" + InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D} + +An inflated `TensorMapping` with dimensions added before and afer its actual dimensions. +""" +struct InflatedTensorMapping{T,R,D,D_before,R_middle,D_middle,D_after, TM<:TensorMapping{T,R_middle,D_middle}} <: TensorMapping{T,R,D} + before::IdentityMapping{T,D_before} + tm::TM + after::IdentityMapping{T,D_after} + + function InflatedTensorMapping(before, tm::TensorMapping{T}, after) where T + R_before = range_dim(before) + R_middle = range_dim(tm) + R_after = range_dim(after) + R = R_before+R_middle+R_after + + D_before = domain_dim(before) + D_middle = domain_dim(tm) + D_after = domain_dim(after) + D = D_before+D_middle+D_after + return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after) + end +end +export InflatedTensorMapping +""" + InflatedTensorMapping(before, tm, after) + InflatedTensorMapping(before,tm) + InflatedTensorMapping(tm,after) + +The outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityMapping`s. + +If one of `before` or `after` is left out, a 0-dimensional `IdentityMapping` is used as the default value. +""" +InflatedTensorMapping(::IdentityMapping, ::TensorMapping, ::IdentityMapping) +InflatedTensorMapping(before::IdentityMapping, tm::TensorMapping{T}) where T = InflatedTensorMapping(before,tm,IdentityMapping{T}()) +InflatedTensorMapping(tm::TensorMapping{T}, after::IdentityMapping) where T = InflatedTensorMapping(IdentityMapping{T}(),tm,after) +# Resolve ambiguity between the two previous methods +InflatedTensorMapping(I1::IdentityMapping{T}, I2::IdentityMapping{T}) where T = InflatedTensorMapping(I1,I2,IdentityMapping{T}()) + +# TODO: Implement syntax and constructors for products of different combinations of InflatedTensorMapping and IdentityMapping + +# TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensorMapping(I(3),B,I(2)) -> I(3)⊗B⊗I(2) + +function range_size(itm::InflatedTensorMapping) + return flatten_tuple( + range_size(itm.before), + range_size(itm.tm), + range_size(itm.after), + ) +end + +function domain_size(itm::InflatedTensorMapping) + return flatten_tuple( + domain_size(itm.before), + domain_size(itm.tm), + domain_size(itm.after), + ) +end + +function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} + view_index, inner_index = split_index(itm, I...) + + v_inner = view(v, view_index...) + return apply(itm.tm, v_inner, inner_index...) +end + + +""" + split_index(...) + +Splits the multi-index into two parts. One part for the view that the inner TensorMapping acts on, and one part for indexing the result +Eg. +``` +(1,2,3,4) -> (1,:,:,4), (2,3) +``` +""" +function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{Any,R}) where {T,R,D} + I_before = slice_tuple(I, Val(1), Val(range_dim(itm.before))) + I_after = slice_tuple(I, Val(R-range_dim(itm.after)+1), Val(R)) + + view_index = (I_before..., ntuple((i)->:,domain_dim(itm.tm))..., I_after...) + inner_index = slice_tuple(I, Val(range_dim(itm.before)+1), Val(R-range_dim(itm.after))) + + return (view_index, inner_index) +end + +# TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21 +# See: +# https://github.com/JuliaLang/julia/issues/34884 +# https://github.com/JuliaLang/julia/issues/30386 +""" + slice_tuple(t, Val(l), Val(u)) + +Get a slice of a tuple in a type stable way. +Equivalent to t[l:u] but type stable. +""" +function slice_tuple(t,::Val{L},::Val{U}) where {L,U} + return ntuple(i->t[i+L-1], U-L+1) +end + +""" + flatten_tuple(t) + +Takes a nested tuple and flattens the whole structure +""" +flatten_tuple(t::NTuple{N, Number} where N) = t +flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify? +flatten_tuple(ts::Vararg) = flatten_tuple(ts)