sbplib_julia: comparison src/LazyTensors/lazy_tensor_operations.jl @ 562:8f7919a9b398 feature/boundary_ops
Merge with default
author | Vidar Stiernström <vidar.stiernstrom@it.uu.se>
date | Mon, 30 Nov 2020 18:30:24 +0100
parents | a5caa934b35f 53828d3ed132
children | 1c512e796c6d
comparison
544:884be64e82d9 | 562:8f7919a9b398 |
---|---|
12 o::AA | 12 o::AA |
13 end | 13 end |
14 # TODO: Do boundschecking on creation! | 14 # TODO: Do boundschecking on creation! |
15 export LazyTensorMappingApplication | 15 export LazyTensorMappingApplication |
16 | 16 |
17 # TODO: Go through and remove unnecessary type parameters on functions | 17 Base.getindex(ta::LazyTensorMappingApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...) |
18 Base.getindex(ta::LazyTensorMappingApplication{T,0}, I::Index) where T = apply(ta.t, ta.o, I) | |
19 Base.getindex(ta::LazyTensorMappingApplication{T,R}, I::Vararg{Index,R}) where {T,R} = apply(ta.t, ta.o, I...) | |
20 Base.getindex(ta::LazyTensorMappingApplication{T,R}, I::Vararg{Int,R}) where {T,R} = apply(ta.t, ta.o, Index{Unknown}.(I)...) | |
21 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t) | 18 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t) |
22 # TODO: What else is needed to implement the AbstractArray interface? | 19 # TODO: What else is needed to implement the AbstractArray interface? |
23 | 20 |
24 Base.:*(a::TensorMapping, v::AbstractArray) = LazyTensorMappingApplication(a,v) | 21 Base.:*(a::TensorMapping, v::AbstractArray) = LazyTensorMappingApplication(a,v) |
25 Base.:*(a::TensorMapping, b::TensorMapping) = throw(MethodError(Base.:*,(a,b))) | 22 Base.:*(a::TensorMapping, b::TensorMapping) = throw(MethodError(Base.:*,(a,b))) |
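A usage note on the change above (not part of the changeset): with `getindex` now forwarding `Vararg{Any,R}` to `apply`, a mapping can be indexed with plain `Int`s as well as `Index` values. The sketch below is hypothetical; `DiagonalScaling` is made up for illustration, and it assumes the `LazyTensors` module exposes `TensorMapping`, `apply`, `range_size` and `domain_size` as used elsewhere in this file.

```
# Hypothetical sketch, not from the changeset.
using LazyTensors

struct DiagonalScaling{T} <: TensorMapping{T,1,1}
    d::Vector{T}  # diagonal entries
end

# Element i of D*v is d[i]*v[i]; i may be an Int or an Index after this change.
LazyTensors.apply(D::DiagonalScaling{T}, v::AbstractArray{T,1}, i) where T = D.d[Int(i)]*v[Int(i)]
LazyTensors.range_size(D::DiagonalScaling) = (length(D.d),)
LazyTensors.domain_size(D::DiagonalScaling) = (length(D.d),)

D = DiagonalScaling([1.0, 2.0, 3.0])
w = D*[10.0, 10.0, 10.0]  # LazyTensorMappingApplication; nothing is evaluated yet
w[2]                      # == 20.0, computed on demand via apply
```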
48 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? | 45 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? |
49 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? | 46 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? |
50 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm) | 47 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm) |
51 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm | 48 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm |
52 | 49 |
53 apply(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Index,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...) | 50 apply(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...) |
54 apply_transpose(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Index,R}) where {T,R,D} = apply(tmt.tm, v, I...) | 51 apply_transpose(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...) |
55 | 52 |
56 range_size(tmt::LazyTensorMappingTranspose) = domain_size(tmt.tm) | 53 range_size(tmt::LazyTensorMappingTranspose) = domain_size(tmt.tm) |
57 domain_size(tmt::LazyTensorMappingTranspose) = range_size(tmt.tm) | 54 domain_size(tmt::LazyTensorMappingTranspose) = range_size(tmt.tm) |
58 | 55 |
59 | 56 |
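Continuing the hypothetical sketch above (not from the changeset): once a mapping defines `apply_transpose`, the adjoint wrapper swaps the roles of `apply`/`apply_transpose` and of `range_size`/`domain_size`.

```
# A diagonal scaling is symmetric, so its transpose can forward to apply.
LazyTensors.apply_transpose(D::DiagonalScaling{T}, v::AbstractArray{T,1}, i) where T =
    LazyTensors.apply(D, v, i)

Dt = D'                    # LazyTensorMappingTranspose wrapping D
Dt' === D                  # true: the adjoint of the transpose unwraps
(Dt*[1.0, 2.0, 3.0])[3]    # forwards to apply_transpose(D, v, 3) == 9.0
```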
65 return new{Op,T,R,D,T1,T2}(tm1,tm2) | 62 return new{Op,T,R,D,T1,T2}(tm1,tm2) |
66 end | 63 end |
67 end | 64 end |
68 # TODO: Boundschecking in constructor. | 65 # TODO: Boundschecking in constructor. |
69 | 66 |
70 apply(tmBinOp::LazyTensorMappingBinaryOperation{:+,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Index,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) | 67 apply(tmBinOp::LazyTensorMappingBinaryOperation{:+,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) |
71 apply(tmBinOp::LazyTensorMappingBinaryOperation{:-,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Index,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) | 68 apply(tmBinOp::LazyTensorMappingBinaryOperation{:-,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) |
72 | 69 |
73 range_size(tmBinOp::LazyTensorMappingBinaryOperation{Op,T,R,D}) where {Op,T,R,D} = range_size(tmBinOp.tm1) | 70 range_size(tmBinOp::LazyTensorMappingBinaryOperation) = range_size(tmBinOp.tm1) |
74 domain_size(tmBinOp::LazyTensorMappingBinaryOperation{Op,T,R,D}) where {Op,T,R,D} = domain_size(tmBinOp.tm1) | 71 domain_size(tmBinOp::LazyTensorMappingBinaryOperation) = domain_size(tmBinOp.tm1) |
75 | 72 |
76 Base.:+(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:+,T,R,D}(tm1,tm2) | 73 Base.:+(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:+,T,R,D}(tm1,tm2) |
77 Base.:-(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:-,T,R,D}(tm1,tm2) | 74 Base.:-(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:-,T,R,D}(tm1,tm2) |
78 | 75 |
79 """ | 76 """ |
93 export TensorMappingComposition | 90 export TensorMappingComposition |
94 | 91 |
95 range_size(tm::TensorMappingComposition) = range_size(tm.t1) | 92 range_size(tm::TensorMappingComposition) = range_size(tm.t1) |
96 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) | 93 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) |
97 | 94 |
98 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,D}, I::Vararg{S,R} where S) where {T,R,K,D} | 95 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,K,D} |
99 apply(c.t1, c.t2*v, I...) | 96 apply(c.t1, c.t2*v, I...) |
100 end | 97 end |
101 | 98 |
102 function apply_transpose(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,R}, I::Vararg{S,D} where S) where {T,R,K,D} | 99 function apply_transpose(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,K,D} |
103 apply_transpose(c.t2, c.t1'*v, I...) | 100 apply_transpose(c.t2, c.t1'*v, I...) |
104 end | 101 end |
105 | 102 |
106 Base.@propagate_inbounds Base.:∘(s::TensorMapping, t::TensorMapping) = TensorMappingComposition(s,t) | 103 Base.@propagate_inbounds Base.:∘(s::TensorMapping, t::TensorMapping) = TensorMappingComposition(s,t) |
107 | 104 |
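For the lazy sums, differences and compositions defined above, a hypothetical continuation of the same sketch (not from the changeset):

```
S = D + D                  # LazyTensorMappingBinaryOperation{:+,Float64,1,1}
(S*[1.0, 1.0, 1.0])[2]     # apply(D, v, 2) + apply(D, v, 2) == 4.0

C = D ∘ D                  # TensorMappingComposition; applies the right operand first
(C*[1.0, 1.0, 1.0])[3]     # apply(D, D*v, 3) == 3.0*(3.0*1.0) == 9.0
```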
130 export LazyLinearMap | 127 export LazyLinearMap |
131 | 128 |
132 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]] | 129 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]] |
133 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]] | 130 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]] |
134 | 131 |
135 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Index,R}) where {T,R,D} | 132 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} |
136 view_index = ntuple(i->:,ndims(llm.A)) | 133 view_index = ntuple(i->:,ndims(llm.A)) |
137 for i ∈ 1:R | 134 for i ∈ 1:R |
138 view_index = Base.setindex(view_index, Int(I[i]), llm.range_indicies[i]) | 135 view_index = Base.setindex(view_index, Int(I[i]), llm.range_indicies[i]) |
139 end | 136 end |
140 A_view = @view llm.A[view_index...] | 137 A_view = @view llm.A[view_index...] |
141 return sum(A_view.*v) | 138 return sum(A_view.*v) |
142 end | 139 end |
143 | 140 |
144 function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Index,D}) where {T,R,D} | 141 function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} |
145 apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...) | 142 apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...) |
146 end | 143 end |
147 | 144 |
148 | 145 |
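A hypothetical `LazyLinearMap` usage (not part of the changeset). The argument order `(A, range_indicies, domain_indicies)` is an assumption inferred from the swap in `apply_transpose` above, not confirmed by this hunk.

```
A = [1.0 2.0; 3.0 4.0]
M = LazyLinearMap(A, (1,), (2,))  # dimension 1 of A indexes the range, dimension 2 the domain
(M*[1.0, 1.0])[2]                 # sum(@view(A[2,:]) .* v) == 7.0
```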
149 """ | 146 """ |
238 InflatedTensorMapping(before::IdentityMapping, tm::TensorMapping{T}) where T = InflatedTensorMapping(before,tm,IdentityMapping{T}()) | 235 InflatedTensorMapping(before::IdentityMapping, tm::TensorMapping{T}) where T = InflatedTensorMapping(before,tm,IdentityMapping{T}()) |
239 InflatedTensorMapping(tm::TensorMapping{T}, after::IdentityMapping) where T = InflatedTensorMapping(IdentityMapping{T}(),tm,after) | 236 InflatedTensorMapping(tm::TensorMapping{T}, after::IdentityMapping) where T = InflatedTensorMapping(IdentityMapping{T}(),tm,after) |
240 # Resolve ambiguity between the two previous methods | 237 # Resolve ambiguity between the two previous methods |
241 InflatedTensorMapping(I1::IdentityMapping{T}, I2::IdentityMapping{T}) where T = InflatedTensorMapping(I1,I2,IdentityMapping{T}()) | 238 InflatedTensorMapping(I1::IdentityMapping{T}, I2::IdentityMapping{T}) where T = InflatedTensorMapping(I1,I2,IdentityMapping{T}()) |
242 | 239 |
243 # TODO: Implement syntax and constructors for products of different combinations of InflatedTensorMapping and IdentityMapping | |
244 | |
245 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensorMapping(I(3),B,I(2)) -> I(3)⊗B⊗I(2) | 240 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensorMapping(I(3),B,I(2)) -> I(3)⊗B⊗I(2) |
246 | 241 |
247 function range_size(itm::InflatedTensorMapping) | 242 function range_size(itm::InflatedTensorMapping) |
248 return flatten_tuple( | 243 return flatten_tuple( |
249 range_size(itm.before), | 244 range_size(itm.before), |
259 domain_size(itm.after), | 254 domain_size(itm.after), |
260 ) | 255 ) |
261 end | 256 end |
262 | 257 |
263 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} | 258 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} |
264 view_index, inner_index = split_index(itm, I...) | 259 dim_before = range_dim(itm.before) |
| 260 dim_domain = domain_dim(itm.tm) |
| 261 dim_range = range_dim(itm.tm) |
| 262 dim_after = range_dim(itm.after) |
| 263 |
| 264 view_index, inner_index = split_index(Val(dim_before), Val(dim_domain), Val(dim_range), Val(dim_after), I...) |
265 | 265 |
266 v_inner = view(v, view_index...) | 266 v_inner = view(v, view_index...) |
267 return apply(itm.tm, v_inner, inner_index...) | 267 return apply(itm.tm, v_inner, inner_index...) |
268 end | 268 end |
269 | 269 |
270 | 270 function apply_transpose(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} |
271 """ | 271 dim_before = range_dim(itm.before) |
272 split_index(...) | 272 dim_domain = domain_dim(itm.tm) |
273 | 273 dim_range = range_dim(itm.tm) |
274 Splits the multi-index into two parts. One part for the view that the inner TensorMapping acts on, and one part for indexing the result | 274 dim_after = range_dim(itm.after) |
| 275 |
| 276 view_index, inner_index = split_index(Val(dim_before), Val(dim_range), Val(dim_domain), Val(dim_after), I...) |
| 277 |
| 278 v_inner = view(v, view_index...) |
| 279 return apply_transpose(itm.tm, v_inner, inner_index...) |
| 280 end |
| 281 |
| 282 |
| 283 """ |
| 284 split_index(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...) |
| 285 |
| 286 Splits the multi-index `I` into two parts. One part which is expected to be |
| 287 used as a view, and one which is expected to be used as an index. |
275 Eg. | 288 Eg. |
276 ``` | 289 ``` |
277 (1,2,3,4) -> (1,:,:,4), (2,3) | 290 split_index(Val(1),Val(3),Val(2),Val(1),1,2,3,4) -> (1,:,:,:,4), (2,3) |
278 ``` | 291 ``` |
279 """ | 292 |
280 function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{Any,R}) where {T,R,D} | 293 `dim_view` controls how many colons are in the view, and `dim_index` controls |
281 I_before = slice_tuple(I, Val(1), Val(range_dim(itm.before))) | 294 how many elements are extracted from the middle. |
282 I_after = slice_tuple(I, Val(R-range_dim(itm.after)+1), Val(R)) | 295 `dim_before` and `dim_after` decide the length of the index parts before and after the colons in the view index. |
283 | 296 |
284 view_index = (I_before..., ntuple((i)->:,domain_dim(itm.tm))..., I_after...) | 297 Arguments should satisfy `length(I) == dim_before+dim_index+dim_after`. |
285 inner_index = slice_tuple(I, Val(range_dim(itm.before)+1), Val(R-range_dim(itm.after))) | 298 |
286 | 299 The returned values satisfy |
287 return (view_index, inner_index) | 300 * `length(view_index) == dim_before + dim_view + dim_after` |
| 301 * `length(I_middle) == dim_index` |
| 302 """ |
| 303 function split_index(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...) where {dim_before,dim_view, dim_index,dim_after} |
| 304 I_before, I_middle, I_after = split_tuple(I, Val(dim_before), Val(dim_index)) |
| 305 |
| 306 view_index = (I_before..., ntuple((i)->:, dim_view)..., I_after...) |
| 307 |
| 308 return view_index, I_middle |
288 end | 309 end |
289 | 310 |
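A worked illustration of how the new `apply` and `apply_transpose` above drive `split_index` (the dimensions are chosen for the example, not taken from the changeset): with a one-dimensional identity on each side of an inner mapping whose domain and range are both two-dimensional, the multi-index `(3,7,8,4)` splits as

```
# dim_before = 1, dim_view = 2, dim_index = 2, dim_after = 1
view_index, inner_index = split_index(Val(1), Val(2), Val(2), Val(1), 3, 7, 8, 4)
view_index   # == (3, :, :, 4)  used as view(v, view_index...)
inner_index  # == (7, 8)        passed on to apply(itm.tm, v_inner, inner_index...)
```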
290 # TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21 | 311 # TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21 |
291 # See: | 312 # See: |
292 # https://github.com/JuliaLang/julia/issues/34884 | 313 # https://github.com/JuliaLang/julia/issues/34884 |
298 Equivalent to t[l:u] but type stable. | 319 Equivalent to t[l:u] but type stable. |
299 """ | 320 """ |
300 function slice_tuple(t,::Val{L},::Val{U}) where {L,U} | 321 function slice_tuple(t,::Val{L},::Val{U}) where {L,U} |
301 return ntuple(i->t[i+L-1], U-L+1) | 322 return ntuple(i->t[i+L-1], U-L+1) |
302 end | 323 end |
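For reference (not part of the changeset), `slice_tuple` preserves the per-element types, so heterogeneous tuples stay inferable:

```
slice_tuple((1, :, "a", 4), Val(2), Val(3))  # == (:, "a"), same as t[2:3] but type stable
```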
| 324 |
| 325 """ |
| 326 split_tuple(t::Tuple{...}, ::Val{M}) where {N,M} |
| 327 |
| 328 Split the tuple `t` into two parts. The first part is `M` long. |
| 329 E.g |
| 330 ``` |
| 331 split_tuple((1,2,3,4),Val(3)) -> (1,2,3), (4,) |
| 332 ``` |
| 333 """ |
| 334 function split_tuple(t::NTuple{N},::Val{M}) where {N,M} |
| 335 return slice_tuple(t,Val(1), Val(M)), slice_tuple(t,Val(M+1), Val(N)) |
| 336 end |
| 337 |
| 338 """ |
| 339 split_tuple(t::Tuple{...},::Val{M},::Val{K}) where {N,M,K} |
| 340 |
| 341 Same as `split_tuple(t::NTuple{N},::Val{M})` but splits the tuple into three parts, with the first |
| 342 two parts having length `M` and `K`. |
| 343 """ |
| 344 function split_tuple(t::NTuple{N},::Val{M},::Val{K}) where {N,M,K} |
| 345 p1, tail = split_tuple(t, Val(M)) |
| 346 p2, p3 = split_tuple(tail, Val(K)) |
| 347 return p1,p2,p3 |
| 348 end |
| 349 |
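A small usage example for the three-way split (not part of the changeset), matching how `split_index` above consumes it:

```
split_tuple((1, 2, 3, 4, 5), Val(2), Val(2))  # == ((1, 2), (3, 4), (5,))
```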
303 | 350 |
304 """ | 351 """ |
305 flatten_tuple(t) | 352 flatten_tuple(t) |
306 | 353 |
307 Takes a nested tuple and flattens the whole structure | 354 Takes a nested tuple and flattens the whole structure |