sbplib_julia: comparison of src/LazyTensors/lazy_tensor_operations.jl @ 1837:200971c71657 (branch refactor/lazy_tensors/elementwise_ops)
Refactor ElementwiseTensorOperation into TensorSum and use TensorNegation for handling subtraction
author:   Jonatan Werpers <jonatan@werpers.com>
date:     Thu, 09 Jan 2025 21:46:01 +0100
parents:  368999a2e243
children: e1077273eda5
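
As the commit message describes, subtraction is no longer a variant of the elementwise operation type but is instead meant to be built from TensorNegation together with the new TensorSum. The snippet below is a minimal sketch of how `-` could be expressed with those two types; the Base operator overloads shown are illustrative assumptions and are not part of this changeset:

    # Sketch only: assumes LazyTensor, TensorSum and TensorNegation from the
    # LazyTensors module (shown in the diff below) are in scope. These
    # overloads are illustrative assumptions, not code from this changeset.
    Base.:+(t1::LazyTensor, t2::LazyTensor) = TensorSum(t1, t2)
    Base.:-(t::LazyTensor) = TensorNegation(t)
    Base.:-(t1::LazyTensor, t2::LazyTensor) = TensorSum(t1, TensorNegation(t2))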
--- src/LazyTensors/lazy_tensor_operations.jl    1836:368999a2e243
+++ src/LazyTensors/lazy_tensor_operations.jl    1837:200971c71657
@@ -60,64 +60,53 @@
 
 range_size(tm::TensorNegation) = range_size(tm.tm)
 domain_size(tm::TensorNegation) = domain_size(tm.tm)
 
 
-struct ElementwiseTensorOperation{Op,T,R,D,TT<:NTuple{N, LazyTensor{T,R,D}} where N} <: LazyTensor{T,R,D}
+struct TensorSum{T,R,D,TT<:NTuple{N, LazyTensor{T,R,D}} where N} <: LazyTensor{T,R,D}
     tms::TT
 
-    function ElementwiseTensorOperation{Op,T,R,D}(tms::TT) where {Op,T,R,D, TT<:NTuple{N, LazyTensor{T,R,D}} where N}
+    function TensorSum{T,R,D}(tms::TT) where {T,R,D, TT<:NTuple{N, LazyTensor{T,R,D}} where N}
         @boundscheck map(tms) do tm
             check_domain_size(tm, domain_size(tms[1]))
             check_range_size(tm, range_size(tms[1]))
         end
 
-        return new{Op,T,R,D,TT}(tms)
+        return new{T,R,D,TT}(tms)
     end
 end
-# TBD: Can we introduce negation of LazyTensors? It could be done generically
-# with a ScalingTensor but also using specializations for specific tensor
-# types. This would allow simplification of ElementwiseTensorOperation to
-# TensorSum. The implementation of `-` can be done using negation and the
-# TensorSum type. We should make sure this doesn't impact the efficiency of
-# for example SATs.
-
-
-function ElementwiseTensorOperation{:+}(ts::Vararg{LazyTensor})
-    return ElementwiseTensorOperation{:+,eltype(ts[1]), range_dim(ts[1]), domain_dim(ts[1])}(ts)
+
+function TensorSum(ts::Vararg{LazyTensor})
+    T = eltype(ts[1])
+    R = range_dim(ts[1])
+    D = domain_dim(ts[1])
+    return TensorSum{T,R,D}(ts)
 end
 
 # The following methods for :+ are intended to reduce the depth of the tree of operations in some caes
-function ElementwiseTensorOperation{:+}(t1::ElementwiseTensorOperation{:+}, t2::ElementwiseTensorOperation{:+})
-    ElementwiseTensorOperation{:+}(t1.tms..., t2.tms...)
+function TensorSum(t1::TensorSum, t2::TensorSum)
+    TensorSum(t1.tms..., t2.tms...)
 end
 
-function ElementwiseTensorOperation{:+}(t1::ElementwiseTensorOperation{:+}, t2::LazyTensor)
-    ElementwiseTensorOperation{:+}(t1.tms..., t2)
+function TensorSum(t1::TensorSum, t2::LazyTensor)
+    TensorSum(t1.tms..., t2)
 end
 
-function ElementwiseTensorOperation{:+}(t1::LazyTensor, t2::ElementwiseTensorOperation{:+})
-    ElementwiseTensorOperation{:+}(t1, t2.tms...)
+function TensorSum(t1::LazyTensor, t2::TensorSum)
+    TensorSum(t1, t2.tms...)
 end
 
-function ElementwiseTensorOperation{:-}(t1::LazyTensor, t2::LazyTensor)
-    return ElementwiseTensorOperation{:-,eltype(t1), range_dim(t1), domain_dim(t1)}((t1,t2))
-end
-
-function apply(tmBinOp::ElementwiseTensorOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
+function apply(tmBinOp::TensorSum{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
     vs = map(tmBinOp.tms) do tm
         apply(tm,v,I...)
     end
 
     return +(vs...)
 end
-function apply(tmBinOp::ElementwiseTensorOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
-    apply(tmBinOp.tms[1], v, I...) - apply(tmBinOp.tms[2], v, I...)
-end
-
-range_size(tmBinOp::ElementwiseTensorOperation) = range_size(tmBinOp.tms[1])
-domain_size(tmBinOp::ElementwiseTensorOperation) = domain_size(tmBinOp.tms[1])
+
+range_size(tmBinOp::TensorSum) = range_size(tmBinOp.tms[1])
+domain_size(tmBinOp::TensorSum) = domain_size(tmBinOp.tms[1])
 
 
 """
     TensorComposition{T,R,K,D}
 
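
For orientation, here is a small usage sketch of the TensorSum introduced above. DiagSketch is a hypothetical tensor invented only for this example (it is not part of the package); it implements the LazyTensor{T,R,D} interface (apply, range_size, domain_size) visible in the diff, and the example assumes those names plus TensorSum have been imported from the LazyTensors module (the module name is inferred from the file path):

    # Hypothetical example tensor: scales a length-n vector elementwise by λ.
    struct DiagSketch{T} <: LazyTensor{T,1,1}
        λ::T
        n::Int
    end
    apply(t::DiagSketch, v::AbstractArray{<:Any,1}, i) = t.λ*v[i]
    range_size(t::DiagSketch) = (t.n,)
    domain_size(t::DiagSketch) = (t.n,)

    a = DiagSketch(2.0, 5)
    b = DiagSketch(3.0, 5)

    s  = TensorSum(a, b)                  # s.tms == (a, b)
    s2 = TensorSum(s, DiagSketch(1.0, 5)) # flattened to three terms, not nested

    v = ones(5)
    apply(s2, v, 3)                       # 2.0*v[3] + 3.0*v[3] + 1.0*v[3] == 6.0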