diff src/LazyTensors/lazy_tensor_operations.jl @ 402:1936e38fe51e

Merge feature/lazy_linear_map
author Vidar Stiernström <vidar.stiernstrom@it.uu.se>
date Mon, 05 Oct 2020 10:45:30 +0200
parents de4746d6d126 c640f37d1c74
children 4aa59af074ef d94891b8dfca
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Wed Sep 30 21:53:52 2020 +0200
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Mon Oct 05 10:45:30 2020 +0200
@@ -102,3 +102,38 @@
 # # Have i gone too crazy with the type parameters? Maybe they aren't all needed?
 
 # export →
+"""
+    LazyLinearMap{T,R,D,...}(A, range_indicies, domain_indicies)
+
+TensorMapping defined by the AbstractArray A. `range_indicies` and `domain_indicies` define which indicies of A should
+be considerd the range and domain of the TensorMapping. Each set of indices must be ordered in ascending order.
+
+For instance, if A is a m x n matrix, and range_size = (1,), domain_size = (2,), then the LazyLinearMap performs the
+standard matrix-vector product on vectors of size n.
+"""
+struct LazyLinearMap{T,R,D, RD, AA<:AbstractArray{T,RD}} <: TensorMapping{T,R,D}
+    A::AA
+    range_indicies::NTuple{R,Int}
+    domain_indicies::NTuple{D,Int}
+
+    function LazyLinearMap(A::AA, range_indicies::NTuple{R,Int}, domain_indicies::NTuple{D,Int}) where {T,R,D, RD, AA<:AbstractArray{T,RD}}
+        if !issorted(range_indicies) || !issorted(domain_indicies)
+            throw(DomainError("range_indicies and domain_indicies must be sorted in ascending order"))
+        end
+
+        return new{T,R,D,RD,AA}(A,range_indicies,domain_indicies)
+    end
+end
+export LazyLinearMap
+
+range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]]
+domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]]
+
+function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Index,R}) where {T,R,D}
+    view_index = ntuple(i->:,ndims(llm.A))
+    for i ∈ 1:R
+        view_index = Base.setindex(view_index, Int(I[i]), llm.range_indicies[i])
+    end
+    A_view = @view llm.A[view_index...]
+    return sum(A_view.*v)
+end