[Mlir-commits] [mlir] [MLIR][Presburger] Fix Gaussian elimination (PR #164437)
Yue Huang
llvmlistbot at llvm.org
Sat Nov 22 02:38:50 PST 2025
https://github.com/AdUhTkJm updated https://github.com/llvm/llvm-project/pull/164437
>From 7eb6a6ea841605f0809378d4809b6f8060e9f91d Mon Sep 17 00:00:00 2001
From: Yue Huang <yh548 at cam.ac.uk>
Date: Tue, 21 Oct 2025 17:03:02 +0100
Subject: [PATCH] [MLIR][Presburger] Fix Gaussian elimination
---
mlir/lib/Analysis/Presburger/Barvinok.cpp | 14 ++++++-------
.../Analysis/Presburger/IntegerRelation.cpp | 21 +++++++++++++++++--
.../Analysis/Presburger/BarvinokTest.cpp | 7 +++++++
.../Presburger/IntegerRelationTest.cpp | 14 +++++++++++++
4 files changed, 47 insertions(+), 9 deletions(-)
diff --git a/mlir/lib/Analysis/Presburger/Barvinok.cpp b/mlir/lib/Analysis/Presburger/Barvinok.cpp
index 75d592e976edf..c31b27794f01e 100644
--- a/mlir/lib/Analysis/Presburger/Barvinok.cpp
+++ b/mlir/lib/Analysis/Presburger/Barvinok.cpp
@@ -178,13 +178,13 @@ mlir::presburger::detail::solveParametricEquations(FracMatrix equations) {
for (unsigned i = 0; i < d; ++i) {
// First ensure that the diagonal element is nonzero, by swapping
// it with a row that is non-zero at column i.
- if (equations(i, i) != 0)
- continue;
- for (unsigned j = i + 1; j < d; ++j) {
- if (equations(j, i) == 0)
- continue;
- equations.swapRows(j, i);
- break;
+ if (equations(i, i) == 0) {
+ for (unsigned j = i + 1; j < d; ++j) {
+ if (equations(j, i) == 0)
+ continue;
+ equations.swapRows(j, i);
+ break;
+ }
}
Fraction diagElement = equations(i, i);
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 949fc2db79809..3249381a6a7e4 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -1112,15 +1112,28 @@ unsigned IntegerRelation::gaussianEliminateVars(unsigned posStart,
return posLimit - posStart;
}
+/// Returns the index of the first equality of `rel` in the row range
+/// [fromRow, rel.getNumEqualities()) whose coefficient at column `colIdx` is
+/// nonzero, or std::nullopt if no such equality exists. Note that `fromRow`
+/// itself is included in the search.
+static std::optional<unsigned>
+findEqualityWithNonZeroAfterRow(const IntegerRelation &rel, unsigned fromRow,
+                                unsigned colIdx) {
+  // `fromRow` indexes equality rows, so it is bounded by the number of
+  // equalities (not the number of variables); fromRow == getNumEqualities()
+  // is an empty search range.
+  assert(fromRow <= rel.getNumEqualities() && colIdx < rel.getNumCols() &&
+         "position out of bounds");
+  for (unsigned rowIdx = fromRow, e = rel.getNumEqualities(); rowIdx < e;
+       ++rowIdx) {
+    if (rel.atEq(rowIdx, colIdx) != 0)
+      return rowIdx;
+  }
+  return std::nullopt;
+}
+
bool IntegerRelation::gaussianEliminate() {
gcdTightenInequalities();
unsigned firstVar = 0, vars = getNumVars();
unsigned nowDone, eqs;
std::optional<unsigned> pivotRow;
for (nowDone = 0, eqs = getNumEqualities(); nowDone < eqs; ++nowDone) {
- // Finds the first non-empty column.
+ // Finds the first non-empty column that we haven't dealt with.
for (; firstVar < vars; ++firstVar) {
- if ((pivotRow = findConstraintWithNonZeroAt(firstVar, /*isEq=*/true)))
+ if ((pivotRow =
+ findEqualityWithNonZeroAfterRow(*this, nowDone, firstVar)))
break;
}
// The matrix has been normalized to row echelon form.
@@ -1143,6 +1156,10 @@ bool IntegerRelation::gaussianEliminate() {
inequalities.normalizeRow(i);
}
gcdTightenInequalities();
+
+ // The column is finished. Tell the next iteration to start at the next
+ // column.
+ firstVar++;
}
// No redundant rows.
diff --git a/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp b/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp
index eaf04379cb529..d687a0072a158 100644
--- a/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp
@@ -301,3 +301,10 @@ TEST(BarvinokTest, computeNumTermsPolytope) {
gf = count[0].second;
EXPECT_EQ(gf.getNumerators().size(), 24u);
}
+
+TEST(BarvinokTest, solveParametricEquations) {
+  FracMatrix equations = makeFracMatrix(2, 3, {{2, 3, -4}, {2, 6, -7}});
+  // The system is non-singular, so a solution must exist. Check before
+  // dereferencing so a regression fails the test instead of invoking UB.
+  std::optional<FracMatrix> solution = solveParametricEquations(equations);
+  ASSERT_TRUE(solution.has_value());
+  // 2a + 3b = 4 and 2a + 6b = 7 give a = 1/2, b = 1.
+  EXPECT_EQ(solution->at(0, 0), Fraction(1, 2));
+  EXPECT_EQ(solution->at(1, 0), 1);
+}
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index 9ae90a4841f3c..c1795b1b48925 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -725,3 +725,17 @@ TEST(IntegerRelationTest, addLocalModulo) {
EXPECT_TRUE(rel.containsPointNoLocal({x, x % 32}));
}
}
+
+TEST(IntegerRelationTest, simplify) {
+  IntegerRelation rel =
+      parseRelationFromSet("(x, y, z): (2*x + y - 4*z - 3 == 0, "
+                           "3*x - y - 3*z + 2 == 0, x + 3*y - 5*z - 8 == 0,"
+                           "x - y + z >= 0)",
+                           2);
+  IntegerRelation copy = rel;
+  rel.simplify();
+
+  // Simplification must not change the set of points.
+  EXPECT_TRUE(rel.isEqual(copy));
+  // The third equality is redundant (it equals 2 * first - second) and should
+  // be removed.
+  EXPECT_EQ(rel.getNumEqualities(), 2u);
+}
More information about the Mlir-commits
mailing list