comparison src/LazyTensors/lazy_tensor_operations.jl @ 1043:c16116e403e2

Merge refactor/lazy_tensors
author Jonatan Werpers <jonatan@werpers.com>
date Tue, 22 Mar 2022 14:33:13 +0100
parents 9e76bf19904c
children f857057e61e6 3bb94ce74697 2e606d4c0ab1
comparison
equal deleted inserted replaced
1039:696a3307b6a4 1043:c16116e403e2
1 """ 1 """
2 LazyTensorMappingApplication{T,R,D} <: LazyArray{T,R} 2 TensorApplication{T,R,D} <: LazyArray{T,R}
3 3
4 Struct for lazy application of a TensorMapping. Created using `*`. 4 Struct for lazy application of a LazyTensor. Created using `*`.
5 5
6 Allows the result of a `TensorMapping` applied to a vector to be treated as an `AbstractArray`. 6 Allows the result of a `LazyTensor` applied to a vector to be treated as an `AbstractArray`.
7 With a mapping `m` and a vector `v` the LazyTensorMappingApplication object can be created by `m*v`. 7 With a mapping `m` and a vector `v` the TensorApplication object can be created by `m*v`.
8 The actual result will be calcualted when indexing into `m*v`. 8 The actual result will be calcualted when indexing into `m*v`.
9 """ 9 """
10 struct LazyTensorMappingApplication{T,R,D, TM<:TensorMapping{<:Any,R,D}, AA<:AbstractArray{<:Any,D}} <: LazyArray{T,R} 10 struct TensorApplication{T,R,D, TM<:LazyTensor{<:Any,R,D}, AA<:AbstractArray{<:Any,D}} <: LazyArray{T,R}
11 t::TM 11 t::TM
12 o::AA 12 o::AA
13 13
14 function LazyTensorMappingApplication(t::TensorMapping{<:Any,R,D}, o::AbstractArray{<:Any,D}) where {R,D} 14 function TensorApplication(t::LazyTensor{<:Any,R,D}, o::AbstractArray{<:Any,D}) where {R,D}
15 @boundscheck check_domain_size(t, size(o))
15 I = ntuple(i->1, range_dim(t)) 16 I = ntuple(i->1, range_dim(t))
16 T = typeof(apply(t,o,I...)) 17 T = typeof(apply(t,o,I...))
17 return new{T,R,D,typeof(t), typeof(o)}(t,o) 18 return new{T,R,D,typeof(t), typeof(o)}(t,o)
18 end 19 end
19 end 20 end
20 # TODO: Do boundschecking on creation! 21
21 22 function Base.getindex(ta::TensorApplication{T,R}, I::Vararg{Any,R}) where {T,R}
22 Base.getindex(ta::LazyTensorMappingApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...) 23 @boundscheck checkbounds(ta, Int.(I)...)
23 Base.getindex(ta::LazyTensorMappingApplication{T,1}, I::CartesianIndex{1}) where {T} = apply(ta.t, ta.o, I.I...) # Would otherwise be caught in the previous method. 24 return @inbounds apply(ta.t, ta.o, I...)
24 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t) 25 end
25 # TODO: What else is needed to implement the AbstractArray interface? 26 Base.@propagate_inbounds Base.getindex(ta::TensorApplication{T,1} where T, I::CartesianIndex{1}) = ta[Tuple(I)...] # Would otherwise be caught in the previous method.
26 27 Base.size(ta::TensorApplication) = range_size(ta.t)
27 Base.:*(a::TensorMapping, v::AbstractArray) = LazyTensorMappingApplication(a,v) 28
28 Base.:*(a::TensorMapping, b::TensorMapping) = throw(MethodError(Base.:*,(a,b))) 29
29 Base.:*(a::TensorMapping, args::Union{TensorMapping, AbstractArray}...) = foldr(*,(a,args...)) 30 """
30 31 TensorTranspose{T,R,D} <: LazyTensor{T,D,R}
31 # # We need the associativity to be a→b→c = a→(b→c), which is the case for '→' 32
32 # # Should we overload some other infix binary opesrator? 33 Struct for lazy transpose of a LazyTensor.
33 # →(tm::TensorMapping{T,R,D}, o::AbstractArray{T,D}) where {T,R,D} = LazyTensorMappingApplication(tm,o)
34 # TODO: We need to be really careful about good error messages.
35 # For example what happens if you try to multiply LazyTensorMappingApplication with a TensorMapping(wrong order)?
36
37 """
38 LazyTensorMappingTranspose{T,R,D} <: TensorMapping{T,D,R}
39
40 Struct for lazy transpose of a TensorMapping.
41 34
42 If a mapping implements the the `apply_transpose` method this allows working with 35 If a mapping implements the the `apply_transpose` method this allows working with
43 the transpose of mapping `m` by using `m'`. `m'` will work as a regular TensorMapping lazily calling 36 the transpose of mapping `m` by using `m'`. `m'` will work as a regular LazyTensor lazily calling
44 the appropriate methods of `m`. 37 the appropriate methods of `m`.
45 """ 38 """
46 struct LazyTensorMappingTranspose{T,R,D, TM<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R} 39 struct TensorTranspose{T,R,D, TM<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R}
47 tm::TM 40 tm::TM
48 end 41 end
49 42
50 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? 43 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors?
51 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? 44 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any LazyTensor even if it doesn't implement `apply_transpose`?
52 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm) 45 Base.adjoint(tm::LazyTensor) = TensorTranspose(tm)
53 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm 46 Base.adjoint(tmt::TensorTranspose) = tmt.tm
54 47
55 apply(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...) 48 apply(tmt::TensorTranspose{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...)
56 apply_transpose(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...) 49 apply_transpose(tmt::TensorTranspose{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...)
57 50
58 range_size(tmt::LazyTensorMappingTranspose) = domain_size(tmt.tm) 51 range_size(tmt::TensorTranspose) = domain_size(tmt.tm)
59 domain_size(tmt::LazyTensorMappingTranspose) = range_size(tmt.tm) 52 domain_size(tmt::TensorTranspose) = range_size(tmt.tm)
60 53
61 54
62 struct LazyTensorMappingBinaryOperation{Op,T,R,D,T1<:TensorMapping{T,R,D},T2<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R} 55 struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R}
63 tm1::T1 56 tm1::T1
64 tm2::T2 57 tm2::T2
65 58
66 @inline function LazyTensorMappingBinaryOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:TensorMapping{T,R,D},T2<:TensorMapping{T,R,D}} 59 function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}}
60 @boundscheck check_domain_size(tm2, domain_size(tm1))
61 @boundscheck check_range_size(tm2, range_size(tm1))
67 return new{Op,T,R,D,T1,T2}(tm1,tm2) 62 return new{Op,T,R,D,T1,T2}(tm1,tm2)
68 end 63 end
69 end 64 end
70 # TODO: Boundschecking in constructor. 65
71 66 ElementwiseTensorOperation{Op}(s,t) where Op = ElementwiseTensorOperation{Op,eltype(s), range_dim(s), domain_dim(s)}(s,t)
72 apply(tmBinOp::LazyTensorMappingBinaryOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) 67
73 apply(tmBinOp::LazyTensorMappingBinaryOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) 68 apply(tmBinOp::ElementwiseTensorOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...)
74 69 apply(tmBinOp::ElementwiseTensorOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...)
75 range_size(tmBinOp::LazyTensorMappingBinaryOperation) = range_size(tmBinOp.tm1) 70
76 domain_size(tmBinOp::LazyTensorMappingBinaryOperation) = domain_size(tmBinOp.tm1) 71 range_size(tmBinOp::ElementwiseTensorOperation) = range_size(tmBinOp.tm1)
77 72 domain_size(tmBinOp::ElementwiseTensorOperation) = domain_size(tmBinOp.tm1)
78 Base.:+(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:+,T,R,D}(tm1,tm2) 73
79 Base.:-(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:-,T,R,D}(tm1,tm2) 74
80 75 """
81 """ 76 TensorComposition{T,R,K,D}
82 TensorMappingComposition{T,R,K,D} 77
83 78 Lazily compose two `LazyTensor`s, so that they can be handled as a single `LazyTensor`.
84 Lazily compose two `TensorMapping`s, so that they can be handled as a single `TensorMapping`. 79 """
85 """ 80 struct TensorComposition{T,R,K,D, TM1<:LazyTensor{T,R,K}, TM2<:LazyTensor{T,K,D}} <: LazyTensor{T,R,D}
86 struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D}
87 t1::TM1 81 t1::TM1
88 t2::TM2 82 t2::TM2
89 83
90 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} 84 function TensorComposition(t1::LazyTensor{T,R,K}, t2::LazyTensor{T,K,D}) where {T,R,K,D}
91 @boundscheck check_domain_size(t1, range_size(t2)) 85 @boundscheck check_domain_size(t1, range_size(t2))
92 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) 86 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2)
93 end 87 end
94 end 88 end
95 89
96 range_size(tm::TensorMappingComposition) = range_size(tm.t1) 90 range_size(tm::TensorComposition) = range_size(tm.t1)
97 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) 91 domain_size(tm::TensorComposition) = domain_size(tm.t2)
98 92
99 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,K,D} 93 function apply(c::TensorComposition{T,R,K,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,K,D}
100 apply(c.t1, c.t2*v, I...) 94 apply(c.t1, c.t2*v, I...)
101 end 95 end
102 96
103 function apply_transpose(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,K,D} 97 function apply_transpose(c::TensorComposition{T,R,K,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,K,D}
104 apply_transpose(c.t2, c.t1'*v, I...) 98 apply_transpose(c.t2, c.t1'*v, I...)
105 end 99 end
106 100
107 Base.@propagate_inbounds Base.:∘(s::TensorMapping, t::TensorMapping) = TensorMappingComposition(s,t) 101
108 102 """
109 """ 103 TensorComposition(tm, tmi::IdentityTensor)
110 LazyLinearMap{T,R,D,...}(A, range_indicies, domain_indicies) 104 TensorComposition(tmi::IdentityTensor, tm)
111 105
112 TensorMapping defined by the AbstractArray A. `range_indicies` and `domain_indicies` define which indicies of A should 106 Composes a `Tensormapping` `tm` with an `IdentityTensor` `tmi`, by returning `tm`
113 be considerd the range and domain of the TensorMapping. Each set of indices must be ordered in ascending order. 107 """
114 108 function TensorComposition(tm::LazyTensor{T,R,D}, tmi::IdentityTensor{T,D}) where {T,R,D}
115 For instance, if A is a m x n matrix, and range_size = (1,), domain_size = (2,), then the LazyLinearMap performs the
116 standard matrix-vector product on vectors of size n.
117 """
118 struct LazyLinearMap{T,R,D, RD, AA<:AbstractArray{T,RD}} <: TensorMapping{T,R,D}
119 A::AA
120 range_indicies::NTuple{R,Int}
121 domain_indicies::NTuple{D,Int}
122
123 function LazyLinearMap(A::AA, range_indicies::NTuple{R,Int}, domain_indicies::NTuple{D,Int}) where {T,R,D, RD, AA<:AbstractArray{T,RD}}
124 if !issorted(range_indicies) || !issorted(domain_indicies)
125 throw(DomainError("range_indicies and domain_indicies must be sorted in ascending order"))
126 end
127
128 return new{T,R,D,RD,AA}(A,range_indicies,domain_indicies)
129 end
130 end
131
132 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]]
133 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]]
134
135 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
136 view_index = ntuple(i->:,ndims(llm.A))
137 for i ∈ 1:R
138 view_index = Base.setindex(view_index, Int(I[i]), llm.range_indicies[i])
139 end
140 A_view = @view llm.A[view_index...]
141 return sum(A_view.*v)
142 end
143
144 function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D}
145 apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...)
146 end
147
148
149 """
150 IdentityMapping{T,D} <: TensorMapping{T,D,D}
151
152 The lazy identity TensorMapping for a given size. Usefull for building up higher dimensional tensor mappings from lower
153 dimensional ones through outer products. Also used in the Implementation for InflatedTensorMapping.
154 """
155 struct IdentityMapping{T,D} <: TensorMapping{T,D,D}
156 size::NTuple{D,Int}
157 end
158
159 IdentityMapping{T}(size::NTuple{D,Int}) where {T,D} = IdentityMapping{T,D}(size)
160 IdentityMapping{T}(size::Vararg{Int,D}) where {T,D} = IdentityMapping{T,D}(size)
161 IdentityMapping(size::Vararg{Int,D}) where D = IdentityMapping{Float64,D}(size)
162
163 range_size(tmi::IdentityMapping) = tmi.size
164 domain_size(tmi::IdentityMapping) = tmi.size
165
166 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
167 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
168
169 """
170 Base.:∘(tm, tmi)
171 Base.:∘(tmi, tm)
172
173 Composes a `Tensormapping` `tm` with an `IdentityMapping` `tmi`, by returning `tm`
174 """
175 @inline function Base.:∘(tm::TensorMapping{T,R,D}, tmi::IdentityMapping{T,D}) where {T,R,D}
176 @boundscheck check_domain_size(tm, range_size(tmi)) 109 @boundscheck check_domain_size(tm, range_size(tmi))
177 return tm 110 return tm
178 end 111 end
179 112
180 @inline function Base.:∘(tmi::IdentityMapping{T,R}, tm::TensorMapping{T,R,D}) where {T,R,D} 113 function TensorComposition(tmi::IdentityTensor{T,R}, tm::LazyTensor{T,R,D}) where {T,R,D}
181 @boundscheck check_domain_size(tmi, range_size(tm)) 114 @boundscheck check_domain_size(tmi, range_size(tm))
182 return tm 115 return tm
183 end 116 end
184 # Specialization for the case where tm is an IdentityMapping. Required to resolve ambiguity. 117 # Specialization for the case where tm is an IdentityTensor. Required to resolve ambiguity.
185 @inline function Base.:∘(tm::IdentityMapping{T,D}, tmi::IdentityMapping{T,D}) where {T,D} 118 function TensorComposition(tm::IdentityTensor{T,D}, tmi::IdentityTensor{T,D}) where {T,D}
186 @boundscheck check_domain_size(tm, range_size(tmi)) 119 @boundscheck check_domain_size(tm, range_size(tmi))
187 return tmi 120 return tmi
188 end 121 end
189 122
190 123
191 """ 124 """
192 InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D} 125 InflatedTensor{T,R,D} <: LazyTensor{T,R,D}
193 126
194 An inflated `TensorMapping` with dimensions added before and afer its actual dimensions. 127 An inflated `LazyTensor` with dimensions added before and afer its actual dimensions.
195 """ 128 """
196 struct InflatedTensorMapping{T,R,D,D_before,R_middle,D_middle,D_after, TM<:TensorMapping{T,R_middle,D_middle}} <: TensorMapping{T,R,D} 129 struct InflatedTensor{T,R,D,D_before,R_middle,D_middle,D_after, TM<:LazyTensor{T,R_middle,D_middle}} <: LazyTensor{T,R,D}
197 before::IdentityMapping{T,D_before} 130 before::IdentityTensor{T,D_before}
198 tm::TM 131 tm::TM
199 after::IdentityMapping{T,D_after} 132 after::IdentityTensor{T,D_after}
200 133
201 function InflatedTensorMapping(before, tm::TensorMapping{T}, after) where T 134 function InflatedTensor(before, tm::LazyTensor{T}, after) where T
202 R_before = range_dim(before) 135 R_before = range_dim(before)
203 R_middle = range_dim(tm) 136 R_middle = range_dim(tm)
204 R_after = range_dim(after) 137 R_after = range_dim(after)
205 R = R_before+R_middle+R_after 138 R = R_before+R_middle+R_after
206 139
209 D_after = domain_dim(after) 142 D_after = domain_dim(after)
210 D = D_before+D_middle+D_after 143 D = D_before+D_middle+D_after
211 return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after) 144 return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after)
212 end 145 end
213 end 146 end
214 """ 147
215 InflatedTensorMapping(before, tm, after) 148 """
216 InflatedTensorMapping(before,tm) 149 InflatedTensor(before, tm, after)
217 InflatedTensorMapping(tm,after) 150 InflatedTensor(before,tm)
218 151 InflatedTensor(tm,after)
219 The outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityMapping`s. 152
220 153 The outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityTensor`s.
221 If one of `before` or `after` is left out, a 0-dimensional `IdentityMapping` is used as the default value. 154
222 155 If one of `before` or `after` is left out, a 0-dimensional `IdentityTensor` is used as the default value.
223 If `tm` already is an `InflatedTensorMapping`, `before` and `after` will be extended instead of 156
224 creating a nested `InflatedTensorMapping`. 157 If `tm` already is an `InflatedTensor`, `before` and `after` will be extended instead of
225 """ 158 creating a nested `InflatedTensor`.
226 InflatedTensorMapping(::IdentityMapping, ::TensorMapping, ::IdentityMapping) 159 """
227 160 InflatedTensor(::IdentityTensor, ::LazyTensor, ::IdentityTensor)
228 function InflatedTensorMapping(before, itm::InflatedTensorMapping, after) 161
229 return InflatedTensorMapping( 162 function InflatedTensor(before, itm::InflatedTensor, after)
230 IdentityMapping(before.size..., itm.before.size...), 163 return InflatedTensor(
164 IdentityTensor(before.size..., itm.before.size...),
231 itm.tm, 165 itm.tm,
232 IdentityMapping(itm.after.size..., after.size...), 166 IdentityTensor(itm.after.size..., after.size...),
233 ) 167 )
234 end 168 end
235 169
236 InflatedTensorMapping(before::IdentityMapping, tm::TensorMapping{T}) where T = InflatedTensorMapping(before,tm,IdentityMapping{T}()) 170 InflatedTensor(before::IdentityTensor, tm::LazyTensor{T}) where T = InflatedTensor(before,tm,IdentityTensor{T}())
237 InflatedTensorMapping(tm::TensorMapping{T}, after::IdentityMapping) where T = InflatedTensorMapping(IdentityMapping{T}(),tm,after) 171 InflatedTensor(tm::LazyTensor{T}, after::IdentityTensor) where T = InflatedTensor(IdentityTensor{T}(),tm,after)
238 # Resolve ambiguity between the two previous methods 172 # Resolve ambiguity between the two previous methods
239 InflatedTensorMapping(I1::IdentityMapping{T}, I2::IdentityMapping{T}) where T = InflatedTensorMapping(I1,I2,IdentityMapping{T}()) 173 InflatedTensor(I1::IdentityTensor{T}, I2::IdentityTensor{T}) where T = InflatedTensor(I1,I2,IdentityTensor{T}())
240 174
241 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensorMapping(I(3),B,I(2)) -> I(3)⊗B⊗I(2) 175 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
242 176
243 function range_size(itm::InflatedTensorMapping) 177 function range_size(itm::InflatedTensor)
244 return flatten_tuple( 178 return flatten_tuple(
245 range_size(itm.before), 179 range_size(itm.before),
246 range_size(itm.tm), 180 range_size(itm.tm),
247 range_size(itm.after), 181 range_size(itm.after),
248 ) 182 )
249 end 183 end
250 184
251 function domain_size(itm::InflatedTensorMapping) 185 function domain_size(itm::InflatedTensor)
252 return flatten_tuple( 186 return flatten_tuple(
253 domain_size(itm.before), 187 domain_size(itm.before),
254 domain_size(itm.tm), 188 domain_size(itm.tm),
255 domain_size(itm.after), 189 domain_size(itm.after),
256 ) 190 )
257 end 191 end
258 192
259 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} 193 function apply(itm::InflatedTensor{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
260 dim_before = range_dim(itm.before) 194 dim_before = range_dim(itm.before)
261 dim_domain = domain_dim(itm.tm) 195 dim_domain = domain_dim(itm.tm)
262 dim_range = range_dim(itm.tm) 196 dim_range = range_dim(itm.tm)
263 dim_after = range_dim(itm.after) 197 dim_after = range_dim(itm.after)
264 198
266 200
267 v_inner = view(v, view_index...) 201 v_inner = view(v, view_index...)
268 return apply(itm.tm, v_inner, inner_index...) 202 return apply(itm.tm, v_inner, inner_index...)
269 end 203 end
270 204
271 function apply_transpose(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D} 205 function apply_transpose(itm::InflatedTensor{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D}
272 dim_before = range_dim(itm.before) 206 dim_before = range_dim(itm.before)
273 dim_domain = domain_dim(itm.tm) 207 dim_domain = domain_dim(itm.tm)
274 dim_range = range_dim(itm.tm) 208 dim_range = range_dim(itm.tm)
275 dim_after = range_dim(itm.after) 209 dim_after = range_dim(itm.after)
276 210
279 v_inner = view(v, view_index...) 213 v_inner = view(v, view_index...)
280 return apply_transpose(itm.tm, v_inner, inner_index...) 214 return apply_transpose(itm.tm, v_inner, inner_index...)
281 end 215 end
282 216
283 217
284 """
285 split_index(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...)
286
287 Splits the multi-index `I` into two parts. One part which is expected to be
288 used as a view, and one which is expected to be used as an index.
289 Eg.
290 ```
291 split_index(Val(1),Val(3),Val(2),Val(1),(1,2,3,4)) -> (1,:,:,:,4), (2,3)
292 ```
293
294 `dim_view` controls how many colons are in the view, and `dim_index` controls
295 how many elements are extracted from the middle.
296 `dim_before` and `dim_after` decides the length of the index parts before and after the colons in the view index.
297
298 Arguments should satisfy `length(I) == dim_before+B_domain+dim_after`.
299
300 The returned values satisfy
301 * `length(view_index) == dim_before + dim_view + dim_after`
302 * `length(I_middle) == dim_index`
303 """
304 function split_index(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...) where {dim_before,dim_view, dim_index,dim_after}
305 I_before, I_middle, I_after = split_tuple(I, Val(dim_before), Val(dim_index))
306
307 view_index = (I_before..., ntuple((i)->:, dim_view)..., I_after...)
308
309 return view_index, I_middle
310 end
311
312 # TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21
313 # See:
314 # https://github.com/JuliaLang/julia/issues/34884
315 # https://github.com/JuliaLang/julia/issues/30386
316 """
317 slice_tuple(t, Val(l), Val(u))
318
319 Get a slice of a tuple in a type stable way.
320 Equivalent to `t[l:u]` but type stable.
321 """
322 function slice_tuple(t,::Val{L},::Val{U}) where {L,U}
323 return ntuple(i->t[i+L-1], U-L+1)
324 end
325
326 """
327 split_tuple(t::Tuple{...}, ::Val{M}) where {N,M}
328
329 Split the tuple `t` into two parts. the first part is `M` long.
330 E.g
331 ```julia
332 split_tuple((1,2,3,4),Val(3)) -> (1,2,3), (4,)
333 ```
334 """
335 function split_tuple(t::NTuple{N,Any},::Val{M}) where {N,M}
336 return slice_tuple(t,Val(1), Val(M)), slice_tuple(t,Val(M+1), Val(N))
337 end
338
339 """
340 split_tuple(t::Tuple{...},::Val{M},::Val{K}) where {N,M,K}
341
342 Same as `split_tuple(t::NTuple{N},::Val{M})` but splits the tuple in three parts. With the first
343 two parts having lenght `M` and `K`.
344 """
345 function split_tuple(t::NTuple{N,Any},::Val{M},::Val{K}) where {N,M,K}
346 p1, tail = split_tuple(t, Val(M))
347 p2, p3 = split_tuple(tail, Val(K))
348 return p1,p2,p3
349 end
350
351
352 """
353 flatten_tuple(t)
354
355 Takes a nested tuple and flattens the whole structure
356 """
357 flatten_tuple(t::NTuple{N, Number} where N) = t
358 flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify?
359 flatten_tuple(ts::Vararg) = flatten_tuple(ts)
360
361 @doc raw""" 218 @doc raw"""
362 LazyOuterProduct(tms...) 219 LazyOuterProduct(tms...)
363 220
364 Creates a `TensorMappingComposition` for the outerproduct of `tms...`. 221 Creates a `TensorComposition` for the outerproduct of `tms...`.
365 This is done by separating the outer product into regular products of outer products involving only identity mappings and one non-identity mapping. 222 This is done by separating the outer product into regular products of outer products involving only identity mappings and one non-identity mapping.
366 223
367 First let 224 First let
368 ```math 225 ```math
369 \begin{aligned} 226 \begin{aligned}
395 (A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]] 252 (A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]]
396 ``` 253 ```
397 """ 254 """
398 function LazyOuterProduct end 255 function LazyOuterProduct end
399 256
400 function LazyOuterProduct(tm1::TensorMapping{T}, tm2::TensorMapping{T}) where T 257 function LazyOuterProduct(tm1::LazyTensor{T}, tm2::LazyTensor{T}) where T
401 itm1 = InflatedTensorMapping(tm1, IdentityMapping{T}(range_size(tm2))) 258 itm1 = InflatedTensor(tm1, IdentityTensor{T}(range_size(tm2)))
402 itm2 = InflatedTensorMapping(IdentityMapping{T}(domain_size(tm1)),tm2) 259 itm2 = InflatedTensor(IdentityTensor{T}(domain_size(tm1)),tm2)
403 260
404 return itm1∘itm2 261 return itm1∘itm2
405 end 262 end
406 263
407 LazyOuterProduct(t1::IdentityMapping{T}, t2::IdentityMapping{T}) where T = IdentityMapping{T}(t1.size...,t2.size...) 264 LazyOuterProduct(t1::IdentityTensor{T}, t2::IdentityTensor{T}) where T = IdentityTensor{T}(t1.size...,t2.size...)
408 LazyOuterProduct(t1::TensorMapping, t2::IdentityMapping) = InflatedTensorMapping(t1, t2) 265 LazyOuterProduct(t1::LazyTensor, t2::IdentityTensor) = InflatedTensor(t1, t2)
409 LazyOuterProduct(t1::IdentityMapping, t2::TensorMapping) = InflatedTensorMapping(t1, t2) 266 LazyOuterProduct(t1::IdentityTensor, t2::LazyTensor) = InflatedTensor(t1, t2)
410 267
411 LazyOuterProduct(tms::Vararg{TensorMapping}) = foldl(LazyOuterProduct, tms) 268 LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms)
412 269
413 ⊗(a::TensorMapping, b::TensorMapping) = LazyOuterProduct(a,b) 270
414 271 function check_domain_size(tm::LazyTensor, sz)
415
416 function check_domain_size(tm::TensorMapping, sz)
417 if domain_size(tm) != sz 272 if domain_size(tm) != sz
418 throw(SizeMismatch(tm,sz)) 273 throw(DomainSizeMismatch(tm,sz))
419 end 274 end
420 end 275 end
421 276
422 struct SizeMismatch <: Exception 277 function check_range_size(tm::LazyTensor, sz)
423 tm::TensorMapping 278 if range_size(tm) != sz
279 throw(RangeSizeMismatch(tm,sz))
280 end
281 end
282
283 struct DomainSizeMismatch <: Exception
284 tm::LazyTensor
424 sz 285 sz
425 end 286 end
426 287
427 function Base.showerror(io::IO, err::SizeMismatch) 288 function Base.showerror(io::IO, err::DomainSizeMismatch)
428 print(io, "SizeMismatch: ") 289 print(io, "DomainSizeMismatch: ")
429 print(io, "domain size $(domain_size(err.tm)) of TensorMapping not matching size $(err.sz)") 290 print(io, "domain size $(domain_size(err.tm)) of LazyTensor not matching size $(err.sz)")
430 end 291 end
292
293
294 struct RangeSizeMismatch <: Exception
295 tm::LazyTensor
296 sz
297 end
298
299 function Base.showerror(io::IO, err::RangeSizeMismatch)
300 print(io, "RangeSizeMismatch: ")
301 print(io, "range size $(range_size(err.tm)) of LazyTensor not matching size $(err.sz)")
302 end