[Mlir-commits] [mlir] [MLIR] Vector: turn the ExtractStridedSlice rewrite pattern from #111541 into a canonicalization (PR #111614)

Benoit Jacob llvmlistbot at llvm.org
Wed Oct 9 04:04:53 PDT 2024


https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/111614

>From 80fc72b859610db2ca01ed45eba7a896a16d27e2 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Tue, 8 Oct 2024 21:45:04 -0400
Subject: [PATCH 1/2] canonicalize

---
 .../Vector/Transforms/VectorRewritePatterns.h |  5 --
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 89 ++++++++++++++++++-
 ...sertExtractStridedSliceRewritePatterns.cpp | 69 --------------
 mlir/test/Dialect/Vector/canonicalize.mlir    | 49 ++++++++++
 ...uous-extract-strided-slice-to-extract.mlir | 24 -----
 .../Dialect/Vector/TestVectorTransforms.cpp   | 23 -----
 6 files changed, 137 insertions(+), 122 deletions(-)
 delete mode 100644 mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index ec1de7fa66aa07..a59f06f3c1ef1b 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -235,11 +235,6 @@ void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
     std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
     PatternBenefit benefit = 1);
 
-/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
-/// slice is contiguous, into extract and shape_cast.
-void populateVectorContiguousExtractStridedSliceToExtractPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit = 1);
-
 /// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
 /// based on the destination vector shape. Bitcasts from a lower bitwidth
 /// element type to a higher bitwidth one are extracted from the lower bitwidth
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index dc92bea09dc160..cda31706474e27 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3772,6 +3772,92 @@ class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
   }
 };
 
+/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
+/// slice is contiguous, into extract and shape_cast.
+///
+/// Example:
+///     Before:
+///         %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0],
+///         sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} :
+///         vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
+///     After:
+///         %0 = vector.extract %arg0[0, 0, 0, 0] : vector<8xi8> from
+///         vector<8x1x1x2x8xi8> %1 = vector.shape_cast %0 : vector<8xi8> to
+///         vector<1x1x1x1x8xi8>
+///
+class ContiguousExtractStridedSliceToExtract final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.hasNonUnitStrides()) {
+      return failure();
+    }
+    Value source = op.getOperand();
+    auto sourceType = cast<VectorType>(source.getType());
+    if (sourceType.isScalable() || sourceType.getRank() == 0) {
+      return failure();
+    }
+
+    // Compute the number of offsets to pass to ExtractOp::build. That is the
+    // difference between the source rank and the desired slice rank. We walk
+    // the dimensions from innermost out, and stop when the next slice dimension
+    // is not full-size.
+    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
+    int numOffsets;
+    for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
+      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
+        break;
+      }
+    }
+
+    // If the created extract op would have no offsets, then this whole
+    // extract_strided_slice is the identity and should have been handled by
+    // other canonicalizations.
+    if (numOffsets == 0) {
+      return failure();
+    }
+
+    // If not even the inner-most dimension is full-size, this op can't be
+    // rewritten as an ExtractOp.
+    if (numOffsets == sourceType.getRank() &&
+        static_cast<int>(sizes.size()) == sourceType.getRank()) {
+      return failure();
+    }
+
+    // The outer dimensions must have unit size.
+    for (int i = 0; i < numOffsets; ++i) {
+      if (sizes[i] != 1) {
+        return failure();
+      }
+    }
+
+    // Avoid generating slices that have leading unit dimensions. The shape_cast
+    // op that we create below would take bad generic fallback patterns
+    // (ShapeCastOpRewritePattern).
+    while (sizes[numOffsets] == 1 &&
+           numOffsets < static_cast<int>(sizes.size()) - 1) {
+      ++numOffsets;
+    }
+    // After exhausting the list of slice sizes, we keep checking for unit
+    // dimensions in the source shape, to remove corner cases where the result
+    // would have a leading unit dimension.
+    while (sourceType.getDimSize(numOffsets) == 1 &&
+           numOffsets < sourceType.getRank() - 1) {
+      ++numOffsets;
+    }
+
+    SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
+    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
+    Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
+                                                       extractOffsets);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractStridedSliceOp::getCanonicalizationPatterns(
@@ -3780,7 +3866,8 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
   results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
               StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
-              StridedSliceSplat>(context);
+              StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index ad845608f18d10..ec2ef3fc7501c2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -329,81 +329,12 @@ class DecomposeNDExtractStridedSlice
   }
 };
 
