Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 1900:418566cdd689
Merge refactor/lazy_tensors/elementwise_ops
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Fri, 31 Jan 2025 20:35:28 +0100 |
parents | ed50eec18365 |
children |
comparison
equal
deleted
inserted
replaced
1896:9d708f3300d5 | 1900:418566cdd689 |
---|---|
50 | 50 |
51 range_size(tmt::TensorTranspose) = domain_size(tmt.tm) | 51 range_size(tmt::TensorTranspose) = domain_size(tmt.tm) |
52 domain_size(tmt::TensorTranspose) = range_size(tmt.tm) | 52 domain_size(tmt::TensorTranspose) = range_size(tmt.tm) |
53 | 53 |
54 | 54 |
55 struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,R,D} | 55 """ |
56 tm1::T1 | 56 TensorNegation{T,R,D,...} <: LazyTensor{T,R,D} |
57 tm2::T2 | 57 |
58 | 58 The negation of a LazyTensor. |
59 function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} | 59 """ |
60 @boundscheck check_domain_size(tm2, domain_size(tm1)) | 60 struct TensorNegation{T,R,D,TM<:LazyTensor{T,R,D}} <: LazyTensor{T,R,D} |
61 @boundscheck check_range_size(tm2, range_size(tm1)) | 61 tm::TM |
62 return new{Op,T,R,D,T1,T2}(tm1,tm2) | 62 end |
63 end | 63 |
64 end | 64 apply(tm::TensorNegation, v, I...) = -apply(tm.tm, v, I...) |
65 | 65 apply_transpose(tm::TensorNegation, v, I...) = -apply_transpose(tm.tm, v, I...) |
66 ElementwiseTensorOperation{Op}(s,t) where Op = ElementwiseTensorOperation{Op,eltype(s), range_dim(s), domain_dim(s)}(s,t) | 66 |
67 | 67 range_size(tm::TensorNegation) = range_size(tm.tm) |
68 apply(tmBinOp::ElementwiseTensorOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) | 68 domain_size(tm::TensorNegation) = domain_size(tm.tm) |
69 apply(tmBinOp::ElementwiseTensorOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) | 69 |
70 | 70 |
71 range_size(tmBinOp::ElementwiseTensorOperation) = range_size(tmBinOp.tm1) | 71 """ |
72 domain_size(tmBinOp::ElementwiseTensorOperation) = domain_size(tmBinOp.tm1) | 72 TensorSum{T,R,D,...} <: LazyTensor{T,R,D} |
73 | |
74 The lazy sum of 2 or more lazy tensors. | |
75 """ | |
76 struct TensorSum{T,R,D,TT<:NTuple{N, LazyTensor{T,R,D}} where N} <: LazyTensor{T,R,D} | |
77 tms::TT | |
78 | |
79 function TensorSum{T,R,D}(tms::TT) where {T,R,D, TT<:NTuple{N, LazyTensor{T,R,D}} where N} | |
80 @boundscheck map(tms) do tm | |
81 check_domain_size(tm, domain_size(tms[1])) | |
82 check_range_size(tm, range_size(tms[1])) | |
83 end | |
84 | |
85 return new{T,R,D,TT}(tms) | |
86 end | |
87 end | |
88 | |
89 """ | |
90 TensorSum(ts::Vararg{LazyTensor}) | |
91 | |
92 The lazy sum of the tensors `ts`. | |
93 """ | |
94 function TensorSum(ts::Vararg{LazyTensor}) | |
95 T = eltype(ts[1]) | |
96 R = range_dim(ts[1]) | |
97 D = domain_dim(ts[1]) | |
98 return TensorSum{T,R,D}(ts) | |
99 end | |
100 | |
101 function apply(tmBinOp::TensorSum{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} | |
102 return sum(tmBinOp.tms) do tm | |
103 apply(tm,v,I...) | |
104 end | |
105 end | |
106 | |
107 function apply_transpose(tmBinOp::TensorSum{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} | |
108 return sum(tmBinOp.tms) do tm | |
109 apply_transpose(tm,v,I...) | |
110 end | |
111 end | |
112 | |
113 range_size(tmBinOp::TensorSum) = range_size(tmBinOp.tms[1]) | |
114 domain_size(tmBinOp::TensorSum) = domain_size(tmBinOp.tms[1]) | |
73 | 115 |
74 | 116 |
75 """ | 117 """ |
76 TensorComposition{T,R,K,D} | 118 TensorComposition{T,R,K,D} |
77 | 119 |
119 return tmi | 161 return tmi |
120 end | 162 end |
121 | 163 |
122 Base.:*(a::T, tm::LazyTensor{T}) where T = TensorComposition(ScalingTensor{T,range_dim(tm)}(a,range_size(tm)), tm) | 164 Base.:*(a::T, tm::LazyTensor{T}) where T = TensorComposition(ScalingTensor{T,range_dim(tm)}(a,range_size(tm)), tm) |
123 Base.:*(tm::LazyTensor{T}, a::T) where T = a*tm | 165 Base.:*(tm::LazyTensor{T}, a::T) where T = a*tm |
124 Base.:-(tm::LazyTensor) = (-one(eltype(tm)))*tm | |
125 | 166 |
126 """ | 167 """ |
127 InflatedTensor{T,R,D} <: LazyTensor{T,R,D} | 168 InflatedTensor{T,R,D} <: LazyTensor{T,R,D} |
128 | 169 |
129 An inflated `LazyTensor` with dimensions added before and after its actual dimensions. | 170 An inflated `LazyTensor` with dimensions added before and after its actual dimensions. |