[Mlir-commits] [mlir] d6f394e - [mlir][Vector] Move `vector.mask` canonicalization to folder (#140324)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed May 21 17:25:04 PDT 2025
Author: Diego Caballero
Date: 2025-05-21T17:25:01-07:00
New Revision: d6f394e141bd8e218356fefc4cb1beabf5c7bc6d
URL: https://github.com/llvm/llvm-project/commit/d6f394e141bd8e218356fefc4cb1beabf5c7bc6d
DIFF: https://github.com/llvm/llvm-project/commit/d6f394e141bd8e218356fefc4cb1beabf5c7bc6d.diff
LOG: [mlir][Vector] Move `vector.mask` canonicalization to folder (#140324)
This PR moves the canonicalization pattern that elides empty `vector.mask` ops
into the `vector.mask` folder and removes the op's canonicalizer hook.
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Conversion/GPUCommon/lower-vector.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 5e8421ed67d66..3f5564541554e 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2559,7 +2559,6 @@ def Vector_MaskOp : Vector_Op<"mask", [
Location loc);
}];
- let hasCanonicalizer = 1;
let hasFolder = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index cf2df1f24f91f..41777347975da 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6650,13 +6650,40 @@ LogicalResult MaskOp::verify() {
return success();
}
-/// Folds vector.mask ops with an all-true mask.
+/// Folds an empty `vector.mask` (its region holds only the `vector.yield`)
+/// into the values its terminator yields. For example:
+///
+///   %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
+///     vector<8xi1> -> vector<8xf32>
+///   %1 = user_op %0 : vector<8xf32>
+///
+/// becomes `%1 = user_op %a : vector<8xf32>`. With a passthru operand, only
+/// an all-true mask is folded: no lane can then take the passthru value.
+static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
+                                     SmallVectorImpl<OpFoldResult> &results) {
+  if (!maskOp.isEmpty())
+    return failure();
+  // Folding the all-true passthru case here also keeps every empty mask out
+  // of the op-hoisting path of `MaskOp::fold`, which has nothing to hoist.
+  if (maskOp.hasPassthru() &&
+      getMaskFormat(maskOp.getMask()) != MaskFormat::AllTrue)
+    return failure();
+
+  auto terminator = cast<vector::YieldOp>(maskOp.getMaskBlock()->front());
+  // No results to forward. A folder cannot erase the op, so this reports an
+  // in-place fold; NOTE(review): erasure presumably left to DCE -- confirm.
+  if (terminator.getNumOperands() == 0)
+    return success();
+  llvm::append_range(results, terminator.getOperands());
+  return success();
+}
+
LogicalResult MaskOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
- MaskFormat maskFormat = getMaskFormat(getMask());
- if (isEmpty())
- return failure();
+ if (succeeded(foldEmptyMaskOp(*this, adaptor, results)))
+ return success();
+ MaskFormat maskFormat = getMaskFormat(getMask());
if (maskFormat != MaskFormat::AllTrue)
return failure();
@@ -6669,37 +6696,6 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
return success();
}
-// Elides empty vector.mask operations with or without return values. Propagates
-// the yielded values by the vector.yield terminator, if any, or erases the op,
-// otherwise.
-class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
- using OpRewritePattern::OpRewritePattern;
-
- LogicalResult matchAndRewrite(MaskOp maskOp,
- PatternRewriter &rewriter) const override {
- auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation());
- if (maskingOp.getMaskableOp())
- return failure();
-
- if (!maskOp.isEmpty())
- return failure();
-
- Block *block = maskOp.getMaskBlock();
- auto terminator = cast<vector::YieldOp>(block->front());
- if (terminator.getNumOperands() == 0)
- rewriter.eraseOp(maskOp);
- else
- rewriter.replaceOp(maskOp, terminator.getOperands());
-
- return success();
- }
-};
-
-void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
- MLIRContext *context) {
- results.add<ElideEmptyMaskOp>(context);
-}
-
// MaskingOpInterface definitions.
/// Returns the operation masked by this 'vector.mask'.
diff --git a/mlir/test/Conversion/GPUCommon/lower-vector.mlir b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
index 532a2383cea9e..b4e3da9d0dbfe 100644
--- a/mlir/test/Conversion/GPUCommon/lower-vector.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-vector.mlir
@@ -1,10 +1,12 @@
// RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
module {
+ // CHECK-LABEL: func @func
+ // CHECK-SAME: %[[IN:.*]]: vector<11xf32>
func.func @func(%arg: vector<11xf32>) -> vector<11xf32> {
%cst_41 = arith.constant dense<true> : vector<11xi1>
- // CHECK: vector.mask
- // CHECK-SAME: vector.yield %arg0
+ // CHECK-NOT: vector.mask
+ // CHECK: return %[[IN]] : vector<11xf32>
%127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32>
return %127 : vector<11xf32>
}
More information about the Mlir-commits
mailing list