Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 944:4a4ef4bf6cb9 feature/tensormapping_application_promotion
Move exports to the top of the files
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Thu, 10 Mar 2022 16:59:43 +0100 |
parents | fb060e98ac0a |
children | 86889fc5b63f |
comparison
equal
deleted
inserted
replaced
943:fb060e98ac0a | 944:4a4ef4bf6cb9 |
---|---|
1 export LazyTensorMappingApplication | |
2 export LazyTensorMappingTranspose | |
3 export TensorMappingComposition | |
4 export LazyLinearMap | |
5 export IdentityMapping | |
6 export InflatedTensorMapping | |
7 export LazyOuterProduct | |
8 export ⊗ | |
9 export SizeMismatch | |
10 | |
1 """ | 11 """ |
2 LazyTensorMappingApplication{T,R,D} <: LazyArray{T,R} | 12 LazyTensorMappingApplication{T,R,D} <: LazyArray{T,R} |
3 | 13 |
4 Struct for lazy application of a TensorMapping. Created using `*`. | 14 Struct for lazy application of a TensorMapping. Created using `*`. |
5 | 15 |
15 T = promote_type(eltype(t), eltype(o)) | 25 T = promote_type(eltype(t), eltype(o)) |
16 return new{T,R,D,typeof(t), typeof(o)}(t,o) | 26 return new{T,R,D,typeof(t), typeof(o)}(t,o) |
17 end | 27 end |
18 end | 28 end |
19 # TODO: Do boundschecking on creation! | 29 # TODO: Do boundschecking on creation! |
20 export LazyTensorMappingApplication | |
21 | 30 |
22 Base.getindex(ta::LazyTensorMappingApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...) | 31 Base.getindex(ta::LazyTensorMappingApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...) |
23 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t) | 32 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t) |
24 # TODO: What else is needed to implement the AbstractArray interface? | 33 # TODO: What else is needed to implement the AbstractArray interface? |
25 | 34 |
43 the appropriate methods of `m`. | 52 the appropriate methods of `m`. |
44 """ | 53 """ |
45 struct LazyTensorMappingTranspose{T,R,D, TM<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R} | 54 struct LazyTensorMappingTranspose{T,R,D, TM<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R} |
46 tm::TM | 55 tm::TM |
47 end | 56 end |
48 export LazyTensorMappingTranspose | |
49 | 57 |
50 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? | 58 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? |
51 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? | 59 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? |
52 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm) | 60 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm) |
53 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm | 61 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm |
90 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} | 98 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} |
91 @boundscheck check_domain_size(t1, range_size(t2)) | 99 @boundscheck check_domain_size(t1, range_size(t2)) |
92 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) | 100 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) |
93 end | 101 end |
94 end | 102 end |
95 export TensorMappingComposition | |
96 | 103 |
97 range_size(tm::TensorMappingComposition) = range_size(tm.t1) | 104 range_size(tm::TensorMappingComposition) = range_size(tm.t1) |
98 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) | 105 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) |
99 | 106 |
100 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,K,D} | 107 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,K,D} |
127 end | 134 end |
128 | 135 |
129 return new{T,R,D,RD,AA}(A,range_indicies,domain_indicies) | 136 return new{T,R,D,RD,AA}(A,range_indicies,domain_indicies) |
130 end | 137 end |
131 end | 138 end |
132 export LazyLinearMap | |
133 | 139 |
134 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]] | 140 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]] |
135 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]] | 141 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]] |
136 | 142 |
137 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} | 143 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} |
155 dimensional ones through outer products. Also used in the Implementation for InflatedTensorMapping. | 161 dimensional ones through outer products. Also used in the Implementation for InflatedTensorMapping. |
156 """ | 162 """ |
157 struct IdentityMapping{T,D} <: TensorMapping{T,D,D} | 163 struct IdentityMapping{T,D} <: TensorMapping{T,D,D} |
158 size::NTuple{D,Int} | 164 size::NTuple{D,Int} |
159 end | 165 end |
160 export IdentityMapping | |
161 | 166 |
162 IdentityMapping{T}(size::NTuple{D,Int}) where {T,D} = IdentityMapping{T,D}(size) | 167 IdentityMapping{T}(size::NTuple{D,Int}) where {T,D} = IdentityMapping{T,D}(size) |
163 IdentityMapping{T}(size::Vararg{Int,D}) where {T,D} = IdentityMapping{T,D}(size) | 168 IdentityMapping{T}(size::Vararg{Int,D}) where {T,D} = IdentityMapping{T,D}(size) |
164 IdentityMapping(size::Vararg{Int,D}) where D = IdentityMapping{Float64,D}(size) | 169 IdentityMapping(size::Vararg{Int,D}) where D = IdentityMapping{Float64,D}(size) |
165 | 170 |
212 D_after = domain_dim(after) | 217 D_after = domain_dim(after) |
213 D = D_before+D_middle+D_after | 218 D = D_before+D_middle+D_after |
214 return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after) | 219 return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after) |
215 end | 220 end |
216 end | 221 end |
217 export InflatedTensorMapping | |
218 """ | 222 """ |
219 InflatedTensorMapping(before, tm, after) | 223 InflatedTensorMapping(before, tm, after) |
220 InflatedTensorMapping(before,tm) | 224 InflatedTensorMapping(before,tm) |
221 InflatedTensorMapping(tm,after) | 225 InflatedTensorMapping(tm,after) |
222 | 226 |
398 ```math | 402 ```math |
399 (A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]] | 403 (A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]] |
400 ``` | 404 ``` |
401 """ | 405 """ |
402 function LazyOuterProduct end | 406 function LazyOuterProduct end |
403 export LazyOuterProduct | |
404 | 407 |
405 function LazyOuterProduct(tm1::TensorMapping{T}, tm2::TensorMapping{T}) where T | 408 function LazyOuterProduct(tm1::TensorMapping{T}, tm2::TensorMapping{T}) where T |
406 itm1 = InflatedTensorMapping(tm1, IdentityMapping{T}(range_size(tm2))) | 409 itm1 = InflatedTensorMapping(tm1, IdentityMapping{T}(range_size(tm2))) |
407 itm2 = InflatedTensorMapping(IdentityMapping{T}(domain_size(tm1)),tm2) | 410 itm2 = InflatedTensorMapping(IdentityMapping{T}(domain_size(tm1)),tm2) |
408 | 411 |
414 LazyOuterProduct(t1::IdentityMapping, t2::TensorMapping) = InflatedTensorMapping(t1, t2) | 417 LazyOuterProduct(t1::IdentityMapping, t2::TensorMapping) = InflatedTensorMapping(t1, t2) |
415 | 418 |
416 LazyOuterProduct(tms::Vararg{TensorMapping}) = foldl(LazyOuterProduct, tms) | 419 LazyOuterProduct(tms::Vararg{TensorMapping}) = foldl(LazyOuterProduct, tms) |
417 | 420 |
418 ⊗(a::TensorMapping, b::TensorMapping) = LazyOuterProduct(a,b) | 421 ⊗(a::TensorMapping, b::TensorMapping) = LazyOuterProduct(a,b) |
419 export ⊗ | |
420 | 422 |
421 | 423 |
422 function check_domain_size(tm::TensorMapping, sz) | 424 function check_domain_size(tm::TensorMapping, sz) |
423 if domain_size(tm) != sz | 425 if domain_size(tm) != sz |
424 throw(SizeMismatch(tm,sz)) | 426 throw(SizeMismatch(tm,sz)) |
427 | 429 |
428 struct SizeMismatch <: Exception | 430 struct SizeMismatch <: Exception |
429 tm::TensorMapping | 431 tm::TensorMapping |
430 sz | 432 sz |
431 end | 433 end |
432 export SizeMismatch | |
433 | 434 |
434 function Base.showerror(io::IO, err::SizeMismatch) | 435 function Base.showerror(io::IO, err::SizeMismatch) |
435 print(io, "SizeMismatch: ") | 436 print(io, "SizeMismatch: ") |
436 print(io, "domain size $(domain_size(err.tm)) of TensorMapping not matching size $(err.sz)") | 437 print(io, "domain size $(domain_size(err.tm)) of TensorMapping not matching size $(err.sz)") |
437 end | 438 end |