comparison src/LazyTensors/lazy_tensor_operations.jl @ 1207:f1c2a4fa0ee1 performance/get_region_type_inference

Merge default
author Jonatan Werpers <jonatan@werpers.com>
date Fri, 03 Feb 2023 22:14:47 +0100
parents b41180efb6c2 2278730f9cee
children
comparison
equal deleted inserted replaced
919:b41180efb6c2 1207:f1c2a4fa0ee1
1 """ 1 """
2 LazyTensorMappingApplication{T,R,D} <: LazyArray{T,R} 2 TensorApplication{T,R,D} <: LazyArray{T,R}
3 3
4 Struct for lazy application of a TensorMapping. Created using `*`. 4 Struct for lazy application of a LazyTensor. Created using `*`.
5 5
6 Allows the result of a `TensorMapping` applied to a vector to be treated as an `AbstractArray`. 6 Allows the result of a `LazyTensor` applied to a vector to be treated as an `AbstractArray`.
7 With a mapping `m` and a vector `v` the LazyTensorMappingApplication object can be created by `m*v`. 7 With a mapping `m` and a vector `v` the TensorApplication object can be created by `m*v`.
8 The actual result will be calcualted when indexing into `m*v`. 8 The actual result will be calcualted when indexing into `m*v`.
9 """ 9 """
10 struct LazyTensorMappingApplication{T,R,D, TM<:TensorMapping{T,R,D}, AA<:AbstractArray{T,D}} <: LazyArray{T,R} 10 struct TensorApplication{T,R,D, TM<:LazyTensor{<:Any,R,D}, AA<:AbstractArray{<:Any,D}} <: LazyArray{T,R}
11 t::TM 11 t::TM
12 o::AA 12 o::AA
13 end 13
14 # TODO: Do boundschecking on creation! 14 function TensorApplication(t::LazyTensor{<:Any,R,D}, o::AbstractArray{<:Any,D}) where {R,D}
15 export LazyTensorMappingApplication 15 @boundscheck check_domain_size(t, size(o))
16 16 I = ntuple(i->1, range_dim(t))
17 Base.getindex(ta::LazyTensorMappingApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...) 17 T = typeof(apply(t,o,I...))
18 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t) 18 return new{T,R,D,typeof(t), typeof(o)}(t,o)
19 # TODO: What else is needed to implement the AbstractArray interface? 19 end
20 20 end
21 Base.:*(a::TensorMapping, v::AbstractArray) = LazyTensorMappingApplication(a,v) 21
22 Base.:*(a::TensorMapping, b::TensorMapping) = throw(MethodError(Base.:*,(a,b))) 22 function Base.getindex(ta::TensorApplication{T,R}, I::Vararg{Any,R}) where {T,R}
23 Base.:*(a::TensorMapping, args::Union{TensorMapping, AbstractArray}...) = foldr(*,(a,args...)) 23 @boundscheck checkbounds(ta, Int.(I)...)
24 24 return @inbounds apply(ta.t, ta.o, I...)
25 # # We need the associativity to be a→b→c = a→(b→c), which is the case for '→' 25 end
26 # # Should we overload some other infix binary opesrator? 26 Base.@propagate_inbounds Base.getindex(ta::TensorApplication{T,1} where T, I::CartesianIndex{1}) = ta[Tuple(I)...] # Would otherwise be caught in the previous method.
27 # →(tm::TensorMapping{T,R,D}, o::AbstractArray{T,D}) where {T,R,D} = LazyTensorMappingApplication(tm,o) 27 Base.size(ta::TensorApplication) = range_size(ta.t)
28 # TODO: We need to be really careful about good error messages. 28
29 # For example what happens if you try to multiply LazyTensorMappingApplication with a TensorMapping(wrong order)? 29
30 30 """
31 """ 31 TensorTranspose{T,R,D} <: LazyTensor{T,D,R}
32 LazyTensorMappingTranspose{T,R,D} <: TensorMapping{T,D,R} 32
33 33 Struct for lazy transpose of a LazyTensor.
34 Struct for lazy transpose of a TensorMapping.
35 34
36 If a mapping implements the the `apply_transpose` method this allows working with 35 If a mapping implements the the `apply_transpose` method this allows working with
37 the transpose of mapping `m` by using `m'`. `m'` will work as a regular TensorMapping lazily calling 36 the transpose of mapping `m` by using `m'`. `m'` will work as a regular LazyTensor lazily calling
38 the appropriate methods of `m`. 37 the appropriate methods of `m`.
39 """ 38 """
40 struct LazyTensorMappingTranspose{T,R,D, TM<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R} 39 struct TensorTranspose{T,R,D, TM<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R}
41 tm::TM 40 tm::TM
42 end 41 end
43 export LazyTensorMappingTranspose
44 42
45 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? 43 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors?
46 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? 44 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any LazyTensor even if it doesn't implement `apply_transpose`?
47 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm) 45 Base.adjoint(tm::LazyTensor) = TensorTranspose(tm)
48 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm 46 Base.adjoint(tmt::TensorTranspose) = tmt.tm
49 47
50 apply(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...) 48 apply(tmt::TensorTranspose{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...)
51 apply_transpose(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...) 49 apply_transpose(tmt::TensorTranspose{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...)
52 50
53 range_size(tmt::LazyTensorMappingTranspose) = domain_size(tmt.tm) 51 range_size(tmt::TensorTranspose) = domain_size(tmt.tm)
54 domain_size(tmt::LazyTensorMappingTranspose) = range_size(tmt.tm) 52 domain_size(tmt::TensorTranspose) = range_size(tmt.tm)
55 53
56 54
57 struct LazyTensorMappingBinaryOperation{Op,T,R,D,T1<:TensorMapping{T,R,D},T2<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R} 55 struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R}
58 tm1::T1 56 tm1::T1
59 tm2::T2 57 tm2::T2
60 58
61 @inline function LazyTensorMappingBinaryOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:TensorMapping{T,R,D},T2<:TensorMapping{T,R,D}} 59 function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}}
60 @boundscheck check_domain_size(tm2, domain_size(tm1))
61 @boundscheck check_range_size(tm2, range_size(tm1))
62 return new{Op,T,R,D,T1,T2}(tm1,tm2) 62 return new{Op,T,R,D,T1,T2}(tm1,tm2)
63 end 63 end
64 end 64 end
65 # TODO: Boundschecking in constructor. 65
66 66 ElementwiseTensorOperation{Op}(s,t) where Op = ElementwiseTensorOperation{Op,eltype(s), range_dim(s), domain_dim(s)}(s,t)
67 apply(tmBinOp::LazyTensorMappingBinaryOperation{:+,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) 67
68 apply(tmBinOp::LazyTensorMappingBinaryOperation{:-,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) 68 apply(tmBinOp::ElementwiseTensorOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...)
69 69 apply(tmBinOp::ElementwiseTensorOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...)
70 range_size(tmBinOp::LazyTensorMappingBinaryOperation) = range_size(tmBinOp.tm1) 70
71 domain_size(tmBinOp::LazyTensorMappingBinaryOperation) = domain_size(tmBinOp.tm1) 71 range_size(tmBinOp::ElementwiseTensorOperation) = range_size(tmBinOp.tm1)
72 72 domain_size(tmBinOp::ElementwiseTensorOperation) = domain_size(tmBinOp.tm1)
73 Base.:+(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:+,T,R,D}(tm1,tm2) 73
74 Base.:-(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:-,T,R,D}(tm1,tm2) 74
75 75 """
76 """ 76 TensorComposition{T,R,K,D}
77 TensorMappingComposition{T,R,K,D} 77
78 78 Lazily compose two `LazyTensor`s, so that they can be handled as a single `LazyTensor`.
79 Lazily compose two `TensorMapping`s, so that they can be handled as a single `TensorMapping`. 79 """
80 """ 80 struct TensorComposition{T,R,K,D, TM1<:LazyTensor{T,R,K}, TM2<:LazyTensor{T,K,D}} <: LazyTensor{T,R,D}
81 struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D}
82 t1::TM1 81 t1::TM1
83 t2::TM2 82 t2::TM2
84 83
85 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} 84 function TensorComposition(t1::LazyTensor{T,R,K}, t2::LazyTensor{T,K,D}) where {T,R,K,D}
86 @boundscheck check_domain_size(t1, range_size(t2)) 85 @boundscheck check_domain_size(t1, range_size(t2))
87 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) 86 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2)
88 end 87 end
89 end 88 end
90 export TensorMappingComposition 89
91 90 range_size(tm::TensorComposition) = range_size(tm.t1)
92 range_size(tm::TensorMappingComposition) = range_size(tm.t1) 91 domain_size(tm::TensorComposition) = domain_size(tm.t2)
93 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) 92
94 93 function apply(c::TensorComposition{T,R,K,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,K,D}
95 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,K,D}
96 apply(c.t1, c.t2*v, I...) 94 apply(c.t1, c.t2*v, I...)
97 end 95 end
98 96
99 function apply_transpose(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,K,D} 97 function apply_transpose(c::TensorComposition{T,R,K,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,K,D}
100 apply_transpose(c.t2, c.t1'*v, I...) 98 apply_transpose(c.t2, c.t1'*v, I...)
101 end 99 end
102 100
103 Base.@propagate_inbounds Base.:∘(s::TensorMapping, t::TensorMapping) = TensorMappingComposition(s,t) 101 """
104 102 TensorComposition(tm, tmi::IdentityTensor)
105 """ 103 TensorComposition(tmi::IdentityTensor, tm)
106 LazyLinearMap{T,R,D,...}(A, range_indicies, domain_indicies) 104
107 105 Composes a `Tensormapping` `tm` with an `IdentityTensor` `tmi`, by returning `tm`
108 TensorMapping defined by the AbstractArray A. `range_indicies` and `domain_indicies` define which indicies of A should 106 """
109 be considerd the range and domain of the TensorMapping. Each set of indices must be ordered in ascending order. 107 function TensorComposition(tm::LazyTensor{T,R,D}, tmi::IdentityTensor{T,D}) where {T,R,D}
110
111 For instance, if A is a m x n matrix, and range_size = (1,), domain_size = (2,), then the LazyLinearMap performs the
112 standard matrix-vector product on vectors of size n.
113 """
114 struct LazyLinearMap{T,R,D, RD, AA<:AbstractArray{T,RD}} <: TensorMapping{T,R,D}
115 A::AA
116 range_indicies::NTuple{R,Int}
117 domain_indicies::NTuple{D,Int}
118
119 function LazyLinearMap(A::AA, range_indicies::NTuple{R,Int}, domain_indicies::NTuple{D,Int}) where {T,R,D, RD, AA<:AbstractArray{T,RD}}
120 if !issorted(range_indicies) || !issorted(domain_indicies)
121 throw(DomainError("range_indicies and domain_indicies must be sorted in ascending order"))
122 end
123
124 return new{T,R,D,RD,AA}(A,range_indicies,domain_indicies)
125 end
126 end
127 export LazyLinearMap
128
129 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]]
130 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]]
131
132 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D}
133 view_index = ntuple(i->:,ndims(llm.A))
134 for i ∈ 1:R
135 view_index = Base.setindex(view_index, Int(I[i]), llm.range_indicies[i])
136 end
137 A_view = @view llm.A[view_index...]
138 return sum(A_view.*v)
139 end
140
141 function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D}
142 apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...)
143 end
144
145
146 """
147 IdentityMapping{T,D} <: TensorMapping{T,D,D}
148
149 The lazy identity TensorMapping for a given size. Usefull for building up higher dimensional tensor mappings from lower
150 dimensional ones through outer products. Also used in the Implementation for InflatedTensorMapping.
151 """
152 struct IdentityMapping{T,D} <: TensorMapping{T,D,D}
153 size::NTuple{D,Int}
154 end
155 export IdentityMapping
156
157 IdentityMapping{T}(size::NTuple{D,Int}) where {T,D} = IdentityMapping{T,D}(size)
158 IdentityMapping{T}(size::Vararg{Int,D}) where {T,D} = IdentityMapping{T,D}(size)
159 IdentityMapping(size::Vararg{Int,D}) where D = IdentityMapping{Float64,D}(size)
160
161 range_size(tmi::IdentityMapping) = tmi.size
162 domain_size(tmi::IdentityMapping) = tmi.size
163
164 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
165 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
166
167 """
168 Base.:∘(tm, tmi)
169 Base.:∘(tmi, tm)
170
171 Composes a `Tensormapping` `tm` with an `IdentityMapping` `tmi`, by returning `tm`
172 """
173 @inline function Base.:∘(tm::TensorMapping{T,R,D}, tmi::IdentityMapping{T,D}) where {T,R,D}
174 @boundscheck check_domain_size(tm, range_size(tmi)) 108 @boundscheck check_domain_size(tm, range_size(tmi))
175 return tm 109 return tm
176 end 110 end
177 111
178 @inline function Base.:∘(tmi::IdentityMapping{T,R}, tm::TensorMapping{T,R,D}) where {T,R,D} 112 function TensorComposition(tmi::IdentityTensor{T,R}, tm::LazyTensor{T,R,D}) where {T,R,D}
179 @boundscheck check_domain_size(tmi, range_size(tm)) 113 @boundscheck check_domain_size(tmi, range_size(tm))
180 return tm 114 return tm
181 end 115 end
182 # Specialization for the case where tm is an IdentityMapping. Required to resolve ambiguity. 116 # Specialization for the case where tm is an IdentityTensor. Required to resolve ambiguity.
183 @inline function Base.:∘(tm::IdentityMapping{T,D}, tmi::IdentityMapping{T,D}) where {T,D} 117 function TensorComposition(tm::IdentityTensor{T,D}, tmi::IdentityTensor{T,D}) where {T,D}
184 @boundscheck check_domain_size(tm, range_size(tmi)) 118 @boundscheck check_domain_size(tm, range_size(tmi))
185 return tmi 119 return tmi
186 end 120 end
187 121
188 122 Base.:*(a::T, tm::LazyTensor{T}) where T = TensorComposition(ScalingTensor{T,range_dim(tm)}(a,range_size(tm)), tm)
189 """ 123 Base.:*(tm::LazyTensor{T}, a::T) where T = a*tm
190 InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D} 124
191 125 """
192 An inflated `TensorMapping` with dimensions added before and afer its actual dimensions. 126 InflatedTensor{T,R,D} <: LazyTensor{T,R,D}
193 """ 127
194 struct InflatedTensorMapping{T,R,D,D_before,R_middle,D_middle,D_after, TM<:TensorMapping{T,R_middle,D_middle}} <: TensorMapping{T,R,D} 128 An inflated `LazyTensor` with dimensions added before and afer its actual dimensions.
195 before::IdentityMapping{T,D_before} 129 """
130 struct InflatedTensor{T,R,D,D_before,R_middle,D_middle,D_after, TM<:LazyTensor{T,R_middle,D_middle}} <: LazyTensor{T,R,D}
131 before::IdentityTensor{T,D_before}
196 tm::TM 132 tm::TM
197 after::IdentityMapping{T,D_after} 133 after::IdentityTensor{T,D_after}
198 134
199 function InflatedTensorMapping(before, tm::TensorMapping{T}, after) where T 135 function InflatedTensor(before, tm::LazyTensor{T}, after) where T
200 R_before = range_dim(before) 136 R_before = range_dim(before)
201 R_middle = range_dim(tm) 137 R_middle = range_dim(tm)
202 R_after = range_dim(after) 138 R_after = range_dim(after)
203 R = R_before+R_middle+R_after 139 R = R_before+R_middle+R_after
204 140
207 D_after = domain_dim(after) 143 D_after = domain_dim(after)
208 D = D_before+D_middle+D_after 144 D = D_before+D_middle+D_after
209 return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after) 145 return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after)
210 end 146 end
211 end 147 end
212 export InflatedTensorMapping 148
213 """ 149 """
214 InflatedTensorMapping(before, tm, after) 150 InflatedTensor(before, tm, after)
215 InflatedTensorMapping(before,tm) 151 InflatedTensor(before,tm)
216 InflatedTensorMapping(tm,after) 152 InflatedTensor(tm,after)
217 153
218 The outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityMapping`s. 154 The outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityTensor`s.
219 155
220 If one of `before` or `after` is left out, a 0-dimensional `IdentityMapping` is used as the default value. 156 If one of `before` or `after` is left out, a 0-dimensional `IdentityTensor` is used as the default value.
221 157
222 If `tm` already is an `InflatedTensorMapping`, `before` and `after` will be extended instead of 158 If `tm` already is an `InflatedTensor`, `before` and `after` will be extended instead of
223 creating a nested `InflatedTensorMapping`. 159 creating a nested `InflatedTensor`.
224 """ 160 """
225 InflatedTensorMapping(::IdentityMapping, ::TensorMapping, ::IdentityMapping) 161 InflatedTensor(::IdentityTensor, ::LazyTensor, ::IdentityTensor)
226 162
227 function InflatedTensorMapping(before, itm::InflatedTensorMapping, after) 163 function InflatedTensor(before, itm::InflatedTensor, after)
228 return InflatedTensorMapping( 164 return InflatedTensor(
229 IdentityMapping(before.size..., itm.before.size...), 165 IdentityTensor(before.size..., itm.before.size...),
230 itm.tm, 166 itm.tm,
231 IdentityMapping(itm.after.size..., after.size...), 167 IdentityTensor(itm.after.size..., after.size...),
232 ) 168 )
233 end 169 end
234 170
235 InflatedTensorMapping(before::IdentityMapping, tm::TensorMapping{T}) where T = InflatedTensorMapping(before,tm,IdentityMapping{T}()) 171 InflatedTensor(before::IdentityTensor, tm::LazyTensor{T}) where T = InflatedTensor(before,tm,IdentityTensor{T}())
236 InflatedTensorMapping(tm::TensorMapping{T}, after::IdentityMapping) where T = InflatedTensorMapping(IdentityMapping{T}(),tm,after) 172 InflatedTensor(tm::LazyTensor{T}, after::IdentityTensor) where T = InflatedTensor(IdentityTensor{T}(),tm,after)
237 # Resolve ambiguity between the two previous methods 173 # Resolve ambiguity between the two previous methods
238 InflatedTensorMapping(I1::IdentityMapping{T}, I2::IdentityMapping{T}) where T = InflatedTensorMapping(I1,I2,IdentityMapping{T}()) 174 InflatedTensor(I1::IdentityTensor{T}, I2::IdentityTensor{T}) where T = InflatedTensor(I1,I2,IdentityTensor{T}())
239 175
240 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensorMapping(I(3),B,I(2)) -> I(3)⊗B⊗I(2) 176 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
241 177
242 function range_size(itm::InflatedTensorMapping) 178 function range_size(itm::InflatedTensor)
243 return flatten_tuple( 179 return flatten_tuple(
244 range_size(itm.before), 180 range_size(itm.before),
245 range_size(itm.tm), 181 range_size(itm.tm),
246 range_size(itm.after), 182 range_size(itm.after),
247 ) 183 )
248 end 184 end
249 185
250 function domain_size(itm::InflatedTensorMapping) 186 function domain_size(itm::InflatedTensor)
251 return flatten_tuple( 187 return flatten_tuple(
252 domain_size(itm.before), 188 domain_size(itm.before),
253 domain_size(itm.tm), 189 domain_size(itm.tm),
254 domain_size(itm.after), 190 domain_size(itm.after),
255 ) 191 )
256 end 192 end
257 193
258 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} 194 function apply(itm::InflatedTensor{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
259 dim_before = range_dim(itm.before) 195 dim_before = range_dim(itm.before)
260 dim_domain = domain_dim(itm.tm) 196 dim_domain = domain_dim(itm.tm)
261 dim_range = range_dim(itm.tm) 197 dim_range = range_dim(itm.tm)
262 dim_after = range_dim(itm.after) 198 dim_after = range_dim(itm.after)
263 199
265 201
266 v_inner = view(v, view_index...) 202 v_inner = view(v, view_index...)
267 return apply(itm.tm, v_inner, inner_index...) 203 return apply(itm.tm, v_inner, inner_index...)
268 end 204 end
269 205
270 function apply_transpose(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} 206 function apply_transpose(itm::InflatedTensor{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D}
271 dim_before = range_dim(itm.before) 207 dim_before = range_dim(itm.before)
272 dim_domain = domain_dim(itm.tm) 208 dim_domain = domain_dim(itm.tm)
273 dim_range = range_dim(itm.tm) 209 dim_range = range_dim(itm.tm)
274 dim_after = range_dim(itm.after) 210 dim_after = range_dim(itm.after)
275 211
278 v_inner = view(v, view_index...) 214 v_inner = view(v, view_index...)
279 return apply_transpose(itm.tm, v_inner, inner_index...) 215 return apply_transpose(itm.tm, v_inner, inner_index...)
280 end 216 end
281 217
282 218
283 """
284 split_index(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...)
285
286 Splits the multi-index `I` into two parts. One part which is expected to be
287 used as a view, and one which is expected to be used as an index.
288 Eg.
289 ```
290 split_index(Val(1),Val(3),Val(2),Val(1),(1,2,3,4)) -> (1,:,:,:,4), (2,3)
291 ```
292
293 `dim_view` controls how many colons are in the view, and `dim_index` controls
294 how many elements are extracted from the middle.
295 `dim_before` and `dim_after` decides the length of the index parts before and after the colons in the view index.
296
297 Arguments should satisfy `length(I) == dim_before+B_domain+dim_after`.
298
299 The returned values satisfy
300 * `length(view_index) == dim_before + dim_view + dim_after`
301 * `length(I_middle) == dim_index`
302 """
303 function split_index(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...) where {dim_before,dim_view, dim_index,dim_after}
304 I_before, I_middle, I_after = split_tuple(I, Val(dim_before), Val(dim_index))
305
306 view_index = (I_before..., ntuple((i)->:, dim_view)..., I_after...)
307
308 return view_index, I_middle
309 end
310
311 # TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21
312 # See:
313 # https://github.com/JuliaLang/julia/issues/34884
314 # https://github.com/JuliaLang/julia/issues/30386
315 """
316 slice_tuple(t, Val(l), Val(u))
317
318 Get a slice of a tuple in a type stable way.
319 Equivalent to `t[l:u]` but type stable.
320 """
321 function slice_tuple(t,::Val{L},::Val{U}) where {L,U}
322 return ntuple(i->t[i+L-1], U-L+1)
323 end
324
325 """
326 split_tuple(t::Tuple{...}, ::Val{M}) where {N,M}
327
328 Split the tuple `t` into two parts. the first part is `M` long.
329 E.g
330 ```julia
331 split_tuple((1,2,3,4),Val(3)) -> (1,2,3), (4,)
332 ```
333 """
334 function split_tuple(t::NTuple{N,Any},::Val{M}) where {N,M}
335 return slice_tuple(t,Val(1), Val(M)), slice_tuple(t,Val(M+1), Val(N))
336 end
337
338 """
339 split_tuple(t::Tuple{...},::Val{M},::Val{K}) where {N,M,K}
340
341 Same as `split_tuple(t::NTuple{N},::Val{M})` but splits the tuple in three parts. With the first
342 two parts having lenght `M` and `K`.
343 """
344 function split_tuple(t::NTuple{N,Any},::Val{M},::Val{K}) where {N,M,K}
345 p1, tail = split_tuple(t, Val(M))
346 p2, p3 = split_tuple(tail, Val(K))
347 return p1,p2,p3
348 end
349
350
351 """
352 flatten_tuple(t)
353
354 Takes a nested tuple and flattens the whole structure
355 """
356 flatten_tuple(t::NTuple{N, Number} where N) = t
357 flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify?
358 flatten_tuple(ts::Vararg) = flatten_tuple(ts)
359
360 @doc raw""" 219 @doc raw"""
361 LazyOuterProduct(tms...) 220 LazyOuterProduct(tms...)
362 221
363 Creates a `TensorMappingComposition` for the outerproduct of `tms...`. 222 Creates a `TensorComposition` for the outerproduct of `tms...`.
364 This is done by separating the outer product into regular products of outer products involving only identity mappings and one non-identity mapping. 223 This is done by separating the outer product into regular products of outer products involving only identity mappings and one non-identity mapping.
365 224
366 First let 225 First let
367 ```math 226 ```math
368 \begin{aligned} 227 \begin{aligned}
393 ```math 252 ```math
394 (A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]] 253 (A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]]
395 ``` 254 ```
396 """ 255 """
397 function LazyOuterProduct end 256 function LazyOuterProduct end
398 export LazyOuterProduct 257
399 258 function LazyOuterProduct(tm1::LazyTensor{T}, tm2::LazyTensor{T}) where T
400 function LazyOuterProduct(tm1::TensorMapping{T}, tm2::TensorMapping{T}) where T 259 itm1 = InflatedTensor(tm1, IdentityTensor{T}(range_size(tm2)))
401 itm1 = InflatedTensorMapping(tm1, IdentityMapping{T}(range_size(tm2))) 260 itm2 = InflatedTensor(IdentityTensor{T}(domain_size(tm1)),tm2)
402 itm2 = InflatedTensorMapping(IdentityMapping{T}(domain_size(tm1)),tm2)
403 261
404 return itm1∘itm2 262 return itm1∘itm2
405 end 263 end
406 264
407 LazyOuterProduct(t1::IdentityMapping{T}, t2::IdentityMapping{T}) where T = IdentityMapping{T}(t1.size...,t2.size...) 265 LazyOuterProduct(t1::IdentityTensor{T}, t2::IdentityTensor{T}) where T = IdentityTensor{T}(t1.size...,t2.size...)
408 LazyOuterProduct(t1::TensorMapping, t2::IdentityMapping) = InflatedTensorMapping(t1, t2) 266 LazyOuterProduct(t1::LazyTensor, t2::IdentityTensor) = InflatedTensor(t1, t2)
409 LazyOuterProduct(t1::IdentityMapping, t2::TensorMapping) = InflatedTensorMapping(t1, t2) 267 LazyOuterProduct(t1::IdentityTensor, t2::LazyTensor) = InflatedTensor(t1, t2)
410 268
411 LazyOuterProduct(tms::Vararg{TensorMapping}) = foldl(LazyOuterProduct, tms) 269 LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms)
412 270
413 ⊗(a::TensorMapping, b::TensorMapping) = LazyOuterProduct(a,b) 271
414 export ⊗ 272
415 273 """
416 274 inflate(tm::LazyTensor, sz, dir)
417 function check_domain_size(tm::TensorMapping, sz) 275
276 Inflate `tm` such that it gets the size `sz` in all directions except `dir`.
277 Here `sz[dir]` is ignored and replaced with the range and domains size of
278 `tm`.
279
280 An example of when this operation is useful is when extending a one
281 dimensional difference operator `D` to a 2D grid of a ceratin size. In that
282 case we could have
283
284 ```julia
285 Dx = inflate(D, (10,10), 1)
286 Dy = inflate(D, (10,10), 2)
287 ```
288 """
289 function inflate(tm::LazyTensor, sz, dir)
290 Is = IdentityTensor{eltype(tm)}.(sz)
291 parts = Base.setindex(Is, tm, dir)
292 return foldl(⊗, parts)
293 end
294
295 function check_domain_size(tm::LazyTensor, sz)
418 if domain_size(tm) != sz 296 if domain_size(tm) != sz
419 throw(SizeMismatch(tm,sz)) 297 throw(DomainSizeMismatch(tm,sz))
420 end 298 end
421 end 299 end
422 300
423 struct SizeMismatch <: Exception 301 function check_range_size(tm::LazyTensor, sz)
424 tm::TensorMapping 302 if range_size(tm) != sz
303 throw(RangeSizeMismatch(tm,sz))
304 end
305 end
306
307 struct DomainSizeMismatch <: Exception
308 tm::LazyTensor
425 sz 309 sz
426 end 310 end
427 export SizeMismatch 311
428 312 function Base.showerror(io::IO, err::DomainSizeMismatch)
429 function Base.showerror(io::IO, err::SizeMismatch) 313 print(io, "DomainSizeMismatch: ")
430 print(io, "SizeMismatch: ") 314 print(io, "domain size $(domain_size(err.tm)) of LazyTensor not matching size $(err.sz)")
431 print(io, "domain size $(domain_size(err.tm)) of TensorMapping not matching size $(err.sz)") 315 end
432 end 316
317
318 struct RangeSizeMismatch <: Exception
319 tm::LazyTensor
320 sz
321 end
322
323 function Base.showerror(io::IO, err::RangeSizeMismatch)
324 print(io, "RangeSizeMismatch: ")
325 print(io, "range size $(range_size(err.tm)) of LazyTensor not matching size $(err.sz)")
326 end
327
433 328
434 function apply_with_region(op, v, boundary_width::Integer, dim_size::Integer, i) 329 function apply_with_region(op, v, boundary_width::Integer, dim_size::Integer, i)
435 if 0 < i <= boundary_width 330 if 0 < i <= boundary_width
436 return LazyTensors.apply(op,v,Index(i,Lower)) 331 return LazyTensors.apply(op,v,Index(i,Lower))
437 elseif boundary_width < i <= dim_size-boundary_width 332 elseif boundary_width < i <= dim_size-boundary_width
453 return LazyTensors.apply_transpose(op,v,Index(i,Upper)) 348 return LazyTensors.apply_transpose(op,v,Index(i,Upper))
454 else 349 else
455 error("Bounds error") # TODO: Make this more standard 350 error("Bounds error") # TODO: Make this more standard
456 end 351 end
457 end 352 end
353