[Mlir-commits] [mlir] [mlir][vector] Add mask elimination transform (PR #99314)

Andrzej Warzyński llvmlistbot at llvm.org
Wed Aug 7 08:15:22 PDT 2024


================
@@ -5862,73 +5887,46 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
 
   LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                 PatternRewriter &rewriter) const override {
-    VectorType retTy = createMaskOp.getResult().getType();
-    bool isScalable = retTy.isScalable();
-
-    // Check every mask operand
-    for (auto [opIdx, operand] : llvm::enumerate(createMaskOp.getOperands())) {
-      if (auto cst = getConstantIntValue(operand)) {
-        // Most basic case - this operand is a constant value. Note that for
-        // scalable dimensions, CreateMaskOp can be folded only if the
-        // corresponding operand is negative or zero.
-        if (retTy.getScalableDims()[opIdx] && *cst > 0)
-          return failure();
-
-        continue;
-      }
-
-      // Non-constant operands are not allowed for non-scalable vectors.
-      if (!isScalable)
-        return failure();
-
-      // For scalable vectors, "arith.muli %vscale, %dimSize" means an "all
-      // true" mask, so can also be treated as constant.
-      auto mul = operand.getDefiningOp<arith::MulIOp>();
-      if (!mul)
-        return failure();
-      auto mulLHS = mul.getRhs();
-      auto mulRHS = mul.getLhs();
-      bool isOneOpVscale =
-          (isa<vector::VectorScaleOp>(mulLHS.getDefiningOp()) ||
-           isa<vector::VectorScaleOp>(mulRHS.getDefiningOp()));
-
-      auto isConstantValMatchingDim =
-          [=, dim = retTy.getShape()[opIdx]](Value operand) {
-            auto constantVal = getConstantIntValue(operand);
-            return (constantVal.has_value() && constantVal.value() == dim);
-          };
-
-      bool isOneOpConstantMatchingDim =
-          isConstantValMatchingDim(mulLHS) || isConstantValMatchingDim(mulRHS);
-
-      if (!isOneOpVscale || !isOneOpConstantMatchingDim)
-        return failure();
+    VectorType maskType = createMaskOp.getVectorType();
+    ArrayRef<int64_t> maskTypeDimSizes = maskType.getShape();
+    ArrayRef<bool> maskTypeDimScalableFlags = maskType.getScalableDims();
+
+    // Special case: Rank zero shape.
+    constexpr std::array<int64_t, 1> rankZeroShape{1};
+    constexpr std::array<bool, 1> rankZeroScalableDims{false};
+    if (maskType.getRank() == 0) {
+      maskTypeDimSizes = rankZeroShape;
+      maskTypeDimScalableFlags = rankZeroScalableDims;
     }
 
-    // Gather constant mask dimension sizes.
-    SmallVector<int64_t, 4> maskDimSizes;
-    maskDimSizes.reserve(createMaskOp->getNumOperands());
-    for (auto [operand, maxDimSize] : llvm::zip_equal(
-             createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
-      std::optional dimSize = getConstantIntValue(operand);
-      if (!dimSize) {
-        // Although not a constant, it is safe to assume that `operand` is
-        // "vscale * maxDimSize".
-        maskDimSizes.push_back(maxDimSize);
-        continue;
-      }
-      int64_t dimSizeVal = std::min(dimSize.value(), maxDimSize);
-      // If one of dim sizes is zero, set all dims to zero.
-      if (dimSize <= 0) {
-        maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
-        break;
+    SmallVector<int64_t, 4> constantDims;
+    for (auto [i, dimSize] : llvm::enumerate(createMaskOp.getOperands())) {
+      if (auto intSize = getConstantIntValue(dimSize)) {
+        // Non scalable dims can have any value. Scalable dims can only be zero.
+        if (intSize >= 0 && maskTypeDimScalableFlags[i])
+          return failure();
+        constantDims.push_back(*intSize);
+      } else if (auto vscaleMultiplier = getConstantVscaleMultiplier(dimSize)) {
+        // Scalable dims must be all-true.
+        if (vscaleMultiplier < maskTypeDimSizes[i])
+          return failure();
+        constantDims.push_back(*vscaleMultiplier);
+      } else {
+        return failure();
----------------
banach-space wrote:

Please document the purpose of this loop. If I understand correctly, it (1) checks whether the folding makes sense at all, and (2) collects the constant dims to use for `vector.constant_mask`.

In particular, let's make it clear how this loop differs from the logic in [VectorMaskElimination.cpp](https://github.com/llvm/llvm-project/pull/99314/files#diff-bd4772744dd105a11599527d87340a3b6e5b354d5fbe6af65d86437cca59c493) 🙏🏻

https://github.com/llvm/llvm-project/pull/99314


More information about the Mlir-commits mailing list