Mercurial > repos > public > sbplib_julia
comparison src/LazyTensors/lazy_tensor_operations.jl @ 417:4c6604b7d990 feature/tensor_composition
Add dimension checking in the constructor
author | Jonatan Werpers <jonatan@werpers.com> |
---|---|
date | Fri, 16 Oct 2020 20:32:09 +0200 |
parents | 814865d40f48 |
children | 264af2bb646f |
comparison
equal
deleted
inserted
replaced
416:ebc9b2383dae | 417:4c6604b7d990 |
---|---|
81 """ | 81 """ |
82 struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D} | 82 struct TensorMappingComposition{T,R,K,D, TM1<:TensorMapping{T,R,K}, TM2<:TensorMapping{T,K,D}} <: TensorMapping{T,R,D} |
83 t1::TM1 | 83 t1::TM1 |
84 t2::TM2 | 84 t2::TM2 |
85 | 85 |
86 @inline function TensorMappingComposition(t1::TensorMapping{T,R,K}, t2::TensorMapping{T,K,D}) where {T,R,K,D} | |
87 @boundscheck if domain_size(t1) != range_size(t2) | |
88 throw(DimensionMismatch("The first argument has domain size $(domain_size(t1)) while the second has range size $(range_size(t2)) ")) | |
89 end | |
90 return new{T,R,K,D, typeof(t1), typeof(t2)}(t1,t2) | |
91 end | |
86 # Add check for matching sizes as a boundscheck | 92 # Add check for matching sizes as a boundscheck |
87 end | 93 end |
88 export TensorMappingComposition | 94 export TensorMappingComposition |
89 | 95 |
90 range_size(tm::TensorMappingComposition) = range_size(tm.t1) | 96 range_size(tm::TensorMappingComposition) = range_size(tm.t1) |
96 | 102 |
97 function apply_transpose(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,R}, I::Vararg{S,D} where S) where {T,R,K,D} | 103 function apply_transpose(c::TensorMappingComposition{T,R,K,D}, v::AbstractArray{T,R}, I::Vararg{S,D} where S) where {T,R,K,D} |
98 apply_transpose(c.t2, LazyTensorMappingApplication(c.t1',v), I...) | 104 apply_transpose(c.t2, LazyTensorMappingApplication(c.t1',v), I...) |
99 end | 105 end |
100 | 106 |
101 Base.:∘(s::TensorMapping, t::TensorMapping) = TensorMappingComposition(s,t) | 107 Base.@propagate_inbounds Base.:∘(s::TensorMapping, t::TensorMapping) = TensorMappingComposition(s,t) |
102 | 108 |
103 """ | 109 """ |
104 LazyLinearMap{T,R,D,...}(A, range_indicies, domain_indicies) | 110 LazyLinearMap{T,R,D,...}(A, range_indicies, domain_indicies) |
105 | 111 |
106 TensorMapping defined by the AbstractArray A. `range_indicies` and `domain_indicies` define which indicies of A should | 112 TensorMapping defined by the AbstractArray A. `range_indicies` and `domain_indicies` define which indicies of A should |