[Mlir-commits] [mlir] [mlir][vector] Add pattern to rewrite contiguous ExtractStridedSlice into Extract (PR #111541)
Benoit Jacob
llvmlistbot at llvm.org
Tue Oct 8 08:37:48 PDT 2024
https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/111541
>From 97202f6bcd579653e5b504d044311d05e30f58a9 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Tue, 8 Oct 2024 09:38:35 -0500
Subject: [PATCH 1/3] contiguous-extract-strided-slice
---
.../Vector/Transforms/VectorRewritePatterns.h | 6 ++
...sertExtractStridedSliceRewritePatterns.cpp | 64 +++++++++++++++++++
...uous-extract-strided-slice-to-extract.mlir | 40 ++++++++++++
.../Dialect/Vector/TestVectorTransforms.cpp | 23 +++++++
4 files changed, 133 insertions(+)
create mode 100644 mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index a59f06f3c1ef1b..9ad78cc282b674 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -235,6 +235,12 @@ void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
PatternBenefit benefit = 1);
+/// Populate `patterns` with a pattern to rewrite simple cases of N-D
+/// extract_strided_slice, where the slice is contiguous, into extract and
+/// shape_cast.
+void populateVectorContiguousExtractStridedSliceToExtractPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
/// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
/// based on the destination vector shape. Bitcasts from a lower bitwidth
/// element type to a higher bitwidth one are extracted from the lower bitwidth
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index ec2ef3fc7501c2..324c7b84ebfa0d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -329,12 +329,76 @@ class DecomposeNDExtractStridedSlice
}
};
+static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
+ SmallVectorImpl<int64_t> &results) {
+ for (auto attr : arrayAttr)
+ results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
+}
+
+/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
+/// slice is contiguous, into extract and shape_cast.
+class ContiguousExtractStridedSliceToExtract final
+ : public OpRewritePattern<ExtractStridedSliceOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ /// Fails for non-unit strides or a partial inner-most dim; else replaces op.
+ LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.hasNonUnitStrides()) {
+ return failure();
+ }
+ SmallVector<int64_t> sizes;
+ populateFromInt64AttrArray(op.getSizes(), sizes);
+ Value source = op.getOperand();
+ ShapedType sourceType = cast<ShapedType>(source.getType());
+
+ // Compute the number of offsets to pass to ExtractOp::build. That is the
+ // difference between the source rank and the desired slice rank. We walk
+ // the dimensions from innermost out, and stop when the next slice dimension
+ // is not full-size.
+ int numOffsets;
+ for (numOffsets = sourceType.getRank(); numOffsets > 0; --numOffsets) {
+ if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
+ break;
+ }
+ }
+ // NOTE(review): loop reads sizes[rank - 1]; presumes rank sizes - TODO confirm.
+ // The loop broke on its first step: not even the inner-most dimension is
+ // full-size, so this op can't be rewritten as an ExtractOp.
+ if (numOffsets == sourceType.getRank()) {
+ return failure();
+ }
+ // NOTE(review): no check that sizes[i] == 1 for i < numOffsets, and
+ // numOffsets == 0 is not rejected - TODO confirm. Avoid generating slices
+ // that have unit outer dimensions: the shape_cast op created below would
+ // take bad generic fallback patterns (ShapeCastOpRewritePattern).
+ while (sizes[numOffsets] == 1 && numOffsets < sourceType.getRank() - 1) {
+ ++numOffsets;
+ }
+
+ SmallVector<int64_t> offsets;
+ populateFromInt64AttrArray(op.getOffsets(), offsets);
+ auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
+ Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
+ extractOffsets);
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+ op, op->getResultTypes()[0], extract);
+ return success();
+ }
+};
+
void vector::populateVectorInsertExtractStridedSliceDecompositionPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<DecomposeDifferentRankInsertStridedSlice,
DecomposeNDExtractStridedSlice>(patterns.getContext(), benefit);
}
+void vector::populateVectorContiguousExtractStridedSliceToExtractPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<ContiguousExtractStridedSliceToExtract>(patterns.getContext(),
+ benefit);
+}
+
void vector::populateVectorExtractStridedSliceToExtractInsertChainPatterns(
RewritePatternSet &patterns,
std::function<bool(ExtractStridedSliceOp)> controlFn,
diff --git a/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
new file mode 100644
index 00000000000000..da8ff492431629
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -split-input-file -test-vector-contiguous-extract-strided-slice-to-extract %s | FileCheck %s
+
+// CHECK-LABEL: @extract_strided_slice_to_extract_i8
+// CHECK: vector.extract {{.*}}[0, 0, 0, 0] : vector<8xi8> from vector<8x1x1x2x8xi8>
+
+func.func @extract_strided_slice_to_extract_i8(%arg0 : vector<8x1x1x2x8xi8>) -> vector<8xi8> {
+ %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
+ %2 = vector.shape_cast %1 : vector<1x1x1x1x8xi8> to vector<8xi8>
+ return %2 : vector<8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @extract_strided_slice_to_extract_i32
+// CHECK: vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
+func.func @extract_strided_slice_to_extract_i32(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
+ %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
+ %2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
+ return %2 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_1
+// CHECK: vector.extract_strided_slice
+func.func @extract_strided_slice_to_extract_i32_non_contiguous_1(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
+ %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 2], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x2xi32>
+ %2 = vector.shape_cast %1 : vector<1x1x1x1x1x2xi32> to vector<2xi32>
+ return %2 : vector<2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_2
+// CHECK: vector.extract_strided_slice
+func.func @extract_strided_slice_to_extract_i32_non_contiguous_2(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
+ %1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 2, 1, 1, 1], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x2x1x1x1xi32>
+ %2 = vector.shape_cast %1 : vector<1x1x2x1x1x1xi32> to vector<2xi32>
+ return %2 : vector<2xi32>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 72aaa7dc4f8973..d91e955b70641e 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -709,6 +709,27 @@ struct TestVectorExtractStridedSliceLowering
}
};
+struct TestVectorContiguousExtractStridedSliceToExtract
+ : public PassWrapper<TestVectorContiguousExtractStridedSliceToExtract,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestVectorContiguousExtractStridedSliceToExtract)
+ // The type ID above is declared under this pass's own class name.
+ StringRef getArgument() const final {
+ return "test-vector-contiguous-extract-strided-slice-to-extract";
+ }
+ StringRef getDescription() const final {
+ return "Test lowering patterns that rewrite simple cases of N-D "
+ "extract_strided_slice, where the slice is contiguous, into extract "
+ "and shape_cast";
+ }
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorContiguousExtractStridedSliceToExtractPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
+
struct TestVectorBreakDownBitCast
: public PassWrapper<TestVectorBreakDownBitCast,
OperationPass<func::FuncOp>> {
@@ -935,6 +956,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorExtractStridedSliceLowering>();
+ PassRegistration<TestVectorContiguousExtractStridedSliceToExtract>();
+
PassRegistration<TestVectorBreakDownBitCast>();
PassRegistration<TestCreateVectorBroadcast>();
>From 4b21279aeff356789fd60eddc7846b0a09c6a7d6 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Tue, 8 Oct 2024 10:27:29 -0500
Subject: [PATCH 2/3] review comments
---
.../mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h | 5 ++---
.../VectorInsertExtractStridedSliceRewritePatterns.cpp | 7 +++++--
2 files changed, 7 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 9ad78cc282b674..ec1de7fa66aa07 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -235,9 +235,8 @@ void populateVectorExtractStridedSliceToExtractInsertChainPatterns(
std::function<bool(ExtractStridedSliceOp)> controlFn = nullptr,
PatternBenefit benefit = 1);
-/// Populate `patterns` with a pattern to rewrite simple cases of N-D
-/// extract_strided_slice, where the slice is contiguous, into extract and
-/// shape_cast.
+/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
+/// slice is contiguous, into extract and shape_cast.
void populateVectorContiguousExtractStridedSliceToExtractPatterns(
RewritePatternSet &patterns, PatternBenefit benefit = 1);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 324c7b84ebfa0d..dda135d099bfb8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -347,10 +347,13 @@ class ContiguousExtractStridedSliceToExtract final
if (op.hasNonUnitStrides()) {
return failure();
}
+ Value source = op.getOperand();
+ VectorType sourceType = cast<VectorType>(source.getType());
+ if (sourceType.isScalable()) {
+ return failure();
+ }
SmallVector<int64_t> sizes;
populateFromInt64AttrArray(op.getSizes(), sizes);
- Value source = op.getOperand();
- ShapedType sourceType = cast<ShapedType>(source.getType());
// Compute the number of offsets to pass to ExtractOp::build. That is the
// difference between the source rank and the desired slice rank. We walk
>From 14f1e4244e8ac1f19a25ccb42f74a4a864e7a3d2 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <jacob.benoit.1 at gmail.com>
Date: Tue, 8 Oct 2024 10:31:56 -0500
Subject: [PATCH 3/3] more review comments
---
...InsertExtractStridedSliceRewritePatterns.cpp | 17 ++++-------------
...iguous-extract-strided-slice-to-extract.mlir | 13 ++++---------
2 files changed, 8 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index dda135d099bfb8..c2da9347aadc87 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -329,12 +329,6 @@ class DecomposeNDExtractStridedSlice
}
};
-static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
- SmallVectorImpl<int64_t> &results) {
- for (auto attr : arrayAttr)
- results.push_back(llvm::cast<IntegerAttr>(attr).getInt());
-}
-
/// Pattern to rewrite simple cases of N-D extract_strided_slice, where the
/// slice is contiguous, into extract and shape_cast.
class ContiguousExtractStridedSliceToExtract final
@@ -348,17 +342,16 @@ class ContiguousExtractStridedSliceToExtract final
return failure();
}
Value source = op.getOperand();
- VectorType sourceType = cast<VectorType>(source.getType());
+ auto sourceType = cast<VectorType>(source.getType());
if (sourceType.isScalable()) {
return failure();
}
- SmallVector<int64_t> sizes;
- populateFromInt64AttrArray(op.getSizes(), sizes);
// Compute the number of offsets to pass to ExtractOp::build. That is the
// difference between the source rank and the desired slice rank. We walk
// the dimensions from innermost out, and stop when the next slice dimension
// is not full-size.
+ SmallVector<int64_t> sizes = getI64SubArray(op.getSizes());
int numOffsets;
for (numOffsets = sourceType.getRank(); numOffsets > 0; --numOffsets) {
if (sizes[numOffsets - 1] != sourceType.getDimSize(numOffsets - 1)) {
@@ -379,13 +372,11 @@ class ContiguousExtractStridedSliceToExtract final
++numOffsets;
}
- SmallVector<int64_t> offsets;
- populateFromInt64AttrArray(op.getOffsets(), offsets);
+ SmallVector<int64_t> offsets = getI64SubArray(op.getOffsets());
auto extractOffsets = ArrayRef(offsets).take_front(numOffsets);
Value extract = rewriter.create<vector::ExtractOp>(op->getLoc(), source,
extractOffsets);
- rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
- op, op->getResultTypes()[0], extract);
+ rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(), extract);
return success();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
index da8ff492431629..cdec7c28e1914f 100644
--- a/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
+++ b/mlir/test/Dialect/Vector/vector-contiguous-extract-strided-slice-to-extract.mlir
@@ -1,26 +1,23 @@
// RUN: mlir-opt -split-input-file -test-vector-contiguous-extract-strided-slice-to-extract %s | FileCheck %s
// CHECK-LABEL: @extract_strided_slice_to_extract_i8
-// CHECK: vector.extract {{.*}}[0, 0, 0, 0] : vector<8xi8> from vector<8x1x1x2x8xi8>
-
+// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0] : vector<8xi8> from vector<8x1x1x2x8xi8>
+// CHECK: return %[[EXTRACT]] : vector<8xi8>
func.func @extract_strided_slice_to_extract_i8(%arg0 : vector<8x1x1x2x8xi8>) -> vector<8xi8> {
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 8], strides = [1, 1, 1, 1, 1]} : vector<8x1x1x2x8xi8> to vector<1x1x1x1x8xi8>
%2 = vector.shape_cast %1 : vector<1x1x1x1x8xi8> to vector<8xi8>
return %2 : vector<8xi8>
}
-// -----
-
// CHECK-LABEL: @extract_strided_slice_to_extract_i32
-// CHECK: vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
+// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
+// CHECK: return %[[EXTRACT]] : vector<4xi32>
func.func @extract_strided_slice_to_extract_i32(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<4xi32> {
%1 = vector.extract_strided_slice %arg0 {offsets = [0, 0, 0, 0, 0, 0], sizes = [1, 1, 1, 1, 1, 4], strides = [1, 1, 1, 1, 1, 1]} : vector<8x1x2x1x1x4xi32> to vector<1x1x1x1x1x4xi32>
%2 = vector.shape_cast %1 : vector<1x1x1x1x1x4xi32> to vector<4xi32>
return %2 : vector<4xi32>
}
-// -----
-
// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_1
// CHECK: vector.extract_strided_slice
func.func @extract_strided_slice_to_extract_i32_non_contiguous_1(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
@@ -29,8 +26,6 @@ func.func @extract_strided_slice_to_extract_i32_non_contiguous_1(%arg0 : vector<
return %2 : vector<2xi32>
}
-// -----
-
// CHECK-LABEL: @extract_strided_slice_to_extract_i32_non_contiguous_2
// CHECK: vector.extract_strided_slice
func.func @extract_strided_slice_to_extract_i32_non_contiguous_2(%arg0 : vector<8x1x2x1x1x4xi32>) -> vector<2xi32> {
More information about the Mlir-commits
mailing list