comparison src/LazyTensors/lazy_tensor_operations.jl @ 1047:d12ab8120d29 feature/first_derivative

Merge default
author Jonatan Werpers <jonatan@werpers.com>
date Wed, 23 Mar 2022 12:43:03 +0100
parents 9e76bf19904c
children f857057e61e6 3bb94ce74697 2e606d4c0ab1
comparison
equal deleted inserted replaced
1046:e00eb000346e 1047:d12ab8120d29
1 """ 1 """
2 LazyTensorMappingApplication{T,R,D} <: LazyArray{T,R} 2 TensorApplication{T,R,D} <: LazyArray{T,R}
3 3
4 Struct for lazy application of a TensorMapping. Created using `*`. 4 Struct for lazy application of a LazyTensor. Created using `*`.
5 5
6 Allows the result of a `TensorMapping` applied to a vector to be treated as an `AbstractArray`. 6 Allows the result of a `LazyTensor` applied to a vector to be treated as an `AbstractArray`.
7 With a mapping `m` and a vector `v` the LazyTensorMappingApplication object can be created by `m*v`. 7 With a mapping `m` and a vector `v` the TensorApplication object can be created by `m*v`.
8 The actual result will be calculated when indexing into `m*v`. 8 The actual result will be calculated when indexing into `m*v`.
9 """ 9 """
10 struct LazyTensorMappingApplication{T,R,D, TM<:TensorMapping{T,R,D}, AA<:AbstractArray{T,D}} <: LazyArray{T,R} 10 struct TensorApplication{T,R,D, TM<:LazyTensor{<:Any,R,D}, AA<:AbstractArray{<:Any,D}} <: LazyArray{T,R}
11 t::TM 11 t::TM
12 o::AA 12 o::AA
13 end 13
14 # TODO: Do boundschecking on creation! 14 function TensorApplication(t::LazyTensor{<:Any,R,D}, o::AbstractArray{<:Any,D}) where {R,D}
15 export LazyTensorMappingApplication 15 @boundscheck check_domain_size(t, size(o))
16 16 I = ntuple(i->1, range_dim(t))
17 Base.getindex(ta::LazyTensorMappingApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...) 17 T = typeof(apply(t,o,I...))
18 Base.getindex(ta::LazyTensorMappingApplication{T,1}, I::CartesianIndex{1}) where {T} = apply(ta.t, ta.o, I.I...) # Would otherwise be caught in the previous method. 18 return new{T,R,D,typeof(t), typeof(o)}(t,o)
19 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t) 19 end
20 # TODO: What else is needed to implement the AbstractArray interface? 20 end
21 21
22 Base.:*(a::TensorMapping, v::AbstractArray) = LazyTensorMappingApplication(a,v) 22 function Base.getindex(ta::TensorApplication{T,R}, I::Vararg{Any,R}) where {T,R}
23 Base.:*(a::TensorMapping, b::TensorMapping) = throw(MethodError(Base.:*,(a,b))) 23 @boundscheck checkbounds(ta, Int.(I)...)
24 Base.:*(a::TensorMapping, args::Union{TensorMapping, AbstractArray}...) = foldr(*,(a,args...)) 24 return @inbounds apply(ta.t, ta.o, I...)
25 25 end
26 # # We need the associativity to be a→b→c = a→(b→c), which is the case for '→' 26 Base.@propagate_inbounds Base.getindex(ta::TensorApplication{T,1} where T, I::CartesianIndex{1}) = ta[Tuple(I)...] # Would otherwise be caught in the previous method.
27 # # Should we overload some other infix binary opesrator? 27 Base.size(ta::TensorApplication) = range_size(ta.t)
28 # →(tm::TensorMapping{T,R,D}, o::AbstractArray{T,D}) where {T,R,D} = LazyTensorMappingApplication(tm,o) 28
29 # TODO: We need to be really careful about good error messages. 29
30 # For example what happens if you try to multiply LazyTensorMappingApplication with a TensorMapping(wrong order)? 30 """
31 31 TensorTranspose{T,R,D} <: LazyTensor{T,D,R}
32 """ 32
33 LazyTensorMappingTranspose{T,R,D} <: TensorMapping{T,D,R} 33 Struct for lazy transpose of a LazyTensor.
34
35 Struct for lazy transpose of a TensorMapping.
36 34
37 If a mapping implements the `apply_transpose` method this allows working with 35 If a mapping implements the `apply_transpose` method this allows working with
38 the transpose of mapping `m` by using `m'`. `m'` will work as a regular TensorMapping lazily calling 36 the transpose of mapping `m` by using `m'`. `m'` will work as a regular LazyTensor lazily calling
39 the appropriate methods of `m`. 37 the appropriate methods of `m`.
40 """ 38 """
41 struct LazyTensorMappingTranspose{T,R,D, TM<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R} 39 struct TensorTranspose{T,R,D, TM<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R}
42 tm::TM 40 tm::TM
43 end 41 end
44 export LazyTensorMappingTranspose
45 42
46 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? 43 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors?
47 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? 44 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any LazyTensor even if it doesn't implement `apply_transpose`?
48 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm) 45 Base.adjoint(tm::LazyTensor) = TensorTranspose(tm)
49 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm 46 Base.adjoint(tmt::TensorTranspose) = tmt.tm
50 47
51 apply(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...) 48 apply(tmt::TensorTranspose{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...)
52 apply_transpose(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...) 49 apply_transpose(tmt::TensorTranspose{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...)
53 50
54 range_size(tmt::LazyTensorMappingTranspose) = domain_size(tmt.tm) 51 range_size(tmt::TensorTranspose) = domain_size(tmt.tm)
55 domain_size(tmt::LazyTensorMappingTranspose) = range_size(tmt.tm) 52 domain_size(tmt::TensorTranspose) = range_size(tmt.tm)
56 53
57 54
58 struct LazyTensorMappingBinaryOperation{Op,T,R,D,T1<:TensorMapping{T,R,D},T2<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R} 55 struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R}
59 tm1::T1 56 tm1::T1
60 tm2::T2 57 tm2::T2
61 58
62 @inline function LazyTensorMappingBinaryOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:TensorMapping{T,R,D},T2<:TensorMapping{T,R,D}} 59 function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}}
60 @boundscheck check_domain_size(tm2, domain_size(tm1))
61 @boundscheck check_range_size(tm2, range_size(tm1))
63 return new{Op,T,R,D,T1,T2}(tm1,tm2) 62 return new{Op,T,R,D,T1,T2}(tm1,tm2)
64 end 63 end
65 end 64 end
66 # TODO: Boundschecking in constructor. 65
67 66 ElementwiseTensorOperation{Op}(s,t) where Op = ElementwiseTensorOperation{Op,eltype(s), range_dim(s), domain_dim(s)}(s,t)
68 apply(tmBinOp::LazyTensorMappingBinaryOperation{:+,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) 67
69 apply(tmBinOp::LazyTensorMappingBinaryOperation{:-,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) 68 apply(tmBinOp::ElementwiseTensorOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...)
70 69 apply(tmBinOp::ElementwiseTensorOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...)
71 range_size(tmBinOp::LazyTensorMappingBinaryOperation) = range_size(tmBinOp.tm1) 70
72 domain_size(tmBinOp::LazyTensorMappingBinaryOperation) = domain_size(tmBinOp.tm1) 71 range_size(tmBinOp::ElementwiseTensorOperation) = range_size(tmBinOp.tm1)
73 72 domain_size(tmBinOp::ElementwiseTensorOperation) = domain_size(tmBinOp.tm1)
74 Base.:+(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:+,T,R,D}(tm1,tm2) 73
75 Base.:-(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:-,T,R,D}(tm1,tm2) 74
76 75 """
77 """ 76 TensorComposition{T,R,K,D}
78 TensorMappingComposition{T,R,K,D} 77
79 78 Lazily compose two `LazyTensor`s, so that they can be handled as a single `LazyTensor`.
80 Lazily compose two `TensorMapping`s, so that they can be handled as a single `TensorMapping`. 79 """
81 """ 80 struct TensorComposition{T,R,K,D, TM1<:LazyTensor{T,R,K}, TM2<:LazyTensor{T,K,D}} <: LazyTensor{T,R,D}
82 struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D}
83 t1::TM1 81 t1::TM1
84 t2::TM2 82 t2::TM2
85 83
86 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} 84 function TensorComposition(t1::LazyTensor{T,R,K}, t2::LazyTensor{T,K,D}) where {T,R,K,D}
87 @boundscheck check_domain_size(t1, range_size(t2)) 85 @boundscheck check_domain_size(t1, range_size(t2))
88 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) 86 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2)
89 end 87 end
90 end 88 end
91 export TensorMappingComposition 89
92 90 range_size(tm::TensorComposition) = range_size(tm.t1)
93 range_size(tm::TensorMappingComposition) = range_size(tm.t1) 91 domain_size(tm::TensorComposition) = domain_size(tm.t2)
94 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) 92
95 93 function apply(c::TensorComposition{T,R,K,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,K,D}
96 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,K,D}
97 apply(c.t1, c.t2*v, I...) 94 apply(c.t1, c.t2*v, I...)
98 end 95 end
99 96
100 function apply_transpose(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,K,D} 97 function apply_transpose(c::TensorComposition{T,R,K,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,K,D}
101 apply_transpose(c.t2, c.t1'*v, I...) 98 apply_transpose(c.t2, c.t1'*v, I...)
102 end 99 end
103 100
104 Base.@propagate_inbounds Base.:∘(s::TensorMapping, t::TensorMapping) = TensorMappingComposition(s,t) 101
105 102 """
106 """ 103 TensorComposition(tm, tmi::IdentityTensor)
107 LazyLinearMap{T,R,D,...}(A, range_indicies, domain_indicies) 104 TensorComposition(tmi::IdentityTensor, tm)
108 105
109 TensorMapping defined by the AbstractArray A. `range_indicies` and `domain_indicies` define which indicies of A should 106 Composes a `LazyTensor` `tm` with an `IdentityTensor` `tmi`, by returning `tm`
110 be considered the range and domain of the TensorMapping. Each set of indices must be ordered in ascending order. 107 """
111 108 function TensorComposition(tm::LazyTensor{T,R,D}, tmi::IdentityTensor{T,D}) where {T,R,D}
112 For instance, if A is a m x n matrix, and range_size = (1,), domain_size = (2,), then the LazyLinearMap performs the
113 standard matrix-vector product on vectors of size n.
114 """
115 struct LazyLinearMap{T,R,D, RD, AA<:AbstractArray{T,RD}} <: TensorMapping{T,R,D}
116 A::AA
117 range_indicies::NTuple{R,Int}
118 domain_indicies::NTuple{D,Int}
119
120 function LazyLinearMap(A::AA, range_indicies::NTuple{R,Int}, domain_indicies::NTuple{D,Int}) where {T,R,D, RD, AA<:AbstractArray{T,RD}}
121 if !issorted(range_indicies) || !issorted(domain_indicies)
122 throw(DomainError("range_indicies and domain_indicies must be sorted in ascending order"))
123 end
124
125 return new{T,R,D,RD,AA}(A,range_indicies,domain_indicies)
126 end
127 end
128 export LazyLinearMap
129
130 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]]
131 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]]
132
133 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D}
134 view_index = ntuple(i->:,ndims(llm.A))
135 for i ∈ 1:R
136 view_index = Base.setindex(view_index, Int(I[i]), llm.range_indicies[i])
137 end
138 A_view = @view llm.A[view_index...]
139 return sum(A_view.*v)
140 end
141
142 function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D}
143 apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...)
144 end
145
146
147 """
148 IdentityMapping{T,D} <: TensorMapping{T,D,D}
149
150 The lazy identity TensorMapping for a given size. Useful for building up higher dimensional tensor mappings from lower
151 dimensional ones through outer products. Also used in the Implementation for InflatedTensorMapping.
152 """
153 struct IdentityMapping{T,D} <: TensorMapping{T,D,D}
154 size::NTuple{D,Int}
155 end
156 export IdentityMapping
157
158 IdentityMapping{T}(size::NTuple{D,Int}) where {T,D} = IdentityMapping{T,D}(size)
159 IdentityMapping{T}(size::Vararg{Int,D}) where {T,D} = IdentityMapping{T,D}(size)
160 IdentityMapping(size::Vararg{Int,D}) where D = IdentityMapping{Float64,D}(size)
161
162 range_size(tmi::IdentityMapping) = tmi.size
163 domain_size(tmi::IdentityMapping) = tmi.size
164
165 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
166 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
167
168 """
169 Base.:∘(tm, tmi)
170 Base.:∘(tmi, tm)
171
172 Composes a `TensorMapping` `tm` with an `IdentityMapping` `tmi`, by returning `tm`
173 """
174 @inline function Base.:∘(tm::TensorMapping{T,R,D}, tmi::IdentityMapping{T,D}) where {T,R,D}
175 @boundscheck check_domain_size(tm, range_size(tmi)) 109 @boundscheck check_domain_size(tm, range_size(tmi))
176 return tm 110 return tm
177 end 111 end
178 112
179 @inline function Base.:∘(tmi::IdentityMapping{T,R}, tm::TensorMapping{T,R,D}) where {T,R,D} 113 function TensorComposition(tmi::IdentityTensor{T,R}, tm::LazyTensor{T,R,D}) where {T,R,D}
180 @boundscheck check_domain_size(tmi, range_size(tm)) 114 @boundscheck check_domain_size(tmi, range_size(tm))
181 return tm 115 return tm
182 end 116 end
183 # Specialization for the case where tm is an IdentityMapping. Required to resolve ambiguity. 117 # Specialization for the case where tm is an IdentityTensor. Required to resolve ambiguity.
184 @inline function Base.:∘(tm::IdentityMapping{T,D}, tmi::IdentityMapping{T,D}) where {T,D} 118 function TensorComposition(tm::IdentityTensor{T,D}, tmi::IdentityTensor{T,D}) where {T,D}
185 @boundscheck check_domain_size(tm, range_size(tmi)) 119 @boundscheck check_domain_size(tm, range_size(tmi))
186 return tmi 120 return tmi
187 end 121 end
188 122
189 123
190 """ 124 """
191 InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D} 125 InflatedTensor{T,R,D} <: LazyTensor{T,R,D}
192 126
193 An inflated `TensorMapping` with dimensions added before and after its actual dimensions. 127 An inflated `LazyTensor` with dimensions added before and after its actual dimensions.
194 """ 128 """
195 struct InflatedTensorMapping{T,R,D,D_before,R_middle,D_middle,D_after, TM<:TensorMapping{T,R_middle,D_middle}} <: TensorMapping{T,R,D} 129 struct InflatedTensor{T,R,D,D_before,R_middle,D_middle,D_after, TM<:LazyTensor{T,R_middle,D_middle}} <: LazyTensor{T,R,D}
196 before::IdentityMapping{T,D_before} 130 before::IdentityTensor{T,D_before}
197 tm::TM 131 tm::TM
198 after::IdentityMapping{T,D_after} 132 after::IdentityTensor{T,D_after}
199 133
200 function InflatedTensorMapping(before, tm::TensorMapping{T}, after) where T 134 function InflatedTensor(before, tm::LazyTensor{T}, after) where T
201 R_before = range_dim(before) 135 R_before = range_dim(before)
202 R_middle = range_dim(tm) 136 R_middle = range_dim(tm)
203 R_after = range_dim(after) 137 R_after = range_dim(after)
204 R = R_before+R_middle+R_after 138 R = R_before+R_middle+R_after
205 139
208 D_after = domain_dim(after) 142 D_after = domain_dim(after)
209 D = D_before+D_middle+D_after 143 D = D_before+D_middle+D_after
210 return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after) 144 return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after)
211 end 145 end
212 end 146 end
213 export InflatedTensorMapping 147
214 """ 148 """
215 InflatedTensorMapping(before, tm, after) 149 InflatedTensor(before, tm, after)
216 InflatedTensorMapping(before,tm) 150 InflatedTensor(before,tm)
217 InflatedTensorMapping(tm,after) 151 InflatedTensor(tm,after)
218 152
219 The outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityMapping`s. 153 The outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityTensor`s.
220 154
221 If one of `before` or `after` is left out, a 0-dimensional `IdentityMapping` is used as the default value. 155 If one of `before` or `after` is left out, a 0-dimensional `IdentityTensor` is used as the default value.
222 156
223 If `tm` already is an `InflatedTensorMapping`, `before` and `after` will be extended instead of 157 If `tm` already is an `InflatedTensor`, `before` and `after` will be extended instead of
224 creating a nested `InflatedTensorMapping`. 158 creating a nested `InflatedTensor`.
225 """ 159 """
226 InflatedTensorMapping(::IdentityMapping, ::TensorMapping, ::IdentityMapping) 160 InflatedTensor(::IdentityTensor, ::LazyTensor, ::IdentityTensor)
227 161
228 function InflatedTensorMapping(before, itm::InflatedTensorMapping, after) 162 function InflatedTensor(before, itm::InflatedTensor, after)
229 return InflatedTensorMapping( 163 return InflatedTensor(
230 IdentityMapping(before.size..., itm.before.size...), 164 IdentityTensor(before.size..., itm.before.size...),
231 itm.tm, 165 itm.tm,
232 IdentityMapping(itm.after.size..., after.size...), 166 IdentityTensor(itm.after.size..., after.size...),
233 ) 167 )
234 end 168 end
235 169
236 InflatedTensorMapping(before::IdentityMapping, tm::TensorMapping{T}) where T = InflatedTensorMapping(before,tm,IdentityMapping{T}()) 170 InflatedTensor(before::IdentityTensor, tm::LazyTensor{T}) where T = InflatedTensor(before,tm,IdentityTensor{T}())
237 InflatedTensorMapping(tm::TensorMapping{T}, after::IdentityMapping) where T = InflatedTensorMapping(IdentityMapping{T}(),tm,after) 171 InflatedTensor(tm::LazyTensor{T}, after::IdentityTensor) where T = InflatedTensor(IdentityTensor{T}(),tm,after)
238 # Resolve ambiguity between the two previous methods 172 # Resolve ambiguity between the two previous methods
239 InflatedTensorMapping(I1::IdentityMapping{T}, I2::IdentityMapping{T}) where T = InflatedTensorMapping(I1,I2,IdentityMapping{T}()) 173 InflatedTensor(I1::IdentityTensor{T}, I2::IdentityTensor{T}) where T = InflatedTensor(I1,I2,IdentityTensor{T}())
240 174
241 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensorMapping(I(3),B,I(2)) -> I(3)⊗B⊗I(2) 175 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
242 176
243 function range_size(itm::InflatedTensorMapping) 177 function range_size(itm::InflatedTensor)
244 return flatten_tuple( 178 return flatten_tuple(
245 range_size(itm.before), 179 range_size(itm.before),
246 range_size(itm.tm), 180 range_size(itm.tm),
247 range_size(itm.after), 181 range_size(itm.after),
248 ) 182 )
249 end 183 end
250 184
251 function domain_size(itm::InflatedTensorMapping) 185 function domain_size(itm::InflatedTensor)
252 return flatten_tuple( 186 return flatten_tuple(
253 domain_size(itm.before), 187 domain_size(itm.before),
254 domain_size(itm.tm), 188 domain_size(itm.tm),
255 domain_size(itm.after), 189 domain_size(itm.after),
256 ) 190 )
257 end 191 end
258 192
259 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} 193 function apply(itm::InflatedTensor{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
260 dim_before = range_dim(itm.before) 194 dim_before = range_dim(itm.before)
261 dim_domain = domain_dim(itm.tm) 195 dim_domain = domain_dim(itm.tm)
262 dim_range = range_dim(itm.tm) 196 dim_range = range_dim(itm.tm)
263 dim_after = range_dim(itm.after) 197 dim_after = range_dim(itm.after)
264 198
266 200
267 v_inner = view(v, view_index...) 201 v_inner = view(v, view_index...)
268 return apply(itm.tm, v_inner, inner_index...) 202 return apply(itm.tm, v_inner, inner_index...)
269 end 203 end
270 204
271 function apply_transpose(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} 205 function apply_transpose(itm::InflatedTensor{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D}
272 dim_before = range_dim(itm.before) 206 dim_before = range_dim(itm.before)
273 dim_domain = domain_dim(itm.tm) 207 dim_domain = domain_dim(itm.tm)
274 dim_range = range_dim(itm.tm) 208 dim_range = range_dim(itm.tm)
275 dim_after = range_dim(itm.after) 209 dim_after = range_dim(itm.after)
276 210
279 v_inner = view(v, view_index...) 213 v_inner = view(v, view_index...)
280 return apply_transpose(itm.tm, v_inner, inner_index...) 214 return apply_transpose(itm.tm, v_inner, inner_index...)
281 end 215 end
282 216
283 217
284 """
285 split_index(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...)
286
287 Splits the multi-index `I` into two parts. One part which is expected to be
288 used as a view, and one which is expected to be used as an index.
289 Eg.
290 ```
291 split_index(Val(1),Val(3),Val(2),Val(1),(1,2,3,4)) -> (1,:,:,:,4), (2,3)
292 ```
293
294 `dim_view` controls how many colons are in the view, and `dim_index` controls
295 how many elements are extracted from the middle.
296 `dim_before` and `dim_after` decides the length of the index parts before and after the colons in the view index.
297
298 Arguments should satisfy `length(I) == dim_before+dim_index+dim_after`.
299
300 The returned values satisfy
301 * `length(view_index) == dim_before + dim_view + dim_after`
302 * `length(I_middle) == dim_index`
303 """
304 function split_index(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...) where {dim_before,dim_view, dim_index,dim_after}
305 I_before, I_middle, I_after = split_tuple(I, Val(dim_before), Val(dim_index))
306
307 view_index = (I_before..., ntuple((i)->:, dim_view)..., I_after...)
308
309 return view_index, I_middle
310 end
311
312 # TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21
313 # See:
314 # https://github.com/JuliaLang/julia/issues/34884
315 # https://github.com/JuliaLang/julia/issues/30386
316 """
317 slice_tuple(t, Val(l), Val(u))
318
319 Get a slice of a tuple in a type stable way.
320 Equivalent to `t[l:u]` but type stable.
321 """
322 function slice_tuple(t,::Val{L},::Val{U}) where {L,U}
323 return ntuple(i->t[i+L-1], U-L+1)
324 end
325
326 """
327 split_tuple(t::Tuple{...}, ::Val{M}) where {N,M}
328
329 Split the tuple `t` into two parts. The first part is `M` long.
330 E.g
331 ```julia
332 split_tuple((1,2,3,4),Val(3)) -> (1,2,3), (4,)
333 ```
334 """
335 function split_tuple(t::NTuple{N,Any},::Val{M}) where {N,M}
336 return slice_tuple(t,Val(1), Val(M)), slice_tuple(t,Val(M+1), Val(N))
337 end
338
339 """
340 split_tuple(t::Tuple{...},::Val{M},::Val{K}) where {N,M,K}
341
342 Same as `split_tuple(t::NTuple{N},::Val{M})` but splits the tuple in three parts. With the first
343 two parts having length `M` and `K`.
344 """
345 function split_tuple(t::NTuple{N,Any},::Val{M},::Val{K}) where {N,M,K}
346 p1, tail = split_tuple(t, Val(M))
347 p2, p3 = split_tuple(tail, Val(K))
348 return p1,p2,p3
349 end
350
351
352 """
353 flatten_tuple(t)
354
355 Takes a nested tuple and flattens the whole structure
356 """
357 flatten_tuple(t::NTuple{N, Number} where N) = t
358 flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify?
359 flatten_tuple(ts::Vararg) = flatten_tuple(ts)
360
361 @doc raw""" 218 @doc raw"""
362 LazyOuterProduct(tms...) 219 LazyOuterProduct(tms...)
363 220
364 Creates a `TensorMappingComposition` for the outerproduct of `tms...`. 221 Creates a `TensorComposition` for the outerproduct of `tms...`.
365 This is done by separating the outer product into regular products of outer products involving only identity mappings and one non-identity mapping. 222 This is done by separating the outer product into regular products of outer products involving only identity mappings and one non-identity mapping.
366 223
367 First let 224 First let
368 ```math 225 ```math
369 \begin{aligned} 226 \begin{aligned}
394 ```math 251 ```math
395 (A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]] 252 (A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]]
396 ``` 253 ```
397 """ 254 """
398 function LazyOuterProduct end 255 function LazyOuterProduct end
399 export LazyOuterProduct 256
400 257 function LazyOuterProduct(tm1::LazyTensor{T}, tm2::LazyTensor{T}) where T
401 function LazyOuterProduct(tm1::TensorMapping{T}, tm2::TensorMapping{T}) where T 258 itm1 = InflatedTensor(tm1, IdentityTensor{T}(range_size(tm2)))
402 itm1 = InflatedTensorMapping(tm1, IdentityMapping{T}(range_size(tm2))) 259 itm2 = InflatedTensor(IdentityTensor{T}(domain_size(tm1)),tm2)
403 itm2 = InflatedTensorMapping(IdentityMapping{T}(domain_size(tm1)),tm2)
404 260
405 return itm1∘itm2 261 return itm1∘itm2
406 end 262 end
407 263
408 LazyOuterProduct(t1::IdentityMapping{T}, t2::IdentityMapping{T}) where T = IdentityMapping{T}(t1.size...,t2.size...) 264 LazyOuterProduct(t1::IdentityTensor{T}, t2::IdentityTensor{T}) where T = IdentityTensor{T}(t1.size...,t2.size...)
409 LazyOuterProduct(t1::TensorMapping, t2::IdentityMapping) = InflatedTensorMapping(t1, t2) 265 LazyOuterProduct(t1::LazyTensor, t2::IdentityTensor) = InflatedTensor(t1, t2)
410 LazyOuterProduct(t1::IdentityMapping, t2::TensorMapping) = InflatedTensorMapping(t1, t2) 266 LazyOuterProduct(t1::IdentityTensor, t2::LazyTensor) = InflatedTensor(t1, t2)
411 267
412 LazyOuterProduct(tms::Vararg{TensorMapping}) = foldl(LazyOuterProduct, tms) 268 LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms)
413 269
414 ⊗(a::TensorMapping, b::TensorMapping) = LazyOuterProduct(a,b) 270
415 export ⊗ 271 function check_domain_size(tm::LazyTensor, sz)
416
417
418 function check_domain_size(tm::TensorMapping, sz)
419 if domain_size(tm) != sz 272 if domain_size(tm) != sz
420 throw(SizeMismatch(tm,sz)) 273 throw(DomainSizeMismatch(tm,sz))
421 end 274 end
422 end 275 end
423 276
424 struct SizeMismatch <: Exception 277 function check_range_size(tm::LazyTensor, sz)
425 tm::TensorMapping 278 if range_size(tm) != sz
279 throw(RangeSizeMismatch(tm,sz))
280 end
281 end
282
283 struct DomainSizeMismatch <: Exception
284 tm::LazyTensor
426 sz 285 sz
427 end 286 end
428 export SizeMismatch 287
429 288 function Base.showerror(io::IO, err::DomainSizeMismatch)
430 function Base.showerror(io::IO, err::SizeMismatch) 289 print(io, "DomainSizeMismatch: ")
431 print(io, "SizeMismatch: ") 290 print(io, "domain size $(domain_size(err.tm)) of LazyTensor not matching size $(err.sz)")
432 print(io, "domain size $(domain_size(err.tm)) of TensorMapping not matching size $(err.sz)") 291 end
433 end 292
293
294 struct RangeSizeMismatch <: Exception
295 tm::LazyTensor
296 sz
297 end
298
299 function Base.showerror(io::IO, err::RangeSizeMismatch)
300 print(io, "RangeSizeMismatch: ")
301 print(io, "range size $(range_size(err.tm)) of LazyTensor not matching size $(err.sz)")
302 end