Mercurial > repos > public > sbplib_julia
changeset 454:eb4c34438e30 feature/inflated_tensormapping
Merge in default
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Tue, 20 Oct 2020 09:59:44 +0200 |
parents | c1ae837f1a2e (diff) a79d7b3209c9 (current diff) |
children | b86312d14873 |
files | |
diffstat | 4 files changed, 198 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl Tue Oct 20 09:23:16 2020 +0200 +++ b/src/LazyTensors/lazy_tensor_operations.jl Tue Oct 20 09:59:44 2020 +0200 @@ -170,3 +170,75 @@ apply(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] apply_transpose(tmi::IdentityMapping{T,D}, v::AbstractArray{T,D}, I::Vararg{Any,D}) where {T,D} = v[I...] +struct InflatedTensorMapping{T,R,D,D_before,R_middle,D_middle,D_after, TM<:TensorMapping{T,R_middle,D_middle}} <: TensorMapping{T,R,D} + before::IdentityMapping{T,D_before} + tm::TM + after::IdentityMapping{T,D_after} + + function InflatedTensorMapping(before, tm::TensorMapping{T}, after) where T + R_before = range_dim(before) + R_middle = range_dim(tm) + R_after = range_dim(after) + R = R_before+R_middle+R_after + + D_before = domain_dim(before) + D_middle = domain_dim(tm) + D_after = domain_dim(after) + D = D_before+D_middle+D_after + return new{T,R,D,D_before,R_middle,D_middle,D_after, typeof(tm)}(before, tm, after) + end +end +export InflatedTensorMapping + +# TODO: Implement constructors where one of `before` or `after` is missing + +# TODO: Implement syntax and constructors for products of different combinations of InflatedTensorMapping and IdentityMapping + +# TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensorMapping(I(3),B,I(2)) -> I(3)⊗B⊗I(2) + +function range_size(itm::InflatedTensorMapping) + return flatten_tuple( + range_size(itm.before), + range_size(itm.tm), + range_size(itm.after), + ) +end + +function domain_size(itm::InflatedTensorMapping) + return flatten_tuple( + domain_size(itm.before), + domain_size(itm.tm), + domain_size(itm.after), + ) +end + +function apply(itm::InflatedTensorMapping{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} + view_index, inner_index = split_index(itm, I...) + + v_inner = view(v, view_index...) + return apply(itm.tm, v_inner, inner_index...) +end + + +""" + split_index(...) + +Splits the multi-index into two parts. One part for the view that the inner TensorMapping acts on, and one part for indexing the result +Eg. +``` +(1,2,3,4) -> (1,:,:,4), (2,3) +``` +""" +function split_index(itm::InflatedTensorMapping{T,R,D}, I::Vararg{Any,R}) where {T,R,D} + I_before = I[1:range_dim(itm.before)] + I_after = I[(end-range_dim(itm.after)+1):end] + + view_index = (I_before..., ntuple((i)->:,domain_dim(itm.tm))..., I_after...) + inner_index = I[range_dim(itm.before)+1:end-range_dim(itm.after)] + + return (view_index, inner_index) +end + +flatten_tuple(t::NTuple{N, Number} where N) = t +flatten_tuple(t::Tuple) = ((flatten_tuple.(t)...)...,) # simplify? +flatten_tuple(ts::Vararg) = flatten_tuple(ts)
--- a/test/Manifest.toml Tue Oct 20 09:23:16 2020 +0200 +++ b/test/Manifest.toml Tue Oct 20 09:59:44 2020 +0200 @@ -1,13 +1,35 @@ # This file is machine-generated - editing it directly is not advised +[[Artifacts]] +deps = ["Pkg"] +git-tree-sha1 = "c30985d8821e0cd73870b17b0ed0ce6dc44cb744" +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" +version = "1.3.0" + [[Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" +[[CompilerSupportLibraries_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "8e695f735fca77e9708e795eda62afdb869cbb70" +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "0.3.4+0" + +[[Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + [[DeepDiffs]] git-tree-sha1 = "9824894295b62a6a4ab6adf1c7bf337b3a9ca34c" uuid = "ab62b9b5-e342-54a8-a765-a90f495de1a6" version = "1.2.0" +[[DiffRules]] +deps = ["NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "eb0c34204c8410888844ada5359ac8b96292cfd1" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.0.1" + [[Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -16,6 +38,15 @@ deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +[[JLLWrappers]] +git-tree-sha1 = "7cec881362e5b4e367ff0279dd99a06526d51a55" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.1.2" + +[[LibGit2]] +deps = ["Printf"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + [[Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" @@ -30,16 +61,54 @@ deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +[[NaNMath]] +git-tree-sha1 = "c84c576296d0e2fbb3fc134d3e09086b3ea617cd" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "0.3.4" + +[[OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "9db77584158d0ab52307f8c04f8e7c08ca76b5b3" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.3+4" + +[[Pkg]] +deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + [[Random]] deps = ["Serialization"] uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +[[Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "28faf1c963ca1dc3ec87f166d92982e3c4a1f66d" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.1.0" + +[[SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + [[Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" [[Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" +[[SpecialFunctions]] +deps = ["OpenSpecFun_jll"] +git-tree-sha1 = "d8d8b8a9f4119829410ecd706da4cc8594a1e020" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "0.10.3" + [[Test]] deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -49,3 +118,16 @@ git-tree-sha1 = "3a2919a78b04c29a1a57b05e1618e473162b15d0" uuid = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" version = "2.0.0" + +[[Tullio]] +deps = ["DiffRules", "LinearAlgebra", "Requires"] +git-tree-sha1 = "b27ec3ce782f69c1c24f373bfb6aa60300ed57c7" +uuid = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc" +version = "0.2.8" + +[[UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
--- a/test/Project.toml Tue Oct 20 09:23:16 2020 +0200 +++ b/test/Project.toml Tue Oct 20 09:59:44 2020 +0200 @@ -2,3 +2,4 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" +Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
--- a/test/testLazyTensors.jl Tue Oct 20 09:23:16 2020 +0200 +++ b/test/testLazyTensors.jl Tue Oct 20 09:59:44 2020 +0200 @@ -2,6 +2,8 @@ using Sbplib.LazyTensors using Sbplib.RegionIndices +using Tullio + @testset "LazyTensors" begin @testset "Generic Mapping methods" begin @@ -304,6 +306,47 @@ @inferred (I*v)[3,2] @inferred (I'*v)[3,2] @inferred range_size(I) + + @inferred range_dim(I) + @inferred domain_dim(I) +end + +@testset "InflatedTensorMapping" begin + I(sz...) = IdentityMapping(sz...) + + Ã = rand(4,2) + B̃ = rand(4,2,3) + C̃ = rand(4,2,3) + + A = LazyLinearMap(Ã,(1,),(2,)) + B = LazyLinearMap(B̃,(1,2),(3,)) + C = LazyLinearMap(C̃,(1,),(2,3)) + + @test InflatedTensorMapping(I(3,2), A, I(4)) isa TensorMapping{Float64, 4, 4} + @test InflatedTensorMapping(I(3,2), B, I(4)) isa TensorMapping{Float64, 5, 4} + @test InflatedTensorMapping(I(3), C, I(2,3)) isa TensorMapping{Float64, 4, 5} + + @test range_size(InflatedTensorMapping(I(3,2), A, I(4))) == (3,2,4,4) + @test domain_size(InflatedTensorMapping(I(3,2), A, I(4))) == (3,2,2,4) + + @test range_size(InflatedTensorMapping(I(3,2), B, I(4))) == (3,2,4,2,4) + @test domain_size(InflatedTensorMapping(I(3,2), B, I(4))) == (3,2,3,4) + + @test range_size(InflatedTensorMapping(I(3), C, I(2,3))) == (3,4,2,3) + @test domain_size(InflatedTensorMapping(I(3), C, I(2,3))) == (3,2,3,2,3) + + @inferred range_size(InflatedTensorMapping(I(3,2), A, I(4))) == (3,2,4,4) + @inferred domain_size(InflatedTensorMapping(I(3,2), A, I(4))) == (3,2,2,4) + + tm = InflatedTensorMapping(I(3,2), A, I(4)) + v = rand(domain_size(tm)...) + + @tullio IAIv[a,b,c,d] := Ã[c,i]*v[a,b,i,d] + @test tm*v ≈ IAIv rtol=1e-14 + + @inferred LazyTensors.split_index(tm,1,1,1,1) + @inferred (tm*v)[1,1,1,1] + end end