[Mlir-commits] [mlir] 550ea38 - [mlir] Remove unnecessary canonicalization from Linalg Detensorize.cpp

Alexander Belyaev llvmlistbot at llvm.org
Mon Jan 3 07:34:03 PST 2022


Author: Alexander Belyaev
Date: 2022-01-03T16:33:45+01:00
New Revision: 550ea385abc2805fd3e0a539bf55bc82edb5c13e

URL: https://github.com/llvm/llvm-project/commit/550ea385abc2805fd3e0a539bf55bc82edb5c13e
DIFF: https://github.com/llvm/llvm-project/commit/550ea385abc2805fd3e0a539bf55bc82edb5c13e.diff

LOG:  [mlir] Remove unnecessary canonicalization from Linalg Detensorize.cpp

After https://reviews.llvm.org/D115821, it became possible to create a
`tensor<elem_type>` with a single `tensor.from_elements` operation, without
collapsing the tensor shape from `tensor<1xelem_type>` to `tensor<elem_type>`.

Differential Revision: https://reviews.llvm.org/D115891

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
    mlir/test/Dialect/Linalg/detensorize_0d.mlir
    mlir/test/Dialect/Linalg/detensorize_br_operands.mlir
    mlir/test/Dialect/Linalg/detensorize_if.mlir
    mlir/test/Dialect/Linalg/detensorize_trivial.mlir
    mlir/test/Dialect/Linalg/detensorize_while.mlir
    mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index aa8a3b9f47715..5aebbe08fcd7f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -24,18 +24,14 @@ using namespace mlir::linalg;
 static Value sourceMaterializationCallback(OpBuilder &builder, Type type,
                                            ValueRange inputs, Location loc) {
   assert(inputs.size() == 1);
-  if (inputs[0].getType().isa<TensorType>())
+  auto inputType = inputs[0].getType();
+  if (inputType.isa<TensorType>())
     return nullptr;
 
   // A detensored value is converted back by creating a new tensor from its
   // element(s).
-  auto createNewTensorOp =
-      builder.create<tensor::FromElementsOp>(loc, inputs[0]);
-
-  // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to
-  // a tensor<dtype> instead.
-  return builder.create<tensor::CollapseShapeOp>(
-      loc, type, createNewTensorOp, ArrayRef<ReassociationExprs>{});
+  return builder.create<tensor::FromElementsOp>(
+      loc, RankedTensorType::get({}, inputType), inputs[0]);
 }
 
 namespace {
@@ -161,39 +157,6 @@ class DetensorizeTypeConverter : public TypeConverter {
   }
 };
 
-/// Canonicalizes the pattern of the form
-///
-/// %tensor = tensor.from_elements(%element) : (i32) -> tensor<1xi32>
-/// %reshaped_tensor = tensor.collapse_shape %tensor []
-///     : tensor<1xi32> into tensor<i32>
-/// %extracted_element = tensor.extract %reshaped_tensor[] : tensor<i32>
-///
-/// to just %element.
-struct ExtractFromReshapeFromElements
-    : public OpRewritePattern<tensor::ExtractOp> {
-  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(tensor::ExtractOp extract,
-                                PatternRewriter &rewriter) const final {
-    if (!extract.indices().empty())
-      return failure();
-
-    auto tensorReshape =
-        extract.tensor().getDefiningOp<tensor::CollapseShapeOp>();
-    if (tensorReshape == nullptr)
-      return failure();
-
-    auto tensorFromElements =
-        tensorReshape.getOperand()
-            .getDefiningOp<mlir::tensor::FromElementsOp>();
-    if (tensorFromElements == nullptr)
-      return failure();
-
-    rewriter.replaceOp(extract, tensorFromElements.getOperand(0));
-    return success();
-  }
-};
-
 /// @see LinalgDetensorize in Linalg/Passes.td for more details.
 struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
   LinalgDetensorize() = default;
@@ -591,7 +554,7 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
       signalPassFailure();
 
     RewritePatternSet canonPatterns(context);
-    canonPatterns.add<ExtractFromReshapeFromElements>(context);
+    tensor::FromElementsOp::getCanonicalizationPatterns(canonPatterns, context);
     if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                             std::move(canonPatterns))))
       signalPassFailure();

