[Mlir-commits] [mlir] [mlir][Vector] Add canonicalization for extract_strided_slice(create_mask) (PR #146745)

Kunwar Grover llvmlistbot at llvm.org
Fri Jul 4 03:53:04 PDT 2025


https://github.com/Groverkss updated https://github.com/llvm/llvm-project/pull/146745

>From 48b888330a93cf52abb35939db3cea5811334b29 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Wed, 2 Jul 2025 17:14:18 +0100
Subject: [PATCH 1/2] [mlir][Vector] Add canonicalization for
 extract_strided_slice(create_mask)

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 62 ++++++++++++++++++++--
 mlir/test/Dialect/Vector/canonicalize.mlir | 17 ++++++
 2 files changed, 76 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1fb8c7a928e06..66c2fe50529d7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4081,6 +4081,62 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
 
 namespace {
 
+class StridedSliceCreateMaskFolder final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+public:
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = extractStridedSliceOp.getLoc();
+    // Return if 'extractStridedSliceOp' operand is not defined by a
+    // CreateMaskOp.
+    auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
+    auto createMaskOp = dyn_cast_or_null<CreateMaskOp>(defOp);
+    if (!createMaskOp)
+      return failure();
+    // Return if 'extractStridedSliceOp' has non-unit strides.
+    if (extractStridedSliceOp.hasNonUnitStrides())
+      return failure();
+    // Gather constant mask dimension sizes.
+    SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
+    // Gather strided slice offsets and sizes.
+    SmallVector<int64_t, 4> sliceOffsets;
+    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
+                               sliceOffsets);
+    SmallVector<int64_t, 4> sliceSizes;
+    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
+
+    // Compute slice of vector mask region.
+    SmallVector<Value> sliceMaskDimSizes;
+    sliceMaskDimSizes.reserve(maskDimSizes.size());
+    for (auto [maskDimSize, sliceOffset, sliceSize] :
+         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
+      // No need to clamp on min/max values, because create_mask has clamping
+      // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
+      // greater than the vector dim size.
+      IntegerAttr offsetAttr =
+          rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
+      Value offset = rewriter.create<arith::ConstantOp>(loc, offsetAttr);
+      Value sliceMaskDimSize =
+          rewriter.create<arith::SubIOp>(loc, maskDimSize, offset);
+      sliceMaskDimSizes.push_back(sliceMaskDimSize);
+    }
+    // Add unchanged dimensions.
+    if (sliceMaskDimSizes.size() < maskDimSizes.size()) {
+      for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i) {
+        sliceMaskDimSizes.push_back(maskDimSizes[i]);
+      }
+    }
+    // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
+    // region.
+    rewriter.replaceOpWithNewOp<CreateMaskOp>(
+        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
+        sliceMaskDimSizes);
+    return success();
+  }
+};
+
 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
 // ConstantMaskOp.
 class StridedSliceConstantMaskFolder final
@@ -4279,9 +4335,9 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
-  results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast,
-              StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
-      context);
+  results.add<StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder,
+              StridedSliceBroadcast, StridedSliceSplat,
+              ContiguousExtractStridedSliceToExtract>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0282e9cac5e02..c09dab8232900 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -361,6 +361,23 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
 
 // -----
 
+// CHECK-LABEL: func.func @extract_strided_slice_of_create_mask
+// CHECK-SAME: (%[[DIM0:.+]]: index, %[[DIM1:.+]]: index)
+func.func @extract_strided_slice_of_create_mask(%dim0: index, %dim1: index) -> (vector<2x2xi1>) {
+  %0 = vector.create_mask %dim0, %dim1 : vector<4x3xi1>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [2, 1], sizes = [2, 2], strides = [1, 1]}
+      : vector<4x3xi1> to vector<2x2xi1>
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK-DAG: %[[A:.+]] = arith.subi %[[DIM0]], %[[C2]]
+  // CHECK-DAG: %[[B:.+]] = arith.subi %[[DIM1]], %[[C1]]
+  // CHECK: vector.create_mask %[[A]], %[[B]] : vector<2x2xi1>
+  return %1 : vector<2x2xi1>
+}
+
+// -----
+
 // CHECK-LABEL: extract_strided_fold
 //  CHECK-SAME: (%[[ARG:.*]]: vector<4x3xi1>)
 //  CHECK-NEXT:   return %[[ARG]] : vector<4x3xi1>

