Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 999:20cb83efb3f1 refactor/lazy_tensors
Move operator definitions to top of file
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Fri, 18 Mar 2022 21:43:38 +0100 |
parents | 390dfc3db4b1 |
children | 1091ac8c69ad |
comparison
equal
deleted
inserted
replaced
998:390dfc3db4b1 | 999:20cb83efb3f1 |
---|---|
1 # TBD: Is there a good way to split this file? | 1 # TBD: Is there a good way to split this file? |
2 # TODO: Split out functions for composition | |
3 # TODO: We need to be really careful about good error messages. | |
2 | 4 |
3 """ | 5 """ |
4 LazyTensorApplication{T,R,D} <: LazyArray{T,R} | 6 LazyTensorApplication{T,R,D} <: LazyArray{T,R} |
5 | 7 |
6 Struct for lazy application of a LazyTensor. Created using `*`. | 8 Struct for lazy application of a LazyTensor. Created using `*`. |
24 Base.getindex(ta::LazyTensorApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...) | 26 Base.getindex(ta::LazyTensorApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...) |
25 Base.getindex(ta::LazyTensorApplication{T,1}, I::CartesianIndex{1}) where {T} = apply(ta.t, ta.o, I.I...) # Would otherwise be caught in the previous method. | 27 Base.getindex(ta::LazyTensorApplication{T,1}, I::CartesianIndex{1}) where {T} = apply(ta.t, ta.o, I.I...) # Would otherwise be caught in the previous method. |
26 Base.size(ta::LazyTensorApplication) = range_size(ta.t) | 28 Base.size(ta::LazyTensorApplication) = range_size(ta.t) |
27 # TODO: What else is needed to implement the AbstractArray interface? | 29 # TODO: What else is needed to implement the AbstractArray interface? |
28 | 30 |
29 Base.:*(a::LazyTensor, v::AbstractArray) = LazyTensorApplication(a,v) | |
30 Base.:*(a::LazyTensor, b::LazyTensor) = throw(MethodError(Base.:*,(a,b))) | |
31 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...)) | |
32 | |
33 # # We need the associativity to be a→b→c = a→(b→c), which is the case for '→' | |
34 # # Should we overload some other infix binary operator? | |
35 # →(tm::LazyTensor{T,R,D}, o::AbstractArray{T,D}) where {T,R,D} = LazyTensorApplication(tm,o) | |
36 # TODO: We need to be really careful about good error messages. | |
37 # For example what happens if you try to multiply LazyTensorApplication with a LazyTensor(wrong order)? | |
38 | 31 |
39 """ | 32 """ |
40 LazyTensorTranspose{T,R,D} <: LazyTensor{T,D,R} | 33 LazyTensorTranspose{T,R,D} <: LazyTensor{T,D,R} |
41 | 34 |
42 Struct for lazy transpose of a LazyTensor. | 35 Struct for lazy transpose of a LazyTensor. |
58 apply_transpose(tmt::LazyTensorTranspose{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...) | 51 apply_transpose(tmt::LazyTensorTranspose{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...) |
59 | 52 |
60 range_size(tmt::LazyTensorTranspose) = domain_size(tmt.tm) | 53 range_size(tmt::LazyTensorTranspose) = domain_size(tmt.tm) |
61 domain_size(tmt::LazyTensorTranspose) = range_size(tmt.tm) | 54 domain_size(tmt::LazyTensorTranspose) = range_size(tmt.tm) |
62 | 55 |
63 | 56 # TODO: Rename this |
64 struct LazyLazyTensorBinaryOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R} | 57 struct LazyLazyTensorBinaryOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R} |
65 tm1::T1 | 58 tm1::T1 |
66 tm2::T2 | 59 tm2::T2 |
67 | 60 |
68 @inline function LazyLazyTensorBinaryOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} | 61 @inline function LazyLazyTensorBinaryOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} |
75 apply(tmBinOp::LazyLazyTensorBinaryOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) | 68 apply(tmBinOp::LazyLazyTensorBinaryOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) |
76 | 69 |
77 range_size(tmBinOp::LazyLazyTensorBinaryOperation) = range_size(tmBinOp.tm1) | 70 range_size(tmBinOp::LazyLazyTensorBinaryOperation) = range_size(tmBinOp.tm1) |
78 domain_size(tmBinOp::LazyLazyTensorBinaryOperation) = domain_size(tmBinOp.tm1) | 71 domain_size(tmBinOp::LazyLazyTensorBinaryOperation) = domain_size(tmBinOp.tm1) |
79 | 72 |
80 Base.:+(tm1::LazyTensor{T,R,D}, tm2::LazyTensor{T,R,D}) where {T,R,D} = LazyLazyTensorBinaryOperation{:+,T,R,D}(tm1,tm2) | |
81 Base.:-(tm1::LazyTensor{T,R,D}, tm2::LazyTensor{T,R,D}) where {T,R,D} = LazyLazyTensorBinaryOperation{:-,T,R,D}(tm1,tm2) | |
82 | 73 |
83 """ | 74 """ |
84 LazyTensorComposition{T,R,K,D} | 75 LazyTensorComposition{T,R,K,D} |
85 | 76 |
86 Lazily compose two `LazyTensor`s, so that they can be handled as a single `LazyTensor`. | 77 Lazily compose two `LazyTensor`s, so that they can be handled as a single `LazyTensor`. |
104 | 95 |
105 function apply_transpose(c::LazyTensorComposition{T,R,K,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,K,D} | 96 function apply_transpose(c::LazyTensorComposition{T,R,K,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,K,D} |
106 apply_transpose(c.t2, c.t1'*v, I...) | 97 apply_transpose(c.t2, c.t1'*v, I...) |
107 end | 98 end |
108 | 99 |
109 Base.@propagate_inbounds Base.:∘(s::LazyTensor, t::LazyTensor) = LazyTensorComposition(s,t) | |
110 | 100 |
111 """ | 101 """ |
112 LazyLinearMap{T,R,D,...}(A, range_indicies, domain_indicies) | 102 LazyLinearMap{T,R,D,...}(A, range_indicies, domain_indicies) |
113 | 103 |
114 LazyTensor defined by the AbstractArray A. `range_indicies` and `domain_indicies` define which indicies of A should | 104 LazyTensor defined by the AbstractArray A. `range_indicies` and `domain_indicies` define which indicies of A should |
186 # Specialization for the case where tm is an IdentityTensor. Required to resolve ambiguity. | 176 # Specialization for the case where tm is an IdentityTensor. Required to resolve ambiguity. |
187 function LazyTensorComposition(tm::IdentityTensor{T,D}, tmi::IdentityTensor{T,D}) where {T,D} | 177 function LazyTensorComposition(tm::IdentityTensor{T,D}, tmi::IdentityTensor{T,D}) where {T,D} |
188 @boundscheck check_domain_size(tm, range_size(tmi)) | 178 @boundscheck check_domain_size(tm, range_size(tmi)) |
189 return tmi | 179 return tmi |
190 end | 180 end |
191 # TODO: Move the operator definitions to one place | |
192 | 181 |
193 """ | 182 """ |
194 ScalingTensor{T,D} <: LazyTensor{T,D,D} | 183 ScalingTensor{T,D} <: LazyTensor{T,D,D} |
195 | 184 |
196 A lazy tensor that scales its input with `λ`. | 185 A lazy tensor that scales its input with `λ`. |
349 LazyOuterProduct(t1::LazyTensor, t2::IdentityTensor) = InflatedLazyTensor(t1, t2) | 338 LazyOuterProduct(t1::LazyTensor, t2::IdentityTensor) = InflatedLazyTensor(t1, t2) |
350 LazyOuterProduct(t1::IdentityTensor, t2::LazyTensor) = InflatedLazyTensor(t1, t2) | 339 LazyOuterProduct(t1::IdentityTensor, t2::LazyTensor) = InflatedLazyTensor(t1, t2) |
351 | 340 |
352 LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms) | 341 LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms) |
353 | 342 |
354 ⊗(a::LazyTensor, b::LazyTensor) = LazyOuterProduct(a,b) | |
355 | |
356 | 343 |
357 function check_domain_size(tm::LazyTensor, sz) | 344 function check_domain_size(tm::LazyTensor, sz) |
358 if domain_size(tm) != sz | 345 if domain_size(tm) != sz |
359 throw(SizeMismatch(tm,sz)) | 346 throw(SizeMismatch(tm,sz)) |
360 end | 347 end |