changeset 1101:1e8270c18edb feature/lazy_tensors/pretty_printing

Merge default
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 12 May 2022 21:52:47 +0200
parents 67969bd7e642 (diff) 74c54996de6a (current diff)
children 84820d4780fa
files src/LazyTensors/lazy_tensor_operations.jl src/LazyTensors/tensor_types.jl test/LazyTensors/lazy_tensor_operations_test.jl test/LazyTensors/tensor_types_test.jl
diffstat 4 files changed, 63 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Sun May 08 11:35:22 2022 +0200
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Thu May 12 21:52:47 2022 +0200
@@ -98,6 +98,7 @@
     apply_transpose(c.t2, c.t1'*v, I...)
 end
 
+
 """
     TensorComposition(tm, tmi::IdentityTensor)
     TensorComposition(tmi::IdentityTensor, tm)
@@ -215,6 +216,15 @@
     return apply_transpose(itm.tm, v_inner, inner_index...)
 end
 
+function Base.show(io::IO, ::MIME"text/plain", tm::InflatedTensor{T}) where T
+    show(IOContext(io, :compact=>true), MIME("text/plain"), tm.before)
+    print(io, "⊗")
+    # if get(io, :compact, false)
+    show(io, MIME("text/plain"), tm.tm)
+    print(io, "⊗")
+    show(IOContext(io, :compact=>true), MIME("text/plain"), tm.after)
+end
+
 
 @doc raw"""
     LazyOuterProduct(tms...)
--- a/src/LazyTensors/tensor_types.jl	Sun May 08 11:35:22 2022 +0200
+++ b/src/LazyTensors/tensor_types.jl	Thu May 12 21:52:47 2022 +0200
@@ -18,6 +18,17 @@
 apply(tmi::IdentityTensor{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
 apply_transpose(tmi::IdentityTensor{T,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,D}) where {T,D} = v[I...]
 
+function Base.show(io::IO, ::MIME"text/plain", tm::IdentityTensor{T}) where T
+    if get(io, :compact, false)
+        print(io, "I")
+    else
+        print(io, "IdentityTensor{$T}")
+    end
+    print(io, "(")
+    join(io, tm.size, ",")
+    print(io, ")")
+end
+
 
 """
     ScalingTensor{T,D} <: LazyTensor{T,D,D}
@@ -35,6 +46,18 @@
 LazyTensors.range_size(m::ScalingTensor) = m.size
 LazyTensors.domain_size(m::ScalingTensor) = m.size
 
+function Base.show(io::IO, ::MIME"text/plain", tm::ScalingTensor{T}) where T
+    if get(io, :compact, false)
+        print(io, "$(tm.λ)*I(")
+        join(io, tm.size, ",")
+        print(io, ")")
+    else
+        print(io, "ScalingTensor{$T}(")
+        print(io, tm.λ, ", ")
+        print(io, tm.size)
+        print(io, ")")
+    end
+end
 
 """
     DiagonalTensor{T,D,...} <: LazyTensor{T,D,D}
--- a/test/LazyTensors/lazy_tensor_operations_test.jl	Sun May 08 11:35:22 2022 +0200
+++ b/test/LazyTensors/lazy_tensor_operations_test.jl	Thu May 12 21:52:47 2022 +0200
@@ -313,6 +313,20 @@
 
         @test InflatedTensor(I(2), I(2), I(2)) isa InflatedTensor # The constructor should always return its type.
     end
+
+    @testset "Pretty printing" begin
+        cases = [
+            InflatedTensor(I(4), ScalingTensor(2., (3,2)), I(2)) => (
+                regular="I(4)⊗ScalingTensor{Float64}(2.0, (3, 2))⊗I(2)",
+                compact="I(4)⊗2.0*I(3,2)⊗I(2)"
+            )
+        ]
+
+        @testset "$tm" for (tm, r) ∈ cases
+            @test repr(MIME("text/plain"), tm) == r.regular
+            @test repr(MIME("text/plain"), tm, context=:compact=>true) == r.compact
+        end
+    end
 end
 
 @testset "LazyOuterProduct" begin
--- a/test/LazyTensors/tensor_types_test.jl	Sun May 08 11:35:22 2022 +0200
+++ b/test/LazyTensors/tensor_types_test.jl	Thu May 12 21:52:47 2022 +0200
@@ -42,6 +42,14 @@
     @test_throws DomainSizeMismatch I1∘A
     @test_throws DomainSizeMismatch A∘I2
     @test_throws DomainSizeMismatch I1∘I2
+
+    @testset "Pretty printing" begin
+        @test repr(MIME("text/plain"), IdentityTensor{Float64}(5)) == "IdentityTensor{Float64}(5)"
+        @test repr(MIME("text/plain"), IdentityTensor{Int}(4,5)) == "IdentityTensor{Int64}(4,5)"
+
+        @test repr(MIME("text/plain"), IdentityTensor{Float64}(5), context=:compact=>true) == "I(5)"
+        @test repr(MIME("text/plain"), IdentityTensor{Int}(4,5), context=:compact=>true) == "I(4,5)"
+    end
 end
 
 
@@ -57,6 +65,14 @@
 
     @inferred (st*v)[2,2]
     @inferred (st'*v)[2,2]
+
+    @testset "Pretty printing" begin
+        @test repr(MIME("text/plain"), ScalingTensor(2., (5,))) == "ScalingTensor{Float64}(2.0, (5,))" # TODO: Can we make this nicer?
+        @test repr(MIME("text/plain"), ScalingTensor(3, (4,5))) == "ScalingTensor{Int64}(3, (4, 5))"
+
+        @test repr(MIME("text/plain"), ScalingTensor(4., (5,)), context=:compact=>true) == "4.0*I(5)"
+        @test repr(MIME("text/plain"), ScalingTensor(2, (4,5)), context=:compact=>true) == "2*I(4,5)"
+    end
 end
 
 @testset "DiagonalTensor" begin