comparison src/LazyTensors/lazy_tensor_operations.jl @ 944:4a4ef4bf6cb9 feature/tensormapping_application_promotion

Move exports to the top of the files
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 10 Mar 2022 16:59:43 +0100
parents fb060e98ac0a
children 86889fc5b63f
comparison
equal deleted inserted replaced
943:fb060e98ac0a 944:4a4ef4bf6cb9
1 export LazyTensorMappingApplication
2 export LazyTensorMappingTranspose
3 export TensorMappingComposition
4 export LazyLinearMap
5 export IdentityMapping
6 export InflatedTensorMapping
7 export LazyOuterProduct
8 export ⊗
9 export SizeMismatch
10
1 """ 11 """
2 LazyTensorMappingApplication{T,R,D} <: LazyArray{T,R} 12 LazyTensorMappingApplication{T,R,D} <: LazyArray{T,R}
3 13
4 Struct for lazy application of a TensorMapping. Created using `*`. 14 Struct for lazy application of a TensorMapping. Created using `*`.
5 15
15 T = promote_type(eltype(t), eltype(o)) 25 T = promote_type(eltype(t), eltype(o))
16 return new{T,R,D,typeof(t), typeof(o)}(t,o) 26 return new{T,R,D,typeof(t), typeof(o)}(t,o)
17 end 27 end
18 end 28 end
19 # TODO: Do boundschecking on creation! 29 # TODO: Do boundschecking on creation!
20 export LazyTensorMappingApplication
21 30
22 Base.getindex(ta::LazyTensorMappingApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...) 31 Base.getindex(ta::LazyTensorMappingApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...)
23 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t) 32 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t)
24 # TODO: What else is needed to implement the AbstractArray interface? 33 # TODO: What else is needed to implement the AbstractArray interface?
25 34
43 the appropriate methods of `m`. 52 the appropriate methods of `m`.
44 """ 53 """
45 struct LazyTensorMappingTranspose{T,R,D, TM<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R} 54 struct LazyTensorMappingTranspose{T,R,D, TM<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R}
46 tm::TM 55 tm::TM
47 end 56 end
48 export LazyTensorMappingTranspose
49 57
50 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? 58 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors?
51 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? 59 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`?
52 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm) 60 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm)
53 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm 61 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm
90 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} 98 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D}
91 @boundscheck check_domain_size(t1, range_size(t2)) 99 @boundscheck check_domain_size(t1, range_size(t2))
92 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) 100 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2)
93 end 101 end
94 end 102 end
95 export TensorMappingComposition
96 103
97 range_size(tm::TensorMappingComposition) = range_size(tm.t1) 104 range_size(tm::TensorMappingComposition) = range_size(tm.t1)
98 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) 105 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2)
99 106
100 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,K,D} 107 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,K,D}
127 end 134 end
128 135
129 return new{T,R,D,RD,AA}(A,range_indicies,domain_indicies) 136 return new{T,R,D,RD,AA}(A,range_indicies,domain_indicies)
130 end 137 end
131 end 138 end
132 export LazyLinearMap
133 139
134 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]] 140 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]]
135 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]] 141 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]]
136 142
137 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} 143 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
155 dimensional ones through outer products. Also used in the Implementation for InflatedTensorMapping. 161 dimensional ones through outer products. Also used in the Implementation for InflatedTensorMapping.
156 """ 162 """
157 struct IdentityMapping{T,D} <: TensorMapping{T,D,D} 163 struct IdentityMapping{T,D} <: TensorMapping{T,D,D}
158 size::NTuple{D,Int} 164 size::NTuple{D,Int}
159 end 165 end
160 export IdentityMapping
161 166
162 IdentityMapping{T}(size::NTuple{D,Int}) where {T,D} = IdentityMapping{T,D}(size) 167 IdentityMapping{T}(size::NTuple{D,Int}) where {T,D} = IdentityMapping{T,D}(size)
163 IdentityMapping{T}(size::Vararg{Int,D}) where {T,D} = IdentityMapping{T,D}(size) 168 IdentityMapping{T}(size::Vararg{Int,D}) where {T,D} = IdentityMapping{T,D}(size)
164 IdentityMapping(size::Vararg{Int,D}) where D = IdentityMapping{Float64,D}(size) 169 IdentityMapping(size::Vararg{Int,D}) where D = IdentityMapping{Float64,D}(size)
165 170
212 D_after = domain_dim(after) 217 D_after = domain_dim(after)
213 D = D_before+D_middle+D_after 218 D = D_before+D_middle+D_after
214 return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after) 219 return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after)
215 end 220 end
216 end 221 end
217 export InflatedTensorMapping
218 """ 222 """
219 InflatedTensorMapping(before, tm, after) 223 InflatedTensorMapping(before, tm, after)
220 InflatedTensorMapping(before,tm) 224 InflatedTensorMapping(before,tm)
221 InflatedTensorMapping(tm,after) 225 InflatedTensorMapping(tm,after)
222 226
398 ```math 402 ```math
399 (A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]] 403 (A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]]
400 ``` 404 ```
401 """ 405 """
402 function LazyOuterProduct end 406 function LazyOuterProduct end
403 export LazyOuterProduct
404 407
405 function LazyOuterProduct(tm1::TensorMapping{T}, tm2::TensorMapping{T}) where T 408 function LazyOuterProduct(tm1::TensorMapping{T}, tm2::TensorMapping{T}) where T
406 itm1 = InflatedTensorMapping(tm1, IdentityMapping{T}(range_size(tm2))) 409 itm1 = InflatedTensorMapping(tm1, IdentityMapping{T}(range_size(tm2)))
407 itm2 = InflatedTensorMapping(IdentityMapping{T}(domain_size(tm1)),tm2) 410 itm2 = InflatedTensorMapping(IdentityMapping{T}(domain_size(tm1)),tm2)
408 411
414 LazyOuterProduct(t1::IdentityMapping, t2::TensorMapping) = InflatedTensorMapping(t1, t2) 417 LazyOuterProduct(t1::IdentityMapping, t2::TensorMapping) = InflatedTensorMapping(t1, t2)
415 418
416 LazyOuterProduct(tms::Vararg{TensorMapping}) = foldl(LazyOuterProduct, tms) 419 LazyOuterProduct(tms::Vararg{TensorMapping}) = foldl(LazyOuterProduct, tms)
417 420
418 ⊗(a::TensorMapping, b::TensorMapping) = LazyOuterProduct(a,b) 421 ⊗(a::TensorMapping, b::TensorMapping) = LazyOuterProduct(a,b)
419 export ⊗
420 422
421 423
422 function check_domain_size(tm::TensorMapping, sz) 424 function check_domain_size(tm::TensorMapping, sz)
423 if domain_size(tm) != sz 425 if domain_size(tm) != sz
424 throw(SizeMismatch(tm,sz)) 426 throw(SizeMismatch(tm,sz))
427 429
428 struct SizeMismatch <: Exception 430 struct SizeMismatch <: Exception
429 tm::TensorMapping 431 tm::TensorMapping
430 sz 432 sz
431 end 433 end
432 export SizeMismatch
433 434
434 function Base.showerror(io::IO, err::SizeMismatch) 435 function Base.showerror(io::IO, err::SizeMismatch)
435 print(io, "SizeMismatch: ") 436 print(io, "SizeMismatch: ")
436 print(io, "domain size $(domain_size(err.tm)) of TensorMapping not matching size $(err.sz)") 437 print(io, "domain size $(domain_size(err.tm)) of TensorMapping not matching size $(err.sz)")
437 end 438 end