Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 1365:4684c7f1c4cb feature/variable_derivatives
Merge with default
author | Vidar Stiernström <vidar.stiernstrom@it.uu.se> |
---|---|
date | Sun, 21 May 2023 21:55:14 +0200 |
parents | 74ceac9c91e4 aa8579b7fc15 |
children | bdcdbd4ea9cd d7bc11053951 |
comparison
equal
deleted
inserted
replaced
1358:e7861cfb6ede | 1365:4684c7f1c4cb |
---|---|
50 | 50 |
51 range_size(tmt::TensorTranspose) = domain_size(tmt.tm) | 51 range_size(tmt::TensorTranspose) = domain_size(tmt.tm) |
52 domain_size(tmt::TensorTranspose) = range_size(tmt.tm) | 52 domain_size(tmt::TensorTranspose) = range_size(tmt.tm) |
53 | 53 |
54 | 54 |
55 struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R} | 55 struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,R,D} |
56 tm1::T1 | 56 tm1::T1 |
57 tm2::T2 | 57 tm2::T2 |
58 | 58 |
59 function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} | 59 function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} |
60 @boundscheck check_domain_size(tm2, domain_size(tm1)) | 60 @boundscheck check_domain_size(tm2, domain_size(tm1)) |
174 InflatedTensor(I1::IdentityTensor{T}, I2::IdentityTensor{T}) where T = InflatedTensor(I1,I2,IdentityTensor{T}()) | 174 InflatedTensor(I1::IdentityTensor{T}, I2::IdentityTensor{T}) where T = InflatedTensor(I1,I2,IdentityTensor{T}()) |
175 | 175 |
176 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2) | 176 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2) |
177 | 177 |
178 function range_size(itm::InflatedTensor) | 178 function range_size(itm::InflatedTensor) |
179 return flatten_tuple( | 179 return concatenate_tuples( |
180 range_size(itm.before), | 180 range_size(itm.before), |
181 range_size(itm.tm), | 181 range_size(itm.tm), |
182 range_size(itm.after), | 182 range_size(itm.after), |
183 ) | 183 ) |
184 end | 184 end |
185 | 185 |
186 function domain_size(itm::InflatedTensor) | 186 function domain_size(itm::InflatedTensor) |
187 return flatten_tuple( | 187 return concatenate_tuples( |
188 domain_size(itm.before), | 188 domain_size(itm.before), |
189 domain_size(itm.tm), | 189 domain_size(itm.tm), |
190 domain_size(itm.after), | 190 domain_size(itm.after), |
191 ) | 191 ) |
192 end | 192 end |
195 dim_before = range_dim(itm.before) | 195 dim_before = range_dim(itm.before) |
196 dim_domain = domain_dim(itm.tm) | 196 dim_domain = domain_dim(itm.tm) |
197 dim_range = range_dim(itm.tm) | 197 dim_range = range_dim(itm.tm) |
198 dim_after = range_dim(itm.after) | 198 dim_after = range_dim(itm.after) |
199 | 199 |
200 view_index, inner_index = split_index(Val(dim_before), Val(dim_domain), Val(dim_range), Val(dim_after), I...) | 200 view_index, inner_index = split_index(dim_before, dim_domain, dim_range, dim_after, I...) |
201 | 201 |
202 v_inner = view(v, view_index...) | 202 v_inner = view(v, view_index...) |
203 return apply(itm.tm, v_inner, inner_index...) | 203 return apply(itm.tm, v_inner, inner_index...) |
204 end | 204 end |
205 | 205 |
207 dim_before = range_dim(itm.before) | 207 dim_before = range_dim(itm.before) |
208 dim_domain = domain_dim(itm.tm) | 208 dim_domain = domain_dim(itm.tm) |
209 dim_range = range_dim(itm.tm) | 209 dim_range = range_dim(itm.tm) |
210 dim_after = range_dim(itm.after) | 210 dim_after = range_dim(itm.after) |
211 | 211 |
212 view_index, inner_index = split_index(Val(dim_before), Val(dim_range), Val(dim_domain), Val(dim_after), I...) | 212 view_index, inner_index = split_index(dim_before, dim_range, dim_domain, dim_after, I...) |
213 | 213 |
214 v_inner = view(v, view_index...) | 214 v_inner = view(v, view_index...) |
215 return apply_transpose(itm.tm, v_inner, inner_index...) | 215 return apply_transpose(itm.tm, v_inner, inner_index...) |
216 end | 216 end |
217 | 217 |