Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 498:5a600ec40ccc feature/outer_product
Merge in default
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Thu, 05 Nov 2020 15:27:04 +0100 |
parents | 2dc2eac27f75 df566372bb4f |
children | 7b550c714f3f |
comparison
equal
deleted
inserted
replaced
491:2dc2eac27f75 | 498:5a600ec40ccc |
---|---|
84 struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D} | 84 struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D} |
85 t1::TM1 | 85 t1::TM1 |
86 t2::TM2 | 86 t2::TM2 |
87 | 87 |
88 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} | 88 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} |
89 @boundscheck if domain_size(t1) != range_size(t2) | 89 @boundscheck check_domain_size(t1, range_size(t2)) |
90 throw(DimensionMismatch("the first argument has domain size $(domain_size(t1)) while the second has range size $(range_size(t2)) ")) | |
91 end | |
92 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) | 90 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) |
93 end | 91 end |
94 # Add check for matching sizes as a boundscheck | |
95 end | 92 end |
96 export TensorMappingComposition | 93 export TensorMappingComposition |
97 | 94 |
98 range_size(tm::TensorMappingComposition) = range_size(tm.t1) | 95 range_size(tm::TensorMappingComposition) = range_size(tm.t1) |
99 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) | 96 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) |
167 range_size(tmi::IdentityMapping) = tmi.size | 164 range_size(tmi::IdentityMapping) = tmi.size |
168 domain_size(tmi::IdentityMapping) = tmi.size | 165 domain_size(tmi::IdentityMapping) = tmi.size |
169 | 166 |
170 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] | 167 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] |
171 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] | 168 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] |
169 | |
170 """ | |
171 Base.:∘(tm, tmi) | |
172 Base.:∘(tmi, tm) | |
173 | |
174 Composes a `TensorMapping` `tm` with an `IdentityMapping` `tmi`, by returning `tm`. | |
175 """ | |
176 @inline function Base.:∘(tm::TensorMapping{T,R,D}, tmi::IdentityMapping{T,D}) where {T,R,D} | |
177 @boundscheck check_domain_size(tm, range_size(tmi)) | |
178 return tm | |
179 end | |
180 | |
181 @inline function Base.:∘(tmi::IdentityMapping{T,R}, tm::TensorMapping{T,R,D}) where {T,R,D} | |
182 @boundscheck check_domain_size(tmi, range_size(tm)) | |
183 return tm | |
184 end | |
185 # Specialization for the case where tm is an IdentityMapping. Required to resolve ambiguity. | |
186 @inline function Base.:∘(tm::IdentityMapping{T,D}, tmi::IdentityMapping{T,D}) where {T,D} | |
187 @boundscheck check_domain_size(tm, range_size(tmi)) | |
188 return tmi | |
189 end | |
172 | 190 |
173 | 191 |
174 """ | 192 """ |
175 InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D} | 193 InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D} |
176 | 194 |
201 InflatedTensorMapping(tm,after) | 219 InflatedTensorMapping(tm,after) |
202 | 220 |
203 The outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityMapping`s. | 221 The outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityMapping`s. |
204 | 222 |
205 If one of `before` or `after` is left out, a 0-dimensional `IdentityMapping` is used as the default value. | 223 If one of `before` or `after` is left out, a 0-dimensional `IdentityMapping` is used as the default value. |
224 | |
225 If `tm` already is an `InflatedTensorMapping`, `before` and `after` will be extended instead of | |
226 creating a nested `InflatedTensorMapping`. | |
206 """ | 227 """ |
207 InflatedTensorMapping(::IdentityMapping, ::TensorMapping, ::IdentityMapping) | 228 InflatedTensorMapping(::IdentityMapping, ::TensorMapping, ::IdentityMapping) |
229 | |
230 function InflatedTensorMapping(before, itm::InflatedTensorMapping, after) | |
231 return InflatedTensorMapping( | |
232 IdentityMapping(before.size..., itm.before.size...), | |
233 itm.tm, | |
234 IdentityMapping(itm.after.size..., after.size...), | |
235 ) | |
236 end | |
237 | |
208 InflatedTensorMapping(before::IdentityMapping, tm::TensorMapping{T}) where T = InflatedTensorMapping(before,tm,IdentityMapping{T}()) | 238 InflatedTensorMapping(before::IdentityMapping, tm::TensorMapping{T}) where T = InflatedTensorMapping(before,tm,IdentityMapping{T}()) |
209 InflatedTensorMapping(tm::TensorMapping{T}, after::IdentityMapping) where T = InflatedTensorMapping(IdentityMapping{T}(),tm,after) | 239 InflatedTensorMapping(tm::TensorMapping{T}, after::IdentityMapping) where T = InflatedTensorMapping(IdentityMapping{T}(),tm,after) |
210 # Resolve ambiguity between the two previous methods | 240 # Resolve ambiguity between the two previous methods |
211 InflatedTensorMapping(I1::IdentityMapping{T}, I2::IdentityMapping{T}) where T = InflatedTensorMapping(I1,I2,IdentityMapping{T}()) | 241 InflatedTensorMapping(I1::IdentityMapping{T}, I2::IdentityMapping{T}) where T = InflatedTensorMapping(I1,I2,IdentityMapping{T}()) |
212 | 242 |
278 """ | 308 """ |
279 flatten_tuple(t::NTuple{N, Number} where N) = t | 309 flatten_tuple(t::NTuple{N, Number} where N) = t |
280 flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify? | 310 flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify? |
281 flatten_tuple(ts::Vararg) = flatten_tuple(ts) | 311 flatten_tuple(ts::Vararg) = flatten_tuple(ts) |
282 | 312 |
283 | |
284 """ | 313 """ |
285 LazyOuterProduct(tms...) | 314 LazyOuterProduct(tms...) |
286 | 315 |
287 Creates a `TensorComposition` for the outer product of `tms...`. | 316 Creates a `TensorComposition` for the outer product of `tms...`. |
288 This is done by separating the outer product into regular products of outer products involving only identity mappings and one non-identity mapping. | 317 This is done by separating the outer product into regular products of outer products involving only identity mappings and one non-identity mapping. |
332 | 361 |
333 ⊗(a::TensorMapping, b::TensorMapping) = LazyOuterProduct(a,b) | 362 ⊗(a::TensorMapping, b::TensorMapping) = LazyOuterProduct(a,b) |
334 export ⊗ | 363 export ⊗ |
335 | 364 |
336 # TBD: Should we implement simplifications for outer products of LazyIdentities with other LazyIdentities or InflatedTensorMappings? | 365 # TBD: Should we implement simplifications for outer products of LazyIdentities with other LazyIdentities or InflatedTensorMappings? |
366 | |
367 function check_domain_size(tm::TensorMapping, sz) | |
368 if domain_size(tm) != sz | |
369 throw(SizeMismatch(tm,sz)) | |
370 end | |
371 end | |
372 | |
373 struct SizeMismatch <: Exception | |
374 tm::TensorMapping | |
375 sz | |
376 end | |
377 export SizeMismatch | |
378 | |
379 function Base.showerror(io::IO, err::SizeMismatch) | |
380 print(io, "SizeMismatch: ") | |
381 print(io, "domain size $(domain_size(err.tm)) of TensorMapping not matching size $(err.sz)") | |
382 end |