[Mlir-commits] [mlir] 5630143 - [MLIR][Presburger] LexSimplex::addEquality: add equalities as fixed columns

Arjun P llvmlistbot at llvm.org
Wed Mar 23 17:41:07 PDT 2022


Author: Arjun P
Date: 2022-03-24T00:41:17Z
New Revision: 5630143af33f7e6e0dabdf38982cc9800140bb75

URL: https://github.com/llvm/llvm-project/commit/5630143af33f7e6e0dabdf38982cc9800140bb75
DIFF: https://github.com/llvm/llvm-project/commit/5630143af33f7e6e0dabdf38982cc9800140bb75.diff

LOG: [MLIR][Presburger] LexSimplex::addEquality: add equalities as fixed columns

In LexSimplex, instead of adding equalities as a pair of inequalities,
add them as a single row, move them into the basis, and keep them there.

There will always be a valid basis involving all non-redundant equalities. Such
equalities can then be ignored by some other operations, such as the search for
pivot columns, which speeds those operations up a little.

More importantly, this is a precursor to adding support for symbolic integer
lexmin, where this heuristic can sometimes make a big difference.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D122165

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Presburger/Simplex.h
    mlir/lib/Analysis/Presburger/Simplex.cpp
    mlir/unittests/Analysis/Presburger/SimplexTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
