[Mlir-commits] [mlir] 204eb70 - [mlir][Vector] Canonicalize empty `vector.mask` into `arith.select` (#140976)

llvmlistbot at llvm.org
Fri May 23 08:30:01 PDT 2025


Author: Diego Caballero
Date: 2025-05-23T08:29:57-07:00
New Revision: 204eb70af894770fb4b9107fbcf3003cb3f9cb72

URL: https://github.com/llvm/llvm-project/commit/204eb70af894770fb4b9107fbcf3003cb3f9cb72
DIFF: https://github.com/llvm/llvm-project/commit/204eb70af894770fb4b9107fbcf3003cb3f9cb72.diff

LOG: [mlir][Vector] Canonicalize empty `vector.mask` into `arith.select` (#140976)

This PR adds a missing canonicalization for empty `vector.mask` ops with
a passthru value.

```
   %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
     vector<8xi1> -> vector<8xf32>

 becomes:

   %0 = arith.select %mask, %a, %passthru : vector<8xf32>
```

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3f5564541554e..5e8421ed67d66 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2559,6 +2559,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
                                  Location loc);
   }];
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
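
For context: in ODS, setting `hasCanonicalizer = 1` asks TableGen to declare the canonicalization hook on the generated `MaskOp` class; the C++ change below supplies its definition. The hook's declared shape, shown here as an illustrative sketch (the real declaration is emitted into the generated `VectorOps.h.inc`):

```
// Illustrative sketch only: the hook that `hasCanonicalizer = 1` declares on
// MaskOp. The actual declaration lives in the generated VectorOps.h.inc.
static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results,
                                        ::mlir::MLIRContext *context);
```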

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 41777347975da..890a5e9e5c9b4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6661,6 +6661,9 @@ LogicalResult MaskOp::verify() {
 ///
 ///   %0 = user_op %a : vector<8xf32>
 ///
+/// Empty `vector.mask` ops with a passthru operand are handled by the
+/// canonicalizer, as folding them requires creating new operations.
+
 static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
                                      SmallVectorImpl<OpFoldResult> &results) {
   if (!maskOp.isEmpty() || maskOp.hasPassthru())
@@ -6696,6 +6699,47 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
   return success();
 }
 
+/// Canonicalize empty `vector.mask` operations that can't be handled in
+/// `MaskOp::fold`, as they require creating new operations.
+///
+/// Example: Empty `vector.mask` with a passthru operand.
+///
+///   %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } :
+///     vector<8xi1> -> vector<8xf32>
+///
+/// becomes:
+///
+///   %0 = arith.select %mask, %a, %passthru : vector<8xf32>
+///
+class CanonicalizeEmptyMaskOp : public OpRewritePattern<MaskOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(MaskOp maskOp,
+                                PatternRewriter &rewriter) const override {
+    if (!maskOp.isEmpty())
+      return failure();
+
+    if (!maskOp.hasPassthru())
+      return failure();
+
+    Block *block = maskOp.getMaskBlock();
+    auto terminator = cast<vector::YieldOp>(block->front());
+    assert(terminator.getNumOperands() == 1 &&
+           "expected one result when passthru is provided");
+
+    rewriter.replaceOpWithNewOp<arith::SelectOp>(
+        maskOp, maskOp.getResultTypes(), maskOp.getMask(),
+        terminator.getOperand(0), maskOp.getPassthru());
+
+    return success();
+  }
+};
+
+void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<CanonicalizeEmptyMaskOp>(context);
+}
+
 // MaskingOpInterface definitions.
 
 /// Returns the operation masked by this 'vector.mask'.
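
For context on how the new pattern is reached: the generic `-canonicalize` pass collects each registered op's canonicalization patterns through this hook and applies them with the greedy rewrite driver. A minimal, hypothetical standalone helper (not part of this commit) that pulls in the `vector.mask` patterns directly might look like the sketch below, assuming the long-standing `applyPatternsAndFoldGreedily` entry point (newer MLIR spells it `applyPatternsGreedily`):

```
#include <utility>

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical helper (not part of this commit): gather the vector.mask
// canonicalization patterns, including the new empty-mask-with-passthru
// -> arith.select rewrite, and apply them greedily under `root`.
static mlir::LogicalResult canonicalizeVectorMasks(mlir::Operation *root) {
  mlir::MLIRContext *ctx = root->getContext();
  mlir::RewritePatternSet patterns(ctx);
  mlir::vector::MaskOp::getCanonicalizationPatterns(patterns, ctx);
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```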

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 974f4506a2ef0..a6543aafd1c77 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -719,7 +719,7 @@ func.func @fold_extract_transpose(
 // CHECK-LABEL: fold_extract_broadcast_same_input_output_scalar
 //  CHECK-SAME:   %[[A:.*]]: f32
 //       CHECK:   return %[[A]] : f32
-func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32, 
+func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
   %idx0 : index, %idx1 : index, %idx2 : index) -> f32 {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
   %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -731,7 +731,7 @@ func.func @fold_extract_broadcast_same_input_output_scalar(%a : f32,
 // CHECK-LABEL: fold_extract_broadcast_same_input_output_vec
 //  CHECK-SAME:   %[[A:.*]]: vector<4xf32>
 //       CHECK:   return %[[A]] : vector<4xf32>
-func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>, 
+func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
   %idx0 : index, %idx1 : index) -> vector<4xf32> {
   %b = vector.broadcast %a : vector<4xf32> to vector<1x2x4xf32>
   %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -744,7 +744,7 @@ func.func @fold_extract_broadcast_same_input_output_vec(%a : vector<4xf32>,
 //  CHECK-SAME:   %[[A:.*]]: vector<f32>
 //       CHECK:   %[[B:.+]] = vector.extract %[[A]][] : f32 from vector<f32>
 //       CHECK:   return %[[B]] : f32
-func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>, 
+func.func @fold_extract_broadcast_0dvec_input_scalar_output(%a : vector<f32>,
   %idx0 : index, %idx1 : index, %idx2: index) -> f32 {
   %b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
   %r = vector.extract %b[%idx0, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -780,7 +780,7 @@ func.func @fold_extract_splat(%a : f32, %idx0 : index, %idx1 : index, %idx2 : in
 //  CHECK-SAME:   %[[IDX:.*]]: index, %[[IDX1:.*]]: index, %[[IDX2:.*]]: index
 //       CHECK:   %[[R:.*]] = vector.extract %[[A]][%[[IDX1]], 0] : f32 from vector<2x1xf32>
 //       CHECK:   return %[[R]] : f32
-func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>, 
+func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
   %idx : index, %idx1 : index, %idx2 : index) -> f32 {
   %b = vector.broadcast %a : vector<2x1xf32> to vector<1x2x4xf32>
   %r = vector.extract %b[%idx, %idx1, %idx2] : f32 from vector<1x2x4xf32>
@@ -795,7 +795,7 @@ func.func @fold_extract_broadcast_dim1_broadcasting(%a : vector<2x1xf32>,
 //       CHECK:   %[[B:.+]] = vector.extract %[[A]][%[[IDX1]]] : vector<4xf32> from vector<2x4xf32>
 //       CHECK:   return %[[B]] : vector<4xf32>
 // rank(extract_output) < rank(broadcast_input)
-func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>, 
+func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
   %idx0 : index, %idx1 : index) -> vector<4xf32> {
   %b = vector.broadcast %a : vector<2x4xf32> to vector<1x2x4xf32>
   %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -808,7 +808,7 @@ func.func @fold_extract_broadcast_to_lower_rank(%a : vector<2x4xf32>,
 //       CHECK:   %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
 //       CHECK:   return %[[B]] : vector<4xf32>
 // rank(extract_output) > rank(broadcast_input)
-func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index) 
+func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1 : index)
   -> vector<4xf32> {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
   %r = vector.extract %b[%idx0, %idx1] : vector<4xf32> from vector<1x2x4xf32>
@@ -822,7 +822,7 @@ func.func @fold_extract_broadcast_to_higher_rank(%a : f32, %idx0 : index, %idx1
 //       CHECK:   %[[R:.*]] = vector.broadcast %[[A]] : vector<1xf32> to vector<8xf32>
 //       CHECK:   return %[[R]] : vector<8xf32>
 // rank(extract_output) == rank(broadcast_input)
-func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index) 
+func.func @fold_extract_broadcast_to_equal_rank(%a : vector<1xf32>, %idx0 : index)
   -> vector<8xf32> {
   %b = vector.broadcast %a : vector<1xf32> to vector<1x8xf32>
   %r = vector.extract %b[%idx0] : vector<8xf32> from vector<1x8xf32>
@@ -1169,7 +1169,7 @@ func.func @broadcast_poison() -> vector<4x6xi8> {
   return %broadcast : vector<4x6xi8>
 }
 
-// ----- 
+// -----
 
 // CHECK-LABEL:  broadcast_splat_constant
 //       CHECK:  %[[CONST:.*]] = arith.constant dense<1> : vector<4x6xi8>
@@ -2756,6 +2756,19 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
 
 // -----
 
+// CHECK-LABEL: func @empty_vector_mask_with_passthru
+//  CHECK-SAME:     %[[IN:.*]]: vector<8xf32>, %[[MASK:.*]]: vector<8xi1>, %[[PASSTHRU:.*]]: vector<8xf32>
+func.func @empty_vector_mask_with_passthru(%a : vector<8xf32>, %mask : vector<8xi1>,
+                                           %passthru : vector<8xf32>) -> vector<8xf32> {
+//   CHECK-NOT:   vector.mask
+//       CHECK:   %[[SEL:.*]] = arith.select %[[MASK]], %[[IN]], %[[PASSTHRU]] : vector<8xi1>, vector<8xf32>
+//       CHECK:   return %[[SEL]] : vector<8xf32>
+  %0 = vector.mask %mask, %passthru { vector.yield %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @all_true_vector_mask
 //  CHECK-SAME:     %[[IN:.*]]: tensor<3x4xf32>
 func.func @all_true_vector_mask(%ta : tensor<3x4xf32>) -> vector<3x4xf32> {


        

