Mercurial > repos > public > sbplib_julia
diff src/LazyTensors/lazy_tensor_operations.jl @ 446:904aae1899df feature/inflated_tensormapping
Start implementing InflatedTensorMapping
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Mon, 19 Oct 2020 08:37:35 +0200 |
parents | 46acb2560451 |
children | 27e0e256e5d9 |
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl Sun Oct 18 22:30:17 2020 +0200 +++ b/src/LazyTensors/lazy_tensor_operations.jl Mon Oct 19 08:37:35 2020 +0200 @@ -162,3 +162,75 @@ apply(tmi::LazyIdentity{T,D}, v::AbstractArray{T,D}, I::Vararg{Index,D}) where {T,D} = v[Int.(I)...] apply_transpose(tmi::LazyIdentity{T,D}, v::AbstractArray{T,D}, I::Vararg{Index,D}) where {T,D} = v[Int.(I)...] +struct InflatedTensorMapping{T,R,D,D_before,R_middle,D_middle,D_after} <: TensorMapping{T,R,D} + before::LazyIdentity{T,D_before} + tm::TensorMapping{T,R_middle,D_middle} + after::LazyIdentity{T,D_after} + + function InflatedTensorMapping(before, tm::TensorMapping{T}, after) where T + R_before = range_dim(before) + R_middle = range_dim(tm) + R_after = range_dim(after) + R = R_before+R_middle+R_after + + D_before = domain_dim(before) + D_middle = domain_dim(tm) + D_after = domain_dim(after) + D = D_before+D_middle+D_after + return new{T,R,D,D_before,R_middle,D_middle,D_after}(before, tm, after) +end + +# TODO: Implement constructors where one of `before` or `after` is missing + +# TODO: Implement syntax and constructors for products of different combinations of InflatedTensorMapping and LazyIdentity + +# TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensorMapping(I(3),B,I(2)) -> I(3)⊗B⊗I(2) + +function range_size(itm::InflatedTensorMapping) + return flatten_tuple( + range_size(itm.before), + range_size(itm.tm), + range_size(itm.after), + ) +end + +function domain_size(itm::InflatedTensorMapping) + return flatten_tuple( + domain_size(itm.before), + domain_size(itm.tm), + domain_size(itm.after), + ) +end + +function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{S,R} where S) where {T,R,D} + view_index, inner_index = split_index(I...) + + v_inner = view(v, view_index...) + return apply(itm.tm, v_inner, inner_index...) +end + + +""" + split_index(...) + +Splits the multi-index into two parts. One part for the view that the inner TensorMapping acts on, and one part for indexing the result +Eg. +``` +(1,2,3,4) -> (1,:,:,4), (2,3) +``` +""" +function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{S,R} where S) where {T,R,D} + I_before = I[1:range_dim(itm.before)] + I_after = I[(end-range_dim(itm.after)+1):end] + + view_index = (I_before..., ntuple((i)->:,domain_dim(itm.tm))..., I_after...) + A_view = @view llm.A[view_index...] + inner_index = I[range_dim(itm.before)+1:end-range_dim(itm.after)] + + return (view_index, inner_index) + return sum(A_view.*v) +end + +flatten_tuple(t::NTuple{N, Number} where N) = t +flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify? +flatten_tuple(ts::Vararg) = flatten_tuple(ts)