>From 8dc4dda3c0b671589951424e94b215921ab028e5 Mon Sep 17 00:00:00 2001
From: Kunwar Grover <groverkss at gmail.com>
Date: Fri, 4 Jul 2025 11:52:38 +0100
Subject: [PATCH 2/2] address comments

---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 25 ++++++++++++++++------
 mlir/test/Dialect/Vector/canonicalize.mlir | 18 ++++++++++++++++
 2 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 66c2fe50529d7..6f166bc26811d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4081,6 +4081,18 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
 
 namespace {
 
+// Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
+// CreateMaskOp.
+//
+// Example (pseudo-IR; `offset` is a static attribute, strides are 1):
+//
+// %mask = vector.create_mask %ub : vector<16xi1>
+// %slice = vector.extract_strided_slice %mask [offset] [8] [1]
+//
+// to
+//
+// %new_ub = arith.subi %ub, %c_offset  // %c_offset = arith.constant offset
+// %slice = vector.create_mask %new_ub : vector<8xi1>
 class StridedSliceCreateMaskFolder final
     : public OpRewritePattern<ExtractStridedSliceOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -4101,10 +4113,10 @@ class StridedSliceCreateMaskFolder final
     // Gather constant mask dimension sizes.
     SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
     // Gather strided slice offsets and sizes.
-    SmallVector<int64_t, 4> sliceOffsets;
+    SmallVector<int64_t> sliceOffsets;
     populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
                                sliceOffsets);
-    SmallVector<int64_t, 4> sliceSizes;
+    SmallVector<int64_t> sliceSizes;
     populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
 
     // Compute slice of vector mask region.
@@ -4124,7 +4136,8 @@ class StridedSliceCreateMaskFolder final
     }
     // Add unchanged dimensions.
     if (sliceMaskDimSizes.size() < maskDimSizes.size()) {
-      for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i) {
+      for (size_t i = sliceMaskDimSizes.size(), e = maskDimSizes.size(); i < e;
+           ++i) {
         sliceMaskDimSizes.push_back(maskDimSizes[i]);
       }
     }
@@ -4158,14 +4171,14 @@ class StridedSliceConstantMaskFolder final
     // Gather constant mask dimension sizes.
     ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
     // Gather strided slice offsets and sizes.
-    SmallVector<int64_t, 4> sliceOffsets;
+    SmallVector<int64_t> sliceOffsets;
     populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
                                sliceOffsets);
-    SmallVector<int64_t, 4> sliceSizes;
+    SmallVector<int64_t> sliceSizes;
     populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
 
     // Compute slice of vector mask region.
-    SmallVector<int64_t, 4> sliceMaskDimSizes;
+    SmallVector<int64_t> sliceMaskDimSizes;
     sliceMaskDimSizes.reserve(maskDimSizes.size());
     for (auto [maskDimSize, sliceOffset, sliceSize] :
          llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index c09dab8232900..e05eb4b0ee5bb 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -378,6 +378,24 @@ func.func @extract_strided_slice_of_create_mask(%dim0: index, %dim1: index) -> (
 
 // -----
 
+// CHECK-LABEL: func.func @extract_strided_slice_partial_of_create_mask
+// CHECK-SAME: (%[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[DIM2:.+]]: index)
+func.func @extract_strided_slice_partial_of_create_mask(
+  %dim0: index, %dim1: index, %dim2 : index) -> (vector<2x2x8xi1>) {
+  %0 = vector.create_mask %dim0, %dim1, %dim2 : vector<4x3x8xi1>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [2, 1], sizes = [2, 2], strides = [1, 1]}
+      : vector<4x3x8xi1> to vector<2x2x8xi1>
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK-DAG: %[[A:.+]] = arith.subi %[[DIM0]], %[[C2]]
+  // CHECK-DAG: %[[B:.+]] = arith.subi %[[DIM1]], %[[C1]]
+  // CHECK: vector.create_mask %[[A]], %[[B]], %[[DIM2]] : vector<2x2x8xi1>
+  return %1 : vector<2x2x8xi1>
+}
+
+// -----
+
 // CHECK-LABEL: extract_strided_fold
 //  CHECK-SAME: (%[[ARG:.*]]: vector<4x3xi1>)
 //  CHECK-NEXT:   return %[[ARG]] : vector<4x3xi1>



More information about the Mlir-commits mailing list