comparison src/LazyTensors/lazy_tensor_operations.jl @ 479:95f3b9036801 feature/compose_identity_mappings

Specialize composition operator for composing a tensormapping with an identitymapping.
author Vidar Stiernström <vidar.stiernstrom@it.uu.se>
date Wed, 04 Nov 2020 20:03:37 +0100
parents f270d82fc9ad
children c1a366331e75
comparison
equal deleted inserted replaced
473:3041f8578bba 479:95f3b9036801
84 struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D} 84 struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D}
85 t1::TM1 85 t1::TM1
86 t2::TM2 86 t2::TM2
87 87
88 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} 88 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D}
89 @boundscheck if domain_size(t1) != range_size(t2) 89 @boundscheck check_matching_size(t1::TensorMapping, t2::TensorMapping)
90 throw(DimensionMismatch("the first argument has domain size $(domain_size(t1)) while the second has range size $(range_size(t2)) "))
91 end
92 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) 90 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2)
93 end 91 end
94 # Add check for matching sizes as a boundscheck
95 end 92 end
96 export TensorMappingComposition 93 export TensorMappingComposition
94
95 function check_matching_size(t1::TensorMapping, t2::TensorMapping)
96 if domain_size(t1) != range_size(t2)
97 throw(DimensionMismatch("the first argument has domain size $(domain_size(t1)) while the second has range size $(range_size(t2)) "))
98 end
99 end
97 100
98 range_size(tm::TensorMappingComposition) = range_size(tm.t1) 101 range_size(tm::TensorMappingComposition) = range_size(tm.t1)
99 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) 102 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2)
100 103
101 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,D}, I::Vararg{S,R} where S) where {T,R,K,D} 104 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,D}, I::Vararg{S,R} where S) where {T,R,K,D}
167 range_size(tmi::IdentityMapping) = tmi.size 170 range_size(tmi::IdentityMapping) = tmi.size
168 domain_size(tmi::IdentityMapping) = tmi.size 171 domain_size(tmi::IdentityMapping) = tmi.size
169 172
170 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] 173 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
171 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] 174 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
175
176 """
177 Base.:∘(tm, tmi)
178 Base.:∘(tmi, tm)
179
180 Composes a `TensorMapping` `tm` with an `IdentityMapping` `tmi` by returning `tm`.
181 """
182 @inline function Base.:∘(tm::TensorMapping{T,R,D}, tmi::IdentityMapping{T,D}) where {T,R,D}
183 @boundscheck check_matching_size(tm::TensorMapping, tmi::TensorMapping)
184 return tm
185 end
186
187 @inline function Base.:∘(tmi::IdentityMapping{T,R}, tm::TensorMapping{T,R,D}) where {T,R,D}
188 @boundscheck check_matching_size(tmi::TensorMapping, tm::TensorMapping)
189 return tm
190 end
191 # Specialization for the case where both arguments are IdentityMappings. Required to resolve the method ambiguity between the two methods above.
192 @inline function Base.:∘(tm::IdentityMapping{T,D}, tmi::IdentityMapping{T,D}) where {T,D}
193 @boundscheck check_matching_size(tm::TensorMapping, tmi::TensorMapping)
194 return tmi
195 end
172 196
173 197
174 """ 198 """
175 InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D} 199 InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D}
176 200