changeset 452:aeda2698166d feature/inflated_tensormapping

Add tullio as a test dependency and add a test for apply
author Jonatan Werpers <jonatan@werpers.com>
date Mon, 19 Oct 2020 22:34:58 +0200
parents 6cf234eef780
children c1ae837f1a2e
files src/LazyTensors/lazy_tensor_operations.jl test/Manifest.toml test/Project.toml test/runtests.jl test/testLazyTensors.jl
diffstat 5 files changed, 102 insertions(+), 9 deletions(-) [+]
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Mon Oct 19 22:03:59 2020 +0200
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Mon Oct 19 22:34:58 2020 +0200
@@ -212,8 +212,8 @@
     )
 end
 
-function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{S,R} where S) where {T,R,D}
-    view_index, inner_index = split_index(I...)
+function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D}
+    view_index, inner_index = split_index(itm, I...)
 
     v_inner = view(v, view_index...)
     return apply(itm.tm, v_inner, inner_index...)
@@ -229,16 +229,14 @@
 (1,2,3,4) -> (1,:,:,4), (2,3)
 ```
 """
-function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{S,R} where S) where {T,R,D}
+function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{Any,R}) where {T,R,D}
     I_before = I[1:range_dim(itm.before)]
     I_after = I[(end-range_dim(itm.after)+1):end]
 
     view_index = (I_before..., ntuple((i)->:,domain_dim(itm.tm))..., I_after...)
-    A_view = @view llm.A[view_index...]
     inner_index = I[range_dim(itm.before)+1:end-range_dim(itm.after)]
 
     return (view_index, inner_index)
-    return sum(A_view.*v)
 end
 
 flatten_tuple(t::NTuple{N, Number} where N) = t
--- a/test/Manifest.toml	Mon Oct 19 22:03:59 2020 +0200
+++ b/test/Manifest.toml	Mon Oct 19 22:34:58 2020 +0200
@@ -1,13 +1,35 @@
 # This file is machine-generated - editing it directly is not advised
 
+[[Artifacts]]
+deps = ["Pkg"]
+git-tree-sha1 = "c30985d8821e0cd73870b17b0ed0ce6dc44cb744"
+uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
+version = "1.3.0"
+
 [[Base64]]
 uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
 
+[[CompilerSupportLibraries_jll]]
+deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"]
+git-tree-sha1 = "8e695f735fca77e9708e795eda62afdb869cbb70"
+uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
+version = "0.3.4+0"
+
+[[Dates]]
+deps = ["Printf"]
+uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
+
 [[DeepDiffs]]
 git-tree-sha1 = "9824894295b62a6a4ab6adf1c7bf337b3a9ca34c"
 uuid = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
 version = "1.2.0"
 
+[[DiffRules]]
+deps = ["NaNMath", "Random", "SpecialFunctions"]
+git-tree-sha1 = "eb0c34204c8410888844ada5359ac8b96292cfd1"
+uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
+version = "1.0.1"
+
 [[Distributed]]
 deps = ["Random", "Serialization", "Sockets"]
 uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -16,6 +38,15 @@
 deps = ["Markdown"]
 uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 
+[[JLLWrappers]]
+git-tree-sha1 = "7cec881362e5b4e367ff0279dd99a06526d51a55"
+uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
+version = "1.1.2"
+
+[[LibGit2]]
+deps = ["Printf"]
+uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
+
 [[Libdl]]
 uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
 
@@ -30,16 +61,54 @@
 deps = ["Base64"]
 uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
 
+[[NaNMath]]
+git-tree-sha1 = "c84c576296d0e2fbb3fc134d3e09086b3ea617cd"
+uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
+version = "0.3.4"
+
+[[OpenSpecFun_jll]]
+deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
+git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3"
+uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
+version = "0.5.3+4"
+
+[[Pkg]]
+deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
+uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+
+[[Printf]]
+deps = ["Unicode"]
+uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
+
+[[REPL]]
+deps = ["InteractiveUtils", "Markdown", "Sockets"]
+uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
+
 [[Random]]
 deps = ["Serialization"]
 uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 
+[[Requires]]
+deps = ["UUIDs"]
+git-tree-sha1 = "28faf1c963ca1dc3ec87f166d92982e3c4a1f66d"
+uuid = "ae029012-a4dd-5104-9daa-d747884805df"
+version = "1.1.0"
+
+[[SHA]]
+uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
+
 [[Serialization]]
 uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
 
 [[Sockets]]
 uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
 
+[[SpecialFunctions]]
+deps = ["OpenSpecFun_jll"]
+git-tree-sha1 = "d8d8b8a9f4119829410ecd706da4cc8594a1e020"
+uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
+version = "0.10.3"
+
 [[Test]]
 deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
 uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -49,3 +118,16 @@
 git-tree-sha1 = "3a2919a78b04c29a1a57b05e1618e473162b15d0"
 uuid = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"
 version = "2.0.0"
+
+[[Tullio]]
+deps = ["DiffRules", "LinearAlgebra", "Requires"]
+git-tree-sha1 = "b27ec3ce782f69c1c24f373bfb6aa60300ed57c7"
+uuid = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
+version = "0.2.8"
+
+[[UUIDs]]
+deps = ["Random", "SHA"]
+uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
+
+[[Unicode]]
+uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
--- a/test/Project.toml	Mon Oct 19 22:03:59 2020 +0200
+++ b/test/Project.toml	Mon Oct 19 22:34:58 2020 +0200
@@ -2,3 +2,4 @@
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"
+Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
--- a/test/runtests.jl	Mon Oct 19 22:03:59 2020 +0200
+++ b/test/runtests.jl	Mon Oct 19 22:34:58 2020 +0200
@@ -1,6 +1,6 @@
 using Test
 using TestSetExtensions
 
-@testset ExtendedTestSet "All" begin
+@testset "All" begin
     @includetests ARGS
 end
--- a/test/testLazyTensors.jl	Mon Oct 19 22:03:59 2020 +0200
+++ b/test/testLazyTensors.jl	Mon Oct 19 22:34:58 2020 +0200
@@ -2,6 +2,8 @@
 using Sbplib.LazyTensors
 using Sbplib.RegionIndices
 
+using Tullio
+
 @testset "LazyTensors" begin
 
 @testset "Generic Mapping methods" begin
@@ -308,9 +310,14 @@
 
 @testset "InflatedTensorMapping" begin
     I(sz...) = IdentityMapping(sz...)
-    A = LazyLinearMap(rand(4,2),(1,),(2,))
-    B = LazyLinearMap(rand(4,2,3),(1,2),(3,))
-    C = LazyLinearMap(rand(4,2,3),(1,),(2,3))
+
+    Ã = rand(4,2)
+    B̃ = rand(4,2,3)
+    C̃ = rand(4,2,3)
+
+    A = LazyLinearMap(Ã,(1,),(2,))
+    B = LazyLinearMap(B̃,(1,2),(3,))
+    C = LazyLinearMap(C̃,(1,),(2,3))
 
     @test InflatedTensorMapping(I(3,2), A, I(4)) isa TensorMapping{Float64, 4, 4}
     @test InflatedTensorMapping(I(3,2), B, I(4)) isa TensorMapping{Float64, 5, 4}
@@ -328,6 +335,11 @@
     @inferred range_size(InflatedTensorMapping(I(3,2), A, I(4))) == (3,2,4,4)
     @inferred domain_size(InflatedTensorMapping(I(3,2), A, I(4))) == (3,2,2,4)
 
+    tm = InflatedTensorMapping(I(3,2), A, I(4))
+    v = rand(domain_size(tm)...)
+
+    @tullio IAIv[a,b,c,d] := Ã[c,i]*v[a,b,i,d]
+    @test tm*v ≈ IAIv rtol=1e-14
 
 end