Mercurial > repos > public > sbplib_julia
changeset 1101:1e8270c18edb feature/lazy_tensors/pretty_printing
Merge default
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Thu, 12 May 2022 21:52:47 +0200 |
parents | 67969bd7e642 (diff) 74c54996de6a (current diff) |
children | 84820d4780fa |
files | src/LazyTensors/lazy_tensor_operations.jl src/LazyTensors/tensor_types.jl test/LazyTensors/lazy_tensor_operations_test.jl test/LazyTensors/tensor_types_test.jl |
diffstat | 4 files changed, 63 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl Sun May 08 11:35:22 2022 +0200 +++ b/src/LazyTensors/lazy_tensor_operations.jl Thu May 12 21:52:47 2022 +0200 @@ -98,6 +98,7 @@ apply_transpose(c.t2, c.t1'*v, I...) end + """ TensorComposition(tm, tmi::IdentityTensor) TensorComposition(tmi::IdentityTensor, tm) @@ -215,6 +216,15 @@ return apply_transpose(itm.tm, v_inner, inner_index...) end +function Base.show(io::IO, ::MIME"text/plain", tm::InflatedTensor{T}) where T + show(IOContext(io, :compact=>true), MIME("text/plain"), tm.before) + print(io, "⊗") + # if get(io, :compact, false) + show(io, MIME("text/plain"), tm.tm) + print(io, "⊗") + show(IOContext(io, :compact=>true), MIME("text/plain"), tm.after) +end + @doc raw""" LazyOuterProduct(tms...)
--- a/src/LazyTensors/tensor_types.jl Sun May 08 11:35:22 2022 +0200 +++ b/src/LazyTensors/tensor_types.jl Thu May 12 21:52:47 2022 +0200 @@ -18,6 +18,17 @@ apply(tmi::IdentityTensor{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = v[I...] apply_transpose(tmi::IdentityTensor{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = v[I...] +function Base.show(io::IO, ::MIME"text/plain", tm::IdentityTensor{T}) where T + if get(io, :compact, false) + print(io, "I") + else + print(io, "IdentityTensor{$T}") + end + print(io, "(") + join(io, tm.size, ",") + print(io, ")") +end + """ ScalingTensor{T,D} <: LazyTensor{T,D,D} @@ -35,6 +46,18 @@ LazyTensors.range_size(m::ScalingTensor) = m.size LazyTensors.domain_size(m::ScalingTensor) = m.size +function Base.show(io::IO, ::MIME"text/plain", tm::ScalingTensor{T}) where T + if get(io, :compact, false) + print(io, "$(tm.λ)*I(") + join(io, tm.size, ",") + print(io, ")") + else + print(io, "ScalingTensor{$T}(") + print(io, tm.λ, ", ") + print(io, tm.size) + print(io, ")") + end +end """ DiagonalTensor{T,D,...} <: LazyTensor{T,D,D}
--- a/test/LazyTensors/lazy_tensor_operations_test.jl Sun May 08 11:35:22 2022 +0200 +++ b/test/LazyTensors/lazy_tensor_operations_test.jl Thu May 12 21:52:47 2022 +0200 @@ -313,6 +313,20 @@ @test InflatedTensor(I(2), I(2), I(2)) isa InflatedTensor # The constructor should always return its type. end + + @testset "Pretty printing" begin + cases = [ + InflatedTensor(I(4), ScalingTensor(2., (3,2)), I(2)) => ( + regular="I(4)⊗ScalingTensor{Float64}(2.0, (3, 2))⊗I(2)", + compact="I(4)⊗2.0*I(3,2)⊗I(2)" + ) + ] + + @testset "$tm" for (tm, r) ∈ cases + @test repr(MIME("text/plain"), tm) == r.regular + @test repr(MIME("text/plain"), tm, context=:compact=>true) == r.compact + end + end end @testset "LazyOuterProduct" begin
--- a/test/LazyTensors/tensor_types_test.jl Sun May 08 11:35:22 2022 +0200 +++ b/test/LazyTensors/tensor_types_test.jl Thu May 12 21:52:47 2022 +0200 @@ -42,6 +42,14 @@ @test_throws DomainSizeMismatch I1∘A @test_throws DomainSizeMismatch A∘I2 @test_throws DomainSizeMismatch I1∘I2 + + @testset "Pretty printing" begin + @test repr(MIME("text/plain"), IdentityTensor{Float64}(5)) == "IdentityTensor{Float64}(5)" + @test repr(MIME("text/plain"), IdentityTensor{Int}(4,5)) == "IdentityTensor{Int64}(4,5)" + + @test repr(MIME("text/plain"), IdentityTensor{Float64}(5), context=:compact=>true) == "I(5)" + @test repr(MIME("text/plain"), IdentityTensor{Int}(4,5), context=:compact=>true) == "I(4,5)" + end end @@ -57,6 +65,14 @@ @inferred (st*v)[2,2] @inferred (st'*v)[2,2] + + @testset "Pretty printing" begin + @test repr(MIME("text/plain"), ScalingTensor(2., (5,))) == "ScalingTensor{Float64}(2.0, (5,))" # TODO: Can we make this nicer? + @test repr(MIME("text/plain"), ScalingTensor(3, (4,5))) == "ScalingTensor{Int64}(3, (4, 5))" + + @test repr(MIME("text/plain"), ScalingTensor(4., (5,)), context=:compact=>true) == "4.0*I(5)" + @test repr(MIME("text/plain"), ScalingTensor(2, (4,5)), context=:compact=>true) == "2*I(4,5)" + end end @testset "DiagonalTensor" begin