Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/LazyTensors.jl @ 1900:418566cdd689
Merge refactor/lazy_tensors/elementwise_ops
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Fri, 31 Jan 2025 20:35:28 +0100 |
parents | e1077273eda5 |
children |
comparison
equal
deleted
inserted
replaced
1896:9d708f3300d5 | 1900:418566cdd689 |
---|---|
7 export range_size, domain_size | 7 export range_size, domain_size |
8 | 8 |
9 export TensorApplication | 9 export TensorApplication |
10 export TensorTranspose | 10 export TensorTranspose |
11 export TensorComposition | 11 export TensorComposition |
| | 12 export TensorNegation |
| | 13 export TensorSum |
12 export IdentityTensor | 14 export IdentityTensor |
13 export ScalingTensor | 15 export ScalingTensor |
14 export DiagonalTensor | 16 export DiagonalTensor |
15 export DenseTensor | 17 export DenseTensor |
16 export InflatedTensor | 18 export InflatedTensor |
33 Base.:*(a::LazyTensor, v::AbstractArray) = TensorApplication(a,v) | 35 Base.:*(a::LazyTensor, v::AbstractArray) = TensorApplication(a,v) |
34 Base.:*(a::LazyTensor, b::LazyTensor) = throw(MethodError(Base.:*,(a,b))) | 36 Base.:*(a::LazyTensor, b::LazyTensor) = throw(MethodError(Base.:*,(a,b))) |
35 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...)) | 37 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...)) |
36 | 38 |
37 # Addition and subtraction of lazy tensors | 39 # Addition and subtraction of lazy tensors |
38 Base.:+(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:+}(s,t) | 40 Base.:+(ts::LazyTensor...) = TensorSum(ts...) |
39 Base.:-(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:-}(s,t) | 41 Base.:-(t::LazyTensor) = TensorNegation(t) |
| | 42 Base.:-(s::LazyTensor, t::LazyTensor) = s + (-t) |
| | 43 ## Specializations to flatten the nesting of tensors. This helps Julia during inference. |
| | 44 Base.:+(t::TensorSum, s::TensorSum) = TensorSum(t.tms..., s.tms...) |
| | 45 Base.:+(t::TensorSum, s::LazyTensor) = TensorSum(t.tms..., s) |
| | 46 Base.:+(t::LazyTensor, s::TensorSum) = TensorSum(t, s.tms...) |
40 | 47 |
41 # Composing lazy tensors | 48 # Composing lazy tensors |
42 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t) | 49 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t) |
43 Base.:∘(s::TensorComposition, t::LazyTensor) = s.t1∘(s.t2∘t) | 50 Base.:∘(s::TensorComposition, t::LazyTensor) = s.t1∘(s.t2∘t) |
44 | 51 |