-/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
-/// slice is contiguous, into extract and shape_cast.
-///
-/// Example:
-///     Before:
-///         %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0],
-///         sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} :
-///         vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
-///     After:
-///         %0 = vector.extract %arg0[0, 0, 0, 0] : vector<8xi8> from
-///         vector<8x1x1x2x8xi8> %1 = vector.shape_cast %0 : vector<8xi8> to
-///         vector<1x1x1x1x8xi8>
-///
-class ContiguousExtractStridedSliceToExtract final
-    : public OpRewritePattern<ExtractStridedSliceOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
-                                PatternRewriter &rewriter) const override {
-    if (op.hasNonUnitStrides()) {
-      return failure();
-    }
-    Value source = op.getOperand();
-    auto sourceType = cast<VectorType>(source.getType());
-    if (sourceType.isScalable()) {
-      return failure();
-    }
-
-    // Compute the number of offsets to pass to ExtractOp::build. That is the
-    // difference between the source rank and the desired slice rank. We walk
-    // the dimensions from innermost out, and stop when the next slice dimension
-    // is not full-size.
-    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
-    int numOffsets;
-    for (numOffsets = sourceType.getRank(); numOffsets > 0; --numOffsets) {
-      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
-        break;
-      }
-    }
-
-    // If not even the inner-most dimension is full-size, this op can't be
-    // rewritten as an ExtractOp.
-    if (numOffsets == sourceType.getRank()) {
-      return failure();
-    }
-
-    // Avoid generating slices that have unit outer dimensions. The shape_cast
-    // op that we create below would take bad generic fallback patterns
-    // (ShapeCastOpRewritePattern).
-    while (sizes[numOffsets] == 1 && numOffsets < sourceType.getRank() - 1) {
-      ++numOffsets;
-    }
-
-    SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
-    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
-    Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
-                                                       extractOffsets);
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
-    return success();
-  }
-};
-
 void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<DecomposeDifferentRankInsertStridedSlice,
                DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
 }
 
-void vector::populateVectorContiguousExtractStridedSliceToExtractPatterns(
-    RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<ContiguousExtractStridedSliceToExtract>(patterns.getContext(),
-                                                       benefit);
-}
-
 void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
     RewritePatternSet &patterns,
     std::function<bool(ExtractStridedSliceOp)> controlFn,
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7c78de4b5bd89..15b77c91439cfe 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2742,3 +2742,52 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
   %1 = vector.insert %arg0, %0 [0] : i8 into vector<4xi8>
   return %1 : vector<4xi8>
 }
