comparison src/LazyTensors/lazy_tensor_operations.jl @ 1232:a8fa8c1137cc refactor/grids

Merge refactor/LazyTensors/tuple_manipulation
author Jonatan Werpers <jonatan@werpers.com>
date Sun, 19 Feb 2023 22:07:57 +0100
parents 8f4259fbd39c
children aa8579b7fc15
comparison
equal deleted inserted replaced
1222:5f677cd6f0b6 1232:a8fa8c1137cc
174 InflatedTensor(I1::IdentityTensor{T}, I2::IdentityTensor{T}) where T = InflatedTensor(I1,I2,IdentityTensor{T}()) 174 InflatedTensor(I1::IdentityTensor{T}, I2::IdentityTensor{T}) where T = InflatedTensor(I1,I2,IdentityTensor{T}())
175 175
176 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2) 176 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
177 177
178 function range_size(itm::InflatedTensor) 178 function range_size(itm::InflatedTensor)
179 return flatten_tuple( 179 return concatenate_tuples(
180 range_size(itm.before), 180 range_size(itm.before),
181 range_size(itm.tm), 181 range_size(itm.tm),
182 range_size(itm.after), 182 range_size(itm.after),
183 ) 183 )
184 end 184 end
185 185
186 function domain_size(itm::InflatedTensor) 186 function domain_size(itm::InflatedTensor)
187 return flatten_tuple( 187 return concatenate_tuples(
188 domain_size(itm.before), 188 domain_size(itm.before),
189 domain_size(itm.tm), 189 domain_size(itm.tm),
190 domain_size(itm.after), 190 domain_size(itm.after),
191 ) 191 )
192 end 192 end
195 dim_before = range_dim(itm.before) 195 dim_before = range_dim(itm.before)
196 dim_domain = domain_dim(itm.tm) 196 dim_domain = domain_dim(itm.tm)
197 dim_range = range_dim(itm.tm) 197 dim_range = range_dim(itm.tm)
198 dim_after = range_dim(itm.after) 198 dim_after = range_dim(itm.after)
199 199
200 view_index, inner_index = split_index(Val(dim_before), Val(dim_domain), Val(dim_range), Val(dim_after), I...) 200 view_index, inner_index = split_index(dim_before, dim_domain, dim_range, dim_after, I...)
201 201
202 v_inner = view(v, view_index...) 202 v_inner = view(v, view_index...)
203 return apply(itm.tm, v_inner, inner_index...) 203 return apply(itm.tm, v_inner, inner_index...)
204 end 204 end
205 205
207 dim_before = range_dim(itm.before) 207 dim_before = range_dim(itm.before)
208 dim_domain = domain_dim(itm.tm) 208 dim_domain = domain_dim(itm.tm)
209 dim_range = range_dim(itm.tm) 209 dim_range = range_dim(itm.tm)
210 dim_after = range_dim(itm.after) 210 dim_after = range_dim(itm.after)
211 211
212 view_index, inner_index = split_index(Val(dim_before), Val(dim_range), Val(dim_domain), Val(dim_after), I...) 212 view_index, inner_index = split_index(dim_before, dim_range, dim_domain, dim_after, I...)
213 213
214 v_inner = view(v, view_index...) 214 v_inner = view(v, view_index...)
215 return apply_transpose(itm.tm, v_inner, inner_index...) 215 return apply_transpose(itm.tm, v_inner, inner_index...)
216 end 216 end
217 217