Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 479:95f3b9036801 feature/compose_identity_mappings
Specialize the composition operator for composing a `TensorMapping` with an `IdentityMapping`.
author | Vidar Stiernström <vidar.stiernstrom@it.uu.se> |
---|---|
date | Wed, 04 Nov 2020 20:03:37 +0100 |
parents | f270d82fc9ad |
children | c1a366331e75 |
comparison
equal
deleted
inserted
replaced
473:3041f8578bba | 479:95f3b9036801 |
---|---|
84 struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D} | 84 struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D} |
85 t1::TM1 | 85 t1::TM1 |
86 t2::TM2 | 86 t2::TM2 |
87 | 87 |
88 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} | 88 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} |
89 @boundscheck if domain_size(t1) != range_size(t2) | 89 @boundscheck check_matching_size(t1::TensorMapping, t2::TensorMapping) |
90 throw(DimensionMismatch("the first argument has domain size $(domain_size(t1)) while the second has range size $(range_size(t2)) ")) | |
91 end | |
92 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) | 90 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) |
93 end | 91 end |
94 # Add check for matching sizes as a boundscheck | |
95 end | 92 end |
96 export TensorMappingComposition | 93 export TensorMappingComposition |
94 | |
95 function check_matching_size(t1::TensorMapping, t2::TensorMapping) | |
96 if domain_size(t1) != range_size(t2) | |
97 throw(DimensionMismatch("the first argument has domain size $(domain_size(t1)) while the second has range size $(range_size(t2)) ")) | |
98 end | |
99 end | |
97 | 100 |
98 range_size(tm::TensorMappingComposition) = range_size(tm.t1) | 101 range_size(tm::TensorMappingComposition) = range_size(tm.t1) |
99 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) | 102 domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) |
100 | 103 |
101 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,D}, I::Vararg{S,R} where S) where {T,R,K,D} | 104 function apply(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,D}, I::Vararg{S,R} where S) where {T,R,K,D} |
167 range_size(tmi::IdentityMapping) = tmi.size | 170 range_size(tmi::IdentityMapping) = tmi.size |
168 domain_size(tmi::IdentityMapping) = tmi.size | 171 domain_size(tmi::IdentityMapping) = tmi.size |
169 | 172 |
170 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] | 173 apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] |
171 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] | 174 apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] |
175 | |
176 """ | |
177 Base.:∘(tm, tmi) | |
178 Base.:∘(tmi, tm) | |
179 | |
180 Composes a `TensorMapping` `tm` with an `IdentityMapping` `tmi`, in either order, by returning `tm` unchanged. | |
181 """ | |
182 @inline function Base.:∘(tm::TensorMapping{T,R,D}, tmi::IdentityMapping{T,D}) where {T,R,D} | |
183 @boundscheck check_matching_size(tm::TensorMapping, tmi::TensorMapping) | |
184 return tm | |
185 end | |
186 | |
187 @inline function Base.:∘(tmi::IdentityMapping{T,R}, tm::TensorMapping{T,R,D}) where {T,R,D} | |
188 @boundscheck check_matching_size(tmi::TensorMapping, tm::TensorMapping) | |
189 return tm | |
190 end | |
191 # Specialization for the case where both arguments are `IdentityMapping`s. Required to resolve the method ambiguity between the two definitions above. | |
192 @inline function Base.:∘(tm::IdentityMapping{T,D}, tmi::IdentityMapping{T,D}) where {T,D} | |
193 @boundscheck check_matching_size(tm::TensorMapping, tmi::TensorMapping) | |
194 return tmi | |
195 end | |
172 | 196 |
173 | 197 |
174 """ | 198 """ |
175 InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D} | 199 InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D} |
176 | 200 |