index 90f092ff39ac2..235b794c6bd74 100644
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -166,21 +166,21 @@ class SimplexBase {
   /// false otherwise.
   bool isEmpty() const;
 
-  /// Add an inequality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
-  /// is the current number of variables, then the corresponding inequality is
-  /// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} >= 0.
-  virtual void addInequality(ArrayRef<int64_t> coeffs) = 0;
-
   /// Returns the number of variables in the tableau.
   unsigned getNumVariables() const;
 
   /// Returns the number of constraints in the tableau.
   unsigned getNumConstraints() const;
 
+  /// Add an inequality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
+  /// is the current number of variables, then the corresponding inequality is
+  /// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} >= 0.
+  virtual void addInequality(ArrayRef<int64_t> coeffs) = 0;
+
   /// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
   /// is the current number of variables, then the corresponding equality is
   /// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
-  void addEquality(ArrayRef<int64_t> coeffs);
+  virtual void addEquality(ArrayRef<int64_t> coeffs) = 0;
 
   /// Add new variables to the end of the list of variables.
   void appendVariable(unsigned count = 1);
@@ -249,6 +249,14 @@ class SimplexBase {
   /// coefficient for it.
   Optional<unsigned> findAnyPivotRow(unsigned col);
 
+  /// Return any column that this row can be pivoted with, ignoring tableau
+  /// consistency. Fixed columns, e.g. equality columns, are skipped.
+  ///
+  /// Returns an empty optional if no pivot is possible, which happens only when
+  /// the row has a zero coefficient in every column that is not a fixed
+  /// column.
+  Optional<unsigned> findAnyPivotCol(unsigned row);
+
   /// Swap the row with the column in the tableau's data structures but not the
   /// tableau itself. This is used by pivot.
   void swapRowWithCol(unsigned row, unsigned col);
@@ -295,6 +303,7 @@ class SimplexBase {
     RemoveLastVariable,
     UnmarkEmpty,
     UnmarkLastRedundant,
+    UnmarkLastEquality,
     RestoreBasis
   };
 
@@ -308,13 +317,14 @@ class SimplexBase {
   /// Undo the operation represented by the log entry.
   void undo(UndoLogEntry entry);
 
-  /// Return the number of fixed columns, as described in the constructor above,
-  /// this is the number of columns beyond those for the variables in var.
-  unsigned getNumFixedCols() const { return usingBigM ? 3u : 2u; }
+  unsigned getNumFixedCols() const { return numFixedCols; }
 
   /// Stores whether or not a big M column is present in the tableau.
   bool usingBigM;
 
+  /// Count of fixed columns: denom, const, big M (if used), then equalities.
+  unsigned numFixedCols;
+
   /// The number of rows in the tableau.
   unsigned nRow;
 
@@ -435,9 +445,12 @@ class LexSimplex : public SimplexBase {
   ///
   /// This just adds the inequality to the tableau and does not try to create a
   /// consistent tableau configuration.
-  void addInequality(ArrayRef<int64_t> coeffs) final {
-    addRow(coeffs, /*makeRestricted=*/true);
-  }
+  void addInequality(ArrayRef<int64_t> coeffs) final;
+
+  /// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
+  /// is the current number of variables, then the corresponding equality is
+  /// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
+  void addEquality(ArrayRef<int64_t> coeffs) final;
 
   /// Get a snapshot of the current state. This is used for rolling back.
   unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); }
@@ -533,6 +546,11 @@ class Simplex : public SimplexBase {
   /// state and marks the Simplex empty if this is not possible.
   void addInequality(ArrayRef<int64_t> coeffs) final;
 
+  /// Add an equality to the tableau. If coeffs is c_0, c_1, ... c_n, where n
+  /// is the current number of variables, then the corresponding equality is
+  /// c_n + c_0*x_0 + c_1*x_1 + ... + c_{n-1}*x_{n-1} == 0.
+  void addEquality(ArrayRef<int64_t> coeffs) final;
+
   /// Compute the maximum or minimum value of the given row, depending on
   /// direction. The specified row is never pivoted. On return, the row may
   /// have a negative sample value if the direction is down.

diff  --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index 341bcbc235dc7..d736a05f9636a 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -19,12 +19,12 @@ using Direction = Simplex::Direction;
 const int nullIndex = std::numeric_limits<int>::max();
 
 SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM)
-    : usingBigM(mustUseBigM), nRow(0), nCol(getNumFixedCols() + nVar),
-      nRedundant(0), tableau(0, nCol), empty(false) {
-  colUnknown.insert(colUnknown.begin(), getNumFixedCols(), nullIndex);
+    : usingBigM(mustUseBigM), numFixedCols(mustUseBigM ? 3 : 2), nRow(0),
+      nCol(numFixedCols + nVar), nRedundant(0), tableau(0, nCol), empty(false) {
+  colUnknown.insert(colUnknown.begin(), numFixedCols, nullIndex);
   for (unsigned i = 0; i < nVar; ++i) {
     var.emplace_back(Orientation::Column, /*restricted=*/false,
-                     /*pos=*/getNumFixedCols() + i);
+                     /*pos=*/numFixedCols + i);
     colUnknown.push_back(i);
   }
 }
@@ -309,7 +309,7 @@ void LexSimplex::restoreRationalConsistency() {
 // minimizes the change in sample value.
 LogicalResult LexSimplex::moveRowUnknownToColumn(unsigned row) {
   Optional<unsigned> maybeColumn;
-  for (unsigned col = 3; col < nCol; ++col) {
+  for (unsigned col = getNumFixedCols(); col < nCol; ++col) {
     if (tableau(row, col) <= 0)
       continue;
     maybeColumn =
@@ -648,7 +648,7 @@ void Simplex::addInequality(ArrayRef<int64_t> coeffs) {
 ///
 /// We simply add two opposing inequalities, which force the expression to
 /// be zero.
-void SimplexBase::addEquality(ArrayRef<int64_t> coeffs) {
+void Simplex::addEquality(ArrayRef<int64_t> coeffs) {
   addInequality(coeffs);
   SmallVector<int64_t, 8> negatedCoeffs;
   for (int64_t coeff : coeffs)
@@ -705,6 +705,15 @@ Optional<unsigned> SimplexBase::findAnyPivotRow(unsigned col) {
   return {};
 }
 
+// Scan the non-fixed columns left to right and return the first one with a
+// non-zero coefficient in this row; fails only if all of them are zero.
+Optional<unsigned> SimplexBase::findAnyPivotCol(unsigned row) {
+  for (unsigned col = getNumFixedCols(); col < nCol; ++col)
+    if (tableau(row, col) != 0)
+      return col;
+  return {};
+}
+
 // It's not valid to remove the constraint by deleting the column since this
 // would result in an invalid basis.
 void Simplex::undoLastConstraint() {
@@ -780,6 +789,10 @@ void SimplexBase::undo(UndoLogEntry entry) {
     empty = false;
   } else if (entry == UndoLogEntry::UnmarkLastRedundant) {
     nRedundant--;
+  } else if (entry == UndoLogEntry::UnmarkLastEquality) {
+    numFixedCols--;
+    assert(getNumFixedCols() >= 2 + usingBigM &&
+           "The denominator, constant, big M and symbols are always fixed!");
   } else if (entry == UndoLogEntry::RestoreBasis) {
     assert(!savedBases.empty() && "No bases saved!");
 
@@ -1110,6 +1123,26 @@ Optional<SmallVector<Fraction, 8>> Simplex::getRationalSample() const {
   return sample;
 }
 
+void LexSimplex::addInequality(ArrayRef<int64_t> coeffs) {
+  addRow(coeffs, /*makeRestricted=*/true);
+}
+
+/// Add the equality as a restricted row, then pivot it into a column and mark
+/// that column as fixed. This fails only if the row has a zero coefficient in
+/// every non-fixed column, i.e. its direction already lies in the span of the
+/// existing equality columns; in that case the row is left in row position.
+void LexSimplex::addEquality(ArrayRef<int64_t> coeffs) {
+  const Unknown &u = con[addRow(coeffs, /*makeRestricted=*/true)];
+  Optional<unsigned> pivotCol = findAnyPivotCol(u.pos);
+  if (!pivotCol)
+    return;
+
+  pivot(u.pos, *pivotCol);
+  swapColumns(*pivotCol, getNumFixedCols());
+  numFixedCols++;
+  undoLog.push_back(UndoLogEntry::UnmarkLastEquality);
+}
+
 MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::getRationalSample() const {
   if (empty)
     return OptimumKind::Empty;

diff  --git a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
index e605c89b46922..ce5dad77819c2 100644
--- a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
@@ -548,3 +548,10 @@ TEST(SimplexTest, addDivisionVariable) {
   ASSERT_TRUE(sample.hasValue());
   EXPECT_EQ((*sample)[0] / 2, (*sample)[1]);
 }
+
+TEST(LexSimplexTest, addEquality) {
+  IntegerRelation rel(/*numDomain=*/0, /*numRange=*/1);
+  rel.addEquality({1, 0});
+  LexSimplex simplex(rel);
+  EXPECT_EQ(simplex.getNumConstraints(), 1u);
+}


        


More information about the Mlir-commits mailing list