[Mlir-commits] [mlir] 385c9f5 - [MLIR] Cleanup `constantTripCount()` (NFC) (#159307)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 17 04:18:01 PDT 2025
Author: Mehdi Amini
Date: 2025-09-17T13:17:57+02:00
New Revision: 385c9f5b0244b6e9ff33856852aafc04e6b370a3
URL: https://github.com/llvm/llvm-project/commit/385c9f5b0244b6e9ff33856852aafc04e6b370a3
DIFF: https://github.com/llvm/llvm-project/commit/385c9f5b0244b6e9ff33856852aafc04e6b370a3.diff
LOG: [MLIR] Cleanup `constantTripCount()` (NFC) (#159307)
Add post-merge review comments on #158679
Added:
Modified:
mlir/lib/Dialect/SCF/Utils/Utils.cpp
mlir/lib/Dialect/Utils/StaticValueUtils.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 18f139c1bd54a..e7bce98c607df 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -483,7 +483,7 @@ LogicalResult mlir::loopUnrollFull(scf::ForOp forOp) {
std::optional<APInt> mayBeConstantTripCount = forOp.getStaticTripCount();
if (!mayBeConstantTripCount.has_value())
return failure();
- APInt &tripCount = *mayBeConstantTripCount;
+ const APInt &tripCount = *mayBeConstantTripCount;
if (tripCount.isZero())
return success();
if (tripCount.getSExtValue() == 1)
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index 5048b19b2891f..8d3944f883963 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/APSInt.h"
@@ -280,27 +281,28 @@ std::optional<APInt> constantTripCount(
computeUbMinusLb) {
// This is the bitwidth used to return 0 when loop does not execute.
// We infer it from the type of the bound if it isn't an index type.
- bool isIndex = true;
- auto getBitwidth = [&](OpFoldResult ofr) -> int {
- if (auto attr = dyn_cast<Attribute>(ofr)) {
- if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
- if (auto intType = dyn_cast<IntegerType>(intAttr.getType())) {
- isIndex = intType.isIndex();
- return intType.getWidth();
- }
- }
+ auto getBitwidth = [&](OpFoldResult ofr) -> std::tuple<int, bool> {
+ if (auto intAttr =
+ dyn_cast_or_null<IntegerAttr>(dyn_cast<Attribute>(ofr))) {
+ if (auto intType = dyn_cast<IntegerType>(intAttr.getType()))
+ return std::make_tuple(intType.getWidth(), intType.isIndex());
} else {
auto val = cast<Value>(ofr);
- if (auto intType = dyn_cast<IntegerType>(val.getType())) {
- isIndex = intType.isIndex();
- return intType.getWidth();
- }
+ if (auto intType = dyn_cast<IntegerType>(val.getType()))
+ return std::make_tuple(intType.getWidth(), intType.isIndex());
}
- return IndexType::kInternalStorageBitWidth;
+ return std::make_tuple(IndexType::kInternalStorageBitWidth, true);
};
- int bitwidth = getBitwidth(lb);
- assert(bitwidth == getBitwidth(ub) &&
- "lb and ub must have the same bitwidth");
+ auto [bitwidth, isIndex] = getBitwidth(lb);
+ // This would better be an assert, but unfortunately it breaks scf.for_all
+ // which is missing attributes and SSA value optionally for its bounds, and
+ // uses Index type for the dynamic bounds but i64 for the static bounds. This
+ // is broken...
+ if (std::tie(bitwidth, isIndex) != getBitwidth(ub)) {
+ LDBG() << "mismatch between lb and ub bitwidth/type: " << ub << " vs "
+ << lb;
+ return std::nullopt;
+ }
if (lb == ub)
return APInt(bitwidth, 0);
More information about the Mlir-commits
mailing list