Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 503:fbbb3733650c
Merge in feature/outer_product
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Fri, 06 Nov 2020 21:37:10 +0100 |
parents | 4b9d124fe984 |
children | a5caa934b35f fe86ac896377 |
comparison
equal
deleted
inserted
replaced
497:d8075fb14418 | 503:fbbb3733650c |
---|---|
308 """ | 308 """ |
flatten_tuple(t::Tuple{Vararg{Number}}) = t  # already flat: every element is a number
function flatten_tuple(t::Tuple)
    # Flatten each element on its own (a bare number is routed through the
    # varargs method below and comes back as a 1-tuple), then splice the
    # resulting tuples together with a nested splat.
    pieces = flatten_tuple.(t)
    return ((pieces...)...,)
end
flatten_tuple(ts...) = flatten_tuple(ts)  # loose arguments are treated as one tuple
312 | 312 |
"""
    LazyOuterProduct(tms...)

Create a lazy `TensorMapping` for the outer product of `tms...`.

The outer product is separated into regular products (compositions) of outer
products that each involve only identity mappings and one non-identity mapping.
For two general mappings the result is the composition (`∘`) of two
`InflatedTensorMapping`s. If one of the two factors is an `IdentityMapping`,
the result is a single `InflatedTensorMapping`. If both are, the result is an
`IdentityMapping`. More than two factors are combined pairwise from the left,
`((tm1⊗tm2)⊗tm3)⊗…`. The factors are expected to share a common element type
`T`.

Two mappings can also be combined with the infix operator `⊗`.

# Background

First let
```math
A = A_{I,J}, B = B_{M,N}, C = C_{P,Q}
```

where ``I``, ``M``, ``P`` are multi-indexes for the ranges of ``A``, ``B``,
``C``, and ``J``, ``N``, ``Q`` are multi-indexes for their domains.

We use ``⊗`` to denote the outer product
```math
(A⊗B)_{IM,JN} = A_{I,J}B_{M,N}
```

so that
```math
(A⊗B⊗C)_{IMP,JNQ} = A_{I,J}B_{M,N}C_{P,Q}
```

This product can be factored as
```math
A⊗B⊗C = (A⊗I_{|M|}⊗I_{|P|})(I_{|J|}⊗B⊗I_{|P|})(I_{|J|}⊗I_{|N|}⊗C)
```

where ``|⋅|`` of a multi-index is the vector of its sizes in each dimension.
``I_v`` denotes the identity tensor of size ``v[i]`` in direction ``i``; it
should not be confused with the multi-index ``I``.

To apply ``A⊗B⊗C`` to ``v`` we evaluate the factors from right to left:
```math
(A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]]
```
"""
function LazyOuterProduct end
export LazyOuterProduct
348 | |
# Two general mappings with a common element type `T`:
#   A⊗B = (A⊗I_{|M|})∘(I_{|J|}⊗B)
# where |M| is the range size of `tm2` and |J| the domain size of `tm1`.
# `itm1` inflates `tm1` with an identity over the range of `tm2`; `itm2`
# inflates `tm2` with an identity over the domain of `tm1`.
function LazyOuterProduct(tm1::TensorMapping{T}, tm2::TensorMapping{T}) where T
    itm1 = InflatedTensorMapping(tm1, IdentityMapping{T}(range_size(tm2)))
    itm2 = InflatedTensorMapping(IdentityMapping{T}(domain_size(tm1)),tm2)

    return itm1∘itm2
end

# Shortcuts when a factor is an identity, so no composition is built.
# Every two-argument method requires the same element type `T` for both
# factors. This keeps the method table unambiguous: the overlap of the two
# inflation methods, (IdentityMapping{T}, IdentityMapping{T}), is exactly the
# first method. Each shortcut is also strictly more specific than the general
# (TensorMapping{T}, TensorMapping{T}) method above.
LazyOuterProduct(t1::IdentityMapping{T}, t2::IdentityMapping{T}) where T = IdentityMapping{T}(t1.size...,t2.size...)
LazyOuterProduct(t1::TensorMapping{T}, t2::IdentityMapping{T}) where T = InflatedTensorMapping(t1, t2)
LazyOuterProduct(t1::IdentityMapping{T}, t2::TensorMapping{T}) where T = InflatedTensorMapping(t1, t2)

# A single factor is returned unchanged.
LazyOuterProduct(tm::TensorMapping) = tm

# Three or more factors are combined pairwise from the left: ((tm1⊗tm2)⊗tm3)⊗…
# This method deliberately requires at least three arguments. A varargs method
# that also accepted two arguments would be selected for any pair without a
# matching two-argument method (e.g. differing element types). It would then
# call itself again through `foldl` until the stack overflowed, instead of
# raising a MethodError.
function LazyOuterProduct(tm1::TensorMapping, tm2::TensorMapping, tm3::TensorMapping,
                          tms::Vararg{TensorMapping})
    return foldl(LazyOuterProduct, (tm1, tm2, tm3, tms...))
end
361 | |
"""
    a ⊗ b

Infix form of `LazyOuterProduct(a, b)`: the lazy outer product of the
`TensorMapping`s `a` and `b`.
"""
function ⊗(a::TensorMapping, b::TensorMapping)
    return LazyOuterProduct(a, b)
end
export ⊗
364 | |
365 | |
"""
    check_domain_size(tm::TensorMapping, sz)

Throw a `SizeMismatch` for `tm` and `sz` if `sz` differs from
`domain_size(tm)`; otherwise return `nothing`.
"""
function check_domain_size(tm::TensorMapping, sz)
    domain_size(tm) == sz || throw(SizeMismatch(tm, sz))
    return nothing
end