[Mlir-commits] [mlir] [mlir][Vector] Move `vector.mask` canonicalization to folder (PR #140324)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 16 17:59:46 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Diego Caballero (dcaballe)
<details>
<summary>Changes</summary>
This PR moves the canonicalization pattern that elides empty `vector.mask` ops into the op's folder.
---
Full diff: https://github.com/llvm/llvm-project/pull/140324.diff
3 Files Affected:
- (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+1-1)
- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+33-35)
- (modified) mlir/test/Conversion/GPUCommon/lower-vector.mlir (+4-2)
``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3aefcea8de994..2e0c9a6de11ae 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2554,7 +2554,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
Location loc);
}];
- let hasCanonicalizer = 1;
+ let hasCanonicalizer = 0;
let hasFolder = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 79bf87ccd34af..104459850d508 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6631,13 +6631,42 @@ LogicalResult MaskOp::verify() {
return success();
}
-/// Folds vector.mask ops with an all-true mask.
+/// Folds empty `vector.mask` with no passthru operand and with or without
+/// return values. For example:
+///
+/// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
+/// vector<8xi1> -> vector<8xf32>
+/// %1 = user_op %0 : vector<8xf32>
+///
+/// becomes:
+///
+/// %0 = user_op %a : vector<8xf32>
+///
+/// `vector.mask` with a passthru is handled by the canonicalizer.
+///
+static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
+ SmallVectorImpl<OpFoldResult> &results) {
+ if (!maskOp.isEmpty() || maskOp.hasPassthru())
+ return failure();
+
+ Block *block = maskOp.getMaskBlock();
+ auto terminator = cast<vector::YieldOp>(block->front());
+ if (terminator.getNumOperands() == 0) {
+ // `vector.mask` has no results, just remove the `vector.mask`.
+ return success();
+ }
+
+ // `vector.mask` has results, propagate the results.
+ llvm::append_range(results, terminator.getOperands());
+ return success();
+}
+
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
- MaskFormat maskFormat = getMaskFormat(getMask());
- if (isEmpty())
- return failure();
+ if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
+ return success();
+ MaskFormat maskFormat = getMaskFormat(getMask());
if (maskFormat != MaskFormat::AllTrue)
return failure();
@@ -6650,37 +6679,6 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
return success();
}
-// Elides empty vector.mask operations with or without return values. Propagates
-// the yielded values by the vector.yield terminator, if any, or erases the op,
-// otherwise.
-class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(MaskOp maskOp,
- PatternRewriter &rewriter) const override {
- auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
- if (maskingOp.getMaskableOp())
- return failure();
-
- if (!maskOp.isEmpty())
- return failure();
-
- Block *block = maskOp.getMaskBlock();
- auto terminator = cast<vector::YieldOp>(block->front());
- if (terminator.getNumOperands() == 0)
- rewriter.eraseOp(maskOp);
- else
- rewriter.replaceOp(maskOp, terminator.getOperands());
-
- return success();
- }
-};
-
-void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<ElideEmptyMaskOp>(context);
-}
-
// MaskingOpInterface definitions.
/// Returns the operation masked by this 'vector.mask'.
diff --git a/mlir/test/Conversion/GPUCommon/lower-vector.mlir b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
index 532a2383cea9e..b4e3da9d0dbfe 100644
--- a/mlir/test/Conversion/GPUCommon/lower-vector.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
@@ -1,10 +1,12 @@
// RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
module {
+ // CHECK-LABEL: func @func
+ // CHECK-SAME: %[[IN:.*]]: vector<11xf32>
func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
%cst_41 = arith.constant dense<true> : vector<11xi1>
- // CHECK: vector.mask
- // CHECK-SAME: vector.yield %arg0
+ // CHECK-NOT: vector.mask
+ // CHECK: return %[[IN]] : vector<11xf32>
%127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
return %127 : vector<11xf32>
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/140324
More information about the Mlir-commits
mailing list