[Mlir-commits] [mlir] 385c9f5 - [MLIR] Cleanup `constantTripCount()` (NFC) (#159307)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Sep 17 04:18:01 PDT 2025


Author: Mehdi Amini
Date: 2025-09-17T13:17:57+02:00
New Revision: 385c9f5b0244b6e9ff33856852aafc04e6b370a3

URL: https://github.com/llvm/llvm-project/commit/385c9f5b0244b6e9ff33856852aafc04e6b370a3
DIFF: https://github.com/llvm/llvm-project/commit/385c9f5b0244b6e9ff33856852aafc04e6b370a3.diff

LOG: [MLIR] Cleanup `constantTripCount()` (NFC) (#159307)

Address post-merge review comments on #158679.

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Utils/Utils.cpp
    mlir/lib/Dialect/Utils/StaticValueUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 18f139c1bd54a..e7bce98c607df 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -483,7 +483,7 @@ LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
   std::optional<APInt> mayBeConstantTripCount = forOp.getStaticTripCount();
   if (!mayBeConstantTripCount.has_value())
     return failure();
-  APInt &tripCount = *mayBeConstantTripCount;
+  const APInt &tripCount = *mayBeConstantTripCount;
   if (tripCount.isZero())
     return success();
   if (tripCount.getSExtValue() == 1)

diff  --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 5048b19b2891f..8d3944f883963 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/APSInt.h"
@@ -280,27 +281,28 @@ std::optional<APInt> constantTripCount(
         computeUbMinusLb) {
   // This is the bitwidth used to return 0 when loop does not execute.
   // We infer it from the type of the bound if it isn't an index type.
-  bool isIndex = true;
-  auto getBitwidth = [&](OpFoldResult ofr) -> int {
-    if (auto attr = dyn_cast<Attribute>(ofr)) {
-      if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
-        if (auto intType = dyn_cast<IntegerType>(intAttr.getType())) {
-          isIndex = intType.isIndex();
-          return intType.getWidth();
-        }
-      }
+  auto getBitwidth = [&](OpFoldResult ofr) -> std::tuple<int, bool> {
+    if (auto intAttr =
+            dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(ofr))) {
+      if (auto intType = dyn_cast<IntegerType>(intAttr.getType()))
+        return std::make_tuple(intType.getWidth(), intType.isIndex());
     } else {
       auto val = cast<Value>(ofr);
-      if (auto intType = dyn_cast<IntegerType>(val.getType())) {
-        isIndex = intType.isIndex();
-        return intType.getWidth();
-      }
+      if (auto intType = dyn_cast<IntegerType>(val.getType()))
+        return std::make_tuple(intType.getWidth(), intType.isIndex());
     }
-    return IndexType::kInternalStorageBitWidth;
+    return std::make_tuple(IndexType::kInternalStorageBitWidth, true);
   };
-  int bitwidth = getBitwidth(lb);
-  assert(bitwidth == getBitwidth(ub) &&
-         "lb and ub must have the same bitwidth");
+  auto [bitwidth, isIndex] = getBitwidth(lb);
+  // This would better be an assert, but unfortunately it breaks scf.for_all
+  // which is missing attributes and SSA value optionally for its bounds, and
+  // uses Index type for the dynamic bounds but i64 for the static bounds. This
+  // is broken...
+  if (std::tie(bitwidth, isIndex) != getBitwidth(ub)) {
+    LDBG() << "mismatch between lb and ub bitwidth/type: " << ub << " vs "
+           << lb;
+    return std::nullopt;
+  }
   if (lb == ub)
     return APInt(bitwidth, 0);
 


        


More information about the Mlir-commits mailing list