Mercurial > repos > public > sbplib_julia
changeset 999:20cb83efb3f1 refactor/lazy_tensors
More operator definitions to top file
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Fri, 18 Mar 2022 21:43:38 +0100 |
parents | 390dfc3db4b1 |
children | 1091ac8c69ad |
files | src/LazyTensors/LazyTensors.jl src/LazyTensors/lazy_tensor_operations.jl |
diffstat | 2 files changed, 18 insertions(+), 16 deletions(-) [+] |
line wrap: on
line diff
--- a/src/LazyTensors/LazyTensors.jl Fri Mar 18 21:36:17 2022 +0100 +++ b/src/LazyTensors/LazyTensors.jl Fri Mar 18 21:43:38 2022 +0100 @@ -16,4 +16,19 @@ include("lazy_tensor_operations.jl") include("tuple_manipulation.jl") +# Applying lazy tensors to vectors +Base.:*(a::LazyTensor, v::AbstractArray) = LazyTensorApplication(a,v) +Base.:*(a::LazyTensor, b::LazyTensor) = throw(MethodError(Base.:*,(a,b))) +Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...)) + +# Addition and subtraction of lazy tensors +Base.:+(tm1::LazyTensor{T,R,D}, tm2::LazyTensor{T,R,D}) where {T,R,D} = LazyLazyTensorBinaryOperation{:+,T,R,D}(tm1,tm2) +Base.:-(tm1::LazyTensor{T,R,D}, tm2::LazyTensor{T,R,D}) where {T,R,D} = LazyLazyTensorBinaryOperation{:-,T,R,D}(tm1,tm2) + +# Composing lazy tensors +Base.:∘(s::LazyTensor, t::LazyTensor) = LazyTensorComposition(s,t) + +# Outer products of tensors +⊗(a::LazyTensor, b::LazyTensor) = LazyOuterProduct(a,b) + end # module
--- a/src/LazyTensors/lazy_tensor_operations.jl Fri Mar 18 21:36:17 2022 +0100 +++ b/src/LazyTensors/lazy_tensor_operations.jl Fri Mar 18 21:43:38 2022 +0100 @@ -1,4 +1,6 @@ # TBD: Is there a good way to split this file? +# TODO: Split out functions for composition +# TODO: We need to be really careful about good error messages. """ LazyTensorApplication{T,R,D} <: LazyArray{T,R} @@ -26,15 +28,6 @@ Base.size(ta::LazyTensorApplication) = range_size(ta.t) # TODO: What else is needed to implement the AbstractArray interface? -Base.:*(a::LazyTensor, v::AbstractArray) = LazyTensorApplication(a,v) -Base.:*(a::LazyTensor, b::LazyTensor) = throw(MethodError(Base.:*,(a,b))) -Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...)) - -# # We need the associativity to be a→b→c = a→(b→c), which is the case for '→' -# # Should we overload some other infix binary opesrator? -# →(tm::LazyTensor{T,R,D}, o::AbstractArray{T,D}) where {T,R,D} = LazyTensorApplication(tm,o) -# TODO: We need to be really careful about good error messages. -# For example what happens if you try to multiply LazyTensorApplication with a LazyTensor(wrong order)? """ LazyTensorTranspose{T,R,D} <: LazyTensor{T,D,R} @@ -60,7 +53,7 @@ range_size(tmt::LazyTensorTranspose) = domain_size(tmt.tm) domain_size(tmt::LazyTensorTranspose) = range_size(tmt.tm) - +# TODO: Rename this struct LazyLazyTensorBinaryOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R} tm1::T1 tm2::T2 @@ -77,8 +70,6 @@ range_size(tmBinOp::LazyLazyTensorBinaryOperation) = range_size(tmBinOp.tm1) domain_size(tmBinOp::LazyLazyTensorBinaryOperation) = domain_size(tmBinOp.tm1) -Base.:+(tm1::LazyTensor{T,R,D}, tm2::LazyTensor{T,R,D}) where {T,R,D} = LazyLazyTensorBinaryOperation{:+,T,R,D}(tm1,tm2) -Base.:-(tm1::LazyTensor{T,R,D}, tm2::LazyTensor{T,R,D}) where {T,R,D} = LazyLazyTensorBinaryOperation{:-,T,R,D}(tm1,tm2) """ LazyTensorComposition{T,R,K,D} @@ -106,7 +97,6 @@ apply_transpose(c.t2, c.t1'*v, I...) end -Base.@propagate_inbounds Base.:∘(s::LazyTensor, t::LazyTensor) = LazyTensorComposition(s,t) """ LazyLinearMap{T,R,D,...}(A, range_indicies, domain_indicies) @@ -188,7 +178,6 @@ @boundscheck check_domain_size(tm, range_size(tmi)) return tmi end -# TODO: Move the operator definitions to one place """ ScalingTensor{T,D} <: LazyTensor{T,D,D} @@ -351,8 +340,6 @@ LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms) -⊗(a::LazyTensor, b::LazyTensor) = LazyOuterProduct(a,b) - function check_domain_size(tm::LazyTensor, sz) if domain_size(tm) != sz