[Mlir-commits] [mlir] 0c71a6e - [mlir][Vector] Add folding for vector.mask with all-true masks
Diego Caballero
llvmlistbot at llvm.org
Thu May 18 12:07:20 PDT 2023
Author: Diego Caballero
Date: 2023-05-18T19:07:07Z
New Revision: 0c71a6e7c8e9a22bc44f25a90e8bdd995b8d9261
URL: https://github.com/llvm/llvm-project/commit/0c71a6e7c8e9a22bc44f25a90e8bdd995b8d9261
DIFF: https://github.com/llvm/llvm-project/commit/0c71a6e7c8e9a22bc44f25a90e8bdd995b8d9261.diff
LOG: [mlir][Vector] Add folding for vector.mask with all-true masks
This patch folds away `vector.mask` operations with all-true masks (i.e.,
all lanes enabled) by moving the masked operation out of the mask region.
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D150743
Added:
Modified:
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index b9bcd2bb5f126..09c8d6a2b2831 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2375,7 +2375,7 @@ def Vector_MaskOp : Vector_Op<"mask", [
```
}];
- // TODO: Support multiple results and passthru values.
+ // TODO: Support multiple passthru values.
let arguments = (ins VectorOf<[I1]>:$mask,
Optional<AnyType>:$passthru);
let results = (outs Variadic<AnyType>:$results);
@@ -2394,10 +2394,21 @@ def Vector_MaskOp : Vector_Op<"mask", [
let extraClassDeclaration = [{
Block *getMaskBlock() { return &getMaskRegion().front(); }
-    static void ensureTerminator(Region &region, Builder &builder, Location loc);
+
+    /// Returns true if mask op is not masking any operation, i.e. the mask
+    /// block holds nothing but its vector.yield terminator.
+    bool isEmpty() {
+      Block *block = getMaskBlock();
+      // At most one op means only the terminator is present.
+      return block->getOperations().size() <= 1;
+    }
+
+    static void ensureTerminator(Region &region, Builder &builder,
+                                 Location loc);
}];
let hasCanonicalizer = 1;
+ let hasFolder = 1;
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1549237f8c9fd..64b64c6ae71de 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5448,6 +5448,25 @@ LogicalResult MaskOp::verify() {
return success();
}
+/// Folds a non-empty vector.mask with an all-true mask by hoisting the
+/// maskable op out of the region and forwarding all of its results.
+LogicalResult MaskOp::fold(FoldAdaptor adaptor,
+                           SmallVectorImpl<OpFoldResult> &results) {
+  // Empty masks have nothing to hoist; ElideEmptyMaskOp handles them.
+  if (isEmpty() || getMaskFormat(getMask()) != MaskFormat::AllTrue)
+    return failure();
+
+  // Move maskable operation outside of the `vector.mask` region, detaching
+  // it from its in-region uses (the vector.yield terminator) first.
+  Operation *maskableOp = getMaskableOp();
+  maskableOp->dropAllUses();
+  maskableOp->moveBefore(getOperation());
+
+  // The maskable op may have zero, one, or several results; forward all.
+  llvm::append_range(results, maskableOp->getResults());
+  return success();
+}
+
// Elides empty vector.mask operations with or without return values. Propagates
// the yielded values by the vector.yield terminator, if any, or erases the op,
// otherwise.
@@ -5460,10 +5479,10 @@ class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> {
if (maskingOp.getMaskableOp())
return failure();
- Block *block = maskOp.getMaskBlock();
- if (block->getOperations().size() > 1)
+ if (!maskOp.isEmpty())
return failure();
+ Block *block = maskOp.getMaskBlock();
auto terminator = cast<vector::YieldOp>(block->front());
if (terminator.getNumOperands() == 0)
rewriter.eraseOp(maskOp);
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 4ce4350f0e4f3..739ab00fa43f9 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2147,4 +2147,16 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1
return %0 : vector<8xf32>
}
+// -----
+
+// `vector.constant_mask [3, 4]` covers the whole vector<3x4xi1>, so the mask
+// is all-true: the mask op should fold away, leaving a plain arith.addf.
+// CHECK-LABEL: func @all_true_vector_mask
+// CHECK-SAME: %[[IN:.*]]: vector<3x4xf32>
+func.func @all_true_vector_mask(%a : vector<3x4xf32>) -> vector<3x4xf32> {
+// CHECK-NOT: vector.mask
+// CHECK: %[[ADD:.*]] = arith.addf %[[IN]], %[[IN]] : vector<3x4xf32>
+// CHECK: return %[[ADD]] : vector<3x4xf32>
+ %all_true = vector.constant_mask [3, 4] : vector<3x4xi1>
+ %0 = vector.mask %all_true { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32>
+ return %0 : vector<3x4xf32>
+}
More information about the Mlir-commits
mailing list