[Mlir-commits] [mlir] 204eb70 - [mlir][Vector] Canonicalize empty `vector.mask` into `arith.select` (#140976)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri May 23 08:30:01 PDT 2025
Author: Diego Caballero
Date: 2025-05-23T08:29:57-07:00
New Revision: 204eb70af894770fb4b9107fbcf3003cb3f9cb72
URL: https://github.com/llvm/llvm-project/commit/204eb70af894770fb4b9107fbcf3003cb3f9cb72
DIFF: https://github.com/llvm/llvm-project/commit/204eb70af894770fb4b9107fbcf3003cb3f9cb72.diff
LOG: [mlir][Vector] Canonicalize empty `vector.mask` into `arith.select` (#140976)
This PR adds a missing canonicalization for empty `vector.mask` ops with
a passthru value.
```
%0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
vector<8xi1> -> vector<8xf32>
becomes:
%0 = arith.select %mask, %a, %passthru : vector<8xf32>
```
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3f5564541554e..5e8421ed67d66 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2559,6 +2559,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
Location loc);
}];
+ let hasCanonicalizer = 1;
let hasFolder = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 41777347975da..890a5e9e5c9b4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6661,6 +6661,9 @@ LogicalResult MaskOp::verify() {
///
/// %0 = user_op %a : vector<8xf32>
///
+/// Empty `vector.mask` ops with a passthru operand are handled by the
+/// canonicalizer, as that rewrite requires creating new operations.
+
static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
if (!maskOp.isEmpty() || maskOp.hasPassthru())
@@ -6696,6 +6699,47 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
return success();
}
+/// Canonicalize empty `vector.mask` operations that can't be handled in
+/// `MaskOp::fold` as they require creating new operations.
+///
+/// Example 1: Empty `vector.mask` with passthru operand.
+///
+///   %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
+///     vector<8xi1> -> vector<8xf32>
+///
+/// becomes:
+///
+///   %0 = arith.select %mask, %a, %passthru : vector<8xf32>
+///
+/// The rewrite is only applied when the yielded value and the passthru both
+/// have the result type and that type is a vector with the same shape (and
+/// scalable dims) as the mask; otherwise the pattern fails to match and the
+/// op is left untouched.
+class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  /// Replaces `maskOp` with an `arith.select` between the yielded value and
+  /// the passthru. Returns failure if `maskOp` is not empty, has no passthru,
+  /// or its operand/result types cannot form a valid `arith.select`.
+  LogicalResult matchAndRewrite(MaskOp maskOp,
+                                PatternRewriter &rewriter) const override {
+    // Only empty masks with a passthru are of interest here.
+    if (!maskOp.isEmpty() || !maskOp.hasPassthru())
+      return failure();
+
+    // An empty mask region contains only its terminator.
+    Block *block = maskOp.getMaskBlock();
+    auto terminator = cast<vector::YieldOp>(block->front());
+
+    // Exactly one yielded value / result is expected when a passthru is
+    // provided. Bail out instead of asserting so that release builds never
+    // read a non-existent operand.
+    if (terminator.getNumOperands() != 1 || maskOp->getNumResults() != 1)
+      return failure();
+
+    Value mask = maskOp.getMask();
+    Value yielded = terminator.getOperand(0);
+    Value passthru = maskOp.getPassthru();
+    Type resultType = maskOp->getResult(0).getType();
+
+    // `arith.select` requires both selected values to have the result type.
+    if (yielded.getType() != resultType || passthru.getType() != resultType)
+      return failure();
+
+    // With a vector condition, `arith.select` requires the values to be
+    // vectors of exactly the condition's shape.
+    // NOTE(review): presumably the `vector.mask` verifier does not enforce
+    // this for empty masks -- confirm against `MaskOp::verify`.
+    auto maskType = cast<VectorType>(mask.getType());
+    auto valueType = dyn_cast<VectorType>(resultType);
+    if (!valueType || valueType.getShape() != maskType.getShape() ||
+        valueType.getScalableDims() != maskType.getScalableDims())
+      return failure();
+
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(
+        maskOp, maskOp.getResultTypes(), mask, yielded, passthru);
+
+    return success();
+  }
+};
+
+/// Populates `patterns` with the canonicalization patterns of `vector.mask`;
+/// currently this is only the empty-mask-with-passthru rewrite.
+void MaskOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                         MLIRContext *ctx) {
+  patterns.add<CanonializeEmptyMaskOp>(ctx);
+}
+
// MaskingOpInterface definitions.
/// Returns the operation masked by this 'vector.mask'.
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 974f4506a2ef0..a6543aafd1c77 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -719,7 +719,7 @@ func.func @fold_extract_transpose(
// CHECK-LABEL: fold_extract_broadcast_same_input_output_scalar
// CHECK-SAME: %[[A:.*]]: f32
// CHECK: return %[[A]] : f32
-func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
+func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
%idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -731,7 +731,7 @@ func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
// CHECK-LABEL: fold_extract_broadcast_same_input_output_vec
// CHECK-SAME: %[[A:.*]]: vector<4xf32>
// CHECK: return %[[A]] : vector<4xf32>
-func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
+func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
%idx0 : index, %idx1 : index) -> vector<4xf32> {
%b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -744,7 +744,7 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
// CHECK-SAME: %[[A:.*]]: vector<f32>
// CHECK: %[[B:.+]] = vector.extract %[[A]][] : f32 from vector<f32>
// CHECK: return %[[B]] : f32
-func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
+func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
%idx0 : index, %idx1 : index, %idx2: index) -> f32 {
%b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -780,7 +780,7 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in
// CHECK-SAME: %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
// CHECK: %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
// CHECK: return %[[R]] : f32
-func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
+func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
%idx : index, %idx1 : index, %idx2 : index) -> f32 {
%b = vector.broadcast %a : vector<2x1xf32> to vector<1x2x4xf32>
%r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -795,7 +795,7 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
// CHECK: %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
// CHECK: return %[[B]] : vector<4xf32>
// rank(extract_output) < rank(broadcast_input)
-func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
+func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
%idx0 : index, %idx1 : index) -> vector<4xf32> {
%b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -808,7 +808,7 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
// CHECK: return %[[B]] : vector<4xf32>
// rank(extract_output) > rank(broadcast_input)
-func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index)
+func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index)
-> vector<4xf32> {
%b = vector.broadcast %a : f32 to vector<1x2x4xf32>
%r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -822,7 +822,7 @@ func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1
// CHECK: %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
// CHECK: return %[[R]] : vector<8xf32>
// rank(extract_output) == rank(broadcast_input)
-func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index)
+func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index)
-> vector<8xf32> {
%b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
%r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
@@ -1169,7 +1169,7 @@ func.func @broadcast_poison() -> vector<4x6xi8> {
return %broadcast : vector<4x6xi8>
}
-// -----
+// -----
// CHECK-LABEL: broadcast_splat_constant
// CHECK: %[[CONST:.*]] = arith.constant dense<1> : vector<4x6xi8>
@@ -2756,6 +2756,19 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
// -----
+// Verifies that an empty `vector.mask` carrying a passthru operand is
+// canonicalized into an `arith.select` that picks the yielded value where the
+// mask is set and the passthru elsewhere.
+// CHECK-LABEL: func @empty_vector_mask_with_passthru
+//  CHECK-SAME:     %[[IN:.*]]: vector<8xf32>, %[[MASK:.*]]: vector<8xi1>, %[[PASSTHRU:.*]]: vector<8xf32>
+func.func @empty_vector_mask_with_passthru(%a : vector<8xf32>, %mask : vector<8xi1>,
+                                           %passthru : vector<8xf32>) -> vector<8xf32> {
+//   CHECK-NOT:   vector.mask
+//       CHECK:   %[[SEL:.*]] = arith.select %[[MASK]], %[[IN]], %[[PASSTHRU]] : vector<8xi1>, vector<8xf32>
+//       CHECK:   return %[[SEL]] : vector<8xf32>
+  %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// -----
+
// CHECK-LABEL: func @all_true_vector_mask
// CHECK-SAME: %[[IN:.*]]: tensor<3x4xf32>
func.func @all_true_vector_mask(%ta : tensor<3x4xf32>) -> vector<3x4xf32> {
More information about the Mlir-commits
mailing list