Mercurial > repos > public > sbplib_julia
diff src/LazyTensors/lazy_tensor_operations.jl @ 403:618b7ee73b25 refactor/sbp_operators_tests/collect_and_compare
Merge in default and close branch before merge
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Wed, 07 Oct 2020 12:31:54 +0200 |
parents | 1936e38fe51e |
children | 4aa59af074ef d94891b8dfca |
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl Sun Oct 04 19:41:14 2020 +0200 +++ b/src/LazyTensors/lazy_tensor_operations.jl Wed Oct 07 12:31:54 2020 +0200 @@ -102,3 +102,38 @@ # # Have i gone too crazy with the type parameters? Maybe they aren't all needed? # export → +""" + LazyLinearMap{T,R,D,...}(A, range_indicies, domain_indicies) + +TensorMapping defined by the AbstractArray A. `range_indicies` and `domain_indicies` define which indicies of A should +be considerd the range and domain of the TensorMapping. Each set of indices must be ordered in ascending order. + +For instance, if A is a m x n matrix, and range_size = (1,), domain_size = (2,), then the LazyLinearMap performs the +standard matrix-vector product on vectors of size n. +""" +struct LazyLinearMap{T,R,D, RD, AA<:AbstractArray{T,RD}} <: TensorMapping{T,R,D} + A::AA + range_indicies::NTuple{R,Int} + domain_indicies::NTuple{D,Int} + + function LazyLinearMap(A::AA, range_indicies::NTuple{R,Int}, domain_indicies::NTuple{D,Int}) where {T,R,D, RD, AA<:AbstractArray{T,RD}} + if !issorted(range_indicies) || !issorted(domain_indicies) + throw(DomainError("range_indicies and domain_indicies must be sorted in ascending order")) + end + + return new{T,R,D,RD,AA}(A,range_indicies,domain_indicies) + end +end +export LazyLinearMap + +range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]] +domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]] + +function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Index,R}) where {T,R,D} + view_index = ntuple(i->:,ndims(llm.A)) + for i ∈ 1:R + view_index = Base.setindex(view_index, Int(I[i]), llm.range_indicies[i]) + end + A_view = @view llm.A[view_index...] + return sum(A_view.*v) +end