comparison src/LazyTensors/lazy_tensor_operations.jl @ 943:fb060e98ac0a feature/tensormapping_application_promotion

Remove more type assertions
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 10 Mar 2022 16:57:01 +0100
parents 7829c09f8137
children 4a4ef4bf6cb9
comparison
equal deleted inserted replaced
942:7829c09f8137 943:fb060e98ac0a
50 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? 50 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors?
51 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? 51 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`?
# Transposing a TensorMapping wraps it in a lazy transpose; transposing that
# wrapper again unwraps it instead of nesting a second wrapper.
function Base.adjoint(tm::TensorMapping)
    return LazyTensorMappingTranspose(tm)
end

function Base.adjoint(tmt::LazyTensorMappingTranspose)
    return tmt.tm
end
54 54
# Evaluating the transpose wrapper defers to the opposite operation on the
# wrapped mapping. The element type of `v` is deliberately left unconstrained
# (`<:Any`) so arrays whose eltype differs from `T` are still accepted.
function apply(
    tmt::LazyTensorMappingTranspose{T,R,D},
    v::AbstractArray{<:Any,R},
    I::Vararg{Any,D},
) where {T,R,D}
    wrapped = tmt.tm
    return apply_transpose(wrapped, v, I...)
end

function apply_transpose(
    tmt::LazyTensorMappingTranspose{T,R,D},
    v::AbstractArray{<:Any,D},
    I::Vararg{Any,R},
) where {T,R,D}
    wrapped = tmt.tm
    return apply(wrapped, v, I...)
end
57 57
# A transpose swaps the roles of range and domain of the wrapped mapping.
function range_size(tmt::LazyTensorMappingTranspose)
    return domain_size(tmt.tm)
end

function domain_size(tmt::LazyTensorMappingTranspose)
    return range_size(tmt.tm)
end
60 60
61 61
67 return new{Op,T,R,D,T1,T2}(tm1,tm2) 67 return new{Op,T,R,D,T1,T2}(tm1,tm2)
68 end 68 end
69 end 69 end
70 # TODO: Boundschecking in constructor. 70 # TODO: Boundschecking in constructor.
71 71
# Both operands are evaluated at the same index `I` and combined with the
# operator encoded in the first type parameter (`:+` or `:-`).
function apply(
    tmBinOp::LazyTensorMappingBinaryOperation{:+,T,R,D},
    v::AbstractArray{<:Any,D},
    I::Vararg{Any,R},
) where {T,R,D}
    first_term = apply(tmBinOp.tm1, v, I...)
    second_term = apply(tmBinOp.tm2, v, I...)
    return first_term + second_term
end

function apply(
    tmBinOp::LazyTensorMappingBinaryOperation{:-,T,R,D},
    v::AbstractArray{<:Any,D},
    I::Vararg{Any,R},
) where {T,R,D}
    first_term = apply(tmBinOp.tm1, v, I...)
    second_term = apply(tmBinOp.tm2, v, I...)
    return first_term - second_term
end
74 74
# Sizes are taken from the first operand `tm1`; `tm2` is not consulted.
function range_size(tmBinOp::LazyTensorMappingBinaryOperation)
    return range_size(tmBinOp.tm1)
end

function domain_size(tmBinOp::LazyTensorMappingBinaryOperation)
    return domain_size(tmBinOp.tm1)
end
77 77
# Lazy sum of two mappings with the same element type and dimensions; the
# result is evaluated on demand by `apply` on the binary-operation type.
function Base.:+(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D}
    return LazyTensorMappingBinaryOperation{:+,T,R,D}(tm1, tm2)
end
95 export TensorMappingComposition 95 export TensorMappingComposition
96 96
# The composition's range comes from the outer mapping `t1` and its domain
# from the inner mapping `t2`.
function range_size(tm::TensorMappingComposition)
    return range_size(tm.t1)
end

function domain_size(tm::TensorMappingComposition)
    return domain_size(tm.t2)
end
99 99
function apply(
    c::TensorMappingComposition{T,R,K,D},
    v::AbstractArray{<:Any,D},
    I::Vararg{Any,R},
) where {T,R,K,D}
    # Apply the inner mapping `t2` to the whole of `v`, then evaluate the
    # outer mapping `t1` on that result at index `I`.
    intermediate = c.t2 * v
    return apply(c.t1, intermediate, I...)
end
103 103
function apply_transpose(
    c::TensorMappingComposition{T,R,K,D},
    v::AbstractArray{<:Any,R},
    I::Vararg{Any,D},
) where {T,R,K,D}
    # The transpose reverses the order: apply `t1'` to the whole of `v`, then
    # evaluate the transpose of `t2` on that result at index `I`.
    intermediate = c.t1' * v
    return apply_transpose(c.t2, intermediate, I...)
