comparison src/LazyTensors/lazy_tensor_operations.jl @ 1049:3bb94ce74697 feature/variable_derivatives

Merge default
author Jonatan Werpers <jonatan@werpers.com>
date Wed, 23 Mar 2022 12:54:45 +0100
parents df562695b1b5 9e76bf19904c
children 5a3281429a48
comparison
equal deleted inserted replaced
1048:86aa69ad3304 1049:3bb94ce74697
1 using Sbplib.RegionIndices 1 using Sbplib.RegionIndices
2 2
3 """ 3 """
4 LazyTensorMappingApplication{T,R,D} <: LazyArray{T,R} 4 TensorApplication{T,R,D} <: LazyArray{T,R}
5 5
6 Struct for lazy application of a TensorMapping. Created using `*`. 6 Struct for lazy application of a LazyTensor. Created using `*`.
7 7
8 Allows the result of a `TensorMapping` applied to a vector to be treated as an `AbstractArray`. 8 Allows the result of a `LazyTensor` applied to a vector to be treated as an `AbstractArray`.
9 With a mapping `m` and a vector `v` the LazyTensorMappingApplication object can be created by `m*v`. 9 With a mapping `m` and a vector `v` the TensorApplication object can be created by `m*v`.
10 The actual result will be calcualted when indexing into `m*v`. 10 The actual result will be calcualted when indexing into `m*v`.
11 """ 11 """
12 struct LazyTensorMappingApplication{T,R,D, TM<:TensorMapping{<:Any,R,D}, AA<:AbstractArray{<:Any,D}} <: LazyArray{T,R} 12 struct TensorApplication{T,R,D, TM<:LazyTensor{<:Any,R,D}, AA<:AbstractArray{<:Any,D}} <: LazyArray{T,R}
13 t::TM 13 t::TM
14 o::AA 14 o::AA
15 15
16 function LazyTensorMappingApplication(t::TensorMapping{<:Any,R,D}, o::AbstractArray{<:Any,D}) where {R,D} 16 function TensorApplication(t::LazyTensor{<:Any,R,D}, o::AbstractArray{<:Any,D}) where {R,D}
17 @boundscheck check_domain_size(t, size(o))
17 I = ntuple(i->1, range_dim(t)) 18 I = ntuple(i->1, range_dim(t))
18 T = typeof(apply(t,o,I...)) 19 T = typeof(apply(t,o,I...))
19 return new{T,R,D,typeof(t), typeof(o)}(t,o) 20 return new{T,R,D,typeof(t), typeof(o)}(t,o)
20 end 21 end
21 end 22 end
22 # TODO: Do boundschecking on creation! 23
23 24 function Base.getindex(ta::TensorApplication{T,R}, I::Vararg{Any,R}) where {T,R}
24 Base.getindex(ta::LazyTensorMappingApplication{T,R}, I::Vararg{Any,R}) where {T,R} = apply(ta.t, ta.o, I...) 25 @boundscheck checkbounds(ta, Int.(I)...)
25 Base.getindex(ta::LazyTensorMappingApplication{T,1}, I::CartesianIndex{1}) where {T} = apply(ta.t, ta.o, I.I...) # Would otherwise be caught in the previous method. 26 return @inbounds apply(ta.t, ta.o, I...)
26 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t) 27 end
27 # TODO: What else is needed to implement the AbstractArray interface? 28 Base.@propagate_inbounds Base.getindex(ta::TensorApplication{T,1} where T, I::CartesianIndex{1}) = ta[Tuple(I)...] # Would otherwise be caught in the previous method.
28 29 Base.size(ta::TensorApplication) = range_size(ta.t)
29 Base.:*(a::TensorMapping, v::AbstractArray) = LazyTensorMappingApplication(a,v) 30
30 Base.:*(a::TensorMapping, b::TensorMapping) = throw(MethodError(Base.:*,(a,b))) 31
31 Base.:*(a::TensorMapping, args::Union{TensorMapping, AbstractArray}...) = foldr(*,(a,args...)) 32 """
32 33 TensorTranspose{T,R,D} <: LazyTensor{T,D,R}
33 # # We need the associativity to be a→b→c = a→(b→c), which is the case for '→' 34
34 # # Should we overload some other infix binary opesrator? 35 Struct for lazy transpose of a LazyTensor.
35 # →(tm::TensorMapping{T,R,D}, o::AbstractArray{T,D}) where {T,R,D} = LazyTensorMappingApplication(tm,o)
36 # TODO: We need to be really careful about good error messages.
37 # For example what happens if you try to multiply LazyTensorMappingApplication with a TensorMapping(wrong order)?
38
39 """
40 LazyTensorMappingTranspose{T,R,D} <: TensorMapping{T,D,R}
41
42 Struct for lazy transpose of a TensorMapping.
43 36
44 If a mapping implements the the `apply_transpose` method this allows working with 37 If a mapping implements the the `apply_transpose` method this allows working with
45 the transpose of mapping `m` by using `m'`. `m'` will work as a regular TensorMapping lazily calling 38 the transpose of mapping `m` by using `m'`. `m'` will work as a regular LazyTensor lazily calling
46 the appropriate methods of `m`. 39 the appropriate methods of `m`.
47 """ 40 """
48 struct LazyTensorMappingTranspose{T,R,D, TM<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R} 41 struct TensorTranspose{T,R,D, TM<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R}
49 tm::TM 42 tm::TM
50 end 43 end
51 44
52 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? 45 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors?
53 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? 46 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any LazyTensor even if it doesn't implement `apply_transpose`?
54 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm) 47 Base.adjoint(tm::LazyTensor) = TensorTranspose(tm)
55 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm 48 Base.adjoint(tmt::TensorTranspose) = tmt.tm
56 49
57 apply(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...) 50 apply(tmt::TensorTranspose{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...)
58 apply_transpose(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...) 51 apply_transpose(tmt::TensorTranspose{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...)
59 52
60 range_size(tmt::LazyTensorMappingTranspose) = domain_size(tmt.tm) 53 range_size(tmt::TensorTranspose) = domain_size(tmt.tm)
61 domain_size(tmt::LazyTensorMappingTranspose) = range_size(tmt.tm) 54 domain_size(tmt::TensorTranspose) = range_size(tmt.tm)
62 55
63 56
64 struct LazyTensorMappingBinaryOperation{Op,T,R,D,T1<:TensorMapping{T,R,D},T2<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R} 57 struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,D,R}
65 tm1::T1 58 tm1::T1
66 tm2::T2 59 tm2::T2
67 60
68 @inline function LazyTensorMappingBinaryOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:TensorMapping{T,R,D},T2<:TensorMapping{T,R,D}} 61 function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}}
62 @boundscheck check_domain_size(tm2, domain_size(tm1))
63 @boundscheck check_range_size(tm2, range_size(tm1))
69 return new{Op,T,R,D,T1,T2}(tm1,tm2) 64 return new{Op,T,R,D,T1,T2}(tm1,tm2)
70 end 65 end
71 end 66 end
72 # TODO: Boundschecking in constructor. 67
73 68 ElementwiseTensorOperation{Op}(s,t) where Op = ElementwiseTensorOperation{Op,eltype(s), range_dim(s), domain_dim(s)}(s,t)
74 apply(tmBinOp::LazyTensorMappingBinaryOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) 69
75 apply(tmBinOp::LazyTensorMappingBinaryOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) 70 apply(tmBinOp::ElementwiseTensorOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...)
76 71 apply(tmBinOp::ElementwiseTensorOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...)
77 range_size(tmBinOp::LazyTensorMappingBinaryOperation) = range_size(tmBinOp.tm1) 72
78 domain_size(tmBinOp::LazyTensorMappingBinaryOperation) = domain_size(tmBinOp.tm1) 73 range_size(tmBinOp::ElementwiseTensorOperation) = range_size(tmBinOp.tm1)
79 74 domain_size(tmBinOp::ElementwiseTensorOperation) = domain_size(tmBinOp.tm1)
80 Base.:+(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:+,T,R,D}(tm1,tm2) 75
81 Base.:-(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:-,T,R,D}(tm1,tm2) 76
82 77 """
83 """ 78 TensorComposition{T,R,K,D}
84 TensorMappingComposition{T,R,K,D} 79
85 80 Lazily compose two `LazyTensor`s, so that they can be handled as a single `LazyTensor`.
86 Lazily compose two `TensorMapping`s, so that they can be handled as a single `TensorMapping`. 81 """
87 """ 82 struct TensorComposition{T,R,K,D, TM1<:LazyTensor{T,R,K}, TM2<:LazyTensor{T,K,D}} <: LazyTensor{T,R,D}
88 struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D}
89 t1::TM1 83 t1::TM1
90 t2::TM2 84 t2::TM2
91 85
92 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} 86 function TensorComposition(t1::LazyTensor{T,R,K}, t2::LazyTensor{T,K,D}) where {T,R,K,D}
93 @boundscheck check_domain_size(t1, range_size(t2)) 87 @boundscheck check_domain_size(t1, range_size(t2))
94 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) 88 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2)
95 end 89 end
96 end 90 end
97 91
98 range_size(tm::TensorMappingComposition) = range_size(tm.t1) 92 range_size(tm::TensorComposition) = range_size(tm.t1)
99 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) 93 domain_size(tm::TensorComposition) = domain_size(tm.t2)
100 94
101 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,K,D} 95 function apply(c::TensorComposition{T,R,K,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,K,D}
102 apply(c.t1, c.t2*v, I...) 96 apply(c.t1, c.t2*v, I...)
103 end 97 end
104 98
105 function apply_transpose(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,K,D} 99 function apply_transpose(c::TensorComposition{T,R,K,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,K,D}
106 apply_transpose(c.t2, c.t1'*v, I...) 100 apply_transpose(c.t2, c.t1'*v, I...)
107 end 101 end
108 102
109 Base.@propagate_inbounds Base.:∘(s::TensorMapping, t::TensorMapping) = TensorMappingComposition(s,t) 103
110 104 """
111 """ 105 TensorComposition(tm, tmi::IdentityTensor)
112 LazyLinearMap{T,R,D,...}(A, range_indicies, domain_indicies) 106 TensorComposition(tmi::IdentityTensor, tm)
113 107
114 TensorMapping defined by the AbstractArray A. `range_indicies` and `domain_indicies` define which indicies of A should 108 Composes a `Tensormapping` `tm` with an `IdentityTensor` `tmi`, by returning `tm`
115 be considerd the range and domain of the TensorMapping. Each set of indices must be ordered in ascending order. 109 """
116 110 function TensorComposition(tm::LazyTensor{T,R,D}, tmi::IdentityTensor{T,D}) where {T,R,D}
117 For instance, if A is a m x n matrix, and range_size = (1,), domain_size = (2,), then the LazyLinearMap performs the
118 standard matrix-vector product on vectors of size n.
119 """
120 struct LazyLinearMap{T,R,D, RD, AA<:AbstractArray{T,RD}} <: TensorMapping{T,R,D}
121 A::AA
122 range_indicies::NTuple{R,Int}
123 domain_indicies::NTuple{D,Int}
124
125 function LazyLinearMap(A::AA, range_indicies::NTuple{R,Int}, domain_indicies::NTuple{D,Int}) where {T,R,D, RD, AA<:AbstractArray{T,RD}}
126 if !issorted(range_indicies) || !issorted(domain_indicies)
127 throw(DomainError("range_indicies and domain_indicies must be sorted in ascending order"))
128 end
129
130 return new{T,R,D,RD,AA}(A,range_indicies,domain_indicies)
131 end
132 end
133
134 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]]
135 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]]
136
137 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
138 view_index = ntuple(i->:,ndims(llm.A))
139 for i ∈ 1:R
140 view_index = Base.setindex(view_index, Int(I[i]), llm.range_indicies[i])
141 end
142 A_view = @view llm.A[view_index...]
143 return sum(A_view.*v)
144 end
145
146 function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D}
147 apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...)
148 end
149
150
151 """
152 IdentityMapping{T,D} <: TensorMapping{T,D,D}
153
154 The lazy identity TensorMapping for a given size. Usefull for building up higher dimensional tensor mappings from lower
155 dimensional ones through outer products. Also used in the Implementation for InflatedTensorMapping.
156 """
157 struct IdentityMapping{T,D} <: TensorMapping{T,D,D}
158 size::NTuple{D,Int}
159 end
160
161 IdentityMapping{T}(size::NTuple{D,Int}) where {T,D} = IdentityMapping{T,D}(size)
162 IdentityMapping{T}(size::Vararg{Int,D}) where {T,D} = IdentityMapping{T,D}(size)
163 IdentityMapping(size::Vararg{Int,D}) where D = IdentityMapping{Float64,D}(size)
164
165 range_size(tmi::IdentityMapping) = tmi.size
166 domain_size(tmi::IdentityMapping) = tmi.size
167
168 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
169 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
170
171 """
172 Base.:∘(tm, tmi)
173 Base.:∘(tmi, tm)
174
175 Composes a `Tensormapping` `tm` with an `IdentityMapping` `tmi`, by returning `tm`
176 """
177 @inline function Base.:∘(tm::TensorMapping{T,R,D}, tmi::IdentityMapping{T,D}) where {T,R,D}
178 @boundscheck check_domain_size(tm, range_size(tmi)) 111 @boundscheck check_domain_size(tm, range_size(tmi))
179 return tm 112 return tm
180 end 113 end
181 114
182 @inline function Base.:∘(tmi::IdentityMapping{T,R}, tm::TensorMapping{T,R,D}) where {T,R,D} 115 function TensorComposition(tmi::IdentityTensor{T,R}, tm::LazyTensor{T,R,D}) where {T,R,D}
183 @boundscheck check_domain_size(tmi, range_size(tm)) 116 @boundscheck check_domain_size(tmi, range_size(tm))
184 return tm 117 return tm
185 end 118 end
186 # Specialization for the case where tm is an IdentityMapping. Required to resolve ambiguity. 119 # Specialization for the case where tm is an IdentityTensor. Required to resolve ambiguity.
187 @inline function Base.:∘(tm::IdentityMapping{T,D}, tmi::IdentityMapping{T,D}) where {T,D} 120 function TensorComposition(tm::IdentityTensor{T,D}, tmi::IdentityTensor{T,D}) where {T,D}
188 @boundscheck check_domain_size(tm, range_size(tmi)) 121 @boundscheck check_domain_size(tm, range_size(tmi))
189 return tmi 122 return tmi
190 end 123 end
191 124
192 125
193 """ 126 """
194 InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D} 127 InflatedTensor{T,R,D} <: LazyTensor{T,R,D}
195 128
196 An inflated `TensorMapping` with dimensions added before and afer its actual dimensions. 129 An inflated `LazyTensor` with dimensions added before and afer its actual dimensions.
197 """ 130 """
198 struct InflatedTensorMapping{T,R,D,D_before,R_middle,D_middle,D_after, TM<:TensorMapping{T,R_middle,D_middle}} <: TensorMapping{T,R,D} 131 struct InflatedTensor{T,R,D,D_before,R_middle,D_middle,D_after, TM<:LazyTensor{T,R_middle,D_middle}} <: LazyTensor{T,R,D}
199 before::IdentityMapping{T,D_before} 132 before::IdentityTensor{T,D_before}
200 tm::TM 133 tm::TM
201 after::IdentityMapping{T,D_after} 134 after::IdentityTensor{T,D_after}
202 135
203 function InflatedTensorMapping(before, tm::TensorMapping{T}, after) where T 136 function InflatedTensor(before, tm::LazyTensor{T}, after) where T
204 R_before = range_dim(before) 137 R_before = range_dim(before)
205 R_middle = range_dim(tm) 138 R_middle = range_dim(tm)
206 R_after = range_dim(after) 139 R_after = range_dim(after)
207 R = R_before+R_middle+R_after 140 R = R_before+R_middle+R_after
208 141
211 D_after = domain_dim(after) 144 D_after = domain_dim(after)
212 D = D_before+D_middle+D_after 145 D = D_before+D_middle+D_after
213 return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after) 146 return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after)
214 end 147 end
215 end 148 end
216 """ 149
217 InflatedTensorMapping(before, tm, after) 150 """
218 InflatedTensorMapping(before,tm) 151 InflatedTensor(before, tm, after)
219 InflatedTensorMapping(tm,after) 152 InflatedTensor(before,tm)
220 153 InflatedTensor(tm,after)
221 The outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityMapping`s. 154
222 155 The outer product of `before`, `tm` and `after`, where `before` and `after` are `IdentityTensor`s.
223 If one of `before` or `after` is left out, a 0-dimensional `IdentityMapping` is used as the default value. 156
224 157 If one of `before` or `after` is left out, a 0-dimensional `IdentityTensor` is used as the default value.
225 If `tm` already is an `InflatedTensorMapping`, `before` and `after` will be extended instead of 158
226 creating a nested `InflatedTensorMapping`. 159 If `tm` already is an `InflatedTensor`, `before` and `after` will be extended instead of
227 """ 160 creating a nested `InflatedTensor`.
228 InflatedTensorMapping(::IdentityMapping, ::TensorMapping, ::IdentityMapping) 161 """
229 162 InflatedTensor(::IdentityTensor, ::LazyTensor, ::IdentityTensor)
230 function InflatedTensorMapping(before, itm::InflatedTensorMapping, after) 163
231 return InflatedTensorMapping( 164 function InflatedTensor(before, itm::InflatedTensor, after)
232 IdentityMapping(before.size..., itm.before.size...), 165 return InflatedTensor(
166 IdentityTensor(before.size..., itm.before.size...),
233 itm.tm, 167 itm.tm,
234 IdentityMapping(itm.after.size..., after.size...), 168 IdentityTensor(itm.after.size..., after.size...),
235 ) 169 )
236 end 170 end
237 171
238 InflatedTensorMapping(before::IdentityMapping, tm::TensorMapping{T}) where T = InflatedTensorMapping(before,tm,IdentityMapping{T}()) 172 InflatedTensor(before::IdentityTensor, tm::LazyTensor{T}) where T = InflatedTensor(before,tm,IdentityTensor{T}())
239 InflatedTensorMapping(tm::TensorMapping{T}, after::IdentityMapping) where T = InflatedTensorMapping(IdentityMapping{T}(),tm,after) 173 InflatedTensor(tm::LazyTensor{T}, after::IdentityTensor) where T = InflatedTensor(IdentityTensor{T}(),tm,after)
240 # Resolve ambiguity between the two previous methods 174 # Resolve ambiguity between the two previous methods
241 InflatedTensorMapping(I1::IdentityMapping{T}, I2::IdentityMapping{T}) where T = InflatedTensorMapping(I1,I2,IdentityMapping{T}()) 175 InflatedTensor(I1::IdentityTensor{T}, I2::IdentityTensor{T}) where T = InflatedTensor(I1,I2,IdentityTensor{T}())
242 176
243 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensorMapping(I(3),B,I(2)) -> I(3)⊗B⊗I(2) 177 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
244 178
245 function range_size(itm::InflatedTensorMapping) 179 function range_size(itm::InflatedTensor)
246 return flatten_tuple( 180 return flatten_tuple(
247 range_size(itm.before), 181 range_size(itm.before),
248 range_size(itm.tm), 182 range_size(itm.tm),
249 range_size(itm.after), 183 range_size(itm.after),
250 ) 184 )
251 end 185 end
252 186
253 function domain_size(itm::InflatedTensorMapping) 187 function domain_size(itm::InflatedTensor)
254 return flatten_tuple( 188 return flatten_tuple(
255 domain_size(itm.before), 189 domain_size(itm.before),
256 domain_size(itm.tm), 190 domain_size(itm.tm),
257 domain_size(itm.after), 191 domain_size(itm.after),
258 ) 192 )
259 end 193 end
260 194
261 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} 195 function apply(itm::InflatedTensor{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
262 dim_before = range_dim(itm.before) 196 dim_before = range_dim(itm.before)
263 dim_domain = domain_dim(itm.tm) 197 dim_domain = domain_dim(itm.tm)
264 dim_range = range_dim(itm.tm) 198 dim_range = range_dim(itm.tm)
265 dim_after = range_dim(itm.after) 199 dim_after = range_dim(itm.after)
266 200
268 202
269 v_inner = view(v, view_index...) 203 v_inner = view(v, view_index...)
270 return apply(itm.tm, v_inner, inner_index...) 204 return apply(itm.tm, v_inner, inner_index...)
271 end 205 end
272 206
273 function apply_transpose(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D} 207 function apply_transpose(itm::InflatedTensor{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D}
274 dim_before = range_dim(itm.before) 208 dim_before = range_dim(itm.before)
275 dim_domain = domain_dim(itm.tm) 209 dim_domain = domain_dim(itm.tm)
276 dim_range = range_dim(itm.tm) 210 dim_range = range_dim(itm.tm)
277 dim_after = range_dim(itm.after) 211 dim_after = range_dim(itm.after)
278 212
281 v_inner = view(v, view_index...) 215 v_inner = view(v, view_index...)
282 return apply_transpose(itm.tm, v_inner, inner_index...) 216 return apply_transpose(itm.tm, v_inner, inner_index...)
283 end 217 end
284 218
285 219
286 """
287 split_index(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...)
288
289 Splits the multi-index `I` into two parts. One part which is expected to be
290 used as a view, and one which is expected to be used as an index.
291 Eg.
292 ```
293 split_index(Val(1),Val(3),Val(2),Val(1),(1,2,3,4)) -> (1,:,:,:,4), (2,3)
294 ```
295
296 `dim_view` controls how many colons are in the view, and `dim_index` controls
297 how many elements are extracted from the middle.
298 `dim_before` and `dim_after` decides the length of the index parts before and after the colons in the view index.
299
300 Arguments should satisfy `length(I) == dim_before+B_domain+dim_after`.
301
302 The returned values satisfy
303 * `length(view_index) == dim_before + dim_view + dim_after`
304 * `length(I_middle) == dim_index`
305 """
306 function split_index(::Val{dim_before}, ::Val{dim_view}, ::Val{dim_index}, ::Val{dim_after}, I...) where {dim_before,dim_view, dim_index,dim_after}
307 I_before, I_middle, I_after = split_tuple(I, Val(dim_before), Val(dim_index))
308
309 view_index = (I_before..., ntuple((i)->:, dim_view)..., I_after...)
310
311 return view_index, I_middle
312 end
313
314 # TODO: Can this be replaced by something more elegant while still being type stable? 2020-10-21
315 # See:
316 # https://github.com/JuliaLang/julia/issues/34884
317 # https://github.com/JuliaLang/julia/issues/30386
318 """
319 slice_tuple(t, Val(l), Val(u))
320
321 Get a slice of a tuple in a type stable way.
322 Equivalent to `t[l:u]` but type stable.
323 """
324 function slice_tuple(t,::Val{L},::Val{U}) where {L,U}
325 return ntuple(i->t[i+L-1], U-L+1)
326 end
327
328 """
329 split_tuple(t::Tuple{...}, ::Val{M}) where {N,M}
330
331 Split the tuple `t` into two parts. the first part is `M` long.
332 E.g
333 ```julia
334 split_tuple((1,2,3,4),Val(3)) -> (1,2,3), (4,)
335 ```
336 """
337 function split_tuple(t::NTuple{N,Any},::Val{M}) where {N,M}
338 return slice_tuple(t,Val(1), Val(M)), slice_tuple(t,Val(M+1), Val(N))
339 end
340
341 """
342 split_tuple(t::Tuple{...},::Val{M},::Val{K}) where {N,M,K}
343
344 Same as `split_tuple(t::NTuple{N},::Val{M})` but splits the tuple in three parts. With the first
345 two parts having lenght `M` and `K`.
346 """
347 function split_tuple(t::NTuple{N,Any},::Val{M},::Val{K}) where {N,M,K}
348 p1, tail = split_tuple(t, Val(M))
349 p2, p3 = split_tuple(tail, Val(K))
350 return p1,p2,p3
351 end
352
353
354 """
355 flatten_tuple(t)
356
357 Takes a nested tuple and flattens the whole structure
358 """
359 flatten_tuple(t::NTuple{N, Number} where N) = t
360 flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify?
361 flatten_tuple(ts::Vararg) = flatten_tuple(ts)
362
363 @doc raw""" 220 @doc raw"""
364 LazyOuterProduct(tms...) 221 LazyOuterProduct(tms...)
365 222
366 Creates a `TensorMappingComposition` for the outerproduct of `tms...`. 223 Creates a `TensorComposition` for the outerproduct of `tms...`.
367 This is done by separating the outer product into regular products of outer products involving only identity mappings and one non-identity mapping. 224 This is done by separating the outer product into regular products of outer products involving only identity mappings and one non-identity mapping.
368 225
369 First let 226 First let
370 ```math 227 ```math
371 \begin{aligned} 228 \begin{aligned}
397 (A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]] 254 (A⊗B⊗C)v = [(A⊗I_{|M|}⊗I_{|P|}) [(I_{|J|}⊗B⊗I_{|P|}) [(I_{|J|}⊗I_{|N|}⊗C)v]]]
398 ``` 255 ```
399 """ 256 """
400 function LazyOuterProduct end 257 function LazyOuterProduct end
401 258
402 function LazyOuterProduct(tm1::TensorMapping{T}, tm2::TensorMapping{T}) where T 259 function LazyOuterProduct(tm1::LazyTensor{T}, tm2::LazyTensor{T}) where T
403 itm1 = InflatedTensorMapping(tm1, IdentityMapping{T}(range_size(tm2))) 260 itm1 = InflatedTensor(tm1, IdentityTensor{T}(range_size(tm2)))
404 itm2 = InflatedTensorMapping(IdentityMapping{T}(domain_size(tm1)),tm2) 261 itm2 = InflatedTensor(IdentityTensor{T}(domain_size(tm1)),tm2)
405 262
406 return itm1∘itm2 263 return itm1∘itm2
407 end 264 end
408 265
409 LazyOuterProduct(t1::IdentityMapping{T}, t2::IdentityMapping{T}) where T = IdentityMapping{T}(t1.size...,t2.size...) 266 LazyOuterProduct(t1::IdentityTensor{T}, t2::IdentityTensor{T}) where T = IdentityTensor{T}(t1.size...,t2.size...)
410 LazyOuterProduct(t1::TensorMapping, t2::IdentityMapping) = InflatedTensorMapping(t1, t2) 267 LazyOuterProduct(t1::LazyTensor, t2::IdentityTensor) = InflatedTensor(t1, t2)
411 LazyOuterProduct(t1::IdentityMapping, t2::TensorMapping) = InflatedTensorMapping(t1, t2) 268 LazyOuterProduct(t1::IdentityTensor, t2::LazyTensor) = InflatedTensor(t1, t2)
412 269
413 LazyOuterProduct(tms::Vararg{TensorMapping}) = foldl(LazyOuterProduct, tms) 270 LazyOuterProduct(tms::Vararg{LazyTensor}) = foldl(LazyOuterProduct, tms)
414 271
415 ⊗(a::TensorMapping, b::TensorMapping) = LazyOuterProduct(a,b) 272
416 273 function check_domain_size(tm::LazyTensor, sz)
417
418 function check_domain_size(tm::TensorMapping, sz)
419 if domain_size(tm) != sz 274 if domain_size(tm) != sz
420 throw(SizeMismatch(tm,sz)) 275 throw(DomainSizeMismatch(tm,sz))
421 end 276 end
422 end 277 end
423 278
424 struct SizeMismatch <: Exception 279 function check_range_size(tm::LazyTensor, sz)
425 tm::TensorMapping 280 if range_size(tm) != sz
281 throw(RangeSizeMismatch(tm,sz))
282 end
283 end
284
285 struct DomainSizeMismatch <: Exception
286 tm::LazyTensor
426 sz 287 sz
427 end 288 end
428 289
429 function Base.showerror(io::IO, err::SizeMismatch) 290 function Base.showerror(io::IO, err::DomainSizeMismatch)
430 print(io, "SizeMismatch: ") 291 print(io, "DomainSizeMismatch: ")
431 print(io, "domain size $(domain_size(err.tm)) of TensorMapping not matching size $(err.sz)") 292 print(io, "domain size $(domain_size(err.tm)) of LazyTensor not matching size $(err.sz)")
432 end 293 end
294
295 struct RangeSizeMismatch <: Exception
296 tm::LazyTensor
297 sz
298 end
299
300 function Base.showerror(io::IO, err::RangeSizeMismatch)
301 print(io, "RangeSizeMismatch: ")
302 print(io, "range size $(range_size(err.tm)) of LazyTensor not matching size $(err.sz)")
303 end
304
433 305
434 function apply_with_region(op, v, boundary_width::Integer, dim_size::Integer, i) 306 function apply_with_region(op, v, boundary_width::Integer, dim_size::Integer, i)
435 if 0 < i <= boundary_width 307 if 0 < i <= boundary_width
436 return LazyTensors.apply(op,v,Index(i,Lower)) 308 return LazyTensors.apply(op,v,Index(i,Lower))
437 elseif boundary_width < i <= dim_size-boundary_width 309 elseif boundary_width < i <= dim_size-boundary_width