[Mlir-commits] [mlir] 5630143 - [MLIR][Presburger] LexSimplex::addEquality: add equalities as fixed columns
Arjun P
llvmlistbot at llvm.org
Wed Mar 23 17:41:07 PDT 2022
Author: Arjun P
Date: 2022-03-24T00:41:17Z
New Revision: 5630143af33f7e6e0dabdf38982cc9800140bb75
URL: https://github.com/llvm/llvm-project/commit/5630143af33f7e6e0dabdf38982cc9800140bb75
DIFF: https://github.com/llvm/llvm-project/commit/5630143af33f7e6e0dabdf38982cc9800140bb75.diff
LOG: [MLIR][Presburger] LexSimplex::addEquality: add equalities as fixed columns
In LexSimplex, instead of adding equalities as a pair of inequalities,
add them as a single row, move them into the basis, and keep them there.
There will always be a valid basis involving all non-redundant equalities. Such
equalities are then ignored by some other operations, such as the search for
pivot columns, which speeds those operations up slightly.
More importantly, this is a precursor to adding support for symbolic integer
lexmin, where this heuristic can sometimes make a big difference.
Reviewed By: Groverkss
Differential Revision: https://reviews.llvm.org/D122165
Added:
Modified:
mlir/include/mlir/Analysis/Presburger/Simplex.h
mlir/lib/Analysis/Presburger/Simplex.cpp
mlir/unittests/Analysis/Presburger/SimplexTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
index 90f092ff39ac2..235b794c6bd74 100644
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -166,21 +166,21 @@ class SimplexBase {
/// false otherwise.
bool isEmpty() const;
- /// Add an inequality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
- /// is the current number of variables, then the corresponding inequality is
- /// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} >= 0.
- virtual void addInequality(ArrayRef<int64_t> coeffs) = 0;
-
/// Returns the number of variables in the tableau.
unsigned getNumVariables() const;
/// Returns the number of constraints in the tableau.
unsigned getNumConstraints() const;
+ /// Add an inequality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
+ /// is the current number of variables, then the corresponding inequality is
+ /// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} >= 0.
+ virtual void addInequality(ArrayRef<int64_t> coeffs) = 0;
+
/// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
/// is the current number of variables, then the corresponding equality is
/// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
- void addEquality(ArrayRef<int64_t> coeffs);
+ virtual void addEquality(ArrayRef<int64_t> coeffs) = 0;
/// Add new variables to the end of the list of variables.
void appendVariable(unsigned count = 1);
@@ -249,6 +249,14 @@ class SimplexBase {
/// coefficient for it.
Optional<unsigned> findAnyPivotRow(unsigned col);
+ /// Return any column that this row can be pivoted with, ignoring tableau
+ /// consistency. Fixed columns (including equalities) are not considered.
+ ///
+ /// Returns an empty optional if no pivot is possible, which happens only when
+ /// the row has a zero coefficient in every non-fixed column, i.e., every
+ /// column other than the denominator, constant, big M, and equalities.
+ Optional<unsigned> findAnyPivotCol(unsigned row);
+
/// Swap the row with the column in the tableau's data structures but not the
/// tableau itself. This is used by pivot.
void swapRowWithCol(unsigned row, unsigned col);
@@ -295,6 +303,7 @@ class SimplexBase {
RemoveLastVariable,
UnmarkEmpty,
UnmarkLastRedundant,
+ UnmarkLastEquality,
RestoreBasis
};
@@ -308,13 +317,14 @@ class SimplexBase {
/// Undo the operation represented by the log entry.
void undo(UndoLogEntry entry);
- /// Return the number of fixed columns, as described in the constructor above,
- /// this is the number of columns beyond those for the variables in var.
- unsigned getNumFixedCols() const { return usingBigM ? 3u : 2u; }
+ unsigned getNumFixedCols() const { return numFixedCols; }
/// Stores whether or not a big M column is present in the tableau.
bool usingBigM;
+ /// denom + const + maybe M + equality columns
+ unsigned numFixedCols;
+
/// The number of rows in the tableau.
unsigned nRow;
@@ -435,9 +445,12 @@ class LexSimplex : public SimplexBase {
///
/// This just adds the inequality to the tableau and does not try to create a
/// consistent tableau configuration.
- void addInequality(ArrayRef<int64_t> coeffs) final {
- addRow(coeffs, /*makeRestricted=*/true);
- }
+ void addInequality(ArrayRef<int64_t> coeffs) final;
+
+ /// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
+ /// is the current number of variables, then the corresponding equality is
+ /// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
+ void addEquality(ArrayRef<int64_t> coeffs) final;
/// Get a snapshot of the current state. This is used for rolling back.
unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); }
@@ -533,6 +546,11 @@ class Simplex : public SimplexBase {
/// state and marks the Simplex empty if this is not possible.
void addInequality(ArrayRef<int64_t> coeffs) final;
+ /// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
+ /// is the current number of variables, then the corresponding equality is
+ /// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
+ void addEquality(ArrayRef<int64_t> coeffs) final;
+
/// Compute the maximum or minimum value of the given row, depending on
/// direction. The specified row is never pivoted. On return, the row may
/// have a negative sample value if the direction is down.
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index 341bcbc235dc7..d736a05f9636a 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -19,12 +19,12 @@ using Direction = Simplex::Direction;
const int nullIndex = std::numeric_limits<int>::max();
SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM)
- : usingBigM(mustUseBigM), nRow(0), nCol(getNumFixedCols() + nVar),
- nRedundant(0), tableau(0, nCol), empty(false) {
- colUnknown.insert(colUnknown.begin(), getNumFixedCols(), nullIndex);
+ : usingBigM(mustUseBigM), numFixedCols(mustUseBigM ? 3 : 2), nRow(0),
+ nCol(numFixedCols + nVar), nRedundant(0), tableau(0, nCol), empty(false) {
+ colUnknown.insert(colUnknown.begin(), numFixedCols, nullIndex);
for (unsigned i = 0; i < nVar; ++i) {
var.emplace_back(Orientation::Column, /*restricted=*/false,
- /*pos=*/getNumFixedCols() + i);
+ /*pos=*/numFixedCols + i);
colUnknown.push_back(i);
}
}
@@ -309,7 +309,7 @@ void LexSimplex::restoreRationalConsistency() {
// minimizes the change in sample value.
LogicalResult LexSimplex::moveRowUnknownToColumn(unsigned row) {
Optional<unsigned> maybeColumn;
- for (unsigned col = 3; col < nCol; ++col) {
+ for (unsigned col = getNumFixedCols(); col < nCol; ++col) {
if (tableau(row, col) <= 0)
continue;
maybeColumn =
@@ -648,7 +648,7 @@ void Simplex::addInequality(ArrayRef<int64_t> coeffs) {
///
/// We simply add two opposing inequalities, which force the expression to
/// be zero.
-void SimplexBase::addEquality(ArrayRef<int64_t> coeffs) {
+void Simplex::addEquality(ArrayRef<int64_t> coeffs) {
addInequality(coeffs);
SmallVector<int64_t, 8> negatedCoeffs;
for (int64_t coeff : coeffs)
@@ -705,6 +705,15 @@ Optional<unsigned> SimplexBase::findAnyPivotRow(unsigned col) {
return {};
}
+// This fails to find a pivot column only if the row has zero coefficients in
+// every non-fixed column (i.e. excluding denom, const, big M and equalities).
+Optional<unsigned> SimplexBase::findAnyPivotCol(unsigned row) {
+ for (unsigned col = getNumFixedCols(); col < nCol; ++col)
+ if (tableau(row, col) != 0)
+ return col;
+ return {};
+}
+
// It's not valid to remove the constraint by deleting the column since this
// would result in an invalid basis.
void Simplex::undoLastConstraint() {
@@ -780,6 +789,10 @@ void SimplexBase::undo(UndoLogEntry entry) {
empty = false;
} else if (entry == UndoLogEntry::UnmarkLastRedundant) {
nRedundant--;
+ } else if (entry == UndoLogEntry::UnmarkLastEquality) {
+ numFixedCols--;
+ assert(getNumFixedCols() >= 2 + usingBigM &&
+ "The denominator, constant, big M and symbols are always fixed!");
} else if (entry == UndoLogEntry::RestoreBasis) {
assert(!savedBases.empty() && "No bases saved!");
@@ -1110,6 +1123,26 @@ Optional<SmallVector<Fraction, 8>> Simplex::getRationalSample() const {
return sample;
}
+void LexSimplex::addInequality(ArrayRef<int64_t> coeffs) {
+ addRow(coeffs, /*makeRestricted=*/true);
+}
+
+/// Try to make the equality a fixed column by finding any pivot and performing
+/// it. The only time this is not possible is when the given equality's
+/// direction is already in the span of the existing fixed column equalities. In
+/// that case, we just leave it in row position.
+void LexSimplex::addEquality(ArrayRef<int64_t> coeffs) {
+ const Unknown &u = con[addRow(coeffs, /*makeRestricted=*/true)];
+ Optional<unsigned> pivotCol = findAnyPivotCol(u.pos);
+ if (!pivotCol)
+ return;
+
+ pivot(u.pos, *pivotCol);
+ swapColumns(*pivotCol, getNumFixedCols());
+ numFixedCols++;
+ undoLog.push_back(UndoLogEntry::UnmarkLastEquality);
+}
+
MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::getRationalSample() const {
if (empty)
return OptimumKind::Empty;
diff --git a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
index e605c89b46922..ce5dad77819c2 100644
--- a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
@@ -548,3 +548,10 @@ TEST(SimplexTest, addDivisionVariable) {
ASSERT_TRUE(sample.hasValue());
EXPECT_EQ((*sample)[0] / 2, (*sample)[1]);
}
+
+TEST(LexSimplexTest, addEquality) {
+ IntegerRelation rel(/*numDomain=*/0, /*numRange=*/1);
+ rel.addEquality({1, 0});
+ LexSimplex simplex(rel);
+ EXPECT_EQ(simplex.getNumConstraints(), 1u);
+}
More information about the Mlir-commits
mailing list