Mercurial > repos > public > sbplib_julia
diff src/LazyTensors/lazy_tensor_operations.jl @ 493:df566372bb4f feature/avoid_nested_inflated_tensormappings
Implement constructors to avoid creating nested InflatedTensorMappings
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Thu, 05 Nov 2020 13:18:24 +0100 |
parents | 6a6b7eaf9edf |
children | 5a600ec40ccc |
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl Thu Nov 05 11:46:03 2020 +0100 +++ b/src/LazyTensors/lazy_tensor_operations.jl Thu Nov 05 13:18:24 2020 +0100 @@ -221,8 +221,20 @@ The outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityMapping`s. If one of `before` or `after` is left out, a 0-dimensional `IdentityMapping` is used as the default value. + +If `tm` already is an `InflatedTensorMapping`, `before` and `after` will be extended instead of +creating a nested `InflatedTensorMapping`. """ InflatedTensorMapping(::IdentityMapping, ::TensorMapping, ::IdentityMapping) + +function InflatedTensorMapping(before, itm::InflatedTensorMapping, after) + return InflatedTensorMapping( + IdentityMapping(before.size..., itm.before.size...), + itm.tm, + IdentityMapping(itm.after.size..., after.size...), + ) +end + InflatedTensorMapping(before::IdentityMapping, tm::TensorMapping{T}) where T = InflatedTensorMapping(before,tm,IdentityMapping{T}()) InflatedTensorMapping(tm::TensorMapping{T}, after::IdentityMapping) where T = InflatedTensorMapping(IdentityMapping{T}(),tm,after) # Resolve ambiguity between the two previous methods