diff src/LazyTensors/lazy_tensor_operations.jl @ 960:e79debd10f7d feature/variable_derivatives

Merge feature/tensormapping_application_promotion
author Jonatan Werpers <jonatan@werpers.com>
date Mon, 14 Mar 2022 10:14:38 +0100
parents e9752c1e92f8 86889fc5b63f
children df562695b1b5
--- a/src/LazyTensors/lazy_tensor_operations.jl	Mon Mar 14 09:46:22 2022 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Mon Mar 14 10:14:38 2022 +0100
@@ -1,5 +1,15 @@
 using Sbplib.RegionIndices
 
+export LazyTensorMappingApplication
+export LazyTensorMappingTranspose
+export TensorMappingComposition
+export LazyLinearMap
+export IdentityMapping
+export InflatedTensorMapping
+export LazyOuterProduct
+export ⊗
+export SizeMismatch
+
 """
     LazyTensorMappingApplication{T,R,D} <: LazyArray{T,R}
 
@@ -9,12 +19,16 @@
 With a mapping `m` and a vector `v` the LazyTensorMappingApplication object can be created by `m*v`.
 The actual result will be calculated when indexing into `m*v`.
 """
-struct LazyTensorMappingApplication{T,R,D, TM<:TensorMapping{T,R,D}, AA<:AbstractArray{T,D}} <: LazyArray{T,R}
+struct LazyTensorMappingApplication{T,R,D, TM<:TensorMapping{<:Any,R,D}, AA<:AbstractArray{<:Any,D}} <: LazyArray{T,R}
     t::TM
     o::AA
+
+    function LazyTensorMappingApplication(t::TensorMapping{<:Any,R,D}, o::AbstractArray{<:Any,D}) where {R,D}
+        T = promote_type(eltype(t), eltype(o))
+        return new{T,R,D,typeof(t), typeof(o)}(t,o)
+    end
 end
 # TODO: Do boundschecking on creation!
-export LazyTensorMappingApplication
 
 Base.getindex(ta::LazyTensorMappingApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...)
 Base.getindex(ta::LazyTensorMappingApplication{T,1}, I::CartesianIndex{1}) where {T} = apply(ta.t, ta.o, I.I...) # Would otherwise be caught in the previous method.
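A minimal sketch of the new promotion behaviour, using `IdentityMapping` from further down this file (values are illustrative):

    Id = IdentityMapping{Float64}((3,))  # mapping with eltype Float64
    v = [1, 2, 3]                        # array with eltype Int
    a = Id*v                             # builds a LazyTensorMappingApplication
    eltype(a)                            # Float64, i.e. promote_type(Float64, Int)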
@@ -43,15 +57,14 @@
 struct LazyTensorMappingTranspose{T,R,D, TM<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R}
     tm::TM
 end
-export LazyTensorMappingTranspose
 
 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors?
 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`?
 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm)
 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm
 
-apply(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...)
-apply_transpose(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...)
+apply(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...)
+apply_transpose(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...)
 
 range_size(tmt::LazyTensorMappingTranspose) = domain_size(tmt.tm)
 domain_size(tmt::LazyTensorMappingTranspose) = range_size(tmt.tm)
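A sketch of the lazy transpose round trip; nothing is computed until `apply`:

    Id = IdentityMapping{Float64}((3,))
    Id' isa LazyTensorMappingTranspose  # true: adjoint only wraps the mapping
    (Id')' === Id                       # true: adjoint of the wrapper unwraps it
    apply(Id', ones(3), 2)              # forwards to apply_transpose(Id, ones(3), 2) == 1.0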
@@ -67,8 +80,8 @@
 end
 # TODO: Boundschecking in constructor.
 
-apply(tmBinOp::LazyTensorMappingBinaryOperation{:+,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...)
-apply(tmBinOp::LazyTensorMappingBinaryOperation{:-,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...)
+apply(tmBinOp::LazyTensorMappingBinaryOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...)
+apply(tmBinOp::LazyTensorMappingBinaryOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...)
 
 range_size(tmBinOp::LazyTensorMappingBinaryOperation) = range_size(tmBinOp.tm1)
 domain_size(tmBinOp::LazyTensorMappingBinaryOperation) = domain_size(tmBinOp.tm1)
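The pointwise semantics, as a sketch assuming `A + B` constructs a `LazyTensorMappingBinaryOperation{:+}` (that overload is not part of this hunk):

    Id = IdentityMapping{Float64}((3,))
    v = ones(3)
    # each index is evaluated through both mappings and the results combined
    apply(Id + Id, v, 2) == apply(Id, v, 2) + apply(Id, v, 2)  # true, both sides 2.0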
@@ -90,16 +103,15 @@
         return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2)
     end
 end
-export TensorMappingComposition
 
 range_size(tm::TensorMappingComposition) = range_size(tm.t1)
 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2)
 
-function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,K,D}
+function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,K,D}
     apply(c.t1, c.t2*v, I...)
 end
 
-function apply_transpose(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,K,D}
+function apply_transpose(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,K,D}
     apply_transpose(c.t2, c.t1'*v, I...)
 end
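A sketch of lazy composition, assuming `∘` on two mappings builds a `TensorMappingComposition` (that overload is defined elsewhere in the file):

    Id = IdentityMapping{Float64}((3,))
    c = Id∘Id
    range_size(c), domain_size(c)  # ((3,), (3,)), taken from t1 and t2 respectively
    apply(c, ones(3), 1)           # == apply(Id, Id*ones(3), 1) == 1.0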
 
@@ -127,12 +139,11 @@
         return new{T,R,D,RD,AA}(A,range_indicies,domain_indicies)
     end
 end
-export LazyLinearMap
 
 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]]
 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]]
 
-function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D}
+function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
     view_index = ntuple(i->:,ndims(llm.A))
     for i ∈ 1:R
         view_index = Base.setindex(view_index, Int(I[i]), llm.range_indicies[i])
@@ -141,7 +152,7 @@
     return sum(A_view.*v)
 end
 
-function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D}
+function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D}
     apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...)
 end
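A concrete sketch: a 3×2 matrix as a mapping from length-2 to length-3 vectors, with dimension 1 of `A` as range and dimension 2 as domain:

    A = [1.0 2.0; 3.0 4.0; 5.0 6.0]
    llm = LazyLinearMap(A, (1,), (2,))
    range_size(llm), domain_size(llm)  # ((3,), (2,))
    apply(llm, [1.0, 1.0], 2)          # sum(view(A, 2, :) .* v) == 7.0 == (A*v)[2]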
 
@@ -155,7 +166,6 @@
 struct IdentityMapping{T,D} <: TensorMapping{T,D,D}
     size::NTuple{D,Int}
 end
-export IdentityMapping
 
 IdentityMapping{T}(size::NTuple{D,Int}) where {T,D} = IdentityMapping{T,D}(size)
 IdentityMapping{T}(size::Vararg{Int,D}) where {T,D} = IdentityMapping{T,D}(size)
@@ -164,8 +174,8 @@
 range_size(tmi::IdentityMapping) = tmi.size
 domain_size(tmi::IdentityMapping) = tmi.size
 
-apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
-apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
+apply(tmi::IdentityMapping{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
+apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
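A sketch of the identity semantics; `apply` simply indexes into its argument:

    Id = IdentityMapping{Float64}(3, 3)  # Vararg constructor from above
    v = reshape(1.0:9.0, 3, 3)
    apply(Id, v, 2, 3)                   # == v[2,3] == 8.0
    range_size(Id)                       # (3, 3)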
 
 """
     Base.:∘(tm, tmi)
@@ -212,7 +222,6 @@
         return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after)
     end
 end
-export InflatedTensorMapping
 """
     InflatedTensorMapping(before, tm, after)
     InflatedTensorMapping(before,tm)
@@ -258,7 +267,7 @@
     )
 end
 
-function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D}
+function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
     dim_before = range_dim(itm.before)
     dim_domain = domain_dim(itm.tm)
     dim_range = range_dim(itm.tm)
@@ -270,7 +279,7 @@
     return apply(itm.tm, v_inner, inner_index...)
 end
 
-function apply_transpose(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D}
+function apply_transpose(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D}
     dim_before = range_dim(itm.before)
     dim_domain = domain_dim(itm.tm)
     dim_range = range_dim(itm.tm)
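A usage sketch, assuming the factors' range and domain sizes concatenate (consistent with the index splitting above): a 1D mapping acting along the middle dimension of a 3D array, with identities on the outer dimensions:

    Id3 = IdentityMapping{Float64}(3)
    Id4 = IdentityMapping{Float64}(4)
    tm  = LazyLinearMap(ones(5, 2), (1,), (2,))  # length-2 -> length-5
    itm = InflatedTensorMapping(Id3, tm, Id4)
    domain_size(itm)  # assumed (3, 2, 4)
    range_size(itm)   # assumed (3, 5, 4)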
@@ -398,7 +407,6 @@
 ```
 """
 function LazyOuterProduct end
-export LazyOuterProduct
 
 function LazyOuterProduct(tm1::TensorMapping{T}, tm2::TensorMapping{T}) where T
     itm1 = InflatedTensorMapping(tm1, IdentityMapping{T}(range_size(tm2)))
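A sketch of the construction, assuming the factors' sizes concatenate; the `⊗` alias in the next hunk is sugar for the same call:

    Id2 = IdentityMapping{Float64}(2)
    Id3 = IdentityMapping{Float64}(3)
    p = LazyOuterProduct(Id2, Id3)  # == Id2⊗Id3
    range_size(p)                   # assumed (2, 3)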
@@ -414,7 +422,6 @@
 LazyOuterProduct(tms::Vararg{TensorMapping}) = foldl(LazyOuterProduct, tms)
 
 ⊗(a::TensorMapping, b::TensorMapping) = LazyOuterProduct(a,b)
-export ⊗
 
 
 function check_domain_size(tm::TensorMapping, sz)
@@ -427,7 +434,6 @@
     tm::TensorMapping
     sz
 end
-export SizeMismatch
 
 function Base.showerror(io::IO, err::SizeMismatch)
     print(io, "SizeMismatch: ")