Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/LazyTensors.jl @ 1837:200971c71657 refactor/lazy_tensors/elementwise_ops
Refactor ElementwiseTensorOperation into TensorSum and use TensorNegation for handling subtraction
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Thu, 09 Jan 2025 21:46:01 +0100 |
parents | 368999a2e243 |
children | e1077273eda5 |
comparison
equal
deleted
inserted
replaced
1836:368999a2e243 | 1837:200971c71657 |
---|---|
8 | 8 |
9 export TensorApplication | 9 export TensorApplication |
10 export TensorTranspose | 10 export TensorTranspose |
11 export TensorComposition | 11 export TensorComposition |
12 export TensorNegation | 12 export TensorNegation |
| 13 export TensorSum |
13 export IdentityTensor | 14 export IdentityTensor |
14 export ScalingTensor | 15 export ScalingTensor |
15 export DiagonalTensor | 16 export DiagonalTensor |
16 export DenseTensor | 17 export DenseTensor |
17 export InflatedTensor | 18 export InflatedTensor |
34 Base.:*(a::LazyTensor, v::AbstractArray) = TensorApplication(a,v) | 35 Base.:*(a::LazyTensor, v::AbstractArray) = TensorApplication(a,v) |
35 Base.:*(a::LazyTensor, b::LazyTensor) = throw(MethodError(Base.:*,(a,b))) | 36 Base.:*(a::LazyTensor, b::LazyTensor) = throw(MethodError(Base.:*,(a,b))) |
36 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...)) | 37 Base.:*(a::LazyTensor, args::Union{LazyTensor, AbstractArray}...) = foldr(*,(a,args...)) |
37 | 38 |
38 # Addition and subtraction of lazy tensors | 39 # Addition and subtraction of lazy tensors |
39 Base.:+(ts::LazyTensor...) = ElementwiseTensorOperation{:+}(ts...) | 40 Base.:+(ts::LazyTensor...) = TensorSum(ts...) |
40 Base.:-(t::LazyTensor) = TensorNegation(t) | 41 Base.:-(t::LazyTensor) = TensorNegation(t) |
41 Base.:-(s::LazyTensor, t::LazyTensor) = ElementwiseTensorOperation{:-}(s,t) | 42 Base.:-(s::LazyTensor, t::LazyTensor) = TensorSum(s,-t) |
42 | 43 |
43 # Composing lazy tensors | 44 # Composing lazy tensors |
44 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t) | 45 Base.:∘(s::LazyTensor, t::LazyTensor) = TensorComposition(s,t) |
45 Base.:∘(s::TensorComposition, t::LazyTensor) = s.t1∘(s.t2∘t) | 46 Base.:∘(s::TensorComposition, t::LazyTensor) = s.t1∘(s.t2∘t) |
46 | 47 |