[Mlir-commits] [mlir] [mlir][Vector] Add canonicalization for extract_strided_slice(create_mask) (PR #146745)
Kunwar Grover
llvmlistbot at llvm.org
Fri Jul 4 03:53:04 PDT 2025
https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/146745
>From 48b888330a93cf52abb35939db3cea5811334b29 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 2 Jul 2025 17:14:18 +0100
Subject: [PATCH 1/2] [mlir][Vector] Add canonicalization for
extract_strided_slice(create_mask)
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 62 ++++++++++++++++++++--
mlir/test/Dialect/Vector/canonicalize.mlir | 17 ++++++
2 files changed, 76 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1fb8c7a928e06..66c2fe50529d7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4081,6 +4081,62 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
namespace {
+class StridedSliceCreateMaskFolder final
+ : public OpRewritePattern<ExtractStridedSliceOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+public:
+ LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = extractStridedSliceOp.getLoc();
+ // Return if 'extractStridedSliceOp' operand is not defined by a
+ // CreateMaskOp.
+ auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
+ auto createMaskOp = dyn_cast_or_null<CreateMaskOp>(defOp);
+ if (!createMaskOp)
+ return failure();
+ // Return if 'extractStridedSliceOp' has non-unit strides.
+ if (extractStridedSliceOp.hasNonUnitStrides())
+ return failure();
+ // Gather constant mask dimension sizes.
+ SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
+ // Gather strided slice offsets and sizes.
+ SmallVector<int64_t, 4> sliceOffsets;
+ populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
+ sliceOffsets);
+ SmallVector<int64_t, 4> sliceSizes;
+ populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
+
+ // Compute slice of vector mask region.
+ SmallVector<Value> sliceMaskDimSizes;
+ sliceMaskDimSizes.reserve(maskDimSizes.size());
+ for (auto [maskDimSize, sliceOffset, sliceSize] :
+ llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
+ // Slice index j is source index sliceOffset + j, which is inside the mask
+ // bound iff j < maskDimSize - sliceOffset. No min/max clamping is needed:
+ // create_mask clamps negative or oversized operands to [0, dim size].
+ IntegerAttr offsetAttr =
+ rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
+ Value offset = rewriter.create<arith::ConstantOp>(loc, offsetAttr);
+ Value sliceMaskDimSize =
+ rewriter.create<arith::SubIOp>(loc, maskDimSize, offset);
+ sliceMaskDimSizes.push_back(sliceMaskDimSize);
+ }
+ // Add unchanged dimensions.
+ if (sliceMaskDimSizes.size() < maskDimSizes.size()) {
+ for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i) {
+ sliceMaskDimSizes.push_back(maskDimSizes[i]);
+ }
+ }
+ // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
+ // region.
+ rewriter.replaceOpWithNewOp<CreateMaskOp>(
+ extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
+ sliceMaskDimSizes);
+ return success();
+ }
+};
+
// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
// ConstantMaskOp.
class StridedSliceConstantMaskFolder final
@@ -4279,9 +4335,9 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
- results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast,
- StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
- context);
+ results.add<StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder,
+ StridedSliceBroadcast, StridedSliceSplat,
+ ContiguousExtractStridedSliceToExtract>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0282e9cac5e02..c09dab8232900 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -361,6 +361,23 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
// -----
+// CHECK-LABEL: func.func @extract_strided_slice_of_create_mask
+// CHECK-SAME: (%[[DIM0:.+]]: index, %[[DIM1:.+]]: index)
+func.func @extract_strided_slice_of_create_mask(%dim0: index, %dim1: index) -> (vector<2x2xi1>) {
+ %0 = vector.create_mask %dim0, %dim1 : vector<4x3xi1>
+ %1 = vector.extract_strided_slice %0
+ {offsets = [2, 1], sizes = [2, 2], strides = [1, 1]}
+ : vector<4x3xi1> to vector<2x2xi1>
+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+ // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+ // CHECK-DAG: %[[A:.+]] = arith.subi %[[DIM0]], %[[C2]]
+ // CHECK-DAG: %[[B:.+]] = arith.subi %[[DIM1]], %[[C1]]
+ // CHECK: vector.create_mask %[[A]], %[[B]] : vector<2x2xi1>
+ return %1 : vector<2x2xi1>
+}
+
+// -----
+
// CHECK-LABEL: extract_strided_fold
// CHECK-SAME: (%[[ARG:.*]]: vector<4x3xi1>)
// CHECK-NEXT: return %[[ARG]] : vector<4x3xi1>
>From 8dc4dda3c0b671589951424e94b215921ab028e5 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 4 Jul 2025 11:52:38 +0100
Subject: [PATCH 2/2] address comments
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 25 ++++++++++++++++------
mlir/test/Dialect/Vector/canonicalize.mlir | 18 ++++++++++++++++
2 files changed, 37 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 66c2fe50529d7..6f166bc26811d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4081,6 +4081,18 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
namespace {
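+// Pattern to rewrite a unit-stride ExtractStridedSliceOp(CreateMaskOp) to a
+// CreateMaskOp whose bounds are the original bounds minus the slice offsets.
+//
+// Example:
+//   %mask = vector.create_mask %ub : vector<16xi1>
+//   %slice = vector.extract_strided_slice %mask
+//       {offsets = [4], sizes = [8], strides = [1]}
+//       : vector<16xi1> to vector<8xi1>
+// to
+//   %c4 = arith.constant 4 : index
+//   %new_ub = arith.subi %ub, %c4 : index
+//   %slice = vector.create_mask %new_ub : vector<8xi1>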
class StridedSliceCreateMaskFolder final
: public OpRewritePattern<ExtractStridedSliceOp> {
using OpRewritePattern::OpRewritePattern;
@@ -4101,10 +4113,10 @@ class StridedSliceCreateMaskFolder final
// Gather constant mask dimension sizes.
SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
// Gather strided slice offsets and sizes.
- SmallVector<int64_t, 4> sliceOffsets;
+ SmallVector<int64_t> sliceOffsets;
populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
sliceOffsets);
- SmallVector<int64_t, 4> sliceSizes;
+ SmallVector<int64_t> sliceSizes;
populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
// Compute slice of vector mask region.
@@ -4124,7 +4136,8 @@ class StridedSliceCreateMaskFolder final
}
// Add unchanged dimensions.
if (sliceMaskDimSizes.size() < maskDimSizes.size()) {
- for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i) {
+ for (size_t i = sliceMaskDimSizes.size(), e = maskDimSizes.size(); i < e;
+ ++i) {
sliceMaskDimSizes.push_back(maskDimSizes[i]);
}
}
@@ -4158,14 +4171,14 @@ class StridedSliceConstantMaskFolder final
// Gather constant mask dimension sizes.
ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
// Gather strided slice offsets and sizes.
- SmallVector<int64_t, 4> sliceOffsets;
+ SmallVector<int64_t> sliceOffsets;
populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
sliceOffsets);
- SmallVector<int64_t, 4> sliceSizes;
+ SmallVector<int64_t> sliceSizes;
populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
// Compute slice of vector mask region.
- SmallVector<int64_t, 4> sliceMaskDimSizes;
+ SmallVector<int64_t> sliceMaskDimSizes;
sliceMaskDimSizes.reserve(maskDimSizes.size());
for (auto [maskDimSize, sliceOffset, sliceSize] :
llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index c09dab8232900..e05eb4b0ee5bb 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -378,6 +378,24 @@ func.func @extract_strided_slice_of_create_mask(%dim0: index, %dim1: index) -> (
// -----
+// CHECK-LABEL: func.func @extract_strided_slice_partial_of_create_mask
+// CHECK-SAME: (%[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[DIM2:.+]]: index)
+func.func @extract_strided_slice_partial_of_create_mask(
+ %dim0: index, %dim1: index, %dim2 : index) -> (vector<2x2x8xi1>) {
+ %0 = vector.create_mask %dim0, %dim1, %dim2 : vector<4x3x8xi1>
+ %1 = vector.extract_strided_slice %0
+ {offsets = [2, 1], sizes = [2, 2], strides = [1, 1]}
+ : vector<4x3x8xi1> to vector<2x2x8xi1>
+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+ // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+ // CHECK-DAG: %[[A:.+]] = arith.subi %[[DIM0]], %[[C2]]
+ // CHECK-DAG: %[[B:.+]] = arith.subi %[[DIM1]], %[[C1]]
+ // CHECK: vector.create_mask %[[A]], %[[B]], %[[DIM2]] : vector<2x2x8xi1>
+ return %1 : vector<2x2x8xi1>
+}
+
+// -----
+
// CHECK-LABEL: extract_strided_fold
// CHECK-SAME: (%[[ARG:.*]]: vector<4x3xi1>)
// CHECK-NEXT: return %[[ARG]] : vector<4x3xi1>
More information about the Mlir-commits
mailing list