diff  --git a/mlir/test/Dialect/Linalg/detensorize_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
index 2d73c5b97c455..9ce2f8ccfa5a7 100644
--- a/mlir/test/Dialect/Linalg/detensorize_0d.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_0d.mlir
@@ -19,8 +19,7 @@ func @detensor_simple(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> att
 // CHECK-DAG:     %[[arg2_val:.*]] = tensor.extract %[[arg2]]
 // CHECK:         %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]]
 // CHECK:         %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
-// CHECK:         %[[reshaped_tensor_res:.*]] = tensor.collapse_shape %[[new_tensor_res]]
-// CHECK:         return %[[reshaped_tensor_res]]
+// CHECK:         return %[[new_tensor_res]]
 
 func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
   %0 = linalg.init_tensor [] : tensor<f32>
@@ -60,8 +59,7 @@ func @detensor_op_sequence(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32
 // CHECK:         %[[detensored_res2:.*]] = arith.mulf %[[arg1_val]], %[[detensored_res]]
 // CHECK:         %[[detensored_res3:.*]] = arith.divf %[[detensored_res]], %[[detensored_res2]]
 // CHECK:         %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]]
-// CHECK:         %[[reshaped_tensor_res:.*]] = tensor.collapse_shape %[[new_tensor_res]]
-// CHECK:         return %[[reshaped_tensor_res]]
+// CHECK:         return %[[new_tensor_res]]
 
 func @detensor_multiple_ops(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
   %0 = linalg.init_tensor [] : tensor<f32>
@@ -82,8 +80,7 @@ func @detensor_multiple_ops(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f3
 // CHECK:         %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]]
 // CHECK:         %[[detensored_res2:.*]] = arith.mulf %[[detensored_res]], %[[arg2_val]]
 // CHECK:         %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res2]]
-// CHECK:         %[[reshaped_tensor_res:.*]] = tensor.collapse_shape %[[new_tensor_res]]
-// CHECK:         return %[[reshaped_tensor_res]]
+// CHECK:         return %[[new_tensor_res]]
 
 func @detensor_foreign_op(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> attributes {iree.module.export} {
   %0 = linalg.init_tensor [] : tensor<f32>
@@ -102,5 +99,4 @@ func @detensor_foreign_op(%arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32>
 // CHECK-DAG:     %[[arg2_val:.*]] = tensor.extract %[[arg2]]
 // CHECK:         %[[detensored_res:.*]] = "foreign.do_something"(%[[arg1_val]], %[[arg2_val]])
 // CHECK:         %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res]]
-// CHECK:         %[[reshaped_tensor_res:.*]] = tensor.collapse_shape %[[new_tensor_res]]
-// CHECK:         return %[[reshaped_tensor_res]]
+// CHECK:         return %[[new_tensor_res]]

diff  --git a/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir b/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir
index 2682a298dd2aa..ff7cd003ad4a8 100644
--- a/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_br_operands.mlir
@@ -2,17 +2,14 @@
 
 // TODO: Detensoring breaks if %arg0 or %arg1 are passed directly as tensors. Fix that.
 func @if_true_test(%arg0: i1, %arg1: i32) -> tensor<i32> attributes {} {
-  %arg0_t = tensor.from_elements %arg0 : tensor<1xi1>
-  %arg0_t2 = tensor.collapse_shape %arg0_t [] : tensor<1xi1> into tensor<i1>
-
-  %arg1_t = tensor.from_elements %arg1 : tensor<1xi32>
-  %arg1_t2 = tensor.collapse_shape %arg1_t [] : tensor<1xi32> into tensor<i32>
+  %arg0_t = tensor.from_elements %arg0 : tensor<i1>
+  %arg1_t = tensor.from_elements %arg1 : tensor<i32>
 
   %cst = arith.constant dense<10> : tensor<i32>
   %2 = linalg.init_tensor [] : tensor<i8>
   %3 = linalg.generic
     {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []}
-    ins(%arg0_t2 : tensor<i1>)
+    ins(%arg0_t : tensor<i1>)
     outs(%2 : tensor<i8>) {
   ^bb0(%arg2: i1, %arg3: i8):  // no predecessors
     %10 = arith.extui %arg2 : i1 to i8
@@ -20,12 +17,12 @@ func @if_true_test(%arg0: i1, %arg1: i32) -> tensor<i32> attributes {} {
   } -> tensor<i8>
   %4 = tensor.extract %3[] : tensor<i8>
   %5 = arith.trunci %4 : i8 to i1
-  cond_br %5, ^bb1, ^bb2(%arg1_t2 : tensor<i32>)
+  cond_br %5, ^bb1, ^bb2(%arg1_t : tensor<i32>)
 ^bb1:
   %6 = linalg.init_tensor [] : tensor<i32>
   %7 = linalg.generic
     {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []}
-    ins(%arg1_t2, %cst : tensor<i32>, tensor<i32>)
+    ins(%arg1_t, %cst : tensor<i32>, tensor<i32>)
     outs(%6 : tensor<i32>) {
   ^bb0(%arg2: i32, %arg3: i32, %arg4: i32):  // no predecessors
     %10 = arith.addi %arg2, %arg3 : i32
@@ -44,6 +41,5 @@ func @if_true_test(%arg0: i1, %arg1: i32) -> tensor<i32> attributes {} {
 // CHECK-NEXT:     %[[add_res:.*]] = arith.addi
 // CHECK-NEXT:     br ^[[bb2]](%[[add_res]] : i32)
 // CHECK-NEXT:   ^[[bb2]]
-// CHECK-NEXT:     tensor.from_elements
-// CHECK-NEXT:     %[[func_res:.*]] = tensor.collapse_shape
+// CHECK-NEXT:     %[[func_res:.*]] = tensor.from_elements
 // CHECK-NEXT:     return %[[func_res]]

diff  --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
index c9e843bc7d69b..4341cf262fb6a 100644
--- a/mlir/test/Dialect/Linalg/detensorize_if.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -9,17 +9,15 @@
 
 func @main() -> (tensor<i32>) attributes {} {
   %c0 = arith.constant 0 : i32
-  %0 = tensor.from_elements %c0 : tensor<1xi32>
-  %reshaped0 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
+  %0 = tensor.from_elements %c0 : tensor<i32>
   %c10 = arith.constant 10 : i32
-  %1 = tensor.from_elements %c10 : tensor<1xi32>
-  %reshaped1 = tensor.collapse_shape %1 [] : tensor<1xi32> into tensor<i32>
-  br ^bb1(%reshaped0 : tensor<i32>)
+  %1 = tensor.from_elements %c10 : tensor<i32>
+  br ^bb1(%0 : tensor<i32>)
 
 ^bb1(%2: tensor<i32>):  // 2 preds: ^bb0, ^bb2
   %3 = linalg.init_tensor [] : tensor<i1>
   %4 = linalg.generic #attrs
-    ins(%2, %reshaped1 : tensor<i32>, tensor<i32>)
+    ins(%2, %1 : tensor<i32>, tensor<i32>)
     outs(%3 : tensor<i1>) {
     ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
       %8 = arith.cmpi slt, %arg0, %arg1 : i32
@@ -54,8 +52,7 @@ func @main() -> (tensor<i32>) attributes {} {
 // CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
 // CHECK-NEXT:     br ^[[bb3:.*]](%{{.*}} : i32)
 // CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<1xi32>
-// CHECK-NEXT:     tensor.collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
 // CHECK-NEXT:     return %{{.*}}
 // CHECK-NEXT:   }
 
@@ -73,17 +70,15 @@ func @main() -> (tensor<i32>) attributes {} {
 
 func @main() -> (tensor<i32>) attributes {} {
   %c0 = arith.constant 0 : i32
-  %0 = tensor.from_elements %c0 : tensor<1xi32>
-  %reshaped0 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
+  %0 = tensor.from_elements %c0 : tensor<i32>
   %c10 = arith.constant 10 : i32
-  %1 = tensor.from_elements %c10 : tensor<1xi32>
-  %reshaped1 = tensor.collapse_shape %1 [] : tensor<1xi32> into tensor<i32>
-  br ^bb1(%reshaped0 : tensor<i32>)
+  %1 = tensor.from_elements %c10 : tensor<i32>
+  br ^bb1(%0 : tensor<i32>)
 
 ^bb1(%2: tensor<i32>):  // 2 preds: ^bb0, ^bb2
   %3 = linalg.init_tensor [] : tensor<i1>
   %4 = linalg.generic #attrs
-    ins(%2, %reshaped1 : tensor<i32>, tensor<i32>)
+    ins(%2, %1 : tensor<i32>, tensor<i32>)
     outs(%3 : tensor<i1>) {
     ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
       %8 = arith.cmpi slt, %arg0, %arg1 : i32
@@ -123,8 +118,7 @@ func @main() -> (tensor<i32>) attributes {} {
 // CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
 // CHECK-NEXT:     br ^[[bb4:.*]](%{{.*}} : i32)
 // CHECK-NEXT:   ^[[bb4]](%{{.*}}: i32)
-// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<1xi32>
-// CHECK-NEXT:     tensor.collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
 // CHECK-NEXT:     return %{{.*}}
 // CHECK-NEXT:   }
 
@@ -139,17 +133,15 @@ func @main() -> (tensor<i32>) attributes {} {
 
 func @main() -> (tensor<i32>) attributes {} {
   %c0 = arith.constant 0 : i32
-  %0 = tensor.from_elements %c0 : tensor<1xi32>
-  %reshaped0 = tensor.collapse_shape %0 [] : tensor<1xi32> into tensor<i32>
+  %0 = tensor.from_elements %c0 : tensor<i32>
   %c10 = arith.constant 10 : i32
-  %1 = tensor.from_elements %c10 : tensor<1xi32>
-  %reshaped1 = tensor.collapse_shape %1 [] : tensor<1xi32> into tensor<i32>
-  br ^bb1(%reshaped0 : tensor<i32>)
+  %1 = tensor.from_elements %c10 : tensor<i32>
+  br ^bb1(%0 : tensor<i32>)
 
 ^bb1(%2: tensor<i32>):  // 2 preds: ^bb0, ^bb2
   %3 = linalg.init_tensor [] : tensor<i1>
   %4 = linalg.generic #attrs
-    ins(%2, %reshaped1 : tensor<i32>, tensor<i32>)
+    ins(%2, %1 : tensor<i32>, tensor<i32>)
     outs(%3 : tensor<i1>) {
     ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):  // no predecessors
       %8 = arith.cmpi slt, %arg0, %arg1 : i32
@@ -163,11 +155,10 @@ func @main() -> (tensor<i32>) attributes {} {
   cond_br %5, ^bb2(%2 : tensor<i32>), ^bb2(%2 : tensor<i32>)
 
 ^bb2(%6: tensor<i32>):  // pred: ^bb1
-  %12 = tensor.from_elements %c10 : tensor<1xi32>
-  %reshaped12 = tensor.collapse_shape %12 [] : tensor<1xi32> into tensor<i32>
+  %12 = tensor.from_elements %c10 : tensor<i32>
   %7 = linalg.init_tensor [] : tensor<i32>
   %8 = linalg.generic #attrs
-    ins(%6, %reshaped12 : tensor<i32>, tensor<i32>)
+    ins(%6, %12 : tensor<i32>, tensor<i32>)
     outs(%7 : tensor<i32>) {
     ^bb0(%arg0: i32, %arg1: i32, %arg2: i32):  // no predecessors
       %9 = arith.addi %arg0, %arg1 : i32
@@ -190,7 +181,6 @@ func @main() -> (tensor<i32>) attributes {} {
 // CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
 // CHECK-NEXT:     br ^[[bb3:.*]](%{{.*}} : i32)
 // CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<1xi32>
-// CHECK-NEXT:     tensor.collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
 // CHECK-NEXT:     return %{{.*}}
 // CHECK-NEXT:   }

diff  --git a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
index 5862327ebe6c7..76b99d916acb1 100644
--- a/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_trivial.mlir
@@ -11,11 +11,10 @@
 
 func @main(%farg0 : tensor<i32>) -> (tensor<i1>) attributes {} {
   %c10 = arith.constant 10 : i32
-  %1 = tensor.from_elements %c10 : tensor<1xi32>
-  %reshaped1 = tensor.collapse_shape %1 [] : tensor<1xi32> into tensor<i32>
+  %1 = tensor.from_elements %c10 : tensor<i32>
   %3 = linalg.init_tensor [] : tensor<i1>
   %4 = linalg.generic #attrs
-    ins(%farg0, %reshaped1 : tensor<i32>, tensor<i32>)
+    ins(%farg0, %1 : tensor<i32>, tensor<i32>)
     outs(%3 : tensor<i1>) {
     ^bb0(%arg0: i32, %arg1: i32, %arg2: i1):
       %8 = arith.cmpi slt, %arg0, %arg1 : i32
@@ -30,7 +29,6 @@ func @main(%farg0 : tensor<i32>) -> (tensor<i1>) attributes {} {
 // DET-ALL-NEXT:    tensor.extract %{{.*}}[]
 // DET-ALL-NEXT:    arith.cmpi slt, %{{.*}}, %{{.*}}
 // DET-ALL-NEXT:    tensor.from_elements %{{.*}}
-// DET-ALL-NEXT:    tensor.collapse_shape %{{.*}}
 // DET-ALL-NEXT:    return %{{.*}} : tensor<i1>
 // DET-ALL-NEXT:  }
 

diff  --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
index 9ece0029737c9..6ae4c1ddef2dd 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -52,7 +52,6 @@ func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attributes {
 // DET-ALL:         br ^[[bb1]](%{{.*}} : i32)
 // DET-ALL:       ^[[bb3]](%{{.*}}: i32)
 // DET-ALL:         tensor.from_elements {{.*}}
-// DET-ALL:         tensor.collapse_shape {{.*}}
 // DET-ALL:         return %{{.*}} : tensor<i32>
 
 // Test detensoring only ops involed in control-flow.
@@ -68,6 +67,5 @@ func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attributes {
 // DET-CF:         arith.addi {{.*}}
 // DET-CF:         br ^[[bb1]](%{{.*}} : i32)
 // DET-CF:       ^[[bb3]](%{{.*}}: i32)
-// DET-CF:         tensor.from_elements %{{.*}} : tensor<1xi32>
-// DET-CF:         tensor.collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// DET-CF:         tensor.from_elements %{{.*}} : tensor<i32>
 // DET-CF:         return %{{.*}} : tensor<i32>

diff  --git a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
index 765692fa2d3d4..a464fb1a90e86 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir
@@ -76,8 +76,7 @@ func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attribute
 // DET-ALL:         cmpi slt, %{{.*}}, %{{.*}} : i32
 // DET-ALL:         cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
 // DET-ALL:       ^[[bb2]](%{{.*}}: i32)
-// DET-ALL:         tensor.from_elements %{{.*}} : tensor<1xi32>
-// DET-ALL:         tensor.collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// DET-ALL:         tensor.from_elements %{{.*}} : tensor<i32>
 // DET-ALL:         linalg.init_tensor [10] : tensor<10xi32>
 // DET-ALL:         linalg.generic {{{.*}}} ins(%{{.*}} : tensor<i32>) outs(%{{.*}} : tensor<10xi32>) {
 // DET-ALL:         ^bb0(%{{.*}}: i32, %{{.*}}: i32):
@@ -85,8 +84,7 @@ func @main(%farg0: tensor<10xi32>, %farg1: tensor<i32>) -> tensor<i32> attribute
 // DET-ALL:         } -> tensor<10xi32>
 // DET-ALL:         br ^[[bb1]](%{{.*}} : tensor<10xi32>)
 // DET-ALL:       ^[[bb3]](%{{.*}}: i32)
-// DET-ALL:         tensor.from_elements %{{.*}} : tensor<1xi32>
-// DET-ALL:         tensor.collapse_shape %{{.*}} [] : tensor<1xi32> into tensor<i32>
+// DET-ALL:         tensor.from_elements %{{.*}} : tensor<i32>
 // DET-ALL:         return %{{.*}} : tensor<i32>
 // DET-ALL:       }
 


        


More information about the Mlir-commits mailing list