[Mlir-commits] [mlir] [mlir][tensor] Fold unpadding collapse_shape into extract_slice (PR #93554)

Adam Siemieniuk llvmlistbot at llvm.org
Fri May 31 01:21:02 PDT 2024


https://github.com/adam-smnk updated https://github.com/llvm/llvm-project/pull/93554

From e4bfb741c6b224c2c5310e897d1171b9337ad81e Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Tue, 28 May 2024 16:22:55 +0200
Subject: [PATCH 1/2] [mlir][tensor] Fold unpadding collapse_shape into
 extract_slice

---
 .../Tensor/Transforms/ReshapePatterns.cpp     | 46 +++++++++++--
 .../Tensor/fold-reassociative-reshapes.mlir   | 69 +++++++++++++++++++
 2 files changed, 109 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index 6cf0f845f59db..d7c608a773bb7 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -48,6 +48,39 @@ struct FoldExpandOfRankReducingExtract
   }
 };
 
+/// Fold a collapse_shape that only removes static dimensions of size `1`
+/// into the producing extract_slice.
+struct FoldUnPaddingCollapseIntoExtract
+    : public OpRewritePattern<tensor::CollapseShapeOp> {
+  using OpRewritePattern<tensor::CollapseShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseShapeOp,
+                                PatternRewriter &rewriter) const override {
+    auto extractSliceOp =
+        collapseShapeOp.getSrc().getDefiningOp<tensor::ExtractSliceOp>();
+    // The collapse cannot be folded away when the extract slice has multiple
+    // users, and it is not necessarily beneficial to merely convert the
+    // collapse into another extract slice.
+    if (!extractSliceOp || !extractSliceOp.getResult().hasOneUse())
+      return failure();
+
+    // Only fold away a simple collapse where all removed dimensions have
+    // static size `1`.
+    SliceVerificationResult res = isRankReducedType(
+        collapseShapeOp.getSrcType(), collapseShapeOp.getResultType());
+    if (res != SliceVerificationResult::Success)
+      return rewriter.notifyMatchFailure(collapseShapeOp,
+                                         "expected unpadding collapse");
+
+    Value unPaddedExtractSlice = rewriter.create<tensor::ExtractSliceOp>(
+        extractSliceOp.getLoc(), collapseShapeOp.getResultType(),
+        extractSliceOp.getSource(), extractSliceOp.getMixedOffsets(),
+        extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
+    rewriter.replaceOp(collapseShapeOp, unPaddedExtractSlice);
+    return success();
+  }
+};
+
 /// Fold insert_slice(collapse_shape) ops that cancel themselves out.
 template <typename OpTy>
 struct FoldInsertOfRankReducingInsert : public OpRewritePattern<OpTy> {
@@ -111,10 +144,11 @@ struct FoldPaddingExpandIntoInsert : public OpRewritePattern<OpTy> {
 
 void mlir::tensor::populateReassociativeReshapeFoldingPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<FoldExpandOfRankReducingExtract,
-               FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
-               FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
-               FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
-               FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
-      patterns.getContext());
+  patterns
+      .add<FoldExpandOfRankReducingExtract, FoldUnPaddingCollapseIntoExtract,
+           FoldInsertOfRankReducingInsert<tensor::InsertSliceOp>,
+           FoldInsertOfRankReducingInsert<tensor::ParallelInsertSliceOp>,
+           FoldPaddingExpandIntoInsert<tensor::InsertSliceOp>,
+           FoldPaddingExpandIntoInsert<tensor::ParallelInsertSliceOp>>(
+          patterns.getContext());
 }
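
For illustration, the new pattern rewrites IR of the following shape into a
single rank-reducing slice (a minimal sketch mirroring the test case added
below; the value names and shapes are illustrative):

  // Before: the slice and the removal of the static unit dimensions
  // are two separate ops.
  %0 = tensor.extract_slice %t[%x, %y, 0, 0] [1, %sz0, 1, %sz1] [1, 1, 1, 1]
      : tensor<?x?x?x?xf32> to tensor<1x?x1x?xf32>
  %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
      : tensor<1x?x1x?xf32> into tensor<?x?xf32>

  // After: one extract_slice with the same offsets, sizes, and strides,
  // whose rank-reduced result type drops the unit dimensions directly.
  %1 = tensor.extract_slice %t[%x, %y, 0, 0] [1, %sz0, 1, %sz1] [1, 1, 1, 1]
      : tensor<?x?x?x?xf32> to tensor<?x?xf32>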
diff --git a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
index 644d9a918f6ca..c2368c4bf2c91 100644
--- a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
@@ -22,6 +22,75 @@ func.func @expand_shape_of_rank_reducing_extract(
 
 // -----
 
+// CHECK-LABEL: func @unpadding_collapse_of_extract_slice(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?x?xf32>
+//  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?xf32>
+//       CHECK:   return %[[extract]]
+func.func @unpadding_collapse_of_extract_slice(
+    %t: tensor<?x?x?x?xf32>, %x: index, %y: index)
+  -> tensor<?x?xf32> {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %sz0 = tensor.dim %t, %c1 : tensor<?x?x?x?xf32>
+  %sz1 = tensor.dim %t, %c3 : tensor<?x?x?x?xf32>
+  %0 = tensor.extract_slice %t[%x, %y, 0, 0] [1, %sz0, 1, %sz1] [1, 1, 1, 1]
+      : tensor<?x?x?x?xf32> to tensor<1x?x1x?xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
+      : tensor<1x?x1x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @non_unpadding_collapse_of_extract_slice(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?x?xf32>
+//  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:     %[[sz:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0] [%{{.*}}, %{{.*}}, %[[sz]], 1] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?xf32>
+//       CHECK:   %[[collapse:.*]] = tensor.collapse_shape %[[extract]] {{\[}}[0], [1, 2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
+//       CHECK:   return %[[collapse]]
+func.func @non_unpadding_collapse_of_extract_slice(
+    %t: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
+  -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %sz0 = tensor.dim %t, %c0 : tensor<?x?x?x?xf32>
+  %sz1 = tensor.dim %t, %c1 : tensor<?x?x?x?xf32>
+  %0 = tensor.extract_slice %t[%x, %y, 0, 0] [%sz0, %sz1, %sz, 1] [1, 1, 1, 1]
+      : tensor<?x?x?x?xf32> to tensor<?x?x?xf32>
+  %1 = tensor.collapse_shape %0 [[0], [1, 2]]
+      : tensor<?x?x?xf32> into tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @unpadding_collapse_of_extract_slice_with_multiple_users(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?x?xf32>
+//  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
+//  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
+//       CHECK:   %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<1x?x1x?xf32>
+//       CHECK:   %[[collapse:.*]] = tensor.collapse_shape %[[extract]] {{\[}}[0, 1], [2, 3]] : tensor<1x?x1x?xf32> into tensor<?x?xf32>
+//       CHECK:   return %[[extract]], %[[collapse]]
+func.func @unpadding_collapse_of_extract_slice_with_multiple_users(
+    %t: tensor<?x?x?x?xf32>, %x: index, %y: index)
+  -> (tensor<1x?x1x?xf32>, tensor<?x?xf32>) {
+  %c1 = arith.constant 1 : index
+  %c3 = arith.constant 3 : index
+  %sz0 = tensor.dim %t, %c1 : tensor<?x?x?x?xf32>
+  %sz1 = tensor.dim %t, %c3 : tensor<?x?x?x?xf32>
+  %0 = tensor.extract_slice %t[%x, %y, 0, 0] [1, %sz0, 1, %sz1] [1, 1, 1, 1]
+      : tensor<?x?x?x?xf32> to tensor<1x?x1x?xf32>
+  %1 = tensor.collapse_shape %0 [[0, 1], [2, 3]]
+      : tensor<1x?x1x?xf32> into tensor<?x?xf32>
+  return %0, %1 : tensor<1x?x1x?xf32>, tensor<?x?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @rank_reducing_insert_of_collapse_shape(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x1x1x5xf32>
 //       CHECK:   %[[insert:.*]] = tensor.insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>

From 51e9a641b66557483c5c5d9636dbd166735e358d Mon Sep 17 00:00:00 2001
From: Adam Siemieniuk <adam.siemieniuk at intel.com>
Date: Fri, 31 May 2024 10:20:45 +0200
Subject: [PATCH 2/2] Address comments

---
 .../Tensor/Transforms/ReshapePatterns.cpp     |  2 +-
 .../Tensor/fold-reassociative-reshapes.mlir   | 39 ++++++++++++-------
 2 files changed, 27 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index d7c608a773bb7..be0d71866a095 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -61,7 +61,7 @@ struct FoldUnPaddingCollapseIntoExtract
     // The collapse cannot be folded away when the extract slice has multiple
     // users, and it is not necessarily beneficial to merely convert the
     // collapse into another extract slice.
-    if (!extractSliceOp || !extractSliceOp.getResult().hasOneUse())
+    if (!extractSliceOp || !extractSliceOp->hasOneUse())
       return failure();
 
     // Only fold away a simple collapse where all removed dimensions have
diff --git a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
index c2368c4bf2c91..594d540dfca0a 100644
--- a/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
+++ b/mlir/test/Dialect/Tensor/fold-reassociative-reshapes.mlir
@@ -2,8 +2,10 @@
 
 // CHECK-LABEL: func @expand_shape_of_rank_reducing_extract(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?x?xf32>
-//   CHECK-DAG:   %[[extract1:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
-//   CHECK-DAG:   %[[extract2:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
+//   CHECK-DAG:   %[[extract1:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0]
+//   CHECK-SAME:    [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
+//   CHECK-DAG:   %[[extract2:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0]
+//   CHECK-SAME:    [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
 //       CHECK:   return %[[extract1]], %[[extract2]]
 func.func @expand_shape_of_rank_reducing_extract(
     %t: tensor<?x?x?x?xf32>, %idx: index)
@@ -26,7 +28,8 @@ func.func @expand_shape_of_rank_reducing_extract(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?x?xf32>
 //  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
 //  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
-//       CHECK:   %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?xf32>
+//       CHECK:   %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0]
+//  CHECK-SAME:     [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?xf32>
 //       CHECK:   return %[[extract]]
 func.func @unpadding_collapse_of_extract_slice(
     %t: tensor<?x?x?x?xf32>, %x: index, %y: index)
@@ -49,7 +52,8 @@ func.func @unpadding_collapse_of_extract_slice(
 //  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
 //  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
 //  CHECK-SAME:     %[[sz:[a-zA-Z0-9_]+]]: index
-//       CHECK:   %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0] [%{{.*}}, %{{.*}}, %[[sz]], 1] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?xf32>
+//       CHECK:   %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0]
+//  CHECK-SAME:     [%{{.*}}, %{{.*}}, %[[sz]], 1] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?xf32>
 //       CHECK:   %[[collapse:.*]] = tensor.collapse_shape %[[extract]] {{\[}}[0], [1, 2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
 //       CHECK:   return %[[collapse]]
 func.func @non_unpadding_collapse_of_extract_slice(
@@ -72,7 +76,8 @@ func.func @non_unpadding_collapse_of_extract_slice(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?x?xf32>
 //  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
 //  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
-//       CHECK:   %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<1x?x1x?xf32>
+//       CHECK:   %[[extract:.*]] = tensor.extract_slice %[[t]][%[[x]], %[[y]], 0, 0]
+//  CHECK-SAME:     [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<1x?x1x?xf32>
 //       CHECK:   %[[collapse:.*]] = tensor.collapse_shape %[[extract]] {{\[}}[0, 1], [2, 3]] : tensor<1x?x1x?xf32> into tensor<?x?xf32>
 //       CHECK:   return %[[extract]], %[[collapse]]
 func.func @unpadding_collapse_of_extract_slice_with_multiple_users(
@@ -93,7 +98,8 @@ func.func @unpadding_collapse_of_extract_slice_with_multiple_users(
 
 // CHECK-LABEL: func @rank_reducing_insert_of_collapse_shape(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x1x1x5xf32>
-//       CHECK:   %[[insert:.*]] = tensor.insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>
+//       CHECK:   %[[insert:.*]] = tensor.insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0]
+//  CHECK-SAME:     [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>
 //       CHECK:   return %[[insert]]
 func.func @rank_reducing_insert_of_collapse_shape(
     %t: tensor<?x1x1x5xf32>, %d: tensor<?x?x?x?xf32>, %sz: index)
@@ -109,7 +115,8 @@ func.func @rank_reducing_insert_of_collapse_shape(
 
 // CHECK-LABEL: func @rank_reducing_parallel_insert_of_collapse_shape(
 //  CHECK-SAME:     %[[t:.*]]: tensor<?x1x1x5xf32>
-//       CHECK:   tensor.parallel_insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>
+//       CHECK:   tensor.parallel_insert_slice %[[t]] into %{{.*}}[0, 0, 0, 0]
+//  CHECK-SAME:     [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x1x1x5xf32> into tensor<?x?x?x?xf32>
 func.func @rank_reducing_parallel_insert_of_collapse_shape(
     %t: tensor<?x1x1x5xf32>, %d: tensor<?x?x?x?xf32>, %sz: index, %thr: index)
   -> tensor<?x?x?x?xf32> {
@@ -131,7 +138,8 @@ func.func @rank_reducing_parallel_insert_of_collapse_shape(
 //  CHECK-SAME:     %[[d:.*]]: tensor<?x?x?x?xf32>
 //  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
 //  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
-//       CHECK:   %[[insert:.*]] = tensor.insert_slice %[[t]] into %[[d]][%[[x]], %[[y]], 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
+//       CHECK:   %[[insert:.*]] = tensor.insert_slice %[[t]] into %[[d]][%[[x]], %[[y]], 0, 0]
+//  CHECK-SAME:     [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
 //       CHECK:   return %[[insert]]
 func.func @insert_of_padding_expand_shape(
     %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
@@ -155,8 +163,10 @@ func.func @insert_of_padding_expand_shape(
 //  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
 //  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
 //  CHECK-SAME:     %[[sz:[a-zA-Z0-9_]+]]: index
-//       CHECK:   %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]] output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
-//       CHECK:   %[[insert:.*]] = tensor.insert_slice %[[expand]] into %[[d]][%[[x]], %[[y]], 0, 0] [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+//       CHECK:   %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]]
+//  CHECK-SAME:     output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
+//       CHECK:   %[[insert:.*]] = tensor.insert_slice %[[expand]] into %[[d]][%[[x]], %[[y]], 0, 0]
+//  CHECK-SAME:     [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
 //       CHECK:   return %[[insert]]
 func.func @insert_of_non_padding_expand_shape(
     %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
@@ -179,7 +189,8 @@ func.func @insert_of_non_padding_expand_shape(
 //  CHECK-SAME:     %[[d:.*]]: tensor<?x?x?x?xf32>
 //  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
 //  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
-//       CHECK:   tensor.parallel_insert_slice %[[t]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
+//       CHECK:   tensor.parallel_insert_slice %[[t]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0]
+//  CHECK-SAME:     [1, %{{.*}}, 1, %{{.*}}] [1, 1, 1, 1] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
 func.func @parallel_insert_of_padding_expand_shape(
     %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index)
   -> tensor<?x?x?x?xf32> {
@@ -206,8 +217,10 @@ func.func @parallel_insert_of_padding_expand_shape(
 //  CHECK-SAME:     %[[x:[a-zA-Z0-9_]+]]: index
 //  CHECK-SAME:     %[[y:[a-zA-Z0-9_]+]]: index
 //  CHECK-SAME:     %[[sz:[a-zA-Z0-9_]+]]: index
-//       CHECK:   %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]] output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
-//       CHECK:   tensor.parallel_insert_slice %[[expand]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0] [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
+//       CHECK:   %[[expand:.*]] = tensor.expand_shape %[[t]] {{\[}}[0, 1], [2]]
+//  CHECK-SAME:     output_shape [%[[sz]], %{{.*}}, %{{.*}}] : tensor<?x?xf32> into tensor<?x?x?xf32>
+//       CHECK:   tensor.parallel_insert_slice %[[expand]] into %{{.*}}[%{{.*}}, %{{.*}}, 0, 0]
+//  CHECK-SAME:     [%[[sz]], 1, %{{.*}}, %{{.*}}] [1, 1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
 func.func @parallel_insert_of_non_padding_expand_shape(
     %t: tensor<?x?xf32>, %d: tensor<?x?x?x?xf32>, %x: index, %y: index, %sz: index)
   -> tensor<?x?x?x?xf32> {
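
As a usage sketch, these patterns can be driven with MLIR's greedy rewrite
driver (assuming a pass context; the pass body below is hypothetical and not
part of this patch):

  // Hypothetical pass body: collect the reassociative reshape folding
  // patterns, including the new FoldUnPaddingCollapseIntoExtract, and
  // apply them greedily to the current operation.
  RewritePatternSet patterns(&getContext());
  tensor::populateReassociativeReshapeFoldingPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));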


