changeset 450:ac6d22570a08 feature/inflated_tensormapping

Merge in feature/lazy_identity
author Jonatan Werpers <jonatan@werpers.com>
date Mon, 19 Oct 2020 21:42:57 +0200
parents 14d60de71b72 (current diff) e70e47fbfa7c (diff)
children 6cf234eef780
files src/LazyTensors/lazy_tensor_operations.jl test/testLazyTensors.jl
diffstat 2 files changed, 13 insertions(+), 5 deletions(-) [+]
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Mon Oct 19 21:14:46 2020 +0200
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Mon Oct 19 21:42:57 2020 +0200
@@ -16,14 +16,16 @@
 
 # TODO: Go through and remove unneccerary type parameters on functions
 
-Base.:*(tm::TensorMapping{T,R,D}, o::AbstractArray{T,D}) where {T,R,D} = LazyTensorMappingApplication(tm,o)
 Base.getindex(ta::LazyTensorMappingApplication{T,R,D}, I::Vararg{Index,R}) where {T,R,D} = apply(ta.t, ta.o, I...)
 Base.getindex(ta::LazyTensorMappingApplication{T,R,D}, I::Vararg{Int,R}) where {T,R,D} = apply(ta.t, ta.o, Index{Unknown}.(I)...)
 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t)
 # TODO: What else is needed to implement the AbstractArray interface?
 
+Base.:*(a::TensorMapping, v::AbstractArray) = LazyTensorMappingApplication(a,v)
+Base.:*(a::TensorMapping, b::TensorMapping) = throw(MethodError(Base.:*,(a,b)))
+Base.:*(a::TensorMapping, args::Union{TensorMapping, AbstractArray}...) = foldr(*,(a,args...))
+
 # # We need the associativity to be a→b→c = a→(b→c), which is the case for '→'
-Base.:*(a::TensorMapping{T,R,D}, b::TensorMapping{T,D,K}, args::Union{TensorMapping{T}, AbstractArray{T}}...) where {T,R,D,K} = foldr(*,(a,b,args...))
 # # Should we overload some other infix binary opesrator?
 # →(tm::TensorMapping{T,R,D}, o::AbstractArray{T,D}) where {T,R,D} = LazyTensorMappingApplication(tm,o)
 # TODO: We need to be really careful about good error messages.
@@ -38,8 +40,8 @@
 the transpose of mapping `m` by using `m'`. `m'` will work as a regular TensorMapping lazily calling
 the appropriate methods of `m`.
 """
-struct LazyTensorMappingTranspose{T,R,D} <: TensorMapping{T,D,R}
-    tm::TensorMapping{T,R,D}
+struct LazyTensorMappingTranspose{T,R,D, TM<:TensorMapping{T,R,D}} <: TensorMapping{T,D,R}
+    tm::TM
 end
 export LazyTensorMappingTranspose
 
@@ -159,6 +161,8 @@
 export IdentityMapping
 
 IdentityMapping{T}(size::NTuple{D,Int}) where {T,D} = IdentityMapping{T,D}(size)
+IdentityMapping{T}(size::Vararg{Int,D}) where {T,D} = IdentityMapping{T,D}(size)
+IdentityMapping(size::Vararg{Int,D}) where D = IdentityMapping{Float64,D}(size)
 
 range_size(tmi::IdentityMapping) = tmi.size
 domain_size(tmi::IdentityMapping) = tmi.size
--- a/test/testLazyTensors.jl	Mon Oct 19 21:14:46 2020 +0200
+++ b/test/testLazyTensors.jl	Mon Oct 19 21:42:57 2020 +0200
@@ -58,6 +58,7 @@
     @test (m*m*v)[6] == (:apply,m*v,(Index{Unknown}(6),))
     @test_broken BoundsError == (m*m*v)[0]
     @test_broken BoundsError == (m*m*v)[7]
+    @test_throws MethodError m*m
 
     m = SizeDoublingMapping{Int, 2, 1}((3,))
     @test_throws MethodError m*ones(Int,2,2)
@@ -284,6 +285,9 @@
 @testset "IdentityMapping" begin
     @test IdentityMapping{Float64}((4,5)) isa IdentityMapping{T,2} where T
     @test IdentityMapping{Float64}((4,5)) isa TensorMapping{T,2,2} where T
+    @test IdentityMapping{Float64}((4,5)) == IdentityMapping{Float64}(4,5)
+
+    @test IdentityMapping(3,2) isa IdentityMapping{Float64,2}
 
     for sz ∈ [(4,5),(3,),(5,6,4)]
         I = IdentityMapping{Float64}(sz)
@@ -298,7 +302,7 @@
     I = IdentityMapping{Float64}((4,5))
     v = rand(4,5)
     @inferred (I*v)[3,2]
-    @test_broken @inferred (I'*v)[3,2] # TODO: Should fix the index typing before investigating this
+    @inferred (I'*v)[3,2]
     @inferred range_size(I)
 end