diff src/LazyTensors/lazy_tensor_operations.jl @ 1240:a9ac86f6be8a

Merge refactor/LazyTensors/tuple_manipulation
author Jonatan Werpers <jonatan@werpers.com>
date Tue, 21 Feb 2023 21:01:46 +0100
parents 8f4259fbd39c
children aa8579b7fc15
line wrap: on
line diff
--- a/src/LazyTensors/lazy_tensor_operations.jl	Wed Feb 08 10:29:06 2023 +0100
+++ b/src/LazyTensors/lazy_tensor_operations.jl	Tue Feb 21 21:01:46 2023 +0100
@@ -176,7 +176,7 @@
 # TODO: Implement some pretty printing in terms of ⊗. E.g InflatedTensor(I(3),B,I(2)) -> I(3)⊗B⊗I(2)
 
 function range_size(itm::InflatedTensor)
-    return flatten_tuple(
+    return concatenate_tuples(
         range_size(itm.before),
         range_size(itm.tm),
         range_size(itm.after),
@@ -184,7 +184,7 @@
 end
 
 function domain_size(itm::InflatedTensor)
-    return flatten_tuple(
+    return concatenate_tuples(
         domain_size(itm.before),
         domain_size(itm.tm),
         domain_size(itm.after),
@@ -197,7 +197,7 @@
     dim_range = range_dim(itm.tm)
     dim_after = range_dim(itm.after)
 
-    view_index, inner_index = split_index(Val(dim_before), Val(dim_domain), Val(dim_range), Val(dim_after), I...)
+    view_index, inner_index = split_index(dim_before, dim_domain, dim_range, dim_after, I...)
 
     v_inner = view(v, view_index...)
     return apply(itm.tm, v_inner, inner_index...)
@@ -209,7 +209,7 @@
     dim_range = range_dim(itm.tm)
     dim_after = range_dim(itm.after)
 
-    view_index, inner_index = split_index(Val(dim_before), Val(dim_range), Val(dim_domain), Val(dim_after), I...)
+    view_index, inner_index = split_index(dim_before, dim_range, dim_domain, dim_after, I...)
 
     v_inner = view(v, view_index...)
     return apply_transpose(itm.tm, v_inner, inner_index...)