comparison src/LazyTensors/LazyTensors.jl @ 1839:e1077273eda5 refactor/lazy_tensors/elementwise_ops
Separate the logic of how TensorSum works from the semantics of + and -
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Thu, 09 Jan 2025 22:12:26 +0100 |
parents | 200971c71657 |
children |
1838:4bd998069053 | 1839:e1077273eda5 |
---|---|
37 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...)) | 37 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...)) |
38 | 38 |
39 # Addition and subtraction of lazy tensors | 39 # Addition and subtraction of lazy tensors |
40 Base.:+(ts::LazyTensor...) = TensorSum(ts...) | 40 Base.:+(ts::LazyTensor...) = TensorSum(ts...) |
41 Base.:-(t::LazyTensor) = TensorNegation(t) | 41 Base.:-(t::LazyTensor) = TensorNegation(t) |
42 Base.:-(s::LazyTensor, t::LazyTensor) = TensorSum(s,-t) | 42 Base.:-(s::LazyTensor, t::LazyTensor) = s + (-t) |
| | 43 ## Specializations to flatten the nesting of tensors. This helps Julia during inference. |
| | 44 Base.:+(t::TensorSum, s::TensorSum) = TensorSum(t.tms..., s.tms...) |
| | 45 Base.:+(t::TensorSum, s::LazyTensor) = TensorSum(t.tms..., s) |
| | 46 Base.:+(t::LazyTensor, s::TensorSum) = TensorSum(t, s.tms...) |
43 | 47 |
44 # Composing lazy tensors | 48 # Composing lazy tensors |
45 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t) | 49 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t) |
46 Base.:∘(s::TensorComposition, t::LazyTensor) = s.t1∘(s.t2∘t) | 50 Base.:∘(s::TensorComposition, t::LazyTensor) = s.t1∘(s.t2∘t) |
47 | 51 |
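
The effect of the flattening specializations, and of routing binary `-` through `+`, can be illustrated with a minimal self-contained sketch. The types `Tensor`, `Leaf`, `Negation` and `Sum` below are hypothetical stand-ins invented for this example; they are not the `LazyTensor`, `TensorNegation` and `TensorSum` definitions from sbplib_julia, whose internals are not shown in this comparison. Only the dispatch pattern is mirrored.

```julia
# Hypothetical stand-in types (assumptions for illustration only).
abstract type Tensor end

struct Leaf <: Tensor
    name::Symbol
end

struct Negation{T<:Tensor} <: Tensor
    t::T
end

struct Sum{TS<:Tuple} <: Tensor
    tms::TS   # the terms of the sum, stored as a tuple
end
Sum(ts::Tensor...) = Sum(ts)

# Semantics of + and -, mirroring the right-hand column of the comparison.
Base.:+(ts::Tensor...) = Sum(ts...)
Base.:-(t::Tensor) = Negation(t)
Base.:-(s::Tensor, t::Tensor) = s + (-t)   # goes through +, so flattening applies

# Specializations that splice existing terms instead of nesting sums.
Base.:+(t::Sum, s::Sum) = Sum(t.tms..., s.tms...)
Base.:+(t::Sum, s::Tensor) = Sum(t.tms..., s)
Base.:+(t::Tensor, s::Sum) = Sum(t, s.tms...)

a, b, c = Leaf(:a), Leaf(:b), Leaf(:c)

flat   = (a + b) - c                  # dispatches to +(::Sum, ::Tensor)
nested = Sum(Sum(a, b), Negation(c))  # what calling the constructor directly gives

println(length(flat.tms))    # number of terms in the flattened sum
println(length(nested.tms))  # number of terms when the inner sum is kept nested
```

In this sketch, `(a + b) - c` ends up as a single `Sum` holding `a`, `b` and `Negation(c)`, whereas calling the constructor directly keeps the inner `Sum` as one nested term. Because binary `-` is written as `s + (-t)`, it picks up whichever `+` method is most specific for its arguments, so the flattening rules only need to be stated once.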