[Mlir-commits] [mlir] d6f394e - [mlir][Vector] Move `vector.mask` canonicalization to folder (#140324)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed May 21 17:25:04 PDT 2025


Author: Diego Caballero
Date: 2025-05-21T17:25:01-07:00
New Revision: d6f394e141bd8e218356fefc4cb1beabf5c7bc6d

URL: https://github.com/llvm/llvm-project/commit/d6f394e141bd8e218356fefc4cb1beabf5c7bc6d
DIFF: https://github.com/llvm/llvm-project/commit/d6f394e141bd8e218356fefc4cb1beabf5c7bc6d.diff

LOG: [mlir][Vector] Move `vector.mask` canonicalization to folder (#140324)

This change moves the logic that elides empty `vector.mask` ops from a
canonicalization pattern into the op's folder.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Conversion/GPUCommon/lower-vector.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 5e8421ed67d66..3f5564541554e 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2559,7 +2559,6 @@ def Vector_MaskOp : Vector_Op<"mask", [
                                  Location loc);
   }];
 
-  let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index cf2df1f24f91f..41777347975da 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6650,13 +6650,40 @@ LogicalResult MaskOp::verify() {
   return success();
 }
 
-/// Folds vector.mask ops with an all-true mask.
+/// Folds empty `vector.mask` with no passthru operand and with or without
+/// return values. For example:
+///
+///   %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
+///     vector<8xi1> -> vector<8xf32>
+///   %1 = user_op %0 : vector<8xf32>
+///
+/// becomes:
+///
+///   %0 = user_op %a : vector<8xf32>
+///
+static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
+                                     SmallVectorImpl<OpFoldResult> &results) {
+  if (!maskOp.isEmpty() || maskOp.hasPassthru())
+    return failure();
+
+  Block *block = maskOp.getMaskBlock();
+  auto terminator = cast<vector::YieldOp>(block->front());
+  if (terminator.getNumOperands() == 0) {
+    // `vector.mask` has no results, just remove the `vector.mask`.
+    return success();
+  }
+
+  // `vector.mask` has results, propagate the results.
+  llvm::append_range(results, terminator.getOperands());
+  return success();
+}
+
 LogicalResult MaskOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
-  MaskFormat maskFormat = getMaskFormat(getMask());
-  if (isEmpty())
-    return failure();
+  if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
+    return success();
 
+  MaskFormat maskFormat = getMaskFormat(getMask());
   if (maskFormat != MaskFormat::AllTrue)
     return failure();
 
@@ -6669,37 +6696,6 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
   return success();
 }
 
-// Elides empty vector.mask operations with or without return values. Propagates
-// the yielded values by the vector.yield terminator, if any, or erases the op,
-// otherwise.
-class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(MaskOp maskOp,
-                                PatternRewriter &rewriter) const override {
-    auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
-    if (maskingOp.getMaskableOp())
-      return failure();
-
-    if (!maskOp.isEmpty())
-      return failure();
-
-    Block *block = maskOp.getMaskBlock();
-    auto terminator = cast<vector::YieldOp>(block->front());
-    if (terminator.getNumOperands() == 0)
-      rewriter.eraseOp(maskOp);
-    else
-      rewriter.replaceOp(maskOp, terminator.getOperands());
-
-    return success();
-  }
-};
-
-void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                         MLIRContext *context) {
-  results.add<ElideEmptyMaskOp>(context);
-}
-
 // MaskingOpInterface definitions.
 
 /// Returns the operation masked by this 'vector.mask'.

diff  --git a/mlir/test/Conversion/GPUCommon/lower-vector.mlir b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
index 532a2383cea9e..b4e3da9d0dbfe 100644
--- a/mlir/test/Conversion/GPUCommon/lower-vector.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
@@ -1,10 +1,12 @@
 // RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
 
 module {
+  // CHECK-LABEL: func @func
+  // CHECK-SAME: %[[IN:.*]]: vector<11xf32>
   func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
     %cst_41 = arith.constant dense<true> : vector<11xi1>
-    // CHECK: vector.mask
-    // CHECK-SAME: vector.yield %arg0
+    // CHECK-NOT: vector.mask
+    // CHECK: return %[[IN]] : vector<11xf32>
     %127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
     return %127 : vector<11xf32>
   }


        


More information about the Mlir-commits mailing list