[Mlir-commits] [mlir] [mlir][presburger] Optimize the compilation time for calculating bounds of an Integer Relation (PR #164199)
donald chen
llvmlistbot at llvm.org
Sun Oct 19 23:35:00 PDT 2025
https://github.com/cxy-1993 updated https://github.com/llvm/llvm-project/pull/164199
>From ee2806fcca3e05594a73546ad9e4002230e009f3 Mon Sep 17 00:00:00 2001
From: donald chen <chenxunyu1993 at gmail.com>
Date: Tue, 14 Oct 2025 13:50:31 +0000
Subject: [PATCH] [mlir][presburger] Optimize the compilation time for
calculating bounds of an Integer Relation
IntegerRelation uses Fourier-Motzkin elimination and Gaussian elimination to
simplify constraints. These methods may repeatedly perform calculations and
elimination on irrelevant variables. Preemptively eliminating irrelevant
variables and their associated constraints can speed up up the calculation process.
---
.../Analysis/Presburger/IntegerRelation.h | 3 +
.../Analysis/Presburger/IntegerRelation.cpp | 56 +++++++++++++++++++
2 files changed, 59 insertions(+)
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index f86535740fec9..026d84529edfb 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -511,6 +511,9 @@ class IntegerRelation {
void projectOut(unsigned pos, unsigned num);
inline void projectOut(unsigned pos) { return projectOut(pos, 1); }
+ /// Prune constraints that are irrelevant to the target variable.
+ void pruneConstraints(unsigned pos);
+
/// Tries to fold the specified variable to a constant using a trivial
/// equality detection; if successful, the constant is substituted for the
/// variable everywhere in the constraint system and then removed from the
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 0dcdd5bb97bc8..838cf329c02df 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -21,6 +21,7 @@
#include "mlir/Analysis/Presburger/Simplex.h"
#include "mlir/Analysis/Presburger/Utils.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallBitVector.h"
@@ -1723,12 +1724,67 @@ std::optional<DynamicAPInt> IntegerRelation::getConstantBoundOnDimSize(
return minDiff;
}
+void IntegerRelation::pruneConstraints(unsigned pos) {
+ llvm::DenseSet<unsigned> relatedCols({pos}), relatedRows;
+
+ llvm::SmallVector<unsigned> rowStack, colStack({pos});
+ unsigned numConstraints = getNumConstraints();
+ if (numConstraints == 0)
+ return;
+ while (!rowStack.empty() || !colStack.empty()) {
+ if (!rowStack.empty()) {
+ unsigned currentRow = rowStack.pop_back_val();
+ for (uint64_t colIndex = 0; colIndex < getNumVars(); ++colIndex) {
+ if (currentRow < getNumInequalities()) {
+ if (atIneq(currentRow, colIndex) != 0 &&
+ relatedCols.insert(colIndex).second) {
+ colStack.push_back(colIndex);
+ }
+ } else {
+ if (atEq(currentRow - getNumInequalities(), colIndex) != 0 &&
+ relatedCols.insert(colIndex).second) {
+ colStack.push_back(colIndex);
+ }
+ }
+ }
+ } else {
+ unsigned currentCol = colStack.pop_back_val();
+ for (uint64_t rowIndex = 0; rowIndex < numConstraints; ++rowIndex) {
+ if (rowIndex < getNumInequalities()) {
+ if (atIneq(rowIndex, currentCol) != 0 &&
+ relatedRows.insert(rowIndex).second) {
+ rowStack.push_back(rowIndex);
+ }
+ } else {
+ if (atEq(rowIndex - getNumInequalities(), currentCol) != 0 &&
+ relatedRows.insert(rowIndex).second) {
+ rowStack.push_back(rowIndex);
+ }
+ }
+ }
+ }
+ }
+
+ for (int64_t constraintId = numConstraints - 1; constraintId >= 0;
+ --constraintId) {
+ if (!relatedRows.contains(constraintId)) {
+ if (constraintId >= getNumInequalities()) {
+ removeEquality(constraintId - getNumInequalities());
+ } else {
+ removeInequality(constraintId);
+ }
+ }
+ }
+}
+
template <bool isLower>
std::optional<DynamicAPInt>
IntegerRelation::computeConstantLowerOrUpperBound(unsigned pos) {
assert(pos < getNumVars() && "invalid position");
// Project to 'pos'.
+ pruneConstraints(pos);
projectOut(0, pos);
+ pruneConstraints(0);
projectOut(1, getNumVars() - 1);
// Check if there's an equality equating the '0'^th variable to a constant.
int eqRowIdx = findEqualityToConstant(/*pos=*/0, /*symbolic=*/false);
More information about the Mlir-commits
mailing list