Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 943:fb060e98ac0a feature/tensormapping_application_promotion
Remove more type assertions
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Thu, 10 Mar 2022 16:57:01 +0100 |
parents | 7829c09f8137 |
children | 4a4ef4bf6cb9 |
comparison legend:
- equal
- deleted
- inserted
- replaced
942:7829c09f8137 | 943:fb060e98ac0a |
---|---|
50 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? | 50 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? |
51 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? | 51 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? |
52 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm) | 52 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm) |
53 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm | 53 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm |
54 | 54 |
55 apply(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...) | 55 apply(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...) |
56 apply_transpose(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...) | 56 apply_transpose(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...) |
57 | 57 |
58 range_size(tmt::LazyTensorMappingTranspose) = domain_size(tmt.tm) | 58 range_size(tmt::LazyTensorMappingTranspose) = domain_size(tmt.tm) |
59 domain_size(tmt::LazyTensorMappingTranspose) = range_size(tmt.tm) | 59 domain_size(tmt::LazyTensorMappingTranspose) = range_size(tmt.tm) |
60 | 60 |
61 | 61 |
67 return new{Op,T,R,D,T1,T2}(tm1,tm2) | 67 return new{Op,T,R,D,T1,T2}(tm1,tm2) |
68 end | 68 end |
69 end | 69 end |
70 # TODO: Boundschecking in constructor. | 70 # TODO: Boundschecking in constructor. |
71 | 71 |
72 apply(tmBinOp::LazyTensorMappingBinaryOperation{:+,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) | 72 apply(tmBinOp::LazyTensorMappingBinaryOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) |
73 apply(tmBinOp::LazyTensorMappingBinaryOperation{:-,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) | 73 apply(tmBinOp::LazyTensorMappingBinaryOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) |
74 | 74 |
75 range_size(tmBinOp::LazyTensorMappingBinaryOperation) = range_size(tmBinOp.tm1) | 75 range_size(tmBinOp::LazyTensorMappingBinaryOperation) = range_size(tmBinOp.tm1) |
76 domain_size(tmBinOp::LazyTensorMappingBinaryOperation) = domain_size(tmBinOp.tm1) | 76 domain_size(tmBinOp::LazyTensorMappingBinaryOperation) = domain_size(tmBinOp.tm1) |
77 | 77 |
78 Base.:+(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:+,T,R,D}(tm1,tm2) | 78 Base.:+(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:+,T,R,D}(tm1,tm2) |
95 export TensorMappingComposition | 95 export TensorMappingComposition |
96 | 96 |
97 range_size(tm::TensorMappingComposition) = range_size(tm.t1) | 97 range_size(tm::TensorMappingComposition) = range_size(tm.t1) |
98 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) | 98 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) |
99 | 99 |
100 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,K,D} | 100 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,K,D} |
101 apply(c.t1, c.t2*v, I...) | 101 apply(c.t1, c.t2*v, I...) |
102 end | 102 end |
103 | 103 |
104 function apply_transpose(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,K,D} | 104 function apply_transpose(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,K,D} |
105 apply_transpose(c.t2, c.t1'*v, I...) | 105 apply_transpose(c.t2, c.t1'*v, I...) |
106 end | 106 end |
107 | 107 |
108 Base.@propagate_inbounds Base.:∘(s::TensorMapping, t::TensorMapping) = TensorMappingComposition(s,t) | 108 Base.@propagate_inbounds Base.:∘(s::TensorMapping, t::TensorMapping) = TensorMappingComposition(s,t) |
109 | 109 |
132 export LazyLinearMap | 132 export LazyLinearMap |
133 | 133 |
134 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]] | 134 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]] |
135 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]] | 135 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]] |
136 | 136 |
137 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} | 137 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} |
138 view_index = ntuple(i->:,ndims(llm.A)) | 138 view_index = ntuple(i->:,ndims(llm.A)) |
139 for i ∈ 1:R | 139 for i ∈ 1:R |
140 view_index = Base.setindex(view_index, Int(I[i]), llm.range_indicies[i]) | 140 view_index = Base.setindex(view_index, Int(I[i]), llm.range_indicies[i]) |
141 end | 141 end |
142 A_view = @view llm.A[view_index...] | 142 A_view = @view llm.A[view_index...] |
143 return sum(A_view.*v) | 143 return sum(A_view.*v) |
144 end | 144 end |
145 | 145 |
146 function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} | 146 function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D} |
147 apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...) | 147 apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...) |
148 end | 148 end |
149 | 149 |
150 | 150 |
151 """ | 151 """ |
164 IdentityMapping(size::Vararg{Int,D}) where D = IdentityMapping{Float64,D}(size) | 164 IdentityMapping(size::Vararg{Int,D}) where D = IdentityMapping{Float64,D}(size) |
165 | 165 |
166 range_size(tmi::IdentityMapping) = tmi.size | 166 range_size(tmi::IdentityMapping) = tmi.size |
167 domain_size(tmi::IdentityMapping) = tmi.size | 167 domain_size(tmi::IdentityMapping) = tmi.size |
168 | 168 |
169 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] | 169 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = v[I...] |
170 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] | 170 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = v[I...] |
171 | 171 |
172 """ | 172 """ |
173 Base.:∘(tm, tmi) | 173 Base.:∘(tm, tmi) |
174 Base.:∘(tmi, tm) | 174 Base.:∘(tmi, tm) |
175 | 175 |
258 domain_size(itm.tm), | 258 domain_size(itm.tm), |
259 domain_size(itm.after), | 259 domain_size(itm.after), |
260 ) | 260 ) |
261 end | 261 end |
262 | 262 |
263 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} | 263 function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} |
264 dim_before = range_dim(itm.before) | 264 dim_before = range_dim(itm.before) |
265 dim_domain = domain_dim(itm.tm) | 265 dim_domain = domain_dim(itm.tm) |
266 dim_range = range_dim(itm.tm) | 266 dim_range = range_dim(itm.tm) |
267 dim_after = range_dim(itm.after) | 267 dim_after = range_dim(itm.after) |
268 | 268 |
270 | 270 |
271 v_inner = view(v, view_index...) | 271 v_inner = view(v, view_index...) |
272 return apply(itm.tm, v_inner, inner_index...) | 272 return apply(itm.tm, v_inner, inner_index...) |
273 end | 273 end |
274 | 274 |
275 function apply_transpose(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} | 275 function apply_transpose(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{<:Any,R}, I::Vararg{Any,D}) where {T,R,D} |
276 dim_before = range_dim(itm.before) | 276 dim_before = range_dim(itm.before) |
277 dim_domain = domain_dim(itm.tm) | 277 dim_domain = domain_dim(itm.tm) |
278 dim_range = range_dim(itm.tm) | 278 dim_range = range_dim(itm.tm) |
279 dim_after = range_dim(itm.after) | 279 dim_after = range_dim(itm.after) |
280 | 280 |