comparison src/LazyTensors/lazy_tensor_operations.jl @ 1365:4684c7f1c4cb feature/variable_derivatives

Merge with default
author Vidar Stiernström <vidar.stiernstrom@it.uu.se>
date Sun, 21 May 2023 21:55:14 +0200
parents 74ceac9c91e4 aa8579b7fc15
children bdcdbd4ea9cd d7bc11053951
comparison
equal deleted inserted replaced
1358:e7861cfb6ede 1365:4684c7f1c4cb
50 50
51 range_size(tmt::TensorTranspose) = domain_size(tmt.tm) 51 range_size(tmt::TensorTranspose) = domain_size(tmt.tm)
52 domain_size(tmt::TensorTranspose) = range_size(tmt.tm) 52 domain_size(tmt::TensorTranspose) = range_size(tmt.tm)
53 53
54 54
55 struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R} 55 struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,R,D}
56 tm1::T1 56 tm1::T1
57 tm2::T2 57 tm2::T2
58 58
59 function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} 59 function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}}
60 @boundscheck check_domain_size(tm2, domain_size(tm1)) 60 @boundscheck check_domain_size(tm2, domain_size(tm1))
174 InflatedTensor(I1::IdentityTensor{T}, I2::IdentityTensor{T}) where T = InflatedTensor(I1,I2,IdentityTensor{T}()) 174 InflatedTensor(I1::IdentityTensor{T}, I2::IdentityTensor{T}) where T = InflatedTensor(I1,I2,IdentityTensor{T}())
175 175
176 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2) 176 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
177 177
178 function range_size(itm::InflatedTensor) 178 function range_size(itm::InflatedTensor)
179 return flatten_tuple( 179 return concatenate_tuples(
180 range_size(itm.before), 180 range_size(itm.before),
181 range_size(itm.tm), 181 range_size(itm.tm),
182 range_size(itm.after), 182 range_size(itm.after),
183 ) 183 )
184 end 184 end
185 185
186 function domain_size(itm::InflatedTensor) 186 function domain_size(itm::InflatedTensor)
187 return flatten_tuple( 187 return concatenate_tuples(
188 domain_size(itm.before), 188 domain_size(itm.before),
189 domain_size(itm.tm), 189 domain_size(itm.tm),
190 domain_size(itm.after), 190 domain_size(itm.after),
191 ) 191 )
192 end 192 end
195 dim_before = range_dim(itm.before) 195 dim_before = range_dim(itm.before)
196 dim_domain = domain_dim(itm.tm) 196 dim_domain = domain_dim(itm.tm)
197 dim_range = range_dim(itm.tm) 197 dim_range = range_dim(itm.tm)
198 dim_after = range_dim(itm.after) 198 dim_after = range_dim(itm.after)
199 199
200 view_index, inner_index = split_index(Val(dim_before), Val(dim_domain), Val(dim_range), Val(dim_after), I...) 200 view_index, inner_index = split_index(dim_before, dim_domain, dim_range, dim_after, I...)
201 201
202 v_inner = view(v, view_index...) 202 v_inner = view(v, view_index...)
203 return apply(itm.tm, v_inner, inner_index...) 203 return apply(itm.tm, v_inner, inner_index...)
204 end 204 end
205 205
207 dim_before = range_dim(itm.before) 207 dim_before = range_dim(itm.before)
208 dim_domain = domain_dim(itm.tm) 208 dim_domain = domain_dim(itm.tm)
209 dim_range = range_dim(itm.tm) 209 dim_range = range_dim(itm.tm)
210 dim_after = range_dim(itm.after) 210 dim_after = range_dim(itm.after)
211 211
212 view_index, inner_index = split_index(Val(dim_before), Val(dim_range), Val(dim_domain), Val(dim_after), I...) 212 view_index, inner_index = split_index(dim_before, dim_range, dim_domain, dim_after, I...)
213 213
214 v_inner = view(v, view_index...) 214 v_inner = view(v, view_index...)
215 return apply_transpose(itm.tm, v_inner, inner_index...) 215 return apply_transpose(itm.tm, v_inner, inner_index...)
216 end 216 end
217 217