comparison src/LazyTensors/lazy_tensor_operations.jl @ 960:e79debd10f7d feature/variable_derivatives

Merge feature/tensormapping_application_promotion
author Jonatan Werpers <jonatan@werpers.com>
date Mon, 14 Mar 2022 10:14:38 +0100
parents e9752c1e92f8 86889fc5b63f
children df562695b1b5
comparison
equal deleted inserted replaced
959:e9752c1e92f8 960:e79debd10f7d
1 using Sbplib.RegionIndices 1 using Sbplib.RegionIndices
2
3 export LazyTensorMappingApplication
4 export LazyTensorMappingTranspose
5 export TensorMappingComposition
6 export LazyLinearMap
7 export IdentityMapping
8 export InflatedTensorMapping
9 export LazyOuterProduct
10 export ⊗
11 export SizeMismatch
2 12
3 """ 13 """
4 LazyTensorMappingApplication{T,R,D} <: LazyArray{T,R} 14 LazyTensorMappingApplication{T,R,D} <: LazyArray{T,R}
5 15
6 Struct for lazy application of a TensorMapping. Created using `*`. 16 Struct for lazy application of a TensorMapping. Created using `*`.
7 17
8 Allows the result of a `TensorMapping` applied to a vector to be treated as an `AbstractArray`. 18 Allows the result of a `TensorMapping` applied to a vector to be treated as an `AbstractArray`.
9 With a mapping `m` and a vector `v` the LazyTensorMappingApplication object can be created by `m*v`. 19 With a mapping `m` and a vector `v` the LazyTensorMappingApplication object can be created by `m*v`.
10 The actual result will be calcualted when indexing into `m*v`. 20 The actual result will be calcualted when indexing into `m*v`.
11 """ 21 """
12 struct LazyTensorMappingApplication{T,R,D, TM<:TensorMapping{T,R,D}, AA<:AbstractArray{T,D}} <: LazyArray{T,R} 22 struct LazyTensorMappingApplication{T,R,D, TM<:TensorMapping{<:Any,R,D}, AA<:AbstractArray{<:Any,D}} <: LazyArray{T,R}
13 t::TM 23 t::TM
14 o::AA 24 o::AA
25
26 function LazyTensorMappingApplication(t::TensorMapping{<:Any,R,D}, o::AbstractArray{<:Any,D}) where {R,D}
27 T = promote_type(eltype(t), eltype(o))
28 return new{T,R,D,typeof(t), typeof(o)}(t,o)
29 end
15 end 30 end
16 # TODO: Do boundschecking on creation! 31 # TODO: Do boundschecking on creation!
17 export LazyTensorMappingApplication
18 32
19 Base.getindex(ta::LazyTensorMappingApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...) 33 Base.getindex(ta::LazyTensorMappingApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...)
20 Base.getindex(ta::LazyTensorMappingApplication{T,1}, I::CartesianIndex{1}) where {T} = apply(ta.t, ta.o, I.I...) # Would otherwise be caught in the previous method. 34 Base.getindex(ta::LazyTensorMappingApplication{T,1}, I::CartesianIndex{1}) where {T} = apply(ta.t, ta.o, I.I...) # Would otherwise be caught in the previous method.
21 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t) 35 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t)
22 # TODO: What else is needed to implement the AbstractArray interface? 36 # TODO: What else is needed to implement the AbstractArray interface?
41 the appropriate methods of `m`. 55 the appropriate methods of `m`.
42 """ 56 """
43 struct LazyTensorMappingTranspose{T,R,D, TM<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R} 57 struct LazyTensorMappingTranspose{T,R,D, TM<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R}
44 tm::TM 58 tm::TM
45 end 59 end
46 export LazyTensorMappingTranspose
47 60
48 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? 61 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors?
49 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? 62 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`?
50 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm) 63 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm)
51 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm 64 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm
52 65
53 apply(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...) 66 apply(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...)
54 apply_transpose(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...) 67 apply_transpose(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...)
55 68
56 range_size(tmt::LazyTensorMappingTranspose) = domain_size(tmt.tm) 69 range_size(tmt::LazyTensorMappingTranspose) = domain_size(tmt.tm)
57 domain_size(tmt::LazyTensorMappingTranspose) = range_size(tmt.tm) 70 domain_size(tmt::LazyTensorMappingTranspose) = range_size(tmt.tm)
58 71
59 72
65 return new{Op,T,R,D,T1,T2}(tm1,tm2) 78 return new{Op,T,R,D,T1,T2}(tm1,tm2)
66 end 79 end
67 end 80 end
68 # TODO: Boundschecking in constructor. 81 # TODO: Boundschecking in constructor.
69 82
70 apply(tmBinOp::LazyTensorMappingBinaryOperation{:+,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) 83 apply(tmBinOp::LazyTensorMappingBinaryOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...)
71 apply(tmBinOp::LazyTensorMappingBinaryOperation{:-,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) 84 apply(tmBinOp::LazyTensorMappingBinaryOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...)
72 85
73 range_size(tmBinOp::LazyTensorMappingBinaryOperation) = range_size(tmBinOp.tm1) 86 range_size(tmBinOp::LazyTensorMappingBinaryOperation) = range_size(tmBinOp.tm1)
74 domain_size(tmBinOp::LazyTensorMappingBinaryOperation) = domain_size(tmBinOp.tm1) 87 domain_size(tmBinOp::LazyTensorMappingBinaryOperation) = domain_size(tmBinOp.tm1)
75 88
76 Base.:+(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:+,T,R,D}(tm1,tm2) 89 Base.:+(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:+,T,R,D}(tm1,tm2)
88 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} 101 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D}
89 @boundscheck check_domain_size(t1, range_size(t2)) 102 @boundscheck check_domain_size(t1, range_size(t2))
90 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) 103 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2)
91 end 104 end
92 end 105 end
93 export TensorMappingComposition
94 106
95 range_size(tm::TensorMappingComposition) = range_size(tm.t1) 107 range_size(tm::TensorMappingComposition) = range_size(tm.t1)
96 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) 108 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2)
97 109
98 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,K,D} 110 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,K,D}
99 apply(c.t1, c.t2*v, I...) 111 apply(c.t1, c.t2*v, I...)
100 end 112 end
101 113
102 function apply_transpose(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,K,D} 114 function apply_transpose(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,K,D}
103 apply_transpose(c.t2, c.t1'*v, I...) 115 apply_transpose(c.t2, c.t1'*v, I...)
104 end 116 end
105 117
106 Base.@propagate_inbounds Base.:∘(s::TensorMapping, t::TensorMapping) = TensorMappingComposition(s,t) 118 Base.@propagate_inbounds Base.:∘(s::TensorMapping, t::TensorMapping) = TensorMappingComposition(s,t)
107 119
125 end 137 end
126 138
127 return new{T,R,D,RD,AA}(A,range_indicies,domain_indicies) 139 return new{T,R,D,RD,AA}(A,range_indicies,domain_indicies)
128 end 140 end
129 end 141 end
130 export LazyLinearMap
131 142
132 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]] 143 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]]
133 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]] 144 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]]
134 145
135 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} 146 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
136 view_index = ntuple(i->:,ndims(llm.A)) 147 view_index = ntuple(i->:,ndims(llm.A))
137 for i ∈ 1:R 148 for i ∈ 1:R
138 view_index = Base.setindex(view_index, Int(I[i]), llm.range_indicies[i]) 149 view_index = Base.setindex(view_index, Int(I[i]), llm.range_indicies[i])
139 end 150 end
140 A_view = @view llm.A[view_index...] 151 A_view = @view llm.A[view_index...]
141 return sum(A_view.*v) 152 return sum(A_view.*v)
142 end 153 end
143 154
144 function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} 155 function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D}
145 apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...) 156 apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...)
146 end 157 end
147 158
148 159
149 """ 160 """
153 dimensional ones through outer products. Also used in the Implementation for InflatedTensorMapping. 164 dimensional ones through outer products. Also used in the Implementation for InflatedTensorMapping.
154 """ 165 """
155 struct IdentityMapping{T,D} <: TensorMapping{T,D,D} 166 struct IdentityMapping{T,D} <: TensorMapping{T,D,D}
156 size::NTuple{D,Int} 167 size::NTuple{D,Int}
157 end 168 end
158 export IdentityMapping
159 169
160 IdentityMapping{T}(size::NTuple{D,Int}) where {T,D} = IdentityMapping{T,D}(size) 170 IdentityMapping{T}(size::NTuple{D,Int}) where {T,D} = IdentityMapping{T,D}(size)
161 IdentityMapping{T}(size::Vararg{Int,D}) where {T,D} = IdentityMapping{T,D}(size) 171 IdentityMapping{T}(size::Vararg{Int,D}) where {T,D} = IdentityMapping{T,D}(size)
162 IdentityMapping(size::Vararg{Int,D}) where D = IdentityMapping{Float64,D}(size) 172 IdentityMapping(size::Vararg{Int,D}) where D = IdentityMapping{Float64,D}(size)
163 173
164 range_size(tmi::IdentityMapping) = tmi.size 174 range_size(tmi::IdentityMapping) = tmi.size
165 domain_size(tmi::IdentityMapping) = tmi.size 175 domain_size(tmi::IdentityMapping) = tmi.size
166 176
167 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] 177 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
168 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] 178 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
169 179
170 """ 180 """
171 Base.:∘(tm, tmi) 181 Base.:∘(tm, tmi)
172 Base.:∘(tmi, tm) 182 Base.:∘(tmi, tm)
173 183
210 D_after = domain_dim(after) 220 D_after = domain_dim(after)
211 D = D_before+D_middle+D_after 221 D = D_before+D_middle+D_after
212 return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after) 222 return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after)
213 end 223 end
214 end 224 end
215 export InflatedTensorMapping
216 """ 225 """
217 InflatedTensorMapping(before, tm, after) 226 InflatedTensorMapping(before, tm, after)
218 InflatedTensorMapping(before,tm) 227 InflatedTensorMapping(before,tm)
219 InflatedTensorMapping(tm,after) 228 InflatedTensorMapping(tm,after)
220 229
256 domain_size(itm.tm), 265 domain_size(itm.tm),
257 domain_size(itm.after), 266 domain_size(itm.after),
258 ) 267 )
259 end 268 end
260 269
261 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} 270 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
262 dim_before = range_dim(itm.before) 271 dim_before = range_dim(itm.before)
263 dim_domain = domain_dim(itm.tm) 272 dim_domain = domain_dim(itm.tm)
264 dim_range = range_dim(itm.tm) 273 dim_range = range_dim(itm.tm)
265 dim_after = range_dim(itm.after) 274 dim_after = range_dim(itm.after)
266 275
268 277
269 v_inner = view(v, view_index...) 278 v_inner = view(v, view_index...)
270 return apply(itm.tm, v_inner, inner_index...) 279 return apply(itm.tm, v_inner, inner_index...)
271 end 280 end
272 281
273 function apply_transpose(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} 282 function apply_transpose(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D}
274 dim_before = range_dim(itm.before) 283 dim_before = range_dim(itm.before)
275 dim_domain = domain_dim(itm.tm) 284 dim_domain = domain_dim(itm.tm)
276 dim_range = range_dim(itm.tm) 285 dim_range = range_dim(itm.tm)
277 dim_after = range_dim(itm.after) 286 dim_after = range_dim(itm.after)
278 287
396 ```math 405 ```math
397 (A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]] 406 (A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]]
398 ``` 407 ```
399 """ 408 """
400 function LazyOuterProduct end 409 function LazyOuterProduct end
401 export LazyOuterProduct
402 410
403 function LazyOuterProduct(tm1::TensorMapping{T}, tm2::TensorMapping{T}) where T 411 function LazyOuterProduct(tm1::TensorMapping{T}, tm2::TensorMapping{T}) where T
404 itm1 = InflatedTensorMapping(tm1, IdentityMapping{T}(range_size(tm2))) 412 itm1 = InflatedTensorMapping(tm1, IdentityMapping{T}(range_size(tm2)))
405 itm2 = InflatedTensorMapping(IdentityMapping{T}(domain_size(tm1)),tm2) 413 itm2 = InflatedTensorMapping(IdentityMapping{T}(domain_size(tm1)),tm2)
406 414
412 LazyOuterProduct(t1::IdentityMapping, t2::TensorMapping) = InflatedTensorMapping(t1, t2) 420 LazyOuterProduct(t1::IdentityMapping, t2::TensorMapping) = InflatedTensorMapping(t1, t2)
413 421
414 LazyOuterProduct(tms::Vararg{TensorMapping}) = foldl(LazyOuterProduct, tms) 422 LazyOuterProduct(tms::Vararg{TensorMapping}) = foldl(LazyOuterProduct, tms)
415 423
416 ⊗(a::TensorMapping, b::TensorMapping) = LazyOuterProduct(a,b) 424 ⊗(a::TensorMapping, b::TensorMapping) = LazyOuterProduct(a,b)
417 export ⊗
418 425
419 426
420 function check_domain_size(tm::TensorMapping, sz) 427 function check_domain_size(tm::TensorMapping, sz)
421 if domain_size(tm) != sz 428 if domain_size(tm) != sz
422 throw(SizeMismatch(tm,sz)) 429 throw(SizeMismatch(tm,sz))
425 432
426 struct SizeMismatch <: Exception 433 struct SizeMismatch <: Exception
427 tm::TensorMapping 434 tm::TensorMapping
428 sz 435 sz
429 end 436 end
430 export SizeMismatch
431 437
432 function Base.showerror(io::IO, err::SizeMismatch) 438 function Base.showerror(io::IO, err::SizeMismatch)
433 print(io, "SizeMismatch: ") 439 print(io, "SizeMismatch: ")
434 print(io, "domain size $(domain_size(err.tm)) of TensorMapping not matching size $(err.sz)") 440 print(io, "domain size $(domain_size(err.tm)) of TensorMapping not matching size $(err.sz)")
435 end 441 end