[Mlir-commits] [mlir] a9ebdbb - [MLIR] Vector: turn the ExtractStridedSlice rewrite pattern from #111541 into a canonicalization (#111614)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Oct 9 06:24:27 PDT 2024
Author: Benoit Jacob
Date: 2024-10-09T09:24:23-04:00
New Revision: a9ebdbb5ac7de7a028f6060b789196a43aea7580
URL: https://github.com/llvm/llvm-project/commit/a9ebdbb5ac7de7a028f6060b789196a43aea7580
DIFF: https://github.com/llvm/llvm-project/commit/a9ebdbb5ac7de7a028f6060b789196a43aea7580.diff
LOG: [MLIR] Vector: turn the ExtractStridedSlice rewrite pattern from #111541 into a canonicalization (#111614)
This is a reasonable canonicalization because `extract` is more
constrained than `extract_strided_slice`, so there is no loss of
semantics here, just lifting an op to a special-case higher/constrained
op. And the additional `shape_cast` is merely adding leading unit dims
to match the original result type.
Context: discussion on #111541. I wasn't sure how this would turn out,
but in the process of writing this PR, I discovered at least 2 bugs in
the pattern introduced in #111541, which shows the value of shared
canonicalization patterns which are exercised on a high number of
testcases.
---------
Signed-off-by: Benoit Jacob <jacob.benoit.1 at gmail.com>
Added:
Modified:
mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Removed:
mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index ec1de7fa66aa07..a59f06f3c1ef1b 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -235,11 +235,6 @@ void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
PatternBenefit benefit = 1);
-/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
-/// slice is contiguous, into extract and shape_cast.
-void populateVectorContiguousExtractStridedSliceToExtractPatterns(
- RewritePatternSet &patterns, PatternBenefit benefit = 1);
-
/// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
/// based on the destination vector shape. Bitcasts from a lower bitwidth
/// element type to a higher bitwidth one are extracted from the lower bitwidth
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1718530b4aa167..a2abe1619454f2 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3772,6 +3772,82 @@ class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
}
};
+/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
+/// slice is contiguous, into extract and shape_cast.
+///
+/// Example:
+/// Before:
+/// %1 = vector.extract_strided_slice %arg0 {
+/// offsets = [0, 0, 0, 0, 0],
+/// sizes = [1, 1, 1, 1, 8],
+/// strides = [1, 1, 1, 1, 1]
+/// } : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
+/// After:
+/// %0 = vector.extract %arg0[0, 0, 0, 0]
+/// : vector<8xi8> from vector<8x1x1x2x8xi8>
+/// %1 = vector.shape_cast %0
+/// : vector<8xi8> to vector<1x1x1x1x8xi8>
+///
+class ContiguousExtractStridedSliceToExtract final
+ : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.hasNonUnitStrides())
+ return failure();
+ Value source = op.getOperand();
+ auto sourceType = cast<VectorType>(source.getType());
+ if (sourceType.isScalable() || sourceType.getRank() == 0)
+ return failure();
+
+ // Compute the number of offsets to pass to ExtractOp::build. That is the
+ // difference between the source rank and the desired slice rank. We walk
+ // the dimensions from innermost out, and stop when the next slice dimension
+ // is not full-size.
+ SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
+ int numOffsets;
+ for (numOffsets = sizes.size(); numOffsets > 0; --numOffsets) {
+ if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1))
+ break;
+ }
+
+ // If the created extract op would have no offsets, then this whole
+ // extract_strided_slice is the identity and should have been handled by
+ // other canonicalizations.
+ if (numOffsets == 0)
+ return failure();
+
+ // If not even the inner-most dimension is full-size, this op can't be
+ // rewritten as an ExtractOp.
+ if (numOffsets == sourceType.getRank() &&
+ static_cast<int>(sizes.size()) == sourceType.getRank())
+ return failure();
+
+ // The outer dimensions must have unit size.
+ for (int i = 0; i < numOffsets; ++i) {
+ if (sizes[i] != 1)
+ return failure();
+ }
+
+ // Avoid generating slices that have leading unit dimensions. The shape_cast
+ // op that we create below would take bad generic fallback patterns
+ // (ShapeCastOpRewritePattern).
+ while (sizes[numOffsets] == 1 &&
+ numOffsets < static_cast<int>(sizes.size()) - 1) {
+ ++numOffsets;
+ }
+
+ SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
+ auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
+ Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
+ extractOffsets);
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
+ return success();
+ }
+};
+
} // namespace
void ExtractStridedSliceOp::getCanonicalizationPatterns(
@@ -3780,7 +3856,8 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
- StridedSliceSplat>(context);
+ StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
+ context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index ad845608f18d10..ec2ef3fc7501c2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -329,81 +329,12 @@ class DecomposeNDExtractStridedSlice
}
};
-/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
-/// slice is contiguous, into extract and shape_cast.
-///
-/// Example:
-/// Before:
-/// %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0],
-/// sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} :
-/// vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
-/// After:
-/// %0 = vector.extract %arg0[0, 0, 0, 0] : vector<8xi8> from
-/// vector<8x1x1x2x8xi8> %1 = vector.shape_cast %0 : vector<8xi8> to
-/// vector<1x1x1x1x8xi8>
-///
-class ContiguousExtractStridedSliceToExtract final
- : public OpRewritePattern<ExtractStridedSliceOp> {
-public:
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
- PatternRewriter &rewriter) const override {
- if (op.hasNonUnitStrides()) {
- return failure();
- }
- Value source = op.getOperand();
- auto sourceType = cast<VectorType>(source.getType());
- if (sourceType.isScalable()) {
- return failure();
- }
-
- // Compute the number of offsets to pass to ExtractOp::build. That is the
- // difference between the source rank and the desired slice rank. We walk
- // the dimensions from innermost out, and stop when the next slice dimension
- // is not full-size.
- SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
- int numOffsets;
- for (numOffsets = sourceType.getRank(); numOffsets > 0; --numOffsets) {
- if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
- break;
- }
- }
-
- // If not even the inner-most dimension is full-size, this op can't be
- // rewritten as an ExtractOp.
- if (numOffsets == sourceType.getRank()) {
- return failure();
- }
-
- // Avoid generating slices that have unit outer dimensions. The shape_cast
- // op that we create below would take bad generic fallback patterns
- // (ShapeCastOpRewritePattern).
- while (sizes[numOffsets] == 1 && numOffsets < sourceType.getRank() - 1) {
- ++numOffsets;
- }
-
- SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
- auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
- Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
- extractOffsets);
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
- return success();
- }
-};
-
void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<DecomposeDifferentRankInsertStridedSlice,
DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
}
-void vector::populateVectorContiguousExtractStridedSliceToExtractPatterns(
- RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<ContiguousExtractStridedSliceToExtract>(patterns.getContext(),
- benefit);
-}
-
void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
RewritePatternSet &patterns,
std::function<bool(ExtractStridedSliceOp)> controlFn,
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index b7c78de4b5bd89..6d6bc199e601c0 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2742,3 +2742,52 @@ func.func @vector_insert_const_regression(%arg0: i8) -> vector<4xi8> {
%1 = vector.insert %arg0, %0 [0] : i8 into vector<4xi8>
return %1 : vector<4xi8>
}
+
+// -----
+
+// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract
+// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
+// CHECK-NEXT: return %[[EXTRACT]] : vector<4xi32>
+func.func @contiguous_extract_strided_slices_to_extract(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
+ %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
+ %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
+ return %2 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_shorter_size_list
+// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0] : vector<1x4xi32> from vector<8x1x2x1x1x4xi32>
+// CHECK-NEXT: return %[[EXTRACT]] : vector<1x4xi32>
+func.func @contiguous_extract_strided_slices_to_extract_shorter_size_list(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x4xi32> {
+ %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1], strides = [1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
+ %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<1x4xi32>
+ return %2 : vector<1x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_failure_non_unit_outer_size
+// CHECK-NEXT: vector.extract_strided_slice
+func.func @contiguous_extract_strided_slices_to_extract_failure_non_unit_outer_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<8x1x1x1x1x4xi32> {
+ %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [8, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<8x1x1x1x1x4xi32>
+ return %1 : vector<8x1x1x1x1x4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_failure_non_full_size
+// CHECK-NEXT: vector.extract_strided_slice
+func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x1x1x1x2xi32> {
+ %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32>
+ return %1 : vector<1x1x1x1x1x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_size
+// CHECK-NEXT: vector.extract_strided_slice
+func.func @contiguous_extract_strided_slices_to_extract_failure_non_full_inner_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x2x1x1x1xi32> {
+ %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
+ return %1 : vector<1x1x2x1x1x1xi32>
+}
diff --git a/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
deleted file mode 100644
index d1401ad7853fc9..00000000000000
--- a/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
+++ /dev/null
@@ -1,24 +0,0 @@
-// RUN: mlir-opt --test-vector-contiguous-extract-strided-slice-to-extract %s | FileCheck %s
-
-// CHECK-LABEL: @contiguous
-// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
-// CHECK-NEXT: return %[[EXTRACT]] : vector<4xi32>
-func.func @contiguous(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
- %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
- %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
- return %2 : vector<4xi32>
-}
-
-// CHECK-LABEL: @non_full_size
-// CHECK-NEXT: vector.extract_strided_slice
-func.func @non_full_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x1x1x1x2xi32> {
- %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32>
- return %1 : vector<1x1x1x1x1x2xi32>
-}
-
-// CHECK-LABEL: @non_full_inner_size
-// CHECK-NEXT: vector.extract_strided_slice
-func.func @non_full_inner_size(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<1x1x2x1x1x1xi32> {
- %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
- return %1 : vector<1x1x2x1x1x1xi32>
-}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index d91e955b70641e..72aaa7dc4f8973 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -709,27 +709,6 @@ struct TestVectorExtractStridedSliceLowering
}
};
-struct TestVectorContiguousExtractStridedSliceToExtract
- : public PassWrapper<TestVectorContiguousExtractStridedSliceToExtract,
- OperationPass<func::FuncOp>> {
- MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
- TestVectorExtractStridedSliceLowering)
-
- StringRef getArgument() const final {
- return "test-vector-contiguous-extract-strided-slice-to-extract";
- }
- StringRef getDescription() const final {
- return "Test lowering patterns that rewrite simple cases of N-D "
- "extract_strided_slice, where the slice is contiguous, into extract "
- "and shape_cast";
- }
- void runOnOperation() override {
- RewritePatternSet patterns(&getContext());
- populateVectorContiguousExtractStridedSliceToExtractPatterns(patterns);
- (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
- }
-};
-
struct TestVectorBreakDownBitCast
: public PassWrapper<TestVectorBreakDownBitCast,
OperationPass<func::FuncOp>> {
@@ -956,8 +935,6 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorExtractStridedSliceLowering>();
- PassRegistration<TestVectorContiguousExtractStridedSliceToExtract>();
-
PassRegistration<TestVectorBreakDownBitCast>();
PassRegistration<TestCreateVectorBroadcast>();
More information about the Mlir-commits
mailing list