changeset 999:20cb83efb3f1 refactor/lazy_tensors

More operator definitions to top file
author Jonatan Werpers <jonatan@werpers.com>
date Fri, 18 Mar 2022 21:43:38 +0100
parents 390dfc3db4b1
children 1091ac8c69ad
files src/LazyTensors/LazyTensors.jl src/LazyTensors/lazy_tensor_operations.jl
diffstat 2 files changed, 18 insertions(+), 16 deletions(-) [+]
line wrap: on
line diff
--- a/src/LazyTensors/LazyTensors.jl	Fri Mar 18 21:36:17 2022 +0100
+++ b/src/LazyTensors/LazyTensors.jl	Fri Mar 18 21:43:38 2022 +0100
@@ -16,4 +16,19 @@
 include("lazy_tensor_operations.jl")
 include("tuple_manipulation.jl")
 
+# Applying lazy tensors to vectors
+Base.:*(a::LazyTensor, v::AbstractArray) = LazyTensorApplication(a,v)
+Base.:*(a::LazyTensor, b::LazyTensor) = throw(MethodError(Base.:*,(a,b)))
+Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...))
+
+# Addition and subtraction of lazy tensors
+Base.:+(tm1::LazyTensor{T,R,D}, tm2::LazyTensor{T,R,D}) where {T,R,D} = LazyLazyTensorBinaryOperation{:+,T,R,D}(tm1,tm2)
+Base.:-(tm1::LazyTensor{T,R,D}, tm2::LazyTensor{T,R,D}) where {T,R,D} = LazyLazyTensorBinaryOperation{:-,T,R,D}(tm1,tm2)
+
+# Composing lazy tensors
+Base.:∘(s::LazyTensor, t::LazyTensor) = LazyTensorComposition(s,t)
+
+# Outer products of tensors
+⊗(a::LazyTensor, b::LazyTensor) = LazyOuterProduct(a,b)
+
 end # module
--- a/src/LazyTensors/lazy_tensor_operations.jl	Fri Mar 18 21:36:17 2022 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Fri Mar 18 21:43:38 2022 +0100
@@ -1,4 +1,6 @@
 # TBD: Is there a good way to split this file?
+# TODO: Split out functions for composition
+# TODO: We need to be really careful about good error messages.
 
 """
     LazyTensorApplication{T,R,D} <: LazyArray{T,R}
@@ -26,15 +28,6 @@
 Base.size(ta::LazyTensorApplication) = range_size(ta.t)
 # TODO: What else is needed to implement the AbstractArray interface?
 
-Base.:*(a::LazyTensor, v::AbstractArray) = LazyTensorApplication(a,v)
-Base.:*(a::LazyTensor, b::LazyTensor) = throw(MethodError(Base.:*,(a,b)))
-Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...))
-
-# # We need the associativity to be a→b→c = a→(b→c), which is the case for '→'
-# # Should we overload some other infix binary opesrator?
-# →(tm::LazyTensor{T,R,D}, o::AbstractArray{T,D}) where {T,R,D} = LazyTensorApplication(tm,o)
-# TODO: We need to be really careful about good error messages.
-# For example what happens if you try to multiply LazyTensorApplication with a LazyTensor(wrong order)?
 
 """
     LazyTensorTranspose{T,R,D} <: LazyTensor{T,D,R}
@@ -60,7 +53,7 @@
 range_size(tmt::LazyTensorTranspose) = domain_size(tmt.tm)
 domain_size(tmt::LazyTensorTranspose) = range_size(tmt.tm)
 
-
+# TODO: Rename this
 struct LazyLazyTensorBinaryOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R}
     tm1::T1
     tm2::T2
@@ -77,8 +70,6 @@
 range_size(tmBinOp::LazyLazyTensorBinaryOperation) = range_size(tmBinOp.tm1)
 domain_size(tmBinOp::LazyLazyTensorBinaryOperation) = domain_size(tmBinOp.tm1)
 
-Base.:+(tm1::LazyTensor{T,R,D}, tm2::LazyTensor{T,R,D}) where {T,R,D} = LazyLazyTensorBinaryOperation{:+,T,R,D}(tm1,tm2)
-Base.:-(tm1::LazyTensor{T,R,D}, tm2::LazyTensor{T,R,D}) where {T,R,D} = LazyLazyTensorBinaryOperation{:-,T,R,D}(tm1,tm2)
 
 """
     LazyTensorComposition{T,R,K,D}
@@ -106,7 +97,6 @@
     apply_transpose(c.t2, c.t1'*v, I...)
 end
 
-Base.@propagate_inbounds Base.:∘(s::LazyTensor, t::LazyTensor) = LazyTensorComposition(s,t)
 
 """
     LazyLinearMap{T,R,D,...}(A, range_indicies, domain_indicies)
@@ -188,7 +178,6 @@
     @boundscheck check_domain_size(tm, range_size(tmi))
     return tmi
 end
-# TODO: Move the operator definitions to one place
 
 """
     ScalingTensor{T,D} <: LazyTensor{T,D,D}
@@ -351,8 +340,6 @@
 
 LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms)
 
-⊗(a::LazyTensor, b::LazyTensor) = LazyOuterProduct(a,b)
-
 
 function check_domain_size(tm::LazyTensor, sz)
     if domain_size(tm) != sz