[Mlir-commits] [mlir] [mlir][vector] Extend `CreateMaskFolder` (PR #75842)

Jakub Kuderski llvmlistbot at llvm.org
Mon Dec 18 11:19:42 PST 2023


================
@@ -5657,46 +5657,102 @@ LogicalResult CreateMaskOp::verify() {
 
 namespace {
 
-// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
+/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
+///
+/// Ex 1:
+///   %c2 = arith.constant 2 : index
+///   %c3 = arith.constant 3 : index
+///   %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
+/// Becomes:
+///    vector.constant_mask [3, 2] : vector<4x3xi1>
+///
+/// Ex 2:
+///   %c_neg_1 = arith.constant -1 : index
+///   %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
+/// becomes:
+///   vector.constant_mask [0] : vector<[8]xi1>
+///
+/// Ex 3:
+///   %c8 = arith.constant 8 : index
+///   %c16 = arith.constant 16 : index
+///   %0 = vector.vscale
+///   %1 = arith.muli %0, %c16 : index
+///   %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
+/// becomes:
+///   %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                 PatternRewriter &rewriter) const override {
-    // Return if any of 'createMaskOp' operands are not defined by a constant.
-    auto isNotDefByConstant = [](Value operand) {
-      return !getConstantIntValue(operand).has_value();
-    };
-    if (llvm::any_of(createMaskOp.getOperands(), isNotDefByConstant))
-      return failure();
+    VectorType retTy = createMaskOp.getResult().getType();
+    bool isScalable = retTy.isScalable();
+
+    // Check every mask operand
+    for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
+      // Most basic case - this operand is a constant value. Note that for
+      // scalable dimensions, CreateMaskOp can be folded only if the
+      // corresponding operand is negative or zero.
+      if (auto op = getConstantIntValue(operand)) {
----------------
kuhar wrote:

The type isn't clear to me from the RHS alone. Can you spell out the full type here instead of `auto`?
Also, is `op` used inside the if body? If not, maybe we can drop it altogether.
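
To make the two suggestions concrete, here is a rough standalone sketch rather than the actual patch. `FakeValue` and both helper functions are hypothetical names made up for illustration, and the local `getConstantIntValue` is only a stub that mirrors the `std::optional<int64_t>` return type of the MLIR utility.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for an MLIR Value: carries a constant when one is known.
struct FakeValue {
  std::optional<int64_t> constant;
};

// Stub mirroring the return type of mlir::getConstantIntValue
// (std::optional<int64_t>); it is not the real implementation.
std::optional<int64_t> getConstantIntValue(const FakeValue &v) {
  return v.constant;
}

// Suggestion 1: spell out the type so the reader does not have to look up
// what the RHS returns. The binding is kept because the value is used.
bool isNonPositiveConstant(const FakeValue &operand) {
  if (std::optional<int64_t> cst = getConstantIntValue(operand))
    return *cst <= 0;
  return false;
}

// Suggestion 2: if the value is never read inside the body, drop the binding
// and test for presence only.
bool isConstant(const FakeValue &operand) {
  return getConstantIntValue(operand).has_value();
}
```

With the type written out, it is also easier to see that the variable holds an integer rather than an operation, so a name such as `cst` may read better than `op`.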

https://github.com/llvm/llvm-project/pull/75842

