changeset 942:7829c09f8137 feature/tensormapping_application_promotion

Add promotion calculation of element type for LazyTensorMappingApplication
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 10 Mar 2022 11:13:34 +0100
parents de1625deb27e
children fb060e98ac0a
files src/LazyTensors/lazy_tensor_operations.jl test/LazyTensors/lazy_tensor_operations_test.jl
diffstat 2 files changed, 28 insertions(+), 1 deletions(-) [+]
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Mon Feb 21 10:38:19 2022 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Thu Mar 10 11:13:34 2022 +0100
@@ -7,9 +7,14 @@
 With a mapping `m` and a vector `v` the LazyTensorMappingApplication object can be created by `m*v`.
 The actual result will be calcualted when indexing into `m*v`.
 """
-struct LazyTensorMappingApplication{T,R,D, TM<:TensorMapping{T,R,D}, AA<:AbstractArray{T,D}} <: LazyArray{T,R}
+struct LazyTensorMappingApplication{T,R,D, TM<:TensorMapping{<:Any,R,D}, AA<:AbstractArray{<:Any,D}} <: LazyArray{T,R}
     t::TM
     o::AA
+
+    function LazyTensorMappingApplication(t::TensorMapping{<:Any,R,D}, o::AbstractArray{<:Any,D}) where {R,D}
+        T = promote_type(eltype(t), eltype(o))
+        return new{T,R,D,typeof(t), typeof(o)}(t,o)
+    end
 end
 # TODO: Do boundschecking on creation!
 export LazyTensorMappingApplication
--- a/test/LazyTensors/lazy_tensor_operations_test.jl	Mon Feb 21 10:38:19 2022 +0100
+++ b/test/LazyTensors/lazy_tensor_operations_test.jl	Thu Mar 10 11:13:34 2022 +0100
@@ -74,6 +74,28 @@
     v = [[1 2];[3 4]]
     @test m*v == [[2 4];[6 8]]
     @test (m*v)[2,1] == 6
+
+    @testset "Promotion" begin
+        m = ScalingOperator{Int,1}(2,(3,))
+        v = [1.,2.,3.]
+        @test m*v isa AbstractVector{Float64}
+        @test m*v == [2.,4.,6.]
+
+        m = ScalingOperator{Int,2}(2,(2,2))
+        v = [[1. 2.];[3. 4.]]
+        @test m*v == [[2. 4.];[6. 8.]]
+        @test (m*v)[2,1] == 6.
+
+        m = ScalingOperator{ComplexF64,1}(2. +2. *im,(3,))
+        v = [1.,2.,3.]
+        @test m*v isa AbstractVector{ComplexF64}
+        @test m*v == [2. + 2. *im, 4. + 4. *im, 6. + 6. *im]
+
+        m = ScalingOperator{ComplexF64,1}(1,(3,))
+        v = [2. + 2. *im, 4. + 4. *im, 6. + 6. *im]
+        @test m*v isa AbstractVector{ComplexF64}
+        @test m*v == [2. + 2. *im, 4. + 4. *im, 6. + 6. *im]
+    end
 end
 
 @testset "TensorMapping binary operations" begin