Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 1788:8b64df6cadba refactor/lazy_tensors/elementwise_ops
Refactor ElementWiseOperation to give a flatter structure of tensor compositions improving type inference
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Wed, 25 Sep 2024 10:25:30 +0200 |
parents | df1856b0e2f0 |
children | 48eaa973159a |
comparison
equal
deleted
inserted
replaced
1770:fbbadc6df706 | 1788:8b64df6cadba |
---|---|
50 | 50 |
51 range_size(tmt::TensorTranspose) = domain_size(tmt.tm) | 51 range_size(tmt::TensorTranspose) = domain_size(tmt.tm) |
52 domain_size(tmt::TensorTranspose) = range_size(tmt.tm) | 52 domain_size(tmt::TensorTranspose) = range_size(tmt.tm) |
53 | 53 |
54 | 54 |
55 struct ElementwiseTensorOperation{Op,T,R,D,T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} <: LazyTensor{T,R,D} | 55 struct ElementwiseTensorOperation{Op,T,R,D,TT<:NTuple{N, LazyTensor{T,R,D}} where N} <: LazyTensor{T,R,D} |
56 tm1::T1 | 56 tms::TT |
57 tm2::T2 | 57 |
58 | 58 function ElementwiseTensorOperation{Op,T,R,D}(tms::TT) where {Op,T,R,D, TT<:NTuple{N, LazyTensor{T,R,D}} where N} |
59 function ElementwiseTensorOperation{Op,T,R,D}(tm1::T1,tm2::T2) where {Op,T,R,D, T1<:LazyTensor{T,R,D},T2<:LazyTensor{T,R,D}} | 59 @boundscheck map(tms) do tm |
60 @boundscheck check_domain_size(tm2, domain_size(tm1)) | 60 check_domain_size(tm, domain_size(tms[1])) |
61 @boundscheck check_range_size(tm2, range_size(tm1)) | 61 check_range_size(tm, range_size(tms[1])) |
62 return new{Op,T,R,D,T1,T2}(tm1,tm2) | 62 end |
63 end | 63 |
64 end | 64 return new{Op,T,R,D,TT}(tms) |
65 | 65 end |
66 ElementwiseTensorOperation{Op}(s,t) where Op = ElementwiseTensorOperation{Op,eltype(s), range_dim(s), domain_dim(s)}(s,t) | 66 end |
67 | 67 |
68 apply(tmBinOp::ElementwiseTensorOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) + apply(tmBinOp.tm2, v, I...) | 68 |
69 apply(tmBinOp::ElementwiseTensorOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} = apply(tmBinOp.tm1, v, I...) - apply(tmBinOp.tm2, v, I...) | 69 function ElementwiseTensorOperation{:+}(ts::Vararg{LazyTensor}) |
70 | 70 return ElementwiseTensorOperation{:+,eltype(ts[1]), range_dim(ts[1]), domain_dim(ts[1])}(ts) |
71 range_size(tmBinOp::ElementwiseTensorOperation) = range_size(tmBinOp.tm1) | 71 end |
72 domain_size(tmBinOp::ElementwiseTensorOperation) = domain_size(tmBinOp.tm1) | 72 |
73 # The following methods for :+ are intended to reduce the depth of the tree of operations in some caes | |
74 function ElementwiseTensorOperation{:+}(t1::ElementwiseTensorOperation{:+}, t2::ElementwiseTensorOperation{:+}) | |
75 ElementwiseTensorOperation{:+}(t1.tms..., t2.tms...) | |
76 end | |
77 | |
78 function ElementwiseTensorOperation{:+}(t1::ElementwiseTensorOperation{:+}, t2::LazyTensor) | |
79 ElementwiseTensorOperation{:+}(t1.tms..., t2) | |
80 end | |
81 | |
82 function ElementwiseTensorOperation{:+}(t1::LazyTensor, t2::ElementwiseTensorOperation{:+}) | |
83 ElementwiseTensorOperation{:+}(t1, t2.tms...) | |
84 end | |
85 | |
86 function ElementwiseTensorOperation{:-}(t1::LazyTensor, t2::LazyTensor) | |
87 return ElementwiseTensorOperation{:-,eltype(t1), range_dim(t1), domain_dim(t1)}((t1,t2)) | |
88 end | |
89 | |
90 function apply(tmBinOp::ElementwiseTensorOperation{:+,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} | |
91 vs = map(tmBinOp.tms) do tm | |
92 apply(tm,v,I...) | |
93 end | |
94 | |
95 return +(vs...) | |
96 end | |
97 function apply(tmBinOp::ElementwiseTensorOperation{:-,T,R,D}, v::AbstractArray{<:Any,D}, I::Vararg{Any,R}) where {T,R,D} | |
98 apply(tmBinOp.tms[1], v, I...) - apply(tmBinOp.tms[2], v, I...) | |
99 end | |
100 | |
101 range_size(tmBinOp::ElementwiseTensorOperation) = range_size(tmBinOp.tms[1]) | |
102 domain_size(tmBinOp::ElementwiseTensorOperation) = domain_size(tmBinOp.tms[1]) | |
73 | 103 |
74 | 104 |
75 """ | 105 """ |
76 TensorComposition{T,R,K,D} | 106 TensorComposition{T,R,K,D} |
77 | 107 |