comparison src/LazyTensors/lazy_tensor_operations.jl @ 1837:200971c71657 refactor/lazy_tensors/elementwise_ops

Refactor ElementwiseTensorOperation into TensorSum and use TensorNegation for handling subtraction
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 09 Jan 2025 21:46:01 +0100
parents 368999a2e243
children e1077273eda5
comparison
equal deleted inserted replaced
1836:368999a2e243 1837:200971c71657
60 60
61 range_size(tm::TensorNegation) = range_size(tm.tm) 61 range_size(tm::TensorNegation) = range_size(tm.tm)
62 domain_size(tm::TensorNegation) = domain_size(tm.tm) 62 domain_size(tm::TensorNegation) = domain_size(tm.tm)
63 63
64 64
65 struct ElementwiseTensorOperation{Op,T,R,D,TT<:NTuple{N, LazyTensor{T,R,D}} where N} <: LazyTensor{T,R,D} 65 struct TensorSum{T,R,D,TT<:NTuple{N, LazyTensor{T,R,D}} where N} <: LazyTensor{T,R,D}
66 tms::TT 66 tms::TT
67 67
68 function ElementwiseTensorOperation{Op,T,R,D}(tms::TT) where {Op,T,R,D, TT<:NTuple{N, LazyTensor{T,R,D}} where N} 68 function TensorSum{T,R,D}(tms::TT) where {T,R,D, TT<:NTuple{N, LazyTensor{T,R,D}} where N}
69 @boundscheck map(tms) do tm 69 @boundscheck map(tms) do tm
70 check_domain_size(tm, domain_size(tms[1])) 70 check_domain_size(tm, domain_size(tms[1]))
71 check_range_size(tm, range_size(tms[1])) 71 check_range_size(tm, range_size(tms[1]))
72 end 72 end
73 73
74 return new{Op,T,R,D,TT}(tms) 74 return new{T,R,D,TT}(tms)
75 end 75 end
76 end 76 end
77 # TBD: Can we introduce negation of LazyTensors? It could be done generically 77
78 # with a ScalingTensor but also using specializations for specific tensor 78 function TensorSum(ts::Vararg{LazyTensor})
79 # types. This would allow simplification of ElementwiseTensorOperation to 79 T = eltype(ts[1])
80 # TensorSum. The implementation of `-` can be done using negation and the 80 R = range_dim(ts[1])
81 # TensorSum type. We should make sure this doesn't impact the efficiency of 81 D = domain_dim(ts[1])
82 # for example SATs. 82 return TensorSum{T,R,D}(ts)
83
84
85 function ElementwiseTensorOperation{:+}(ts::Vararg{LazyTensor})
86 return ElementwiseTensorOperation{:+,eltype(ts[1]), range_dim(ts[1]), domain_dim(ts[1])}(ts)
87 end 83 end
88 84
89 # The following methods for :+ are intended to reduce the depth of the tree of operations in some cases 85 # The following methods for :+ are intended to reduce the depth of the tree of operations in some cases
90 function ElementwiseTensorOperation{:+}(t1::ElementwiseTensorOperation{:+}, t2::ElementwiseTensorOperation{:+}) 86 function TensorSum(t1::TensorSum, t2::TensorSum)
91 ElementwiseTensorOperation{:+}(t1.tms..., t2.tms...) 87 TensorSum(t1.tms..., t2.tms...)
92 end 88 end
93 89
94 function ElementwiseTensorOperation{:+}(t1::ElementwiseTensorOperation{:+}, t2::LazyTensor) 90 function TensorSum(t1::TensorSum, t2::LazyTensor)
95 ElementwiseTensorOperation{:+}(t1.tms..., t2) 91 TensorSum(t1.tms..., t2)
96 end 92 end
97 93
98 function ElementwiseTensorOperation{:+}(t1::LazyTensor, t2::ElementwiseTensorOperation{:+}) 94 function TensorSum(t1::LazyTensor, t2::TensorSum)
99 ElementwiseTensorOperation{:+}(t1, t2.tms...) 95 TensorSum(t1, t2.tms...)
100 end 96 end
101 97
102 function ElementwiseTensorOperation{:-}(t1::LazyTensor, t2::LazyTensor) 98 function apply(tmBinOp::TensorSum{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
103 return ElementwiseTensorOperation{:-,eltype(t1), range_dim(t1), domain_dim(t1)}((t1,t2))
104 end
105
106 function apply(tmBinOp::ElementwiseTensorOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
107 vs = map(tmBinOp.tms) do tm 99 vs = map(tmBinOp.tms) do tm
108 apply(tm,v,I...) 100 apply(tm,v,I...)
109 end 101 end
110 102
111 return +(vs...) 103 return +(vs...)
112 end 104 end
113 function apply(tmBinOp::ElementwiseTensorOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} 105
114 apply(tmBinOp.tms[1], v, I...) - apply(tmBinOp.tms[2], v, I...) 106 range_size(tmBinOp::TensorSum) = range_size(tmBinOp.tms[1])
115 end 107 domain_size(tmBinOp::TensorSum) = domain_size(tmBinOp.tms[1])
116
117 range_size(tmBinOp::ElementwiseTensorOperation) = range_size(tmBinOp.tms[1])
118 domain_size(tmBinOp::ElementwiseTensorOperation) = domain_size(tmBinOp.tms[1])
119 108
120 109
121 """ 110 """
122 TensorComposition{T,R,K,D} 111 TensorComposition{T,R,K,D}
123 112