Mercurial > repos > public > sbplib_julia
changeset 479:95f3b9036801 feature/compose_identity_mappings
Specialize composition operator for composing a tensormapping with an identitymapping.
author | Vidar Stiernström <vidar.stiernstrom@it.uu.se> |
---|---|
date | Wed, 04 Nov 2020 20:03:37 +0100 |
parents | 3041f8578bba |
children | c1a366331e75 |
files | src/LazyTensors/lazy_tensor_operations.jl test/testLazyTensors.jl |
diffstat | 2 files changed, 39 insertions(+), 4 deletions(-) [+] |
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl Mon Nov 02 21:33:35 2020 +0100 +++ b/src/LazyTensors/lazy_tensor_operations.jl Wed Nov 04 20:03:37 2020 +0100 @@ -86,15 +86,18 @@ t2::TM2 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} - @boundscheck if domain_size(t1) != range_size(t2) - throw(DimensionMismatch("the first argument has domain size $(domain_size(t1)) while the second has range size $(range_size(t2)) ")) - end + @boundscheck check_matching_size(t1::TensorMapping, t2::TensorMapping) return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) end - # Add check for matching sizes as a boundscheck end export TensorMappingComposition +function check_matching_size(t1::TensorMapping, t2::TensorMapping) + if domain_size(t1) != range_size(t2) + throw(DimensionMismatch("the first argument has domain size $(domain_size(t1)) while the second has range size $(range_size(t2)) ")) + end +end + range_size(tm::TensorMappingComposition) = range_size(tm.t1) domain_size(tm::TensorMappingComposition) = domain_size(tm.t2) @@ -170,6 +173,27 @@ apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] +""" +Base.:∘(tm, tmi) +Base.:∘(tmi, tm) + +Composes a `Tensormapping`s `tm` with an `IdentityMapping`s `tmi`, by returning `tm` +""" +@inline function Base.:∘(tm::TensorMapping{T,R,D}, tmi::IdentityMapping{T,D}) where {T,R,D} + @boundscheck check_matching_size(tm::TensorMapping, tmi::TensorMapping) + return tm +end + +@inline function Base.:∘(tmi::IdentityMapping{T,R}, tm::TensorMapping{T,R,D}) where {T,R,D} + @boundscheck check_matching_size(tmi::TensorMapping, tm::TensorMapping) + return tm +end +# Specialization for the case where tm is an IdentityMapping. Required to resolve ambiguity. +@inline function Base.:∘(tm::IdentityMapping{T,D}, tmi::IdentityMapping{T,D}) where {T,D} + @boundscheck check_matching_size(tm::TensorMapping, tmi::TensorMapping) + return tmi +end + """ InflatedTensorMapping{T,R,D} <: TensorMapping{T,R,D}
--- a/test/testLazyTensors.jl Mon Nov 02 21:33:35 2020 +0100 +++ b/test/testLazyTensors.jl Wed Nov 04 20:03:37 2020 +0100 @@ -312,6 +312,17 @@ @inferred range_dim(I) @inferred domain_dim(I) + + Ã = rand(4,2) + A = LazyLinearMap(Ã,(1,),(2,)) + I1 = IdentityMapping{Float64}(2) + I2 = IdentityMapping{Float64}(4) + @test A∘I1 == A + @test I2∘A == A + @test I1∘I1 == I1 + @test_throws DimensionMismatch I1∘A + @test_throws DimensionMismatch A∘I2 + @test_throws DimensionMismatch I1∘I2 end @testset "InflatedTensorMapping" begin