comparison src/LazyTensors/lazy_tensor_operations.jl @ 1073:5a3281429a48 feature/variable_derivatives

Merge feature/variable_derivatives
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 24 Mar 2022 12:35:14 +0100
parents 3bb94ce74697 f857057e61e6
children fa0800591306
comparison
equal deleted inserted replaced
1068:0b0444adacd3 1073:5a3281429a48
# When either factor of a two-term outer product is an `IdentityTensor`, the
# product is built as an `InflatedTensor` wrapping both factors, in order.
function LazyOuterProduct(t1::LazyTensor, t2::IdentityTensor)
    return InflatedTensor(t1, t2)
end

function LazyOuterProduct(t1::IdentityTensor, t2::LazyTensor)
    return InflatedTensor(t1, t2)
end

# Any number of factors: combine pairwise from the left, i.e.
# LazyOuterProduct(a, b, c) == LazyOuterProduct(LazyOuterProduct(a, b), c).
# A single factor is returned as-is by `foldl`.
function LazyOuterProduct(tms::LazyTensor...)
    return foldl(LazyOuterProduct, tms)
end
271 271
272
273
# TODO: Describe when `inflate` is useful.
"""
    inflate(tm, sz, dir)

Inflate `tm` with identity tensors in all directions `d` for `d != dir`.

Builds the outer product `I_1 ⊗ … ⊗ tm ⊗ … ⊗ I_n` by folding `⊗` from the
left. `tm` occupies position `dir` and every other position `d` holds an
`IdentityTensor{eltype(tm)}` of size `sz[d]`. The entry `sz[dir]` does not
appear in the result, since `tm` takes the place of that identity tensor.

# Arguments
- `tm::LazyTensor`: the tensor placed at position `dir`.
- `sz`: one size per direction, used to build the identity tensors.
  Presumably a tuple, since `Base.setindex` is applied to the broadcast
  result; confirm before passing other collection types.
- `dir`: the index into `sz` of the position taken by `tm`.

# Returns
The left-folded `⊗` product of the parts described above.

# Throws
- `BoundsError` if `dir` is not a valid index of `sz` when `sz` is a tuple
  (raised by `Base.setindex`).
"""
function inflate(tm::LazyTensor, sz, dir)
    # One identity tensor per direction, with the same element type as `tm`.
    Is = IdentityTensor{eltype(tm)}.(sz)
    # `Base.setindex` returns a new collection with entry `dir` replaced by
    # `tm`; the identity built for `sz[dir]` is discarded.
    parts = Base.setindex(Is, tm, dir)
    return foldl(⊗, parts)
end
272 286
273 function check_domain_size(tm::LazyTensor, sz) 287 function check_domain_size(tm::LazyTensor, sz)
274 if domain_size(tm) != sz 288 if domain_size(tm) != sz
275 throw(DomainSizeMismatch(tm,sz)) 289 throw(DomainSizeMismatch(tm,sz))
276 end 290 end