[Mlir-commits] [mlir] 10054ba - [mlir][vector] Add pattern to rewrite contiguous ExtractStridedSlice into Extract (#111541)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Oct 8 08:51:05 PDT 2024


Author: Benoit Jacob
Date: 2024-10-08T11:51:01-04:00
New Revision: 10054ba4acbc5378d2e2aa869a5bccd88aa4b59e

URL: https://github.com/llvm/llvm-project/commit/10054ba4acbc5378d2e2aa869a5bccd88aa4b59e
DIFF: https://github.com/llvm/llvm-project/commit/10054ba4acbc5378d2e2aa869a5bccd88aa4b59e.diff

LOG: [mlir][vector] Add pattern to rewrite contiguous ExtractStridedSlice into Extract (#111541)

Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>

Added: 
    mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
    mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
    mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a59f06f3c1ef1b..ec1de7fa66aa07 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -235,6 +235,11 @@ void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
     std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
     PatternBenefit benefit = 1);
 
+/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
+/// slice is contiguous, into extract and shape_cast.
+///
+/// Only extract_strided_slice ops with unit strides on fixed-size
+/// (non-scalable) vectors are considered. The innermost full-size slice
+/// dimensions are taken with a vector.extract at the leading offsets, and the
+/// extracted vector is then reshaped to the slice result type with a
+/// vector.shape_cast. `benefit` is the benefit assigned to the added pattern.
+void populateVectorContiguousExtractStridedSliceToExtractPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
 /// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
 /// based on the destination vector shape. Bitcasts from a lower bitwidth
 /// element type to a higher bitwidth one are extracted from the lower bitwidth

diff  --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index ec2ef3fc7501c2..c2da9347aadc87 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -329,12 +329,70 @@ class DecomposeNDExtractStridedSlice
   }
 };
 
+/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
+/// slice is contiguous, into extract and shape_cast.
+///
+/// The rewrite applies when the op has unit strides, the source is a
+/// fixed-size vector, and the slice consists of a run of full-size innermost
+/// dimensions preceded only by unit-size slice dimensions. In that case the
+/// slice is exactly the sub-vector selected by a vector.extract at the
+/// leading offsets, reshaped by a vector.shape_cast to the result type.
+class ContiguousExtractStridedSliceToExtract final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.hasNonUnitStrides()) {
+      return failure();
+    }
+    Value source = op.getOperand();
+    auto sourceType = cast<VectorType>(source.getType());
+    if (sourceType.isScalable()) {
+      return failure();
+    }
+
+    // `sizes` (and `offsets`) may have fewer entries than the source rank;
+    // trailing source dimensions without an entry are taken in full. Only the
+    // explicit entries are ever indexed below.
+    SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
+    const int numSizes = static_cast<int>(sizes.size());
+
+    // Compute the number of offsets to pass to ExtractOp::build. That is the
+    // difference between the source rank and the desired slice rank. We walk
+    // the dimensions from innermost out, and stop when the next slice dimension
+    // is not full-size.
+    int numOffsets = numSizes;
+    while (numOffsets > 0 &&
+           sizes[numOffsets - 1] == sourceType.getDimSize(numOffsets - 1)) {
+      --numOffsets;
+    }
+
+    // Every slice dimension is full-size, so the op is an identity. Leave it
+    // to folding/canonicalization rather than build an ExtractOp without any
+    // position.
+    if (numOffsets == 0) {
+      return failure();
+    }
+
+    // If not even the inner-most dimension is full-size, this op can't be
+    // rewritten as an ExtractOp.
+    if (numOffsets == sourceType.getRank()) {
+      return failure();
+    }
+
+    // ExtractOp selects a single index in each of the leading `numOffsets`
+    // dimensions, so the slice must have unit size in all of them. Otherwise
+    // the extracted vector would hold more elements than the result type and
+    // the shape_cast created below would be invalid (e.g. a [2, 4] slice of a
+    // vector<4x4xf32>).
+    for (int64_t size : ArrayRef(sizes).take_front(numOffsets)) {
+      if (size != 1) {
+        return failure();
+      }
+    }
+
+    // Avoid generating slices that have unit outer dimensions. The shape_cast
+    // op that we create below would take bad generic fallback patterns
+    // (ShapeCastOpRewritePattern). The bound is checked before indexing, and
+    // at least one explicit dimension is kept so the extract yields a vector.
+    while (numOffsets < numSizes - 1 && sizes[numOffsets] == 1) {
+      ++numOffsets;
+    }
+
+    SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
+    auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
+    Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
+                                                       extractOffsets);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
+    return success();
+  }
+};
+
+/// Adds the DecomposeDifferentRankInsertStridedSlice and
+/// DecomposeNDExtractStridedSlice patterns to `patterns`, both with the given
+/// `benefit`.
 void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<DecomposeDifferentRankInsertStridedSlice,
                DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
 }
 
