[Mlir-commits] [mlir] 0c71a6e - [mlir][Vector] Add folding for vector.mask with all-true masks

Diego Caballero llvmlistbot at llvm.org
Thu May 18 12:07:20 PDT 2023


Author: Diego Caballero
Date: 2023-05-18T19:07:07Z
New Revision: 0c71a6e7c8e9a22bc44f25a90e8bdd995b8d9261

URL: https://github.com/llvm/llvm-project/commit/0c71a6e7c8e9a22bc44f25a90e8bdd995b8d9261
DIFF: https://github.com/llvm/llvm-project/commit/0c71a6e7c8e9a22bc44f25a90e8bdd995b8d9261.diff

LOG: [mlir][Vector] Add folding for vector.mask with all-true masks

This patch folds away `vector.mask` operations with all-true masks (i.e.,
all lanes enabled) by moving the masked operation out of the region and
forwarding its results.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D150743

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index b9bcd2bb5f126..09c8d6a2b2831 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2375,7 +2375,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
     ```
   }];
 
-  // TODO: Support multiple results and passthru values.
+  // TODO: Support multiple passthru values.
   let arguments = (ins VectorOf<[I1]>:$mask,
                    Optional<AnyType>:$passthru);
   let results = (outs Variadic<AnyType>:$results);
@@ -2394,10 +2394,21 @@ def Vector_MaskOp : Vector_Op<"mask", [
 
   let extraClassDeclaration = [{
     Block *getMaskBlock() { return &getMaskRegion().front(); }
-    static void ensureTerminator(Region &region, Builder &builder, Location loc);
+
+    /// Returns true if mask op is not masking any operation, i.e. its region
+    /// holds nothing but the terminator.
+    bool isEmpty() {
+      // A second operation in the block would be the masked operation.
+      const size_t numOps = getMaskBlock()->getOperations().size();
+      return numOps <= 1;
+    }
+
+    static void ensureTerminator(Region &region, Builder &builder,
+                                 Location loc);
   }];
 
   let hasCanonicalizer = 1;
+  let hasFolder = 1;
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 }

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1549237f8c9fd..64b64c6ae71de 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5448,6 +5448,25 @@ LogicalResult MaskOp::verify() {
   return success();
 }
 
+/// Folds vector.mask ops with an all-true mask by hoisting the maskable op.
+LogicalResult MaskOp::fold(FoldAdaptor adaptor,
+                           SmallVectorImpl<OpFoldResult> &results) {
+  // Nothing to hoist if the region only contains the terminator.
+  if (isEmpty())
+    return failure();
+  if (getMaskFormat(getMask()) != MaskFormat::AllTrue)
+    return failure();
+
+  // Move maskable operation outside of the `vector.mask` region.
+  Operation *maskableOp = getMaskableOp();
+  maskableOp->dropAllUses();
+  maskableOp->moveBefore(getOperation());
+
+  // Forward every result: the maskable op may have zero or several results.
+  llvm::append_range(results, maskableOp->getResults());
+  return success();
+}
+
 // Elides empty vector.mask operations with or without return values. Propagates
 // the yielded values by the vector.yield terminator, if any, or erases the op,
 // otherwise.
@@ -5460,10 +5479,10 @@ class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
     if (maskingOp.getMaskableOp())
       return failure();
 
-    Block *block = maskOp.getMaskBlock();
-    if (block->getOperations().size() > 1)
+    if (!maskOp.isEmpty())
       return failure();
 
+    Block *block = maskOp.getMaskBlock();
     auto terminator = cast<vector::YieldOp>(block->front());
     if (terminator.getNumOperands() == 0)
       rewriter.eraseOp(maskOp);

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 4ce4350f0e4f3..739ab00fa43f9 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2147,4 +2147,16 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
   return %0 : vector<8xf32>
 }
 
+// -----
+
+// CHECK-LABEL: func @all_true_vector_mask
+//  CHECK-SAME:     %[[ARG:.*]]: vector<3x4xf32>
+func.func @all_true_vector_mask(%src : vector<3x4xf32>) -> vector<3x4xf32> {
+//   CHECK-NOT:   vector.mask
+//       CHECK:   %[[SUM:.*]] = arith.addf %[[ARG]], %[[ARG]] : vector<3x4xf32>
+//       CHECK:   return %[[SUM]] : vector<3x4xf32>
+  %mask = vector.constant_mask [3, 4] : vector<3x4xi1>
+  %res = vector.mask %mask { arith.addf %src, %src : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
+  return %res : vector<3x4xf32>
+}
 


        


More information about the Mlir-commits mailing list