Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 1877:21e5fe1545c0 refactor/lazy_tensors/elementwise_ops
Merge default
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Mon, 27 Jan 2025 16:56:04 +0100 |
parents | cb3a8450ed44 164e26a6cf79 |
children | b12e28a03b2e |
comparison
equal
deleted
inserted
replaced
1840:cb3a8450ed44 | 1877:21e5fe1545c0 |
---|---|
198 itm.tm, | 198 itm.tm, |
199 IdentityTensor(itm.after.size..., after.size...), | 199 IdentityTensor(itm.after.size..., after.size...), |
200 ) | 200 ) |
201 end | 201 end |
202 | 202 |
203 InflatedTensor(before::IdentityTensor, tm::LazyTensor{T}) where T = InflatedTensor(before,tm,IdentityTensor{T}()) | 203 InflatedTensor(before::IdentityTensor, tm::LazyTensor) = InflatedTensor(before,tm,IdentityTensor{eltype(tm)}()) |
204 InflatedTensor(tm::LazyTensor{T}, after::IdentityTensor) where T = InflatedTensor(IdentityTensor{T}(),tm,after) | 204 InflatedTensor(tm::LazyTensor, after::IdentityTensor) = InflatedTensor(IdentityTensor{eltype(tm)}(),tm,after) |
205 # Resolve ambiguity between the two previous methods | 205 # Resolve ambiguity between the two previous methods |
206 InflatedTensor(I1::IdentityTensor{T}, I2::IdentityTensor{T}) where T = InflatedTensor(I1,I2,IdentityTensor{T}()) | 206 InflatedTensor(I1::IdentityTensor, I2::IdentityTensor) = InflatedTensor(I1,I2,IdentityTensor{promote_type(eltype(I1), eltype(I2))}()) |
207 | 207 |
208 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2) | 208 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2) |
209 | 209 |
210 function range_size(itm::InflatedTensor) | 210 function range_size(itm::InflatedTensor) |
211 return concatenate_tuples( | 211 return concatenate_tuples( |
292 itm2 = InflatedTensor(IdentityTensor{T}(domain_size(tm1)),tm2) | 292 itm2 = InflatedTensor(IdentityTensor{T}(domain_size(tm1)),tm2) |
293 | 293 |
294 return itm1∘itm2 | 294 return itm1∘itm2 |
295 end | 295 end |
296 | 296 |
297 LazyOuterProduct(t1::IdentityTensor{T}, t2::IdentityTensor{T}) where T = IdentityTensor{T}(t1.size...,t2.size...) | 297 LazyOuterProduct(t1::IdentityTensor, t2::IdentityTensor) = IdentityTensor{promote_type(eltype(t1),eltype(t2))}(t1.size...,t2.size...) |
298 LazyOuterProduct(t1::LazyTensor, t2::IdentityTensor) = InflatedTensor(t1, t2) | 298 LazyOuterProduct(t1::LazyTensor, t2::IdentityTensor) = InflatedTensor(t1, t2) |
299 LazyOuterProduct(t1::IdentityTensor, t2::LazyTensor) = InflatedTensor(t1, t2) | 299 LazyOuterProduct(t1::IdentityTensor, t2::LazyTensor) = InflatedTensor(t1, t2) |
300 | 300 |
301 LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms) | 301 LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms) |
302 | 302 |