+/// Adds ContiguousExtractStridedSliceToExtract to `patterns`, created in the
+/// context that owns `patterns` and using the given `benefit`.
+void vector::populateVectorContiguousExtractStridedSliceToExtractPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  auto *context = patterns.getContext();
+  patterns.add<ContiguousExtractStridedSliceToExtract>(context, benefit);
+}
+
 void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
     RewritePatternSet &patterns,
     std::function<bool(ExtractStridedSliceOp)> controlFn,

diff  --git a/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
new file mode 100644
index 00000000000000..9147e7bf02581e
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt --test-vector-contiguous-extract-strided-slice-to-extract %s | FileCheck %s
+
+// CHECK-LABEL: @extract_strided_slice_to_extract_i8
+// CHECK:       %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0] : vector<8xi8> from vector<8x1x1x2x8xi8>
+// CHECK:       return %[[EXTRACT]] :  vector<8xi8>
+func.func @extract_strided_slice_to_extract_i8(%source : vector<8x1x1x2x8xi8>) -> vector<8xi8> {
+  %slice = vector.extract_strided_slice %source {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
+  %flat = vector.shape_cast %slice : vector<1x1x1x1x8xi8> to vector<8xi8>
+  return %flat : vector<8xi8>
+}
+
+// CHECK-LABEL: @extract_strided_slice_to_extract_i32
+// CHECK:        %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
+// CHECK:       return %[[EXTRACT]] :  vector<4xi32>
+func.func @extract_strided_slice_to_extract_i32(%source : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
+  %slice = vector.extract_strided_slice %source {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
+  %flat = vector.shape_cast %slice : vector<1x1x1x1x1x4xi32> to vector<4xi32>
+  return %flat : vector<4xi32>
+}
+
+// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_1
+// CHECK:        vector.extract_strided_slice
+func.func @extract_strided_slice_to_extract_i32_non_contiguous_1(%source : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
+  %slice = vector.extract_strided_slice %source {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32>
+  %flat = vector.shape_cast %slice : vector<1x1x1x1x1x2xi32> to vector<2xi32>
+  return %flat : vector<2xi32>
+}
+
+// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_2
+// CHECK:        vector.extract_strided_slice
+func.func @extract_strided_slice_to_extract_i32_non_contiguous_2(%source : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
+  %slice = vector.extract_strided_slice %source {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
+  %flat = vector.shape_cast %slice : vector<1x1x2x1x1x1xi32> to vector<2xi32>
+  return %flat : vector<2xi32>
+}

diff  --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 72aaa7dc4f8973..d91e955b70641e 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -709,6 +709,27 @@ struct TestVectorExtractStridedSliceLowering
   }
 };
 
+/// Test pass that greedily applies the patterns from
+/// populateVectorContiguousExtractStridedSliceToExtractPatterns (with folding)
+/// to each func.func.
+struct TestVectorContiguousExtractStridedSliceToExtract
+    : public PassWrapper<TestVectorContiguousExtractStridedSliceToExtract,
+                         OperationPass<func::FuncOp>> {
+  // The TypeID macro must name this pass class itself.
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      TestVectorContiguousExtractStridedSliceToExtract)
+
+  StringRef getArgument() const final {
+    return "test-vector-contiguous-extract-strided-slice-to-extract";
+  }
+  StringRef getDescription() const final {
+    return "Test lowering patterns that rewrite simple cases of N-D "
+           "extract_strided_slice, where the slice is contiguous, into extract "
+           "and shape_cast";
+  }
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorContiguousExtractStridedSliceToExtractPatterns(patterns);
+    // Non-convergence of the greedy driver is not treated as a pass failure.
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestVectorBreakDownBitCast
     : public PassWrapper<TestVectorBreakDownBitCast,
                          OperationPass<func::FuncOp>> {
@@ -935,6 +956,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestVectorExtractStridedSliceLowering>();
 
+  PassRegistration<TestVectorContiguousExtractStridedSliceToExtract>();
+
   PassRegistration<TestVectorBreakDownBitCast>();
 
   PassRegistration<TestCreateVectorBroadcast>();


        


More information about the Mlir-commits mailing list