Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 541:62d96e2cd165 refactor/tensor_index_coupling
Make the coupling between all the LazyTensors code and the Index type much weaker to make the module more flexible
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Thu, 26 Nov 2020 21:35:34 +0100 |
parents | 848dec405332 |
children | 53828d3ed132 |
comparison
equal
deleted
inserted
replaced
540:013ca4892540 | 541:62d96e2cd165 |
---|---|
14 # TODO: Do boundschecking on creation! | 14 # TODO: Do boundschecking on creation! |
15 export LazyTensorMappingApplication | 15 export LazyTensorMappingApplication |
16 | 16 |
17 # TODO: Go through and remove unneccerary type parameters on functions | 17 # TODO: Go through and remove unneccerary type parameters on functions |
18 | 18 |
19 Base.getindex(ta::LazyTensorMappingApplication{T,R,D}, I::Vararg{Index,R}) where {T,R,D} = apply(ta.t, ta.o, I...) | 19 Base.getindex(ta::LazyTensorMappingApplication{T,R,D}, I::Vararg{Any,R}) where {T,R,D} = apply(ta.t, ta.o, I...) |
20 Base.getindex(ta::LazyTensorMappingApplication{T,R,D}, I::Vararg{Int,R}) where {T,R,D} = apply(ta.t, ta.o, Index{Unknown}.(I)...) | |
21 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t) | 20 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t) |
22 # TODO: What else is needed to implement the AbstractArray interface? | 21 # TODO: What else is needed to implement the AbstractArray interface? |
23 | 22 |
24 Base.:*(a::TensorMapping, v::AbstractArray) = LazyTensorMappingApplication(a,v) | 23 Base.:*(a::TensorMapping, v::AbstractArray) = LazyTensorMappingApplication(a,v) |
25 Base.:*(a::TensorMapping, b::TensorMapping) = throw(MethodError(Base.:*,(a,b))) | 24 Base.:*(a::TensorMapping, b::TensorMapping) = throw(MethodError(Base.:*,(a,b))) |
48 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? | 47 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? |
49 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? | 48 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? |
50 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm) | 49 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm) |
51 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm | 50 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm |
52 | 51 |
53 apply(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Index,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...) | 52 apply(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...) |
54 apply_transpose(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Index,R}) where {T,R,D} = apply(tmt.tm, v, I...) | 53 apply_transpose(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...) |
55 | 54 |
56 range_size(tmt::LazyTensorMappingTranspose) = domain_size(tmt.tm) | 55 range_size(tmt::LazyTensorMappingTranspose) = domain_size(tmt.tm) |
57 domain_size(tmt::LazyTensorMappingTranspose) = range_size(tmt.tm) | 56 domain_size(tmt::LazyTensorMappingTranspose) = range_size(tmt.tm) |
58 | 57 |
59 | 58 |
65 return new{Op,T,R,D,T1,T2}(tm1,tm2) | 64 return new{Op,T,R,D,T1,T2}(tm1,tm2) |
66 end | 65 end |
67 end | 66 end |
68 # TODO: Boundschecking in constructor. | 67 # TODO: Boundschecking in constructor. |
69 | 68 |
70 apply(tmBinOp::LazyTensorMappingBinaryOperation{:+,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Index,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) | 69 apply(tmBinOp::LazyTensorMappingBinaryOperation{:+,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) |
71 apply(tmBinOp::LazyTensorMappingBinaryOperation{:-,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Index,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) | 70 apply(tmBinOp::LazyTensorMappingBinaryOperation{:-,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) |
72 | 71 |
73 range_size(tmBinOp::LazyTensorMappingBinaryOperation{Op,T,R,D}) where {Op,T,R,D} = range_size(tmBinOp.tm1) | 72 range_size(tmBinOp::LazyTensorMappingBinaryOperation{Op,T,R,D}) where {Op,T,R,D} = range_size(tmBinOp.tm1) |
74 domain_size(tmBinOp::LazyTensorMappingBinaryOperation{Op,T,R,D}) where {Op,T,R,D} = domain_size(tmBinOp.tm1) | 73 domain_size(tmBinOp::LazyTensorMappingBinaryOperation{Op,T,R,D}) where {Op,T,R,D} = domain_size(tmBinOp.tm1) |
75 | 74 |
76 Base.:+(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:+,T,R,D}(tm1,tm2) | 75 Base.:+(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:+,T,R,D}(tm1,tm2) |
130 export LazyLinearMap | 129 export LazyLinearMap |
131 | 130 |
132 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]] | 131 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]] |
133 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]] | 132 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]] |
134 | 133 |
135 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Index,R}) where {T,R,D} | 134 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} |
136 view_index = ntuple(i->:,ndims(llm.A)) | 135 view_index = ntuple(i->:,ndims(llm.A)) |
137 for i ∈ 1:R | 136 for i ∈ 1:R |
138 view_index = Base.setindex(view_index, Int(I[i]), llm.range_indicies[i]) | 137 view_index = Base.setindex(view_index, Int(I[i]), llm.range_indicies[i]) |
139 end | 138 end |
140 A_view = @view llm.A[view_index...] | 139 A_view = @view llm.A[view_index...] |
141 return sum(A_view.*v) | 140 return sum(A_view.*v) |
142 end | 141 end |
143 | 142 |
144 function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Index,D}) where {T,R,D} | 143 function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} |
145 apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...) | 144 apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...) |
146 end | 145 end |
147 | 146 |
148 | 147 |
149 """ | 148 """ |