Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 960:e79debd10f7d feature/variable_derivatives
Merge feature/tensormapping_application_promotion
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Mon, 14 Mar 2022 10:14:38 +0100 |
parents | e9752c1e92f8 86889fc5b63f |
children | df562695b1b5 |
comparison
equal
deleted
inserted
replaced
959:e9752c1e92f8 | 960:e79debd10f7d |
---|---|
1 using Sbplib.RegionIndices | 1 using Sbplib.RegionIndices |
2 | |
3 export LazyTensorMappingApplication | |
4 export LazyTensorMappingTranspose | |
5 export TensorMappingComposition | |
6 export LazyLinearMap | |
7 export IdentityMapping | |
8 export InflatedTensorMapping | |
9 export LazyOuterProduct | |
10 export ⊗ | |
11 export SizeMismatch | |
2 | 12 |
3 """ | 13 """ |
4 LazyTensorMappingApplication{T,R,D} <: LazyArray{T,R} | 14 LazyTensorMappingApplication{T,R,D} <: LazyArray{T,R} |
5 | 15 |
6 Struct for lazy application of a TensorMapping. Created using `*`. | 16 Struct for lazy application of a TensorMapping. Created using `*`. |
7 | 17 |
8 Allows the result of a `TensorMapping` applied to a vector to be treated as an `AbstractArray`. | 18 Allows the result of a `TensorMapping` applied to a vector to be treated as an `AbstractArray`. |
9 With a mapping `m` and a vector `v` the LazyTensorMappingApplication object can be created by `m*v`. | 19 With a mapping `m` and a vector `v` the LazyTensorMappingApplication object can be created by `m*v`. |
10 The actual result will be calcualted when indexing into `m*v`. | 20 The actual result will be calcualted when indexing into `m*v`. |
11 """ | 21 """ |
12 struct LazyTensorMappingApplication{T,R,D, TM<:TensorMapping{T,R,D}, AA<:AbstractArray{T,D}} <: LazyArray{T,R} | 22 struct LazyTensorMappingApplication{T,R,D, TM<:TensorMapping{<:Any,R,D}, AA<:AbstractArray{<:Any,D}} <: LazyArray{T,R} |
13 t::TM | 23 t::TM |
14 o::AA | 24 o::AA |
25 | |
26 function LazyTensorMappingApplication(t::TensorMapping{<:Any,R,D}, o::AbstractArray{<:Any,D}) where {R,D} | |
27 T = promote_type(eltype(t), eltype(o)) | |
28 return new{T,R,D,typeof(t), typeof(o)}(t,o) | |
29 end | |
15 end | 30 end |
16 # TODO: Do boundschecking on creation! | 31 # TODO: Do boundschecking on creation! |
17 export LazyTensorMappingApplication | |
18 | 32 |
19 Base.getindex(ta::LazyTensorMappingApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...) | 33 Base.getindex(ta::LazyTensorMappingApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...) |
20 Base.getindex(ta::LazyTensorMappingApplication{T,1}, I::CartesianIndex{1}) where {T} = apply(ta.t, ta.o, I.I...) # Would otherwise be caught in the previous method. | 34 Base.getindex(ta::LazyTensorMappingApplication{T,1}, I::CartesianIndex{1}) where {T} = apply(ta.t, ta.o, I.I...) # Would otherwise be caught in the previous method. |
21 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t) | 35 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t) |
22 # TODO: What else is needed to implement the AbstractArray interface? | 36 # TODO: What else is needed to implement the AbstractArray interface? |
41 the appropriate methods of `m`. | 55 the appropriate methods of `m`. |
42 """ | 56 """ |
43 struct LazyTensorMappingTranspose{T,R,D, TM<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R} | 57 struct LazyTensorMappingTranspose{T,R,D, TM<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R} |
44 tm::TM | 58 tm::TM |
45 end | 59 end |
46 export LazyTensorMappingTranspose | |
47 | 60 |
48 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? | 61 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? |
49 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? | 62 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? |
50 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm) | 63 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm) |
51 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm | 64 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm |
52 | 65 |
53 apply(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...) | 66 apply(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...) |
54 apply_transpose(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...) | 67 apply_transpose(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...) |
55 | 68 |
56 range_size(tmt::LazyTensorMappingTranspose) = domain_size(tmt.tm) | 69 range_size(tmt::LazyTensorMappingTranspose) = domain_size(tmt.tm) |
57 domain_size(tmt::LazyTensorMappingTranspose) = range_size(tmt.tm) | 70 domain_size(tmt::LazyTensorMappingTranspose) = range_size(tmt.tm) |
58 | 71 |
59 | 72 |
65 return new{Op,T,R,D,T1,T2}(tm1,tm2) | 78 return new{Op,T,R,D,T1,T2}(tm1,tm2) |
66 end | 79 end |
67 end | 80 end |
68 # TODO: Boundschecking in constructor. | 81 # TODO: Boundschecking in constructor. |
69 | 82 |
70 apply(tmBinOp::LazyTensorMappingBinaryOperation{:+,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) | 83 apply(tmBinOp::LazyTensorMappingBinaryOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) |
71 apply(tmBinOp::LazyTensorMappingBinaryOperation{:-,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) | 84 apply(tmBinOp::LazyTensorMappingBinaryOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) |
72 | 85 |
73 range_size(tmBinOp::LazyTensorMappingBinaryOperation) = range_size(tmBinOp.tm1) | 86 range_size(tmBinOp::LazyTensorMappingBinaryOperation) = range_size(tmBinOp.tm1) |
74 domain_size(tmBinOp::LazyTensorMappingBinaryOperation) = domain_size(tmBinOp.tm1) | 87 domain_size(tmBinOp::LazyTensorMappingBinaryOperation) = domain_size(tmBinOp.tm1) |
75 | 88 |
76 Base.:+(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:+,T,R,D}(tm1,tm2) | 89 Base.:+(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:+,T,R,D}(tm1,tm2) |
88 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} | 101 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} |
89 @boundscheck check_domain_size(t1, range_size(t2)) | 102 @boundscheck check_domain_size(t1, range_size(t2)) |
90 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) | 103 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) |
91 end | 104 end |
92 end | 105 end |
93 export TensorMappingComposition | |
94 | 106 |
95 range_size(tm::TensorMappingComposition) = range_size(tm.t1) | 107 range_size(tm::TensorMappingComposition) = range_size(tm.t1) |
96 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) | 108 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) |
97 | 109 |
98 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,K,D} | 110 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,K,D} |
99 apply(c.t1, c.t2*v, I...) | 111 apply(c.t1, c.t2*v, I...) |
100 end | 112 end |
101 | 113 |
102 function apply_transpose(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,K,D} | 114 function apply_transpose(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,K,D} |
103 apply_transpose(c.t2, c.t1'*v, I...) | 115 apply_transpose(c.t2, c.t1'*v, I...) |
104 end | 116 end |
105 | 117 |
106 Base.@propagate_inbounds Base.:∘(s::TensorMapping, t::TensorMapping) = TensorMappingComposition(s,t) | 118 Base.@propagate_inbounds Base.:∘(s::TensorMapping, t::TensorMapping) = TensorMappingComposition(s,t) |
107 | 119 |
125 end | 137 end |
126 | 138 |
127 return new{T,R,D,RD,AA}(A,range_indicies,domain_indicies) | 139 return new{T,R,D,RD,AA}(A,range_indicies,domain_indicies) |
128 end | 140 end |
129 end | 141 end |
130 export LazyLinearMap | |
131 | 142 |
132 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]] | 143 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]] |
133 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]] | 144 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]] |
134 | 145 |
135 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} | 146 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} |
136 view_index = ntuple(i->:,ndims(llm.A)) | 147 view_index = ntuple(i->:,ndims(llm.A)) |
137 for i ∈ 1:R | 148 for i ∈ 1:R |
138 view_index = Base.setindex(view_index, Int(I[i]), llm.range_indicies[i]) | 149 view_index = Base.setindex(view_index, Int(I[i]), llm.range_indicies[i]) |
139 end | 150 end |
140 A_view = @view llm.A[view_index...] | 151 A_view = @view llm.A[view_index...] |
141 return sum(A_view.*v) | 152 return sum(A_view.*v) |
142 end | 153 end |
143 | 154 |
144 function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} | 155 function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D} |
145 apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...) | 156 apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...) |
146 end | 157 end |
147 | 158 |
148 | 159 |
149 """ | 160 """ |
153 dimensional ones through outer products. Also used in the Implementation for InflatedTensorMapping. | 164 dimensional ones through outer products. Also used in the Implementation for InflatedTensorMapping. |
154 """ | 165 """ |
155 struct IdentityMapping{T,D} <: TensorMapping{T,D,D} | 166 struct IdentityMapping{T,D} <: TensorMapping{T,D,D} |
156 size::NTuple{D,Int} | 167 size::NTuple{D,Int} |
157 end | 168 end |
158 export IdentityMapping | |
159 | 169 |
160 IdentityMapping{T}(size::NTuple{D,Int}) where {T,D} = IdentityMapping{T,D}(size) | 170 IdentityMapping{T}(size::NTuple{D,Int}) where {T,D} = IdentityMapping{T,D}(size) |
161 IdentityMapping{T}(size::Vararg{Int,D}) where {T,D} = IdentityMapping{T,D}(size) | 171 IdentityMapping{T}(size::Vararg{Int,D}) where {T,D} = IdentityMapping{T,D}(size) |
162 IdentityMapping(size::Vararg{Int,D}) where D = IdentityMapping{Float64,D}(size) | 172 IdentityMapping(size::Vararg{Int,D}) where D = IdentityMapping{Float64,D}(size) |
163 | 173 |
164 range_size(tmi::IdentityMapping) = tmi.size | 174 range_size(tmi::IdentityMapping) = tmi.size |
165 domain_size(tmi::IdentityMapping) = tmi.size | 175 domain_size(tmi::IdentityMapping) = tmi.size |
166 | 176 |
167 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] | 177 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = v[I...] |
168 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] | 178 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = v[I...] |
169 | 179 |
170 """ | 180 """ |
171 Base.:∘(tm, tmi) | 181 Base.:∘(tm, tmi) |
172 Base.:∘(tmi, tm) | 182 Base.:∘(tmi, tm) |
173 | 183 |
210 D_after = domain_dim(after) | 220 D_after = domain_dim(after) |
211 D = D_before+D_middle+D_after | 221 D = D_before+D_middle+D_after |
212 return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after) | 222 return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after) |
213 end | 223 end |
214 end | 224 end |
215 export InflatedTensorMapping | |
216 """ | 225 """ |
217 InflatedTensorMapping(before, tm, after) | 226 InflatedTensorMapping(before, tm, after) |
218 InflatedTensorMapping(before,tm) | 227 InflatedTensorMapping(before,tm) |
219 InflatedTensorMapping(tm,after) | 228 InflatedTensorMapping(tm,after) |
220 | 229 |
256 domain_size(itm.tm), | 265 domain_size(itm.tm), |
257 domain_size(itm.after), | 266 domain_size(itm.after), |
258 ) | 267 ) |
259 end | 268 end |
260 | 269 |
261 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} | 270 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} |
262 dim_before = range_dim(itm.before) | 271 dim_before = range_dim(itm.before) |
263 dim_domain = domain_dim(itm.tm) | 272 dim_domain = domain_dim(itm.tm) |
264 dim_range = range_dim(itm.tm) | 273 dim_range = range_dim(itm.tm) |
265 dim_after = range_dim(itm.after) | 274 dim_after = range_dim(itm.after) |
266 | 275 |
268 | 277 |
269 v_inner = view(v, view_index...) | 278 v_inner = view(v, view_index...) |
270 return apply(itm.tm, v_inner, inner_index...) | 279 return apply(itm.tm, v_inner, inner_index...) |
271 end | 280 end |
272 | 281 |
273 function apply_transpose(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} | 282 function apply_transpose(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D} |
274 dim_before = range_dim(itm.before) | 283 dim_before = range_dim(itm.before) |
275 dim_domain = domain_dim(itm.tm) | 284 dim_domain = domain_dim(itm.tm) |
276 dim_range = range_dim(itm.tm) | 285 dim_range = range_dim(itm.tm) |
277 dim_after = range_dim(itm.after) | 286 dim_after = range_dim(itm.after) |
278 | 287 |
396 ```math | 405 ```math |
397 (A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]] | 406 (A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]] |
398 ``` | 407 ``` |
399 """ | 408 """ |
400 function LazyOuterProduct end | 409 function LazyOuterProduct end |
401 export LazyOuterProduct | |
402 | 410 |
403 function LazyOuterProduct(tm1::TensorMapping{T}, tm2::TensorMapping{T}) where T | 411 function LazyOuterProduct(tm1::TensorMapping{T}, tm2::TensorMapping{T}) where T |
404 itm1 = InflatedTensorMapping(tm1, IdentityMapping{T}(range_size(tm2))) | 412 itm1 = InflatedTensorMapping(tm1, IdentityMapping{T}(range_size(tm2))) |
405 itm2 = InflatedTensorMapping(IdentityMapping{T}(domain_size(tm1)),tm2) | 413 itm2 = InflatedTensorMapping(IdentityMapping{T}(domain_size(tm1)),tm2) |
406 | 414 |
412 LazyOuterProduct(t1::IdentityMapping, t2::TensorMapping) = InflatedTensorMapping(t1, t2) | 420 LazyOuterProduct(t1::IdentityMapping, t2::TensorMapping) = InflatedTensorMapping(t1, t2) |
413 | 421 |
414 LazyOuterProduct(tms::Vararg{TensorMapping}) = foldl(LazyOuterProduct, tms) | 422 LazyOuterProduct(tms::Vararg{TensorMapping}) = foldl(LazyOuterProduct, tms) |
415 | 423 |
416 ⊗(a::TensorMapping, b::TensorMapping) = LazyOuterProduct(a,b) | 424 ⊗(a::TensorMapping, b::TensorMapping) = LazyOuterProduct(a,b) |
417 export ⊗ | |
418 | 425 |
419 | 426 |
420 function check_domain_size(tm::TensorMapping, sz) | 427 function check_domain_size(tm::TensorMapping, sz) |
421 if domain_size(tm) != sz | 428 if domain_size(tm) != sz |
422 throw(SizeMismatch(tm,sz)) | 429 throw(SizeMismatch(tm,sz)) |
425 | 432 |
426 struct SizeMismatch <: Exception | 433 struct SizeMismatch <: Exception |
427 tm::TensorMapping | 434 tm::TensorMapping |
428 sz | 435 sz |
429 end | 436 end |
430 export SizeMismatch | |
431 | 437 |
432 function Base.showerror(io::IO, err::SizeMismatch) | 438 function Base.showerror(io::IO, err::SizeMismatch) |
433 print(io, "SizeMismatch: ") | 439 print(io, "SizeMismatch: ") |
434 print(io, "domain size $(domain_size(err.tm)) of TensorMapping not matching size $(err.sz)") | 440 print(io, "domain size $(domain_size(err.tm)) of TensorMapping not matching size $(err.sz)") |
435 end | 441 end |