comparison src/LazyTensors/lazy_tensor_operations.jl @ 1023:52f07c77299d refactor/sbpoperators/inflation

Merge refactor/lazy_tensors
author Jonatan Werpers <jonatan@werpers.com>
date Mon, 21 Mar 2022 09:51:07 +0100
parents bbbc31953367 f7a718bcb4da
children f857057e61e6
comparison
equal deleted inserted replaced
1022:bbbc31953367 1023:52f07c77299d
1 """ 1 """
2 LazyTensorMappingApplication{T,R,D} <: LazyArray{T,R} 2 LazyTensorApplication{T,R,D} <: LazyArray{T,R}
3 3
4 Struct for lazy application of a TensorMapping. Created using `*`. 4 Struct for lazy application of a LazyTensor. Created using `*`.
5 5
6 Allows the result of a `TensorMapping` applied to a vector to be treated as an `AbstractArray`. 6 Allows the result of a `LazyTensor` applied to a vector to be treated as an `AbstractArray`.
7 With a mapping `m` and a vector `v` the LazyTensorMappingApplication object can be created by `m*v`. 7 With a mapping `m` and a vector `v` the LazyTensorApplication object can be created by `m*v`.
8 The actual result will be calcualted when indexing into `m*v`. 8 The actual result will be calcualted when indexing into `m*v`.
9 """ 9 """
10 struct LazyTensorMappingApplication{T,R,D, TM<:TensorMapping{<:Any,R,D}, AA<:AbstractArray{<:Any,D}} <: LazyArray{T,R} 10 struct LazyTensorApplication{T,R,D, TM<:LazyTensor{<:Any,R,D}, AA<:AbstractArray{<:Any,D}} <: LazyArray{T,R}
11 t::TM 11 t::TM
12 o::AA 12 o::AA
13 13
14 function LazyTensorMappingApplication(t::TensorMapping{<:Any,R,D}, o::AbstractArray{<:Any,D}) where {R,D} 14 function LazyTensorApplication(t::LazyTensor{<:Any,R,D}, o::AbstractArray{<:Any,D}) where {R,D}
15 @boundscheck check_domain_size(t, size(o))
15 I = ntuple(i->1, range_dim(t)) 16 I = ntuple(i->1, range_dim(t))
16 T = typeof(apply(t,o,I...)) 17 T = typeof(apply(t,o,I...))
17 return new{T,R,D,typeof(t), typeof(o)}(t,o) 18 return new{T,R,D,typeof(t), typeof(o)}(t,o)
18 end 19 end
19 end 20 end
20 # TODO: Do boundschecking on creation! 21
21 22 function Base.getindex(ta::LazyTensorApplication{T,R}, I::Vararg{Any,R}) where {T,R}
22 Base.getindex(ta::LazyTensorMappingApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...) 23 @boundscheck checkbounds(ta, Int.(I)...)
23 Base.getindex(ta::LazyTensorMappingApplication{T,1}, I::CartesianIndex{1}) where {T} = apply(ta.t, ta.o, I.I...) # Would otherwise be caught in the previous method. 24 return apply(ta.t, ta.o, I...)
24 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t) 25 end
25 # TODO: What else is needed to implement the AbstractArray interface? 26 Base.getindex(ta::LazyTensorApplication{T,1} where T, I::CartesianIndex{1}) = ta[Tuple(I)...] # Would otherwise be caught in the previous method.
26 27 Base.size(ta::LazyTensorApplication) = range_size(ta.t)
27 Base.:*(a::TensorMapping, v::AbstractArray) = LazyTensorMappingApplication(a,v) 28
28 Base.:*(a::TensorMapping, b::TensorMapping) = throw(MethodError(Base.:*,(a,b))) 29
29 Base.:*(a::TensorMapping, args::Union{TensorMapping, AbstractArray}...) = foldr(*,(a,args...)) 30 """
30 31 LazyTensorTranspose{T,R,D} <: LazyTensor{T,D,R}
31 # # We need the associativity to be a→b→c = a→(b→c), which is the case for '→' 32
32 # # Should we overload some other infix binary opesrator? 33 Struct for lazy transpose of a LazyTensor.
33 # →(tm::TensorMapping{T,R,D}, o::AbstractArray{T,D}) where {T,R,D} = LazyTensorMappingApplication(tm,o)
34 # TODO: We need to be really careful about good error messages.
35 # For example what happens if you try to multiply LazyTensorMappingApplication with a TensorMapping(wrong order)?
36
37 """
38 LazyTensorMappingTranspose{T,R,D} <: TensorMapping{T,D,R}
39
40 Struct for lazy transpose of a TensorMapping.
41 34
42 If a mapping implements the the `apply_transpose` method this allows working with 35 If a mapping implements the the `apply_transpose` method this allows working with
43 the transpose of mapping `m` by using `m'`. `m'` will work as a regular TensorMapping lazily calling 36 the transpose of mapping `m` by using `m'`. `m'` will work as a regular LazyTensor lazily calling
44 the appropriate methods of `m`. 37 the appropriate methods of `m`.
45 """ 38 """
46 struct LazyTensorMappingTranspose{T,R,D, TM<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R} 39 struct LazyTensorTranspose{T,R,D, TM<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R}
47 tm::TM 40 tm::TM
48 end 41 end
49 42
50 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? 43 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors?
51 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? 44 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any LazyTensor even if it doesn't implement `apply_transpose`?
52 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm) 45 Base.adjoint(tm::LazyTensor) = LazyTensorTranspose(tm)
53 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm 46 Base.adjoint(tmt::LazyTensorTranspose) = tmt.tm
54 47
55 apply(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...) 48 apply(tmt::LazyTensorTranspose{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...)
56 apply_transpose(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...) 49 apply_transpose(tmt::LazyTensorTranspose{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...)
57 50
58 range_size(tmt::LazyTensorMappingTranspose) = domain_size(tmt.tm) 51 range_size(tmt::LazyTensorTranspose) = domain_size(tmt.tm)
59 domain_size(tmt::LazyTensorMappingTranspose) = range_size(tmt.tm) 52 domain_size(tmt::LazyTensorTranspose) = range_size(tmt.tm)
60 53
61 54
62 struct LazyTensorMappingBinaryOperation{Op,T,R,D,T1<:TensorMapping{T,R,D},T2<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R} 55 struct LazyTensorBinaryOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R}
63 tm1::T1 56 tm1::T1
64 tm2::T2 57 tm2::T2
65 58
66 @inline function LazyTensorMappingBinaryOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:TensorMapping{T,R,D},T2<:TensorMapping{T,R,D}} 59 function LazyTensorBinaryOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}}
60 @boundscheck check_domain_size(tm2, domain_size(tm1))
61 @boundscheck check_range_size(tm2, range_size(tm1))
67 return new{Op,T,R,D,T1,T2}(tm1,tm2) 62 return new{Op,T,R,D,T1,T2}(tm1,tm2)
68 end 63 end
69 end 64 end
70 # TODO: Boundschecking in constructor. 65
71 66 LazyTensorBinaryOperation{Op}(s,t) where Op = LazyTensorBinaryOperation{Op,eltype(s), range_dim(s), domain_dim(s)}(s,t)
72 apply(tmBinOp::LazyTensorMappingBinaryOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) 67
73 apply(tmBinOp::LazyTensorMappingBinaryOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) 68 apply(tmBinOp::LazyTensorBinaryOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...)
74 69 apply(tmBinOp::LazyTensorBinaryOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...)
75 range_size(tmBinOp::LazyTensorMappingBinaryOperation) = range_size(tmBinOp.tm1) 70
76 domain_size(tmBinOp::LazyTensorMappingBinaryOperation) = domain_size(tmBinOp.tm1) 71 range_size(tmBinOp::LazyTensorBinaryOperation) = range_size(tmBinOp.tm1)
77 72 domain_size(tmBinOp::LazyTensorBinaryOperation) = domain_size(tmBinOp.tm1)
78 Base.:+(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:+,T,R,D}(tm1,tm2) 73
79 Base.:-(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:-,T,R,D}(tm1,tm2) 74
80 75 """
81 """ 76 LazyTensorComposition{T,R,K,D}
82 TensorMappingComposition{T,R,K,D} 77
83 78 Lazily compose two `LazyTensor`s, so that they can be handled as a single `LazyTensor`.
84 Lazily compose two `TensorMapping`s, so that they can be handled as a single `TensorMapping`. 79 """
85 """ 80 struct LazyTensorComposition{T,R,K,D, TM1<:LazyTensor{T,R,K}, TM2<:LazyTensor{T,K,D}} <: LazyTensor{T,R,D}
86 struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D}
87 t1::TM1 81 t1::TM1
88 t2::TM2 82 t2::TM2
89 83
90 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} 84 function LazyTensorComposition(t1::LazyTensor{T,R,K}, t2::LazyTensor{T,K,D}) where {T,R,K,D}
91 @boundscheck check_domain_size(t1, range_size(t2)) 85 @boundscheck check_domain_size(t1, range_size(t2))
92 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) 86 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2)
93 end 87 end
94 end 88 end
95 89
96 range_size(tm::TensorMappingComposition) = range_size(tm.t1) 90 range_size(tm::LazyTensorComposition) = range_size(tm.t1)
97 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) 91 domain_size(tm::LazyTensorComposition) = domain_size(tm.t2)
98 92
99 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,K,D} 93 function apply(c::LazyTensorComposition{T,R,K,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,K,D}
100 apply(c.t1, c.t2*v, I...) 94 apply(c.t1, c.t2*v, I...)
101 end 95 end
102 96
103 function apply_transpose(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,K,D} 97 function apply_transpose(c::LazyTensorComposition{T,R,K,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,K,D}
104 apply_transpose(c.t2, c.t1'*v, I...) 98 apply_transpose(c.t2, c.t1'*v, I...)
105 end 99 end
106 100
107 Base.@propagate_inbounds Base.:∘(s::TensorMapping, t::TensorMapping) = TensorMappingComposition(s,t) 101
108 102 """
109 """ 103 LazyTensorComposition(tm, tmi::IdentityTensor)
110 LazyLinearMap{T,R,D,...}(A, range_indicies, domain_indicies) 104 LazyTensorComposition(tmi::IdentityTensor, tm)
111 105
112 TensorMapping defined by the AbstractArray A. `range_indicies` and `domain_indicies` define which indicies of A should 106 Composes a `Tensormapping` `tm` with an `IdentityTensor` `tmi`, by returning `tm`
113 be considerd the range and domain of the TensorMapping. Each set of indices must be ordered in ascending order. 107 """
114 108 function LazyTensorComposition(tm::LazyTensor{T,R,D}, tmi::IdentityTensor{T,D}) where {T,R,D}
115 For instance, if A is a m x n matrix, and range_size = (1,), domain_size = (2,), then the LazyLinearMap performs the
116 standard matrix-vector product on vectors of size n.
117 """
118 struct LazyLinearMap{T,R,D, RD, AA<:AbstractArray{T,RD}} <: TensorMapping{T,R,D}
119 A::AA
120 range_indicies::NTuple{R,Int}
121 domain_indicies::NTuple{D,Int}
122
123 function LazyLinearMap(A::AA, range_indicies::NTuple{R,Int}, domain_indicies::NTuple{D,Int}) where {T,R,D, RD, AA<:AbstractArray{T,RD}}
124 if !issorted(range_indicies) || !issorted(domain_indicies)
125 throw(DomainError("range_indicies and domain_indicies must be sorted in ascending order"))
126 end
127
128 return new{T,R,D,RD,AA}(A,range_indicies,domain_indicies)
129 end
130 end
131
132 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]]
133 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]]
134
135 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
136 view_index = ntuple(i->:,ndims(llm.A))
137 for i ∈ 1:R
138 view_index = Base.setindex(view_index, Int(I[i]), llm.range_indicies[i])
139 end
140 A_view = @view llm.A[view_index...]
141 return sum(A_view.*v)
142 end
143
144 function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D}
145 apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...)
146 end
147
148
149 """
150 IdentityMapping{T,D} <: TensorMapping{T,D,D}
151
152 The lazy identity TensorMapping for a given size. Usefull for building up higher dimensional tensor mappings from lower
153 dimensional ones through outer products. Also used in the Implementation for InflatedTensorMapping.
154 """
155 struct IdentityMapping{T,D} <: TensorMapping{T,D,D}
156 size::NTuple{D,Int}
157 end
158
159 IdentityMapping{T}(size::NTuple{D,Int}) where {T,D} = IdentityMapping{T,D}(size)
160 IdentityMapping{T}(size::Vararg{Int,D}) where {T,D} = IdentityMapping{T,D}(size)
161 IdentityMapping(size::Vararg{Int,D}) where D = IdentityMapping{Float64,D}(size)
162
163 range_size(tmi::IdentityMapping) = tmi.size
164 domain_size(tmi::IdentityMapping) = tmi.size
165
166 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
167 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
168
169 """
170 Base.:∘(tm, tmi)
171 Base.:∘(tmi, tm)
172
173 Composes a `Tensormapping` `tm` with an `IdentityMapping` `tmi`, by returning `tm`
174 """
175 @inline function Base.:∘(tm::TensorMapping{T,R,D}, tmi::IdentityMapping{T,D}) where {T,R,D}
176 @boundscheck check_domain_size(tm, range_size(tmi)) 109 @boundscheck check_domain_size(tm, range_size(tmi))
177 return tm 110 return tm
178 end 111 end
179 112
180 @inline function Base.:∘(tmi::IdentityMapping{T,R}, tm::TensorMapping{T,R,D}) where {T,R,D} 113 function LazyTensorComposition(tmi::IdentityTensor{T,R}, tm::LazyTensor{T,R,D}) where {T,R,D}
181 @boundscheck check_domain_size(tmi, range_size(tm)) 114 @boundscheck check_domain_size(tmi, range_size(tm))
182 return tm 115 return tm
183 end 116 end
184 # Specialization for the case where tm is an IdentityMapping. Required to resolve ambiguity. 117 # Specialization for the case where tm is an IdentityTensor. Required to resolve ambiguity.
185 @inline function Base.:∘(tm::IdentityMapping{T,D}, tmi::IdentityMapping{T,D}) where {T,D} 118 function LazyTensorComposition(tm::IdentityTensor{T,D}, tmi::IdentityTensor{T,D}) where {T,D}
186 @boundscheck check_domain_size(tm, range_size(tmi)) 119 @boundscheck check_domain_size(tm, range_size(tmi))
187 return tmi 120 return tmi
188 end 121 end
189 122
190 123
191 """ 124 """
192 InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D} 125 InflatedLazyTensor{T,R,D} <: LazyTensor{T,R,D}
193 126
194 An inflated `TensorMapping` with dimensions added before and afer its actual dimensions. 127 An inflated `LazyTensor` with dimensions added before and afer its actual dimensions.
195 """ 128 """
196 struct InflatedTensorMapping{T,R,D,D_before,R_middle,D_middle,D_after, TM<:TensorMapping{T,R_middle,D_middle}} <: TensorMapping{T,R,D} 129 struct InflatedLazyTensor{T,R,D,D_before,R_middle,D_middle,D_after, TM<:LazyTensor{T,R_middle,D_middle}} <: LazyTensor{T,R,D}
197 before::IdentityMapping{T,D_before} 130 before::IdentityTensor{T,D_before}
198 tm::TM 131 tm::TM
199 after::IdentityMapping{T,D_after} 132 after::IdentityTensor{T,D_after}
200 133
201 function InflatedTensorMapping(before, tm::TensorMapping{T}, after) where T 134 function InflatedLazyTensor(before, tm::LazyTensor{T}, after) where T
202 R_before = range_dim(before) 135 R_before = range_dim(before)
203 R_middle = range_dim(tm) 136 R_middle = range_dim(tm)
204 R_after = range_dim(after) 137 R_after = range_dim(after)
205 R = R_before+R_middle+R_after 138 R = R_before+R_middle+R_after
206 139
209 D_after = domain_dim(after) 142 D_after = domain_dim(after)
210 D = D_before+D_middle+D_after 143 D = D_before+D_middle+D_after
211 return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after) 144 return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after)
212 end 145 end
213 end 146 end
214 """ 147
215 InflatedTensorMapping(before, tm, after) 148 """
216 InflatedTensorMapping(before,tm) 149 InflatedLazyTensor(before, tm, after)
217 InflatedTensorMapping(tm,after) 150 InflatedLazyTensor(before,tm)
218 151 InflatedLazyTensor(tm,after)
219 The outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityMapping`s. 152
220 153 The outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityTensor`s.
221 If one of `before` or `after` is left out, a 0-dimensional `IdentityMapping` is used as the default value. 154
222 155 If one of `before` or `after` is left out, a 0-dimensional `IdentityTensor` is used as the default value.
223 If `tm` already is an `InflatedTensorMapping`, `before` and `after` will be extended instead of 156
224 creating a nested `InflatedTensorMapping`. 157 If `tm` already is an `InflatedLazyTensor`, `before` and `after` will be extended instead of
225 """ 158 creating a nested `InflatedLazyTensor`.
226 InflatedTensorMapping(::IdentityMapping, ::TensorMapping, ::IdentityMapping) 159 """
227 160 InflatedLazyTensor(::IdentityTensor, ::LazyTensor, ::IdentityTensor)
228 function InflatedTensorMapping(before, itm::InflatedTensorMapping, after) 161
229 return InflatedTensorMapping( 162 function InflatedLazyTensor(before, itm::InflatedLazyTensor, after)
230 IdentityMapping(before.size..., itm.before.size...), 163 return InflatedLazyTensor(
164 IdentityTensor(before.size..., itm.before.size...),
231 itm.tm, 165 itm.tm,
232 IdentityMapping(itm.after.size..., after.size...), 166 IdentityTensor(itm.after.size..., after.size...),
233 ) 167 )
234 end 168 end
235 169
236 InflatedTensorMapping(before::IdentityMapping, tm::TensorMapping{T}) where T = InflatedTensorMapping(before,tm,IdentityMapping{T}()) 170 InflatedLazyTensor(before::IdentityTensor, tm::LazyTensor{T}) where T = InflatedLazyTensor(before,tm,IdentityTensor{T}())
237 InflatedTensorMapping(tm::TensorMapping{T}, after::IdentityMapping) where T = InflatedTensorMapping(IdentityMapping{T}(),tm,after) 171 InflatedLazyTensor(tm::LazyTensor{T}, after::IdentityTensor) where T = InflatedLazyTensor(IdentityTensor{T}(),tm,after)
238 # Resolve ambiguity between the two previous methods 172 # Resolve ambiguity between the two previous methods
239 InflatedTensorMapping(I1::IdentityMapping{T}, I2::IdentityMapping{T}) where T = InflatedTensorMapping(I1,I2,IdentityMapping{T}()) 173 InflatedLazyTensor(I1::IdentityTensor{T}, I2::IdentityTensor{T}) where T = InflatedLazyTensor(I1,I2,IdentityTensor{T}())
240 174
241 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensorMapping(I(3),B,I(2)) -> I(3)⊗B⊗I(2) 175 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedLazyTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
242 176
243 function range_size(itm::InflatedTensorMapping) 177 function range_size(itm::InflatedLazyTensor)
244 return flatten_tuple( 178 return flatten_tuple(
245 range_size(itm.before), 179 range_size(itm.before),
246 range_size(itm.tm), 180 range_size(itm.tm),
247 range_size(itm.after), 181 range_size(itm.after),
248 ) 182 )
249 end 183 end
250 184
251 function domain_size(itm::InflatedTensorMapping) 185 function domain_size(itm::InflatedLazyTensor)
252 return flatten_tuple( 186 return flatten_tuple(
253 domain_size(itm.before), 187 domain_size(itm.before),
254 domain_size(itm.tm), 188 domain_size(itm.tm),
255 domain_size(itm.after), 189 domain_size(itm.after),
256 ) 190 )
257 end 191 end
258 192
259 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} 193 function apply(itm::InflatedLazyTensor{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
260 dim_before = range_dim(itm.before) 194 dim_before = range_dim(itm.before)
261 dim_domain = domain_dim(itm.tm) 195 dim_domain = domain_dim(itm.tm)
262 dim_range = range_dim(itm.tm) 196 dim_range = range_dim(itm.tm)
263 dim_after = range_dim(itm.after) 197 dim_after = range_dim(itm.after)
264 198
266 200
267 v_inner = view(v, view_index...) 201 v_inner = view(v, view_index...)
268 return apply(itm.tm, v_inner, inner_index...) 202 return apply(itm.tm, v_inner, inner_index...)
269 end 203 end
270 204
271 function apply_transpose(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D} 205 function apply_transpose(itm::InflatedLazyTensor{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D}
272 dim_before = range_dim(itm.before) 206 dim_before = range_dim(itm.before)
273 dim_domain = domain_dim(itm.tm) 207 dim_domain = domain_dim(itm.tm)
274 dim_range = range_dim(itm.tm) 208 dim_range = range_dim(itm.tm)
275 dim_after = range_dim(itm.after) 209 dim_after = range_dim(itm.after)
276 210
279 v_inner = view(v, view_index...) 213 v_inner = view(v, view_index...)
280 return apply_transpose(itm.tm, v_inner, inner_index...) 214 return apply_transpose(itm.tm, v_inner, inner_index...)
281 end 215 end
282 216
283 217
284 """
285 split_index(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...)
286
287 Splits the multi-index `I` into two parts. One part which is expected to be
288 used as a view, and one which is expected to be used as an index.
289 Eg.
290 ```
291 split_index(Val(1),Val(3),Val(2),Val(1),(1,2,3,4)) -> (1,:,:,:,4), (2,3)
292 ```
293
294 `dim_view` controls how many colons are in the view, and `dim_index` controls
295 how many elements are extracted from the middle.
296 `dim_before` and `dim_after` decides the length of the index parts before and after the colons in the view index.
297
298 Arguments should satisfy `length(I) == dim_before+B_domain+dim_after`.
299
300 The returned values satisfy
301 * `length(view_index) == dim_before + dim_view + dim_after`
302 * `length(I_middle) == dim_index`
303 """
304 function split_index(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...) where {dim_before,dim_view, dim_index,dim_after}
305 I_before, I_middle, I_after = split_tuple(I, Val(dim_before), Val(dim_index))
306
307 view_index = (I_before..., ntuple((i)->:, dim_view)..., I_after...)
308
309 return view_index, I_middle
310 end
311
312 # TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21
313 # See:
314 # https://github.com/JuliaLang/julia/issues/34884
315 # https://github.com/JuliaLang/julia/issues/30386
316 """
317 slice_tuple(t, Val(l), Val(u))
318
319 Get a slice of a tuple in a type stable way.
320 Equivalent to `t[l:u]` but type stable.
321 """
322 function slice_tuple(t,::Val{L},::Val{U}) where {L,U}
323 return ntuple(i->t[i+L-1], U-L+1)
324 end
325
326 """
327 split_tuple(t::Tuple{...}, ::Val{M}) where {N,M}
328
329 Split the tuple `t` into two parts. the first part is `M` long.
330 E.g
331 ```julia
332 split_tuple((1,2,3,4),Val(3)) -> (1,2,3), (4,)
333 ```
334 """
335 function split_tuple(t::NTuple{N,Any},::Val{M}) where {N,M}
336 return slice_tuple(t,Val(1), Val(M)), slice_tuple(t,Val(M+1), Val(N))
337 end
338
339 """
340 split_tuple(t::Tuple{...},::Val{M},::Val{K}) where {N,M,K}
341
342 Same as `split_tuple(t::NTuple{N},::Val{M})` but splits the tuple in three parts. With the first
343 two parts having lenght `M` and `K`.
344 """
345 function split_tuple(t::NTuple{N,Any},::Val{M},::Val{K}) where {N,M,K}
346 p1, tail = split_tuple(t, Val(M))
347 p2, p3 = split_tuple(tail, Val(K))
348 return p1,p2,p3
349 end
350
351
352 """
353 flatten_tuple(t)
354
355 Takes a nested tuple and flattens the whole structure
356 """
357 flatten_tuple(t::NTuple{N, Number} where N) = t
358 flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify?
359 flatten_tuple(ts::Vararg) = flatten_tuple(ts)
360
361 @doc raw""" 218 @doc raw"""
362 LazyOuterProduct(tms...) 219 LazyOuterProduct(tms...)
363 220
364 Creates a `TensorMappingComposition` for the outerproduct of `tms...`. 221 Creates a `LazyTensorComposition` for the outerproduct of `tms...`.
365 This is done by separating the outer product into regular products of outer products involving only identity mappings and one non-identity mapping. 222 This is done by separating the outer product into regular products of outer products involving only identity mappings and one non-identity mapping.
366 223
367 First let 224 First let
368 ```math 225 ```math
369 \begin{aligned} 226 \begin{aligned}
395 (A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]] 252 (A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]]
396 ``` 253 ```
397 """ 254 """
398 function LazyOuterProduct end 255 function LazyOuterProduct end
399 256
400 function LazyOuterProduct(tm1::TensorMapping{T}, tm2::TensorMapping{T}) where T 257 function LazyOuterProduct(tm1::LazyTensor{T}, tm2::LazyTensor{T}) where T
401 itm1 = InflatedTensorMapping(tm1, IdentityMapping{T}(range_size(tm2))) 258 itm1 = InflatedLazyTensor(tm1, IdentityTensor{T}(range_size(tm2)))
402 itm2 = InflatedTensorMapping(IdentityMapping{T}(domain_size(tm1)),tm2) 259 itm2 = InflatedLazyTensor(IdentityTensor{T}(domain_size(tm1)),tm2)
403 260
404 return itm1∘itm2 261 return itm1∘itm2
405 end 262 end
406 263
407 LazyOuterProduct(t1::IdentityMapping{T}, t2::IdentityMapping{T}) where T = IdentityMapping{T}(t1.size...,t2.size...) 264 LazyOuterProduct(t1::IdentityTensor{T}, t2::IdentityTensor{T}) where T = IdentityTensor{T}(t1.size...,t2.size...)
408 LazyOuterProduct(t1::TensorMapping, t2::IdentityMapping) = InflatedTensorMapping(t1, t2) 265 LazyOuterProduct(t1::LazyTensor, t2::IdentityTensor) = InflatedLazyTensor(t1, t2)
409 LazyOuterProduct(t1::IdentityMapping, t2::TensorMapping) = InflatedTensorMapping(t1, t2) 266 LazyOuterProduct(t1::IdentityTensor, t2::LazyTensor) = InflatedLazyTensor(t1, t2)
410 267
411 LazyOuterProduct(tms::Vararg{TensorMapping}) = foldl(LazyOuterProduct, tms) 268 LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms)
412 269
413 ⊗(a::TensorMapping, b::TensorMapping) = LazyOuterProduct(a,b)
414 270
415 271
416 """ 272 """
417 inflate(tm, sz, dir) 273 inflate(tm, sz, dir)
418 274
419 Inflate `tm` with identity tensors in all directions `d` for `d != dir`. 275 Inflate `tm` with identity tensors in all directions `d` for `d != dir`.
420 276
421 # TODO: Describe when it is useful 277 # TODO: Describe when it is useful
422 """ 278 """
423 function inflate(tm::TensorMapping, sz, dir) 279 function inflate(tm::LazyTensor, sz, dir)
424 Is = IdentityMapping{eltype(tm)}.(sz) 280 Is = IdentityTensor{eltype(tm)}.(sz)
425 parts = Base.setindex(Is, tm, dir) 281 parts = Base.setindex(Is, tm, dir)
426 return foldl(⊗, parts) 282 return foldl(⊗, parts)
427 end 283 end
428 284
429 function check_domain_size(tm::TensorMapping, sz) 285 function check_domain_size(tm::LazyTensor, sz)
430 if domain_size(tm) != sz 286 if domain_size(tm) != sz
431 throw(SizeMismatch(tm,sz)) 287 throw(DomainSizeMismatch(tm,sz))
432 end 288 end
433 end 289 end
434 290
435 struct SizeMismatch <: Exception 291 function check_range_size(tm::LazyTensor, sz)
436 tm::TensorMapping 292 if range_size(tm) != sz
293 throw(RangeSizeMismatch(tm,sz))
294 end
295 end
296
297 struct DomainSizeMismatch <: Exception
298 tm::LazyTensor
437 sz 299 sz
438 end 300 end
439 301
440 function Base.showerror(io::IO, err::SizeMismatch) 302 function Base.showerror(io::IO, err::DomainSizeMismatch)
441 print(io, "SizeMismatch: ") 303 print(io, "DomainSizeMismatch: ")
442 print(io, "domain size $(domain_size(err.tm)) of TensorMapping not matching size $(err.sz)") 304 print(io, "domain size $(domain_size(err.tm)) of LazyTensor not matching size $(err.sz)")
443 end 305 end
306
307
308 struct RangeSizeMismatch <: Exception
309 tm::LazyTensor
310 sz
311 end
312
313 function Base.showerror(io::IO, err::RangeSizeMismatch)
314 print(io, "RangeSizeMismatch: ")
315 print(io, "range size $(range_size(err.tm)) of LazyTensor not matching size $(err.sz)")
316 end