end
107 107
# `s ∘ t` builds a TensorMappingComposition from `s` and `t`.
Base.@propagate_inbounds function Base.:∘(s::TensorMapping, t::TensorMapping)
    return TensorMappingComposition(s, t)
end
109 109
132 export LazyLinearMap 132 export LazyLinearMap
133 133
# Range/domain sizes are the sizes of the dimensions of `A` selected by the
# stored range/domain index lists.
function range_size(llm::LazyLinearMap)
    selected = [llm.range_indicies...]
    return size(llm.A)[selected]
end

function domain_size(llm::LazyLinearMap)
    selected = [llm.domain_indicies...]
    return size(llm.A)[selected]
end
136 136
# Fix the range dimensions of `A` at the integer positions given by `I`, keep
# every other dimension as a full slice, then multiply that view elementwise
# with `v` and sum the products.
function apply(
    llm::LazyLinearMap{T,R,D},
    v::AbstractArray{<:Any,D},
    I::Vararg{Any,R},
) where {T,R,D}
    all_colons = ntuple(_ -> Colon(), ndims(llm.A))
    index_tuple = foldl(1:R; init=all_colons) do idx, k
        # Tuples are immutable, so `Base.setindex` returns an updated copy.
        Base.setindex(idx, Int(I[k]), llm.range_indicies[k])
    end
    A_slice = view(llm.A, index_tuple...)
    return sum(A_slice .* v)
end
145 145
function apply_transpose(
    llm::LazyLinearMap{T,R,D},
    v::AbstractArray{<:Any,R},
    I::Vararg{Any,D},
) where {T,R,D}
    # Reuse `apply` on a map over the same array with the domain and range
    # index lists passed in swapped positions.
    flipped = LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies)
    return apply(flipped, v, I...)
end
149 149
150 150
151 """ 151 """
# Convenience constructor: sizes given as separate Int arguments, element type
# defaulting to Float64.
function IdentityMapping(size::Vararg{Int,D}) where D
    return IdentityMapping{Float64,D}(size)
end
165 165
# The identity has identical range and domain sizes, both stored in `size`.
function range_size(tmi::IdentityMapping)
    return tmi.size
end

function domain_size(tmi::IdentityMapping)
    return tmi.size
end
168 168
# Both the identity and its transpose just read `v` at index `I`.
function apply(
    tmi::IdentityMapping{T,D},
    v::AbstractArray{<:Any,D},
    I::Vararg{Any,D},
) where {T,D}
    return getindex(v, I...)
end

function apply_transpose(
    tmi::IdentityMapping{T,D},
    v::AbstractArray{<:Any,D},
    I::Vararg{Any,D},
) where {T,D}
    return getindex(v, I...)
end
171 171
172 """ 172 """
173 Base.:∘(tm, tmi) 173 Base.:∘(tm, tmi)
174 Base.:∘(tmi, tm) 174 Base.:∘(tmi, tm)
175 175
258 domain_size(itm.tm), 258 domain_size(itm.tm),
259 domain_size(itm.after), 259 domain_size(itm.after),
260 ) 260 )
261 end 261 end
262 262
263 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} 263 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D}
264 dim_before = range_dim(itm.before) 264 dim_before = range_dim(itm.before)
265 dim_domain = domain_dim(itm.tm) 265 dim_domain = domain_dim(itm.tm)
266 dim_range = range_dim(itm.tm) 266 dim_range = range_dim(itm.tm)
267 dim_after = range_dim(itm.after) 267 dim_after = range_dim(itm.after)
268 268
270 270
271 v_inner = view(v, view_index...) 271 v_inner = view(v, view_index...)
272 return apply(itm.tm, v_inner, inner_index...) 272 return apply(itm.tm, v_inner, inner_index...)
273 end 273 end
274 274
275 function apply_transpose(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} 275 function apply_transpose(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D}
276 dim_before = range_dim(itm.before) 276 dim_before = range_dim(itm.before)
277 dim_domain = domain_dim(itm.tm) 277 dim_domain = domain_dim(itm.tm)
278 dim_range = range_dim(itm.tm) 278 dim_range = range_dim(itm.tm)
279 dim_after = range_dim(itm.after) 279 dim_after = range_dim(itm.after)
280 280