comparison src/LazyTensors/lazy_tensor_operations.jl @ 541:62d96e2cd165 refactor/tensor_index_coupling

Make the coupling between all the LazyTensors code and the Index type much weaker to make the module more flexible
author Jonatan Werpers <jonatan@werpers.com>
date Thu, 26 Nov 2020 21:35:34 +0100
parents 848dec405332
children 53828d3ed132
comparison
equal deleted inserted replaced
540:013ca4892540 541:62d96e2cd165
14 # TODO: Do boundschecking on creation! 14 # TODO: Do boundschecking on creation!
15 export LazyTensorMappingApplication 15 export LazyTensorMappingApplication
16 16
17 # TODO: Go through and remove unneccerary type parameters on functions 17 # TODO: Go through and remove unneccerary type parameters on functions
18 18
19 Base.getindex(ta::LazyTensorMappingApplication{T,R,D}, I::Vararg{Index,R}) where {T,R,D} = apply(ta.t, ta.o, I...) 19 Base.getindex(ta::LazyTensorMappingApplication{T,R,D}, I::Vararg{Any,R}) where {T,R,D} = apply(ta.t, ta.o, I...)
20 Base.getindex(ta::LazyTensorMappingApplication{T,R,D}, I::Vararg{Int,R}) where {T,R,D} = apply(ta.t, ta.o, Index{Unknown}.(I)...)
21 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t) 20 Base.size(ta::LazyTensorMappingApplication) = range_size(ta.t)
22 # TODO: What else is needed to implement the AbstractArray interface? 21 # TODO: What else is needed to implement the AbstractArray interface?
23 22
24 Base.:*(a::TensorMapping, v::AbstractArray) = LazyTensorMappingApplication(a,v) 23 Base.:*(a::TensorMapping, v::AbstractArray) = LazyTensorMappingApplication(a,v)
25 Base.:*(a::TensorMapping, b::TensorMapping) = throw(MethodError(Base.:*,(a,b))) 24 Base.:*(a::TensorMapping, b::TensorMapping) = throw(MethodError(Base.:*,(a,b)))
48 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors? 47 # # TBD: Should this be implemented on a type by type basis or through a trait to provide earlier errors?
49 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`? 48 # Jonatan 2020-09-25: Is the problem that you can take the transpose of any TensorMapping even if it doesn't implement `apply_transpose`?
50 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm) 49 Base.adjoint(tm::TensorMapping) = LazyTensorMappingTranspose(tm)
51 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm 50 Base.adjoint(tmt::LazyTensorMappingTranspose) = tmt.tm
52 51
53 apply(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Index,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...) 52 apply(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D} = apply_transpose(tmt.tm, v, I...)
54 apply_transpose(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Index,R}) where {T,R,D} = apply(tmt.tm, v, I...) 53 apply_transpose(tmt::LazyTensorMappingTranspose{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmt.tm, v, I...)
55 54
56 range_size(tmt::LazyTensorMappingTranspose) = domain_size(tmt.tm) 55 range_size(tmt::LazyTensorMappingTranspose) = domain_size(tmt.tm)
57 domain_size(tmt::LazyTensorMappingTranspose) = range_size(tmt.tm) 56 domain_size(tmt::LazyTensorMappingTranspose) = range_size(tmt.tm)
58 57
59 58
65 return new{Op,T,R,D,T1,T2}(tm1,tm2) 64 return new{Op,T,R,D,T1,T2}(tm1,tm2)
66 end 65 end
67 end 66 end
68 # TODO: Boundschecking in constructor. 67 # TODO: Boundschecking in constructor.
69 68
70 apply(tmBinOp::LazyTensorMappingBinaryOperation{:+,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Index,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) 69 apply(tmBinOp::LazyTensorMappingBinaryOperation{:+,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...)
71 apply(tmBinOp::LazyTensorMappingBinaryOperation{:-,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Index,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) 70 apply(tmBinOp::LazyTensorMappingBinaryOperation{:-,T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...)
72 71
73 range_size(tmBinOp::LazyTensorMappingBinaryOperation{Op,T,R,D}) where {Op,T,R,D} = range_size(tmBinOp.tm1) 72 range_size(tmBinOp::LazyTensorMappingBinaryOperation{Op,T,R,D}) where {Op,T,R,D} = range_size(tmBinOp.tm1)
74 domain_size(tmBinOp::LazyTensorMappingBinaryOperation{Op,T,R,D}) where {Op,T,R,D} = domain_size(tmBinOp.tm1) 73 domain_size(tmBinOp::LazyTensorMappingBinaryOperation{Op,T,R,D}) where {Op,T,R,D} = domain_size(tmBinOp.tm1)
75 74
76 Base.:+(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:+,T,R,D}(tm1,tm2) 75 Base.:+(tm1::TensorMapping{T,R,D}, tm2::TensorMapping{T,R,D}) where {T,R,D} = LazyTensorMappingBinaryOperation{:+,T,R,D}(tm1,tm2)
130 export LazyLinearMap 129 export LazyLinearMap
131 130
132 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]] 131 range_size(llm::LazyLinearMap) = size(llm.A)[[llm.range_indicies...]]
133 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]] 132 domain_size(llm::LazyLinearMap) = size(llm.A)[[llm.domain_indicies...]]
134 133
135 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Index,R}) where {T,R,D} 134 function apply(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,D}, I::Vararg{Any,R}) where {T,R,D}
136 view_index = ntuple(i->:,ndims(llm.A)) 135 view_index = ntuple(i->:,ndims(llm.A))
137 for i ∈ 1:R 136 for i ∈ 1:R
138 view_index = Base.setindex(view_index, Int(I[i]), llm.range_indicies[i]) 137 view_index = Base.setindex(view_index, Int(I[i]), llm.range_indicies[i])
139 end 138 end
140 A_view = @view llm.A[view_index...] 139 A_view = @view llm.A[view_index...]
141 return sum(A_view.*v) 140 return sum(A_view.*v)
142 end 141 end
143 142
144 function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Index,D}) where {T,R,D} 143 function apply_transpose(llm::LazyLinearMap{T,R,D}, v::AbstractArray{T,R}, I::Vararg{Any,D}) where {T,R,D}
145 apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...) 144 apply(LazyLinearMap(llm.A, llm.domain_indicies, llm.range_indicies), v, I...)
146 end 145 end
147 146
148 147
149 """ 148 """