Mercurial > repos > public > sbplib_julia
diff src/LazyTensors/lazy_tensor_operations.jl @ 520:fe86ac896377 feature/inflated_tensormapping_transpose
Start refactoring split index and apply to accomodate future addition of apply_transpose
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Mon, 23 Nov 2020 21:30:11 +0100 |
parents | 4b9d124fe984 |
children | 7e6250c51eb2 |
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl Mon Nov 23 21:15:04 2020 +0100 +++ b/src/LazyTensors/lazy_tensor_operations.jl Mon Nov 23 21:30:11 2020 +0100 @@ -261,7 +261,12 @@ end function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} - view_index, inner_index = split_index(itm, I...) + A = range_dim(itm.before) + B_domain = domain_dim(itm.tm) + B_range = range_dim(itm.tm) + C = range_dim(itm.after) + + view_index, inner_index = split_index(Val(A), Val(B_range), Val(B_domain), Val(C), I...) v_inner = view(v, view_index...) return apply(itm.tm, v_inner, inner_index...) @@ -269,22 +274,31 @@ """ - split_index(...) + split_index(:Val{A}, ::Val{B_view}, ::Val{B_middle}, ::Val{C}, I...) -Splits the multi-index into two parts. One part for the view that the inner TensorMapping acts on, and one part for indexing the result +Splits the multi-index `I` into two parts. One part which is expected to be used as a view, which is expected to be used as an index. Eg. ``` -(1,2,3,4) -> (1,:,:,4), (2,3) +(1,2,3,4) -> (1,:,:,:,4), (2,3) ``` + +`B_view` controls how many colons are in the view, and `B_middle` controls how many elements are extracted from the middle. +`A` and `C` decides the length of the parts before and after the colons in the view index. +length(I) == A+B_domain+C +length(I_middle) == B_domain +length(I_view) == A + B_range + C + +TODO: Finish documentation. """ -function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{Any,R}) where {T,R,D} - I_before = slice_tuple(I, Val(1), Val(range_dim(itm.before))) - I_after = slice_tuple(I, Val(R-range_dim(itm.after)+1), Val(R)) +function split_index(::Val{A}, ::Val{B_view}, ::Val{B_middle}, ::Val{C}, I...) where {A,B_view, B_middle,C} + I_before = slice_tuple(I, Val(1), Val(A)) + I_middle = slice_tuple(I, Val(A+1), Val(A+B_middle)) + I_after = slice_tuple(I, Val(A+B_middle+1), Val(A+B_middle+C)) - view_index = (I_before..., ntuple((i)->:,domain_dim(itm.tm))..., I_after...) - inner_index = slice_tuple(I, Val(range_dim(itm.before)+1), Val(R-range_dim(itm.after))) + view_index = (I_before..., ntuple((i)->:, B_view)..., I_after...) + inner_index = - return (view_index, inner_index) + return view_index, I_middle end # TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21