[Mlir-commits] [mlir] [mlir][vector] Extend `CreateMaskFolder` (PR #75842)

Andrzej Warzyński llvmlistbot at llvm.org
Tue Dec 19 09:02:27 PST 2023


https://github.com/banach-space updated https://github.com/llvm/llvm-project/pull/75842

>From 87149572f24ef07660328e147873329cb1b3154d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 18 Dec 2023 18:43:08 +0000
Subject: [PATCH 1/4] [mlir][vector] Extend `CreateMaskFolder`

Extends the `CreateMaskFolder` pattern so that the following:
```mlir
  %c8 = arith.constant 8 : index
  %c16 = arith.constant 16 : index
  %0 = vector.vscale
  %1 = arith.muli %0, %c16 : index
  %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
```

is folded as:

```mlir
  %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
```
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 96 +++++++++++++++++-----
 mlir/test/Dialect/Vector/canonicalize.mlir | 13 +++
 2 files changed, 89 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 540959b486db9c..3619c1c00f1664 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5657,30 +5657,79 @@ LogicalResult CreateMaskOp::verify() {
 
 namespace {
 
-// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
+/// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
+///
+/// Ex 1:
+///   %c2 = arith.constant 2 : index
+///   %c3 = arith.constant 3 : index
+///   %0 = vector.create_mask %c3, %c2 : vector<4x3xi1>
+/// Becomes:
+///    vector.constant_mask [3, 2] : vector<4x3xi1>
+///
+/// Ex 2:
+///   %c_neg_1 = arith.constant -1 : index
+///   %0 = vector.create_mask %c_neg_1 : vector<[8]xi1>
+/// becomes:
+///   vector.constant_mask [0] : vector<[8]xi1>
+///
+/// Ex 3:
+///   %c8 = arith.constant 8 : index
+///   %c16 = arith.constant 16 : index
+///   %0 = vector.vscale
+///   %1 = arith.muli %0, %c16 : index
+///   %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
+/// becomes:
+///   %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
 class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                 PatternRewriter &rewriter) const override {
-    // Return if any of 'createMaskOp' operands are not defined by a constant.
-    auto isNotDefByConstant = [](Value operand) {
-      return !getConstantIntValue(operand).has_value();
-    };
-    if (llvm::any_of(createMaskOp.getOperands(), isNotDefByConstant))
-      return failure();
+    VectorType retTy = createMaskOp.getResult().getType();
+    bool isScalable = retTy.isScalable();
+
+    // Check every mask operand
+    for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
+      // Most basic case - this operand is a constant value. Note that for
+      // scalable dimensions, CreateMaskOp can be folded only if the
+      // corresponding operand is negative or zero.
+      if (auto op = getConstantIntValue(operand)) {
+        APInt intVal;
+        if (isScalable && !(matchPattern(operand, m_ConstantInt(&intVal)) ||
+                            intVal.isStrictlyPositive()))
+          return failure();
 
-    // CreateMaskOp for scalable vectors can be folded only if all dimensions
-    // are negative or zero.
-    if (auto vType = llvm::dyn_cast<VectorType>(createMaskOp.getType())) {
-      if (vType.isScalable())
-        for (auto opDim : createMaskOp.getOperands()) {
-          APInt intVal;
-          if (matchPattern(opDim, m_ConstantInt(&intVal)) &&
-              intVal.isStrictlyPositive())
-            return failure();
-        }
+        continue;
+      }
+
+      // Non-constant operands are not allowed for non-scalable vectors.
+      if (!isScalable)
+        return failure();
+
+      // For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
+      // true" mask, so can also be treated as constant.
+      auto mul = llvm::dyn_cast_or_null<arith::MulIOp>(operand.getDefiningOp());
+      if (!mul)
+        return failure();
+      auto mulLHS = mul.getOperands()[0];
+      auto mulRHS = mul.getOperands()[1];
+      bool isOneOpVscale =
+          (isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) ||
+           isa<vector::VectorScaleOp>(mulRHS.getDefiningOp()));
+
+      auto isConstantValMatchingDim =
+          [=, dim = createMaskOp.getResult().getType().getShape()[opIdx]](
+              Value operand) {
+            auto constantVal = getConstantIntValue(operand);
+            return (constantVal.has_value() && constantVal.value() == dim);
+          };
+
+      bool isOneOpConstantMatchingDim =
+          isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS);
+
+      if (!isOneOpVscale || !isOneOpConstantMatchingDim)
+        return failure();
     }
 
     // Gather constant mask dimension sizes.
@@ -5688,15 +5737,22 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
     maskDimSizes.reserve(createMaskOp->getNumOperands());
     for (auto [operand, maxDimSize] : llvm::zip_equal(
              createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
-      int64_t dimSize = getConstantIntValue(operand).value();
-      dimSize = std::min(dimSize, maxDimSize);
+      auto dimSize = getConstantIntValue(operand);
+      if (not dimSize) {
+        // Although not a constant, it is safe to assume that `operand` is
+        // "vscale * maxDimSize".
+        maskDimSizes.push_back(maxDimSize);
+        continue;
+      }
+      int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize);
       // If one of dim sizes is zero, set all dims to zero.
       if (dimSize <= 0) {
         maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
         break;
       }
-      maskDimSizes.push_back(dimSize);
+      maskDimSizes.push_back(dimSizeVal);
     }
+
     // Replace 'createMaskOp' with ConstantMaskOp.
     rewriter.replaceOpWithNewOp<ConstantMaskOp>(
         createMaskOp, createMaskOp.getResult().getType(),
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 1021c73cc57d34..a30016ea857d97 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -53,6 +53,19 @@ func.func @create_vector_mask_to_constant_mask_truncation_zero() -> (vector<4x3x
 
 // -----
 
+// CHECK-LABEL: create_vector_mask_to_constant_mask_scalable_all_true
+func.func @create_vector_mask_to_constant_mask_scalable_all_true() -> (vector<8x[16]xi1>) {
+  %c8 = arith.constant 8 : index
+  %c16 = arith.constant 16 : index
+  %0 = vector.vscale
+  %1 = arith.muli %0, %c16 : index
+  // CHECK: vector.constant_mask [8, 16] : vector<8x[16]xi1>
+  %10 = vector.create_mask %c8, %1 : vector<8x[16]xi1>
+  return %10 : vector<8x[16]xi1>
+}
+
+// -----
+
 // CHECK-LABEL: create_mask_transpose_to_transposed_create_mask
 //  CHECK-SAME: %[[DIM0:.*]]: index, %[[DIM1:.*]]: index, %[[DIM2:.*]]: index
 func.func @create_mask_transpose_to_transposed_create_mask(

>From 60ec419f816d777e8fa10b93ec972b7462d2df6e Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Mon, 18 Dec 2023 20:22:54 +0000
Subject: [PATCH 2/4] fixup! [mlir][vector] Extend `CreateMaskFolder`

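Address comments from Jakub: drop an unused variable, use `getDefiningOp<arith::MulIOp>()`, `getLhs()`/`getRhs()` and `std::optional`.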
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3619c1c00f1664..5eef2f4c3271b1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5694,7 +5694,7 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
       // Most basic case - this operand is a constant value. Note that for
       // scalable dimensions, CreateMaskOp can be folded only if the
       // corresponding operand is negative or zero.
-      if (auto op = getConstantIntValue(operand)) {
+      if (getConstantIntValue(operand)) {
         APInt intVal;
         if (isScalable && !(matchPattern(operand, m_ConstantInt(&intVal)) ||
                             intVal.isStrictlyPositive()))
@@ -5709,11 +5709,11 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
 
       // For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
       // true" mask, so can also be treated as constant.
-      auto mul = llvm::dyn_cast_or_null<arith::MulIOp>(operand.getDefiningOp());
+      auto mul = operand.getDefiningOp<arith::MulIOp>();
       if (!mul)
         return failure();
-      auto mulLHS = mul.getOperands()[0];
-      auto mulRHS = mul.getOperands()[1];
+      auto mulLHS = mul.getRhs();
+      auto mulRHS = mul.getLhs();
       bool isOneOpVscale =
           (isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) ||
            isa<vector::VectorScaleOp>(mulRHS.getDefiningOp()));
@@ -5737,8 +5737,8 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
     maskDimSizes.reserve(createMaskOp->getNumOperands());
     for (auto [operand, maxDimSize] : llvm::zip_equal(
              createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
-      auto dimSize = getConstantIntValue(operand);
-      if (not dimSize) {
+      std::optional dimSize = getConstantIntValue(operand);
+      if (!dimSize) {
         // Although not a constant, it is safe to assume that `operand` is
         // "vscale * maxDimSize".
         maskDimSizes.push_back(maxDimSize);

>From 08cd6617ef412746088424e94da9ff575d4218d9 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 19 Dec 2023 12:56:51 +0000
Subject: [PATCH 3/4] fixup! [mlir][vector] Extend `CreateMaskFolder`

Address comments from Cullen
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 5eef2f4c3271b1..66da2e6ac4a1fd 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5719,8 +5719,7 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
            isa<vector::VectorScaleOp>(mulRHS.getDefiningOp()));
 
       auto isConstantValMatchingDim =
-          [=, dim = createMaskOp.getResult().getType().getShape()[opIdx]](
-              Value operand) {
+          [=, dim = retTy.getShape()[opIdx]](Value operand) {
             auto constantVal = getConstantIntValue(operand);
             return (constantVal.has_value() && constantVal.value() == dim);
           };
@@ -5755,7 +5754,7 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
 
     // Replace 'createMaskOp' with ConstantMaskOp.
     rewriter.replaceOpWithNewOp<ConstantMaskOp>(
-        createMaskOp, createMaskOp.getResult().getType(),
+        createMaskOp, retTy,
         vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
     return success();
   }

>From 1e0d190d884572e3120d9a3dcb6133156db2b9ed Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 19 Dec 2023 17:01:33 +0000
Subject: [PATCH 4/4] fixup! [mlir][vector] Extend `CreateMaskFolder`

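Fix how scalable dims are treated: for a constant mask operand, bail out only when its own dimension is scalable and the constant is strictly positive. Previously every dimension of a scalable vector was checked, and the misplaced negation meant positive constants were never rejected.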
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 66da2e6ac4a1fd..0e584ae2b19a05 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5691,13 +5691,14 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
 
     // Check every mask operand
     for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
-      // Most basic case - this operand is a constant value. Note that for
-      // scalable dimensions, CreateMaskOp can be folded only if the
-      // corresponding operand is negative or zero.
       if (getConstantIntValue(operand)) {
+        // Most basic case - this operand is a constant value. Note that for
+        // scalable dimensions, CreateMaskOp can be folded only if the
+        // corresponding operand is negative or zero.
         APInt intVal;
-        if (isScalable && !(matchPattern(operand, m_ConstantInt(&intVal)) ||
-                            intVal.isStrictlyPositive()))
+        if (retTy.getScalableDims()[opIdx] &&
+            (!matchPattern(operand, m_ConstantInt(&intVal)) ||
+             intVal.isStrictlyPositive()))
           return failure();
 
         continue;



More information about the Mlir-commits mailing list