[Mlir-commits] [mlir] [mlir][vector] Switch to using `getNumScalableDims` (nfc) (PR #100806)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Jul 26 12:57:05 PDT 2024


https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/100806

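Updates the codebase to use the recently introduced VectorType helper:
  * `getNumScalableDims` (introduced in #100331),

replacing the `llvm::count(type.getScalableDims(), true)` idiom in the ArmSME transforms and the Vector utilities.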
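To illustrate why this is a pure refactor, here is a minimal, self-contained C++ sketch that does not depend on MLIR. `ToyVectorType`, `countScalableDimsOld` and `checkEquivalent` are hypothetical stand-ins assumed for illustration: a vector type is modelled as a shape plus one scalable flag per dimension. The sketch shows that counting the `true` flags (the old idiom) and calling a `getNumScalableDims`-style helper agree, and it mirrors the `isLinearizableVector` predicate as it reads after this patch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for mlir::VectorType (NOT the real class):
// only the shape and the per-dimension "scalable" flags are modelled.
struct ToyVectorType {
  std::vector<int64_t> shape;
  std::vector<bool> scalableDims; // one flag per dimension

  int64_t getRank() const { return static_cast<int64_t>(shape.size()); }

  // Old idiom, analogous to llvm::count(type.getScalableDims(), true).
  int64_t countScalableDimsOld() const {
    return std::count(scalableDims.begin(), scalableDims.end(), true);
  }

  // Models the helper the patch switches to: the number of scalable dims.
  int64_t getNumScalableDims() const {
    return std::count(scalableDims.begin(), scalableDims.end(), true);
  }
};

// Mirrors vector::isLinearizableVector after the patch:
// rank > 1 and at most one scalable dimension.
bool isLinearizableVector(const ToyVectorType &type) {
  return (type.getRank() > 1) && (type.getNumScalableDims() <= 1);
}

// Sanity check: both idioms must agree for any type (the "nfc" claim).
void checkEquivalent(const ToyVectorType &type) {
  assert(type.countScalableDimsOld() == type.getNumScalableDims());
}
```

For example, a 2-D type with both dimensions scalable (the SME-like case that `FoldExtractFromVectorOfSMELikeCreateMasks` checks for) is rejected by `isLinearizableVector`, while a 2-D type with a single scalable dimension is accepted.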


From ebaf3ae9e91f499ee7edf6ee9a7c1de0608c674b Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Fri, 26 Jul 2024 20:49:39 +0100
Subject: [PATCH] [mlir][vector] Switch to using `getNumScalableDims` (nfc)

Updates the codebase to use the recently introduced VectorType helper:
  * `getNumScalableDims` (introduced in #100331).
---
 mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp | 2 +-
 mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp | 2 +-
 mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp             | 3 +--
 3 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
index 39292c4533d69..1e711678dc9ab 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp
@@ -479,7 +479,7 @@ struct SwapVectorExtractOfArithExtend
       return rewriter.notifyMatchFailure(extractOp,
                                          "extracted type is not a vector type");
 
-    auto numScalableDims = llvm::count(resultType.getScalableDims(), true);
+    auto numScalableDims = resultType.getNumScalableDims();
     if (numScalableDims != 1)
       return rewriter.notifyMatchFailure(
           extractOp, "extracted type is not a 1-D scalable vector type");
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 96dad6518fec8..81fa195c4bf62 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -549,7 +549,7 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
       return rewriter.notifyMatchFailure(extractOp,
                                          "extracted type is not a vector type");
 
-    auto numScalable = llvm::count(extractedMaskType.getScalableDims(), true);
+    auto numScalable = extractedMaskType.getNumScalableDims();
     if (numScalable != 2)
       return rewriter.notifyMatchFailure(
           extractOp, "expected extracted type to be an SME-like mask");
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 4ed5a8bac20d1..e590d8c43c44b 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -323,8 +323,7 @@ SmallVector<OpFoldResult> vector::getMixedSizesXfer(bool hasTensorSemantics,
 }
 
 bool vector::isLinearizableVector(VectorType type) {
-  auto numScalableDims = llvm::count(type.getScalableDims(), true);
-  return (type.getRank() > 1) && (numScalableDims <= 1);
+  return (type.getRank() > 1) && (type.getNumScalableDims() <= 1);
 }
 
 Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,


