comparison src/LazyTensors/lazy_tensor_operations.jl @ 1877:21e5fe1545c0 refactor/lazy_tensors/elementwise_ops

Merge default
author Jonatan Werpers <jonatan@werpers.com>
date Mon, 27 Jan 2025 16:56:04 +0100
parents cb3a8450ed44 164e26a6cf79
children b12e28a03b2e
comparison
equal deleted inserted replaced
1840:cb3a8450ed44 1877:21e5fe1545c0
198 itm.tm, 198 itm.tm,
199 IdentityTensor(itm.after.size..., after.size...), 199 IdentityTensor(itm.after.size..., after.size...),
200 ) 200 )
201 end 201 end
202 202
203 InflatedTensor(before::IdentityTensor, tm::LazyTensor{T}) where T = InflatedTensor(before,tm,IdentityTensor{T}()) 203 InflatedTensor(before::IdentityTensor, tm::LazyTensor) = InflatedTensor(before,tm,IdentityTensor{eltype(tm)}())
204 InflatedTensor(tm::LazyTensor{T}, after::IdentityTensor) where T = InflatedTensor(IdentityTensor{T}(),tm,after) 204 InflatedTensor(tm::LazyTensor, after::IdentityTensor) = InflatedTensor(IdentityTensor{eltype(tm)}(),tm,after)
205 # Resolve ambiguity between the two previous methods 205 # Resolve ambiguity between the two previous methods
206 InflatedTensor(I1::IdentityTensor{T}, I2::IdentityTensor{T}) where T = InflatedTensor(I1,I2,IdentityTensor{T}()) 206 InflatedTensor(I1::IdentityTensor, I2::IdentityTensor) = InflatedTensor(I1,I2,IdentityTensor{promote_type(eltype(I1), eltype(I2))}())
207 207
208 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2) 208 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
209 209
210 function range_size(itm::InflatedTensor) 210 function range_size(itm::InflatedTensor)
211 return concatenate_tuples( 211 return concatenate_tuples(
292 itm2 = InflatedTensor(IdentityTensor{T}(domain_size(tm1)),tm2) 292 itm2 = InflatedTensor(IdentityTensor{T}(domain_size(tm1)),tm2)
293 293
294 return itm1∘itm2 294 return itm1∘itm2
295 end 295 end
296 296
297 LazyOuterProduct(t1::IdentityTensor{T}, t2::IdentityTensor{T}) where T = IdentityTensor{T}(t1.size...,t2.size...) 297 LazyOuterProduct(t1::IdentityTensor, t2::IdentityTensor) = IdentityTensor{promote_type(eltype(t1),eltype(t2))}(t1.size...,t2.size...)
298 LazyOuterProduct(t1::LazyTensor, t2::IdentityTensor) = InflatedTensor(t1, t2) 298 LazyOuterProduct(t1::LazyTensor, t2::IdentityTensor) = InflatedTensor(t1, t2)
299 LazyOuterProduct(t1::IdentityTensor, t2::LazyTensor) = InflatedTensor(t1, t2) 299 LazyOuterProduct(t1::IdentityTensor, t2::LazyTensor) = InflatedTensor(t1, t2)
300 300
301 LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms) 301 LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms)
302 302