+
+// -----
+
+// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract
+// CHECK:        %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
+// CHECK-NEXT:   return %[[EXTRACT]] :  vector<4xi32>
+func.func @contiguous_extract_strided_slices_to_extract(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
+  %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
+  return %2 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_shorter_size_list
+// CHECK:        %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
+// CHECK-NEXT:   return %[[EXTRACT]] :  vector<4xi32>
+func.func @contiguous_extract_strided_slices_to_extract_shorter_size_list(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1], strides = [1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
+  %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
+  return %2 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_failure_non_unit_outer_size
+// CHECK-NEXT:   vector.extract_strided_slice
+func.func @contiguous_extract_strided_slices_to_extract_failure_non_unit_outer_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<8x1x1x1x1x4xi32> {
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [8, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<8x1x1x1x1x4xi32>
+  return %1 : vector<8x1x1x1x1x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_failure_non_full_size
+// CHECK-NEXT:   vector.extract_strided_slice
+func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x1x1x1x2xi32> {
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32>
+  return %1 : vector<1x1x1x1x1x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_size
+// CHECK-NEXT:    vector.extract_strided_slice
+func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x2x1x1x1xi32> {
+  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
+  return %1 : vector<1x1x2x1x1x1xi32>
+}
diff --git a/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
deleted file mode 100644
index d1401ad7853fc9..00000000000000
--- a/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
+++ /dev/null
@@ -1,24 +0,0 @@
-// RUN: mlir-opt --test-vector-contiguous-extract-strided-slice-to-extract %s | FileCheck %s
-
-// CHECK-LABEL: @contiguous
-// CHECK:        %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
-// CHECK-NEXT:   return %[[EXTRACT]] :  vector<4xi32>
-func.func @contiguous(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
-  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
-  %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
-  return %2 : vector<4xi32>
-}
-
-// CHECK-LABEL: @non_full_size
-// CHECK-NEXT:   vector.extract_strided_slice
-func.func @non_full_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x1x1x1x2xi32> {
-  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32>
-  return %1 : vector<1x1x1x1x1x2xi32>
-}
-
-// CHECK-LABEL: @non_full_inner_size
-// CHECK-NEXT:    vector.extract_strided_slice
-func.func @non_full_inner_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x2x1x1x1xi32> {
-  %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
-  return %1 : vector<1x1x2x1x1x1xi32>
-}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index d91e955b70641e..72aaa7dc4f8973 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -709,27 +709,6 @@ struct TestVectorExtractStridedSliceLowering
   }
 };
 
-struct TestVectorContiguousExtractStridedSliceToExtract
-    : public PassWrapper<TestVectorContiguousExtractStridedSliceToExtract,
-                         OperationPass<func::FuncOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
-      TestVectorExtractStridedSliceLowering)
-
-  StringRef getArgument() const final {
-    return "test-vector-contiguous-extract-strided-slice-to-extract";
-  }
-  StringRef getDescription() const final {
-    return "Test lowering patterns that rewrite simple cases of N-D "
-           "extract_strided_slice, where the slice is contiguous, into extract "
-           "and shape_cast";
-  }
-  void runOnOperation() override {
-    RewritePatternSet patterns(&getContext());
-    populateVectorContiguousExtractStridedSliceToExtractPatterns(patterns);
-    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
-  }
-};
-
 struct TestVectorBreakDownBitCast
     : public PassWrapper<TestVectorBreakDownBitCast,
                          OperationPass<func::FuncOp>> {
@@ -956,8 +935,6 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorExtractStridedSliceLowering>();
 
-  PassRegistration<TestVectorContiguousExtractStridedSliceToExtract>();
-
   PassRegistration<TestVectorBreakDownBitCast>();
 
   PassRegistration<TestCreateVectorBroadcast>();

>From 17d9d8eaf808a3f785b467e43c8ae7a4fa7d4a07 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Wed, 9 Oct 2024 07:00:40 -0400
Subject: [PATCH 2/2] review comments and fix

Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 40 ++++++++--------------
 mlir/test/Dialect/Vector/canonicalize.mlir | 10 +++---
 2 files changed, 20 insertions(+), 30 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index cda31706474e27..decf43749d6639 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3777,13 +3777,16 @@ class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
 ///
 /// Example:
 ///     Before:
-///         %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0],
-///         sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} :
-///         vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
+///         %1 = vector.extract_strided_slice %arg0 {
+///                offsets = [0, 0, 0, 0, 0],
+///                sizes = [1, 1, 1, 1, 8],
+///                strides = [1, 1, 1, 1, 1]
+///              } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
 ///     After:
-///         %0 = vector.extract %arg0[0, 0, 0, 0] : vector<8xi8> from
-///         vector<8x1x1x2x8xi8> %1 = vector.shape_cast %0 : vector<8xi8> to
-///         vector<1x1x1x1x8xi8>
+///         %0 = vector.extract %arg0[0, 0, 0, 0]
+///                : vector<8xi8> from vector<8x1x1x2x8xi8>
+///         %1 = vector.shape_cast %0
+///                : vector<8xi8> to vector<1x1x1x1x8xi8>
 ///
 class ContiguousExtractStridedSliceToExtract final
     : public OpRewritePattern<ExtractStridedSliceOp> {
@@ -3792,14 +3795,12 @@ class ContiguousExtractStridedSliceToExtract final
 
   LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
                                 PatternRewriter &rewriter) const override {
-    if (op.hasNonUnitStrides()) {
+    if (op.hasNonUnitStrides())
       return failure();
-    }
     Value source = op.getOperand();
     auto sourceType = cast<VectorType>(source.getType());
-    if (sourceType.isScalable() || sourceType.getRank() == 0) {
+    if (sourceType.isScalable() || sourceType.getRank() == 0)
       return failure();
-    }
 
     // Compute the number of offsets to pass to ExtractOp::build. That is the
     // difference between the source rank and the desired slice rank. We walk
@@ -3808,30 +3809,26 @@ class ContiguousExtractStridedSliceToExtract final
     SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
     int numOffsets;
     for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
-      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
+      if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
         break;
-      }
     }
 
     // If the created extract op would have no offsets, then this whole
     // extract_strided_slice is the identity and should have been handled by
     // other canonicalizations.
-    if (numOffsets == 0) {
+    if (numOffsets == 0)
       return failure();
-    }
 
     // If not even the inner-most dimension is full-size, this op can't be
     // rewritten as an ExtractOp.
     if (numOffsets == sourceType.getRank() &&
-        static_cast<int>(sizes.size()) == sourceType.getRank()) {
+        static_cast<int>(sizes.size()) == sourceType.getRank())
       return failure();
-    }
 
     // The outer dimensions must have unit size.
     for (int i = 0; i < numOffsets; ++i) {
-      if (sizes[i] != 1) {
+      if (sizes[i] != 1)
         return failure();
-      }
     }
 
     // Avoid generating slices that have leading unit dimensions. The shape_cast
