Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 1232:a8fa8c1137cc refactor/grids
Merge refactor/LazyTensors/tuple_manipulation
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Sun, 19 Feb 2023 22:07:57 +0100 |
parents | 8f4259fbd39c |
children | aa8579b7fc15 |
comparison
equal
deleted
inserted
replaced
1222:5f677cd6f0b6 | 1232:a8fa8c1137cc |
---|---|
174 InflatedTensor(I1::IdentityTensor{T}, I2::IdentityTensor{T}) where T = InflatedTensor(I1,I2,IdentityTensor{T}()) | 174 InflatedTensor(I1::IdentityTensor{T}, I2::IdentityTensor{T}) where T = InflatedTensor(I1,I2,IdentityTensor{T}()) |
175 | 175 |
176 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2) | 176 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2) |
177 | 177 |
178 function range_size(itm::InflatedTensor) | 178 function range_size(itm::InflatedTensor) |
179 return flatten_tuple( | 179 return concatenate_tuples( |
180 range_size(itm.before), | 180 range_size(itm.before), |
181 range_size(itm.tm), | 181 range_size(itm.tm), |
182 range_size(itm.after), | 182 range_size(itm.after), |
183 ) | 183 ) |
184 end | 184 end |
185 | 185 |
186 function domain_size(itm::InflatedTensor) | 186 function domain_size(itm::InflatedTensor) |
187 return flatten_tuple( | 187 return concatenate_tuples( |
188 domain_size(itm.before), | 188 domain_size(itm.before), |
189 domain_size(itm.tm), | 189 domain_size(itm.tm), |
190 domain_size(itm.after), | 190 domain_size(itm.after), |
191 ) | 191 ) |
192 end | 192 end |
195 dim_before = range_dim(itm.before) | 195 dim_before = range_dim(itm.before) |
196 dim_domain = domain_dim(itm.tm) | 196 dim_domain = domain_dim(itm.tm) |
197 dim_range = range_dim(itm.tm) | 197 dim_range = range_dim(itm.tm) |
198 dim_after = range_dim(itm.after) | 198 dim_after = range_dim(itm.after) |
199 | 199 |
200 view_index, inner_index = split_index(Val(dim_before), Val(dim_domain), Val(dim_range), Val(dim_after), I...) | 200 view_index, inner_index = split_index(dim_before, dim_domain, dim_range, dim_after, I...) |
201 | 201 |
202 v_inner = view(v, view_index...) | 202 v_inner = view(v, view_index...) |
203 return apply(itm.tm, v_inner, inner_index...) | 203 return apply(itm.tm, v_inner, inner_index...) |
204 end | 204 end |
205 | 205 |
207 dim_before = range_dim(itm.before) | 207 dim_before = range_dim(itm.before) |
208 dim_domain = domain_dim(itm.tm) | 208 dim_domain = domain_dim(itm.tm) |
209 dim_range = range_dim(itm.tm) | 209 dim_range = range_dim(itm.tm) |
210 dim_after = range_dim(itm.after) | 210 dim_after = range_dim(itm.after) |
211 | 211 |
212 view_index, inner_index = split_index(Val(dim_before), Val(dim_range), Val(dim_domain), Val(dim_after), I...) | 212 view_index, inner_index = split_index(dim_before, dim_range, dim_domain, dim_after, I...) |
213 | 213 |
214 v_inner = view(v, view_index...) | 214 v_inner = view(v, view_index...) |
215 return apply_transpose(itm.tm, v_inner, inner_index...) | 215 return apply_transpose(itm.tm, v_inner, inner_index...) |
216 end | 216 end |
217 | 217 |