[Mlir-commits] [mlir] [mlir][Vector] Move `vector.mask` canonicalization to folder (PR #140324)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri May 16 17:59:46 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Diego Caballero (dcaballe)

<details>
<summary>Changes</summary>

This PR moves the canonicalization pattern that elides empty `vector.mask` ops into the `vector.mask` folder.

---
Full diff: https://github.com/llvm/llvm-project/pull/140324.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1-1) 
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+33-35) 
- (modified) mlir/test/Conversion/GPUCommon/lower-vector.mlir (+4-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3aefcea8de994..2e0c9a6de11ae 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2554,7 +2554,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
                                  Location loc);
   }];
 
-  let hasCanonicalizer = 1;
+  let hasCanonicalizer = 0;
   let hasFolder = 1;
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 79bf87ccd34af..104459850d508 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6631,13 +6631,42 @@ LogicalResult MaskOp::verify() {
   return success();
 }
 
-/// Folds vector.mask ops with an all-true mask.
+/// Folds empty `vector.mask` with no passthru operand and with or without
+/// return values. For example:
+///
+///   %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
+///     vector<8xi1> -> vector<8xf32>
+///   %1 = user_op %0 : vector<8xf32>
+///
+/// becomes:
+///
+///   %0 = user_op %a : vector<8xf32>
+///
+/// `vector.mask` with a passthru is handled by the canonicalizer.
+///
+static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
+                                     SmallVectorImpl<OpFoldResult> &results) {
+  if (!maskOp.isEmpty() || maskOp.hasPassthru())
+    return failure();
+
+  Block *block = maskOp.getMaskBlock();
+  auto terminator = cast<vector::YieldOp>(block->front());
+  if (terminator.getNumOperands() == 0) {
+    // `vector.mask` has no results, just remove the `vector.mask`.
+    return success();
+  }
+
+  // `vector.mask` has results, propagate the results.
+  llvm::append_range(results, terminator.getOperands());
+  return success();
+}
+
 LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
-  MaskFormat maskFormat = getMaskFormat(getMask());
-  if (isEmpty())
-    return failure();
+  if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
+    return success();
 
+  MaskFormat maskFormat = getMaskFormat(getMask());
   if (maskFormat != MaskFormat::AllTrue)
     return failure();
 
@@ -6650,37 +6679,6 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
   return success();
 }
 
-// Elides empty vector.mask operations with or without return values. Propagates
-// the yielded values by the vector.yield terminator, if any, or erases the op,
-// otherwise.
-class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(MaskOp maskOp,
-                                PatternRewriter &rewriter) const override {
-    auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
-    if (maskingOp.getMaskableOp())
-      return failure();
-
-    if (!maskOp.isEmpty())
-      return failure();
-
-    Block *block = maskOp.getMaskBlock();
-    auto terminator = cast<vector::YieldOp>(block->front());
-    if (terminator.getNumOperands() == 0)
-      rewriter.eraseOp(maskOp);
-    else
-      rewriter.replaceOp(maskOp, terminator.getOperands());
-
-    return success();
-  }
-};
-
-void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                         MLIRContext *context) {
-  results.add<ElideEmptyMaskOp>(context);
-}
-
 // MaskingOpInterface definitions.
 
 /// Returns the operation masked by this 'vector.mask'.
diff --git a/mlir/test/Conversion/GPUCommon/lower-vector.mlir b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
index 532a2383cea9e..b4e3da9d0dbfe 100644
--- a/mlir/test/Conversion/GPUCommon/lower-vector.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
@@ -1,10 +1,12 @@
 // RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
 
 module {
+  // CHECK-LABEL: func @func
+  // CHECK-SAME: %[[IN:.*]]: vector<11xf32>
   func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
     %cst_41 = arith.constant dense<true> : vector<11xi1>
-    // CHECK: vector.mask
-    // CHECK-SAME: vector.yield %arg0
+    // CHECK-NOT: vector.mask
+    // CHECK: return %[[IN]] : vector<11xf32>
     %127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
     return %127 : vector<11xf32>
   }

``````````

</details>


https://github.com/llvm/llvm-project/pull/140324


More information about the Mlir-commits mailing list