@@ -3841,13 +3838,6 @@ class ContiguousExtractStridedSliceToExtract final
            numOffsets < static_cast<int>(sizes.size()) - 1) {
       ++numOffsets;
     }
-    // After exhausting the list of slice sizes, we keep checking for unit
-    // dimensions in the source shape, to remove corner cases where the result
-    // would have a leading unit dimension.
-    while (sourceType.getDimSize(numOffsets) == 1 &&
-           numOffsets < sourceType.getRank() - 1) {
-      ++numOffsets;
-    }
 
     SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
     auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 15b77c91439cfe..6d6bc199e601c0 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2757,12 +2757,12 @@ func.func @contiguous_extract_strided_slices_to_extract(%arg0 : vector<8x1x2x1x1
 // -----
 
 // CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_shorter_size_list
-// CHECK:        %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
-// CHECK-NEXT:   return %[[EXTRACT]] :  vector<4xi32>
-func.func @contiguous_extract_strided_slices_to_extract_shorter_size_list(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
+// CHECK:        %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0] : vector<1x4xi32> from vector<8x1x2x1x1x4xi32>
+// CHECK-NEXT:   return %[[EXTRACT]] :  vector<1x4xi32>
+func.func @contiguous_extract_strided_slices_to_extract_shorter_size_list(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x4xi32> {
   %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1], strides = [1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
-  %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
-  return %2 : vector<4xi32>
+  %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<1x4xi32>
+  return %2 : vector<1x4xi32>
 }
 
 // -----



More information about the Mlir-commits mailing list