Mercurial > repos > public > sbplib_julia
changeset 439:00c317c9ccfb feature/lazy_identity
Merge in default to get fix for LazyTranspose
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Mon, 19 Oct 2020 21:11:01 +0200 |
parents | 1db5ec38955e (diff) 907b0510699f (current diff) |
children | a57b71343aeb |
files | src/LazyTensors/lazy_tensor_operations.jl test/testLazyTensors.jl |
diffstat | 2 files changed, 43 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl Mon Oct 19 21:09:11 2020 +0200 +++ b/src/LazyTensors/lazy_tensor_operations.jl Mon Oct 19 21:11:01 2020 +0200 @@ -147,3 +147,24 @@ function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Index,D}) where {T,R,D} apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...) end + + +""" + IdentityMapping{T,D} <: TensorMapping{T,D,D} + +The lazy identity TensorMapping for a given size. Usefull for building up higher dimensional tensor mappings from lower +dimensional ones through outer products. Also used in the Implementation for InflatedTensorMapping. +""" +struct IdentityMapping{T,D} <: TensorMapping{T,D,D} + size::NTuple{D,Int} +end +export IdentityMapping + +IdentityMapping{T}(size::NTuple{D,Int}) where {T,D} = IdentityMapping{T,D}(size) + +range_size(tmi::IdentityMapping) = tmi.size +domain_size(tmi::IdentityMapping) = tmi.size + +apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] +apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] +
--- a/test/testLazyTensors.jl Mon Oct 19 21:09:11 2020 +0200 +++ b/test/testLazyTensors.jl Mon Oct 19 21:11:01 2020 +0200 @@ -281,4 +281,26 @@ end + +@testset "IdentityMapping" begin + @test IdentityMapping{Float64}((4,5)) isa IdentityMapping{T,2} where T + @test IdentityMapping{Float64}((4,5)) isa TensorMapping{T,2,2} where T + + for sz ∈ [(4,5),(3,),(5,6,4)] + I = IdentityMapping{Float64}(sz) + v = rand(sz...) + @test I*v == v + @test I'*v == v + + @test range_size(I) == sz + @test domain_size(I) == sz + end + + I = IdentityMapping{Float64}((4,5)) + v = rand(4,5) + @inferred (I*v)[3,2] + @test_broken @inferred (I'*v)[3,2] # TODO: Should fix the index typing before investigating this + @inferred range_size(I) end + +end