[Mlir-commits] [mlir] da92f92 - [MLIR][Presburger] IntegerPolyhedron: add support for symbolic integer lexmin
Arjun P
llvmlistbot at llvm.org
Mon Apr 4 16:25:00 PDT 2022
Author: Arjun P
Date: 2022-04-05T00:24:57+01:00
New Revision: da92f92621e28a56fe8ad79d82eb60e436bf1d39
URL: https://github.com/llvm/llvm-project/commit/da92f92621e28a56fe8ad79d82eb60e436bf1d39
DIFF: https://github.com/llvm/llvm-project/commit/da92f92621e28a56fe8ad79d82eb60e436bf1d39.diff
LOG: [MLIR][Presburger] IntegerPolyhedron: add support for symbolic integer lexmin
Add support for computing the symbolic integer lexmin of a polyhedron.
This finds, for every assignment to the symbols, the lexicographically
minimum value attained by the dimensions. For example, the symbolic lexmin
of the set
`(x)[a, b, c] : (a <= x, b <= x, x <= c)`
can be written as
```
x = a if b <= a, a <= c
x = b if a < b, b <= c
```
This also finds the set of assignments to the symbols that make the lexmin unbounded.
Reviewed By: Groverkss
Differential Revision: https://reviews.llvm.org/D122985
Added:
Modified:
mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
mlir/include/mlir/Analysis/Presburger/Matrix.h
mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
mlir/include/mlir/Analysis/Presburger/Simplex.h
mlir/lib/Analysis/Presburger/IntegerRelation.cpp
mlir/lib/Analysis/Presburger/Matrix.cpp
mlir/lib/Analysis/Presburger/PWMAFunction.cpp
mlir/lib/Analysis/Presburger/Simplex.cpp
mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 6add6e8aa9b24..709e4f8438356 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -536,6 +536,7 @@ class IntegerRelation : public PresburgerSpace {
Matrix inequalities;
};
+struct SymbolicLexMin;
/// An IntegerPolyhedron is a PresburgerSpace subject to affine
/// constraints. Affine constraints can be inequalities or equalities in the
/// form:
@@ -593,6 +594,28 @@ class IntegerPolyhedron : public IntegerRelation {
/// column position (i.e., not relative to the kind of identifier) of the
/// first added identifier.
unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override;
+
+ /// Compute the symbolic integer lexmin of the polyhedron.
+ /// This finds, for every assignment to the symbols, the lexicographically
+ /// minimum value attained by the dimensions. For example, the symbolic lexmin
+ /// of the set
+ ///
+ /// (x)[a, b, c] : (a <= x, b <= x, x <= c)
+ ///
+ /// can be written as
+ ///
+ /// x = a if b <= a, a <= c
+ /// x = b if a < b, b <= c
+ ///
+ /// This function is stored in the `lexmin` member of the result.
+ /// Some assignments to the symbols might make the set empty.
+ /// Such points are not part of the function's domain.
+ /// In the above example, this happens when max(a, b) > c.
+ ///
+ /// For some values of the symbols, the lexmin may be unbounded.
+ /// `SymbolicLexMin` stores these parts of the symbolic domain in a separate
+ /// `PresburgerSet`, `unboundedDomain`.
+ SymbolicLexMin findSymbolicIntegerLexMin() const;
};
} // namespace presburger
diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h
index 940b88d8148f4..e2ad543070a4b 100644
--- a/mlir/include/mlir/Analysis/Presburger/Matrix.h
+++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h
@@ -151,6 +151,9 @@ class Matrix {
/// Add an extra row at the bottom of the matrix and return its position.
unsigned appendExtraRow();
+ /// Same as above, but copy the given elements into the row. The length of
+ /// `elems` must be equal to the number of columns.
+ unsigned appendExtraRow(ArrayRef<int64_t> elems);
/// Print the matrix.
void print(raw_ostream &os) const;
diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
index ce0d77da9bc2c..f4bffe5b4e7a4 100644
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -106,6 +106,11 @@ class MultiAffineFunction : protected IntegerPolyhedron {
/// outside the domain, an empty optional is returned.
Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
+ /// Truncate the output dimensions to the first `count` dimensions.
+ ///
+ /// TODO: refactor so that this can be accomplished through removeIdRange.
+ void truncateOutput(unsigned count);
+
void print(raw_ostream &os) const;
void dump() const;
@@ -165,6 +170,11 @@ class PWMAFunction : public PresburgerSpace {
/// value at every point in the domain.
bool isEqual(const PWMAFunction &other) const;
+ /// Truncate the output dimensions to the first `count` dimensions.
+ ///
+ /// TODO: refactor so that this can be accomplished through removeIdRange.
+ void truncateOutput(unsigned count);
+
void print(raw_ostream &os) const;
void dump() const;
diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
index 66d408dbf8b69..67a4b5f68e202 100644
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -18,6 +18,7 @@
#include "mlir/Analysis/Presburger/Fraction.h"
#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/Presburger/Matrix.h"
+#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "mlir/Analysis/Presburger/Utils.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
@@ -41,8 +42,9 @@ class GBRSimplex;
/// these constraints that are redundant, i.e. a subset of constraints that
/// doesn't constrain the affine set further after adding the non-redundant
/// constraints. The LexSimplex class provides support for computing the
-/// lexicographical minimum of an IntegerRelation. Both these classes can be
-/// constructed from an IntegerRelation, and both inherit common
+/// lexicographic minimum of an IntegerRelation. The SymbolicLexSimplex class
+/// provides support for computing symbolic lexicographic minimums. All of these
+/// classes can be constructed from an IntegerRelation, and all inherit common
/// functionality from SimplexBase.
///
/// The implementations of the Simplex and SimplexBase classes, other than the
@@ -72,19 +74,22 @@ class GBRSimplex;
/// respectively. As described above, the first column is the common
/// denominator. The second column represents the constant term, explained in
/// more detail below. These two are _fixed columns_; they always retain their
-/// position as the first and second columns. Additionally, LexSimplex stores
-/// a so-call big M parameter (explained below) in the third column, so
-/// LexSimplex has three fixed columns.
+/// position as the first and second columns. Additionally, LexSimplexBase
+/// stores a so-called big M parameter (explained below) in the third column, so
+/// LexSimplexBase has three fixed columns. Finally, SymbolicLexSimplex has
+/// `nSymbol` variables designated as symbols. These occupy the next `nSymbol`
+/// columns, viz. the columns [3, 3 + nSymbol). For more information on symbols,
+/// see LexSimplexBase and SymbolicLexSimplex.
///
-/// LexSimplex does not directly support variables which can be negative, so we
-/// introduce the so-called big M parameter, an artificial variable that is
+/// LexSimplexBase does not directly support variables which can be negative, so
+/// we introduce the so-called big M parameter, an artificial variable that is
/// considered to have an arbitrarily large value. We then transform the
/// variables, say x, y, z, ... to M, M + x, M + y, M + z. Since M has been
/// added to these variables, they are now known to have non-negative values.
-/// For more details, see the documentation for LexSimplex. The big M parameter
-/// is not considered a real unknown and is not stored in the `var` data
-/// structure; rather the tableau just has an extra fixed column for it just
-/// like the constant term.
+/// For more details, see the documentation for LexSimplexBase. The big M
+/// parameter is not considered a real unknown and is not stored in the `var`
+/// data structure; rather the tableau just has an extra fixed column for it
+/// just like the constant term.
///
/// The vectors var and con store information about the variables and
/// constraints respectively, namely, whether they are in row or column
@@ -146,8 +151,8 @@ class GBRSimplex;
/// operation from the end until we reach the snapshot's location. SimplexBase
/// also supports taking a snapshot including the exact set of basis unknowns;
/// if this functionality is used, then on rolling back the exact basis will
-/// also be restored. This is used by LexSimplex because its algorithm, unlike
-/// Simplex, is sensitive to the exact basis used at a point.
+/// also be restored. This is used by LexSimplexBase because the lex algorithm,
+/// unlike `Simplex`, is sensitive to the exact basis used at a point.
class SimplexBase {
public:
SimplexBase() = delete;
@@ -211,7 +216,8 @@ class SimplexBase {
/// constant term, whereas LexSimplex has an extra fixed column for the
/// so-called big M parameter. For more information see the documentation for
/// LexSimplex.
- SimplexBase(unsigned nVar, bool mustUseBigM);
+ SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset,
+ unsigned nSymbol);
enum class Orientation { Row, Column };
@@ -223,11 +229,14 @@ class SimplexBase {
/// always be non-negative and if it cannot be made non-negative without
/// violating other constraints, the tableau is empty.
struct Unknown {
- Unknown(Orientation oOrientation, bool oRestricted, unsigned oPos)
- : pos(oPos), orientation(oOrientation), restricted(oRestricted) {}
+ Unknown(Orientation oOrientation, bool oRestricted, unsigned oPos,
+ bool oIsSymbol = false)
+ : pos(oPos), orientation(oOrientation), restricted(oRestricted),
+ isSymbol(oIsSymbol) {}
unsigned pos;
Orientation orientation;
bool restricted : 1;
+ bool isSymbol : 1;
void print(raw_ostream &os) const {
os << (orientation == Orientation::Row ? "r" : "c");
@@ -326,6 +335,10 @@ class SimplexBase {
/// nRedundant rows.
unsigned nRedundant;
+ /// The number of symbols. This must be consistent with the number of
+ /// Unknowns in `var` below that have `isSymbol` set to true.
+ unsigned nSymbol;
+
/// The matrix representing the tableau.
Matrix tableau;
@@ -363,62 +376,45 @@ class SimplexBase {
/// introduce an artifical variable M that is considered to have a value of
/// +infinity and instead of the variables x, y, z, we internally use variables
/// M + x, M + y, M + z, which are now guaranteed to be non-negative. See the
-/// documentation for Simplex for more details. The whole algorithm is performed
-/// without having to fix a "big enough" value of the big M parameter; it is
-/// just considered to be infinite throughout and it never appears in the final
-/// outputs. We will deal with sample values throughout that may in general be
-/// some linear expression involving M like pM + q or aM + b. We can compare
-/// these with each other. They have a total order:
-/// aM + b < pM + q iff a < p or (a == p and b < q).
+/// documentation for SimplexBase for more details. M is also considered to be
+/// an integer that is divisible by everything.
+///
+/// The whole algorithm is performed with M treated as a symbol;
+/// it is just considered to be infinite throughout and it never appears in the
+/// final outputs. We will deal with sample values throughout that may in
+/// general be some affine expression involving M, like pM + q or aM + b. We can
+/// compare these with each other. They have a total order:
+///
+/// aM + b < pM + q iff a < p or (a == p and b < q).
/// In particular, aM + b < 0 iff a < 0 or (a == 0 and b < 0).
///
+/// When performing symbolic optimization, sample values will be affine
+/// expressions in M and the symbols. For example, we could have sample values
+/// aM + bS + c and pM + qS + r, where S is a symbol. Now we have
+/// aM + bS + c < pM + qS + r iff (a < p) or (a == p and bS + c < qS + r).
+/// bS + c < qS + r can be always true, always false, or neither,
+/// depending on the set of values S can take. The symbols are always stored
+/// in columns [3, 3 + nSymbol). For more details, see the
+/// documentation for SymbolicLexSimplex.
+///
/// Initially all the constraints to be added are added as rows, with no attempt
/// to keep the tableau consistent. Pivots are only performed when some query
/// is made, such as a call to getRationalLexMin. Care is taken to always
/// maintain a lexicopositive basis transform, explained below.
///
-/// Let the variables be x = (x_1, ... x_n). Let the basis unknowns at a
-/// particular point be y = (y_1, ... y_n). We know that x = A*y + b for some
-/// n x n matrix A and n x 1 column vector b. We want every column in A to be
-/// lexicopositive, i.e., have at least one non-zero element, with the first
-/// such element being positive. This property is preserved throughout the
-/// operation of LexSimplex. Note that on construction, the basis transform A is
-/// the indentity matrix and so every column is lexicopositive. Note that for
-/// LexSimplex, for the tableau to be consistent we must have non-negative
-/// sample values not only for the constraints but also for the variables.
-/// So if the tableau is consistent then x >= 0 and y >= 0, by which we mean
-/// every element in these vectors is non-negative. (note that this is a
-/// different concept from lexicopositivity!)
-///
-/// When we arrive at a basis such the basis transform is lexicopositive and the
-/// tableau is consistent, the sample point is the lexiographically minimum
-/// point in the polytope. We will show that A*y is zero or lexicopositive when
-/// y >= 0. Adding a lexicopositive vector to b will make it lexicographically
-/// bigger, so A*y + b is lexicographically bigger than b for any y >= 0 except
-/// y = 0. This shows that no point lexicographically smaller than x = b can be
-/// obtained. Since we already know that x = b is valid point in the space, this
-/// shows that x = b is the lexicographic minimum.
-///
-/// Proof that A*y is lexicopositive or zero when y > 0. Recall that every
-/// column of A is lexicopositive. Begin by considering A_1, the first row of A.
-/// If this row is all zeros, then (A*y)_1 = (A_1)*y = 0; proceed to the next
-/// row. If we run out of rows, A*y is zero and we are done; otherwise, we
-/// encounter some row A_i that has a non-zero element. Every column is
-/// lexicopositive and so has some positive element before any negative elements
-/// occur, so the element in this row for any column, if non-zero, must be
-/// positive. Consider (A*y)_i = (A_i)*y. All the elements in both vectors are
-/// non-negative, so if this is non-zero then it must be positive. Then the
-/// first non-zero element of A*y is positive so A*y is lexicopositive.
-///
-/// Otherwise, if (A_i)*y is zero, then for every column j that had a non-zero
-/// element in A_i, y_j is zero. Thus these columns have no contribution to A*y
-/// and we can completely ignore these columns of A. We now continue downwards,
-/// looking for rows of A that have a non-zero element other than in the ignored
-/// columns. If we find one, say A_k, once again these elements must be positive
-/// since they are the first non-zero element in each of these columns, so if
-/// (A_k)*y is not zero then we have that A*y is lexicopositive and if not we
-/// ignore more columns; eventually if all these dot products become zero then
-/// A*y is zero and we are done.
+/// Let the variables be x = (x_1, ... x_n).
+/// Let the symbols be s = (s_1, ... s_m). Let the basis unknowns at a
+/// particular point be y = (y_1, ... y_n). We know that x = A*y + T*s + b for
+/// some n x n matrix A, n x m matrix T, and n x 1 column vector b. We want
+/// every column in A to be lexicopositive, i.e., have at least one non-zero
+/// element, with the first such element being positive. This property is
+/// preserved throughout the operation of LexSimplexBase. Note that on
+/// construction, the basis transform A is the identity matrix and so every
+/// column is lexicopositive. Note that for LexSimplexBase, for the tableau to
+/// be consistent we must have non-negative sample values not only for the
+/// constraints but also for the variables. So if the tableau is consistent then
+/// x >= 0 and y >= 0, by which we mean every element in these vectors is
+/// non-negative. (note that this is a different concept from lexicopositivity!)
class LexSimplexBase : public SimplexBase {
public:
~LexSimplexBase() override = default;
@@ -435,25 +431,37 @@ class LexSimplexBase : public SimplexBase {
unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); }
protected:
- LexSimplexBase(unsigned nVar) : SimplexBase(nVar, /*mustUseBigM=*/true) {}
+ LexSimplexBase(unsigned nVar, unsigned symbolOffset, unsigned nSymbol)
+ : SimplexBase(nVar, /*mustUseBigM=*/true, symbolOffset, nSymbol) {}
explicit LexSimplexBase(const IntegerRelation &constraints)
- : LexSimplexBase(constraints.getNumIds()) {
+ : LexSimplexBase(constraints.getNumIds(),
+ constraints.getIdKindOffset(IdKind::Symbol),
+ constraints.getNumSymbolIds()) {
intersectIntegerRelation(constraints);
}
+ /// Add a new symbolic variable to the end of the list of variables.
+ void appendSymbol();
+
/// Try to move the specified row to column orientation while preserving the
- /// lexicopositivity of the basis transform. If this is not possible, return
- /// failure. This only occurs when the constraints have no solution; the
- /// tableau will be marked empty in such a case.
+ /// lexicopositivity of the basis transform. The row must have a negative
+ /// sample value. If this is not possible, return failure. This only occurs
+ /// when the constraints have no solution; the tableau will be marked empty in
+ /// such a case.
LogicalResult moveRowUnknownToColumn(unsigned row);
- /// Given a row that has a non-integer sample value, add an inequality such
- /// that this fractional sample value is cut away from the polytope. The added
- /// inequality will be such that no integer points are removed.
+ /// Given a row that has a non-integer sample value, add an inequality to cut
+ /// away this fractional sample value from the polytope without removing any
+ /// integer points. The integer lexmin, if one existed, remains the same on
+ /// return.
///
- /// Returns whether the cut constraint could be enforced, i.e. failure if the
- /// cut made the polytope empty, and success if it didn't. Failure status
- /// indicates that the polytope didn't have any integer points.
+ /// This assumes that the symbolic part of the sample is integral,
+ /// i.e., if the symbolic sample is (c + aM + b_1*s_1 + ... b_n*s_n)/d,
+ /// where s_1, ... s_n are symbols, this assumes that
+ /// (b_1*s_1 + ... + b_n*s_n)/d is integral.
+ ///
+ /// Return failure if the tableau became empty, and success if it didn't.
+ /// Failure status indicates that the polytope was integer empty.
LogicalResult addCut(unsigned row);
/// Undo the addition of the last constraint. This is only called while
@@ -461,14 +469,19 @@ class LexSimplexBase : public SimplexBase {
void undoLastConstraint() final;
/// Given two potential pivot columns for a row, return the one that results
- /// in the lexicographically smallest sample vector.
+ /// in the lexicographically smallest sample vector. The row's sample value
+ /// must be negative. If symbols are involved, the sample value must be
+ /// negative for all possible assignments to the symbols.
unsigned getLexMinPivotColumn(unsigned row, unsigned colA,
unsigned colB) const;
};
+/// A class for lexicographic optimization without any symbols. This also
+/// provides support for integer-exact redundancy and separateness checks.
class LexSimplex : public LexSimplexBase {
public:
- explicit LexSimplex(unsigned nVar) : LexSimplexBase(nVar) {}
+ explicit LexSimplex(unsigned nVar)
+ : LexSimplexBase(nVar, /*symbolOffset=*/0, /*nSymbol=*/0) {}
explicit LexSimplex(const IntegerRelation &constraints)
: LexSimplexBase(constraints) {
assert(constraints.getNumSymbolIds() == 0 &&
@@ -502,7 +515,7 @@ class LexSimplex : public LexSimplexBase {
MaybeOptimum<SmallVector<Fraction, 8>> getRationalSample() const;
/// Make the tableau configuration consistent.
- void restoreRationalConsistency();
+ LogicalResult restoreRationalConsistency();
/// Return whether the specified row is violated;
bool rowIsViolated(unsigned row) const;
@@ -514,11 +527,122 @@ class LexSimplex : public LexSimplexBase {
/// Get a row corresponding to a var that has a non-integral sample value, if
/// one exists. Otherwise, return an empty optional.
Optional<unsigned> maybeGetNonIntegralVarRow() const;
+};
- /// Given two potential pivot columns for a row, return the one that results
- /// in the lexicographically smallest sample vector.
- unsigned getLexMinPivotColumn(unsigned row, unsigned colA,
- unsigned colB) const;
+/// Represents the result of a symbolic lexicographic minimization computation.
+struct SymbolicLexMin {
+ SymbolicLexMin(unsigned nSymbols, unsigned nNonSymbols)
+ : lexmin(PresburgerSpace::getSetSpace(nSymbols), nNonSymbols),
+ unboundedDomain(
+ PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(nSymbols))) {}
+
+ /// This maps assignments of symbols to the corresponding lexmin.
+ /// Takes no value when no integer sample exists for the assignment or if the
+ /// lexmin is unbounded.
+ PWMAFunction lexmin;
+ /// Contains all assignments to the symbols that made the lexmin unbounded.
+ /// Note that the symbols of the input set to the symbolic lexmin are dims
+ /// of this PresburgerSet.
+ PresburgerSet unboundedDomain;
+};
+
+/// A class to perform symbolic lexicographic optimization,
+/// i.e., to find, for every assignment to the symbols in the specified
+/// `symbolDomain`, the lexicographically minimum integer value attained
+/// by the non-symbol variables.
+///
+/// The input is a set parametrized by some symbols, i.e., the constant terms
+/// of the constraints in the set are affine expressions in the symbols, and
+/// every assignment to the symbols defines a non-symbolic set.
+///
+/// Accordingly, the sample values of the rows in our tableau will be affine
+/// expressions in the symbols, and every assignment to the symbols will define
+/// a non-symbolic LexSimplex. We then run the algorithm of
+/// LexSimplex::findIntegerLexMin simultaneously for every value of the symbols
+/// in the domain.
+///
+/// Often, the pivot to be performed is the same for all values of the symbols,
+/// in which case we just do it. For example, if the symbolic sample of a row is
+/// negative for all values in the symbol domain, the row needs to be pivoted
+/// irrespective of the precise value of the symbols. To answer queries like
+/// "Is this symbolic sample always negative in the symbol domain?", we maintain
+/// a `LexSimplex domainSimplex` corresponding to the symbol domain.
+///
+/// In other cases, it may be that the symbolic sample is violated at some
+/// values in the symbol domain and not violated at others. In this case,
+/// the pivot to be performed does depend on the value of the symbols. We
+/// handle this by splitting the symbol domain. We run the algorithm for the
+/// case where the row isn't violated, and then come back and run the case
+/// where it is.
+class SymbolicLexSimplex : public LexSimplexBase {
+public:
+ /// `constraints` is the set for which the symbolic lexmin will be computed.
+ /// `symbolDomain` is the set of values of the symbols for which the lexmin
+ /// will be computed. `symbolDomain` should have a dim id for every symbol in
+ /// `constraints`, and no other ids.
+ SymbolicLexSimplex(const IntegerPolyhedron &constraints,
+ const IntegerPolyhedron &symbolDomain)
+ : LexSimplexBase(constraints), domainPoly(symbolDomain),
+ domainSimplex(symbolDomain) {
+ assert(domainPoly.getNumIds() == constraints.getNumSymbolIds());
+ assert(domainPoly.getNumDimIds() == constraints.getNumSymbolIds());
+ }
+
+ /// The lexmin will be stored as a function `lexmin` from symbols to
+ /// non-symbols in the result.
+ ///
+ /// For some values of the symbols, the lexmin may be unbounded.
+ /// These parts of the symbol domain will be stored in `unboundedDomain`.
+ SymbolicLexMin computeSymbolicIntegerLexMin();
+
+private:
+ /// Perform all pivots that do not require branching.
+ ///
+ /// Return failure if the tableau became empty, indicating that the polytope
+ /// is always integer empty in the current symbol domain.
+ /// Return success otherwise.
+ LogicalResult doNonBranchingPivots();
+
+ /// Get a row that is always violated in the current domain, if one exists.
+ Optional<unsigned> maybeGetAlwaysViolatedRow();
+
+ /// Get a row corresponding to a variable with non-integral sample value, if
+ /// one exists.
+ Optional<unsigned> maybeGetNonIntegralVarRow();
+
+ /// Given a row that has a non-integer sample value, cut away this fractional
+ /// sample value without removing any integer points, i.e., the integer
+ /// lexmin, if it exists, remains the same after a call to this function. This
+ /// may add constraints or local variables to the tableau, as well as to the
+ /// domain.
+ ///
+ /// Returns whether the cut constraint could be enforced, i.e. failure if the
+ /// cut made the polytope empty, and success if it didn't. Failure status
+ /// indicates that the polytope is always integer empty in the symbol domain
+ /// at the time of the call. (This function may modify the symbol domain, but
+ /// failure status indicates that the polytope was empty for all symbol values
+ /// in the initial domain.)
+ LogicalResult addSymbolicCut(unsigned row);
+
+ /// Get the numerator of the symbolic sample of the specific row.
+ /// This is an affine expression in the symbols with integer coefficients.
+ /// The last element is the constant term. This ignores the big M coefficient.
+ SmallVector<int64_t, 8> getSymbolicSampleNumerator(unsigned row) const;
+
+ /// Return whether all the coefficients of the symbolic sample are integers.
+ ///
+ /// This does not consult the domain to check if the specified expression
+ /// is always integral despite coefficients being fractional.
+ bool isSymbolicSampleIntegral(unsigned row) const;
+
+ /// Record a lexmin. The tableau must be consistent with all variables
+ /// having symbolic samples with integer coefficients.
+ void recordOutput(SymbolicLexMin &result) const;
+
+ /// The symbol domain.
+ IntegerPolyhedron domainPoly;
+ /// Simplex corresponding to the symbol domain.
+ LexSimplex domainSimplex;
};
/// The Simplex class uses the Normal pivot rule and supports integer emptiness
@@ -540,7 +664,9 @@ class Simplex : public SimplexBase {
enum class Direction { Up, Down };
Simplex() = delete;
- explicit Simplex(unsigned nVar) : SimplexBase(nVar, /*mustUseBigM=*/false) {}
+ explicit Simplex(unsigned nVar)
+ : SimplexBase(nVar, /*mustUseBigM=*/false, /*symbolOffset=*/0,
+ /*nSymbol=*/0) {}
explicit Simplex(const IntegerRelation &constraints)
: Simplex(constraints.getNumIds()) {
intersectIntegerRelation(constraints);
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index bfa9a6539077d..5e527b5467f54 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -14,6 +14,7 @@
#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/Presburger/LinearTransform.h"
+#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Analysis/Presburger/Simplex.h"
#include "mlir/Analysis/Presburger/Utils.h"
@@ -145,6 +146,21 @@ void IntegerRelation::truncate(const CountsSnapshot &counts) {
removeEqualityRange(counts.getNumEqs(), getNumEqualities());
}
+SymbolicLexMin IntegerPolyhedron::findSymbolicIntegerLexMin() const {
+ // Compute the symbolic lexmin of the dims and locals, with the symbols being
+ // the actual symbols of this set.
+ SymbolicLexMin result =
+ SymbolicLexSimplex(
+ *this, PresburgerSpace::getSetSpace(/*numDims=*/getNumSymbolIds()))
+ .computeSymbolicIntegerLexMin();
+
+ // We want to return only the lexmin over the dims, so strip the locals from
+ // the computed lexmin.
+ result.lexmin.truncateOutput(result.lexmin.getNumOutputs() -
+ getNumLocalIds());
+ return result;
+}
+
unsigned IntegerRelation::insertId(IdKind kind, unsigned pos, unsigned num) {
assert(pos <= getNumIdKind(kind));
diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index 219d490e7368a..680e4509b7cc8 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -66,6 +66,14 @@ unsigned Matrix::appendExtraRow() {
return nRows - 1;
}
+unsigned Matrix::appendExtraRow(ArrayRef<int64_t> elems) {
+ assert(elems.size() == nColumns && "elems must match row length!");
+ unsigned row = appendExtraRow();
+ for (unsigned col = 0; col < nColumns; ++col)
+ at(row, col) = elems[col];
+ return row;
+}
+
void Matrix::resizeHorizontally(unsigned newNColumns) {
if (newNColumns < nColumns)
removeColumns(newNColumns, nColumns - newNColumns);
diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
index b995bc00a19c8..711e99aab35b4 100644
--- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
+++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
@@ -114,6 +114,18 @@ void MultiAffineFunction::eliminateRedundantLocalId(unsigned posA,
IntegerPolyhedron::eliminateRedundantLocalId(posA, posB);
}
+void MultiAffineFunction::truncateOutput(unsigned count) {
+ assert(count <= output.getNumRows());
+ output.resizeVertically(count);
+}
+
+void PWMAFunction::truncateOutput(unsigned count) {
+ assert(count <= numOutputs);
+ for (MultiAffineFunction &piece : pieces)
+ piece.truncateOutput(count);
+ numOutputs = count;
+}
+
bool MultiAffineFunction::isEqualWhereDomainsOverlap(
MultiAffineFunction other) const {
if (!isSpaceCompatible(other))
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index 57e8f485742d2..f3bf42f40b177 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -18,15 +18,24 @@ using Direction = Simplex::Direction;
const int nullIndex = std::numeric_limits<int>::max();
-SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM)
+SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset,
+ unsigned nSymbol)
: usingBigM(mustUseBigM), nRow(0), nCol(getNumFixedCols() + nVar),
- nRedundant(0), tableau(0, nCol), empty(false) {
+ nRedundant(0), nSymbol(nSymbol), tableau(0, nCol), empty(false) {
+ assert(symbolOffset + nSymbol <= nVar);
+
colUnknown.insert(colUnknown.begin(), getNumFixedCols(), nullIndex);
for (unsigned i = 0; i < nVar; ++i) {
var.emplace_back(Orientation::Column, /*restricted=*/false,
/*pos=*/getNumFixedCols() + i);
colUnknown.push_back(i);
}
+
+ // Move the symbols to be in columns [3, 3 + nSymbol).
+ for (unsigned i = 0; i < nSymbol; ++i) {
+ var[symbolOffset + i].isSymbol = true;
+ swapColumns(var[symbolOffset + i].pos, getNumFixedCols() + i);
+ }
}
const Simplex::Unknown &SimplexBase::unknownFromIndex(int index) const {
@@ -96,9 +105,13 @@ unsigned SimplexBase::addRow(ArrayRef<int64_t> coeffs, bool makeRestricted) {
// where M is the big M parameter. As such, when the user tries to add
// a row ax + by + cz + d, we express it in terms of our internal variables
// as -(a + b + c)M + a(M + x) + b(M + y) + c(M + z) + d.
+ //
+ // Symbols don't use the big M parameter since they do not get lex
+ // optimized.
int64_t bigMCoeff = 0;
for (unsigned i = 0; i < coeffs.size() - 1; ++i)
- bigMCoeff -= coeffs[i];
+ if (!var[i].isSymbol)
+ bigMCoeff -= coeffs[i];
// The coefficient to the big M parameter is stored in column 2.
tableau(nRow - 1, 2) = bigMCoeff;
}
@@ -164,19 +177,97 @@ Direction flippedDirection(Direction direction) {
}
} // namespace
+/// We simply make the tableau consistent while maintaining a lexicopositive
+/// basis transform, and then return the sample value. If the tableau becomes
+/// empty, we return empty.
+///
+/// Let the variables be x = (x_1, ... x_n).
+/// Let the basis unknowns be y = (y_1, ... y_n).
+/// We have that x = A*y + b for some n x n matrix A and n x 1 column vector b.
+///
+/// As we will show below, A*y is either zero or lexicopositive.
+/// Adding a lexicopositive vector to b will make it lexicographically
+/// greater, so A*y + b is always equal to or lexicographically greater than b.
+/// Thus, since we can attain x = b, that is the lexicographic minimum.
+///
+/// We have that every column in A is lexicopositive, i.e., has at least
+/// one non-zero element, with the first such element being positive. Since for
+/// the tableau to be consistent we must have non-negative sample values not
+/// only for the constraints but also for the variables, we also have x >= 0 and
+/// y >= 0, by which we mean every element in these vectors is non-negative.
+///
+/// Proof that if every column in A is lexicopositive, and y >= 0, then
+/// A*y is zero or lexicopositive. Begin by considering A_1, the first row of A.
+/// If this row is all zeros, then (A*y)_1 = (A_1)*y = 0; proceed to the next
+/// row. If we run out of rows, A*y is zero and we are done; otherwise, we
+/// encounter some row A_i that has a non-zero element. Every column is
+/// lexicopositive and so has some positive element before any negative elements
+/// occur, so the element in this row for any column, if non-zero, must be
+/// positive. Consider (A*y)_i = (A_i)*y. All the elements in both vectors are
+/// non-negative, so if this is non-zero then it must be positive. Then the
+/// first non-zero element of A*y is positive so A*y is lexicopositive.
+///
+/// Otherwise, if (A_i)*y is zero, then for every column j that had a non-zero
+/// element in A_i, y_j is zero. Thus these columns have no contribution to A*y
+/// and we can completely ignore these columns of A. We now continue downwards,
+/// looking for rows of A that have a non-zero element other than in the ignored
+/// columns. If we find one, say A_k, once again these elements must be positive
+/// since they are the first non-zero element in each of these columns, so if
+/// (A_k)*y is not zero then we have that A*y is lexicopositive and if not we
+/// add these to the set of ignored columns and continue to the next row. If we
+/// run out of rows, then A*y is zero and we are done.
MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::findRationalLexMin() {
- restoreRationalConsistency();
+ if (restoreRationalConsistency().failed())
+ return OptimumKind::Empty;
return getRationalSample();
}
+/// Given a row that has a non-integer sample value, add an inequality such
+/// that this fractional sample value is cut away from the polytope. The added
+/// inequality will be such that no integer points are removed. i.e., the
+/// integer lexmin, if it exists, is the same with and without this constraint.
+///
+/// Let the row be
+/// (c + coeffM*M + a_1*s_1 + ... + a_m*s_m + b_1*y_1 + ... + b_n*y_n)/d,
+/// where s_1, ... s_m are the symbols and
+/// y_1, ... y_n are the other basis unknowns.
+///
+/// For this to be an integer, we want
+/// coeffM*M + a_1*s_1 + ... + a_m*s_m + b_1*y_1 + ... + b_n*y_n = -c (mod d)
+/// Note that this constraint must always hold, independent of the basis,
+/// because the row unknown's value always equals this expression, even if *we*
+/// later compute the sample value from a different expression based on a
+/// different basis.
+///
+/// Let us assume that M has a factor of d in it. Imposing this constraint on M
+/// does not in any way hinder us from finding a value of M that is big enough.
+/// Moreover, this function is only called when the symbolic part of the sample,
+/// a_1*s_1 + ... + a_m*s_m, is known to be an integer.
+///
+/// Also, we can safely reduce the coefficients modulo d, so we have:
+///
+/// (b_1%d)y_1 + ... + (b_n%d)y_n = (-c%d) + k*d for some integer `k`
+///
+/// Note that all coefficient modulos here are non-negative. Also, all the
+/// unknowns are non-negative here as both constraints and variables are
+/// non-negative in LexSimplexBase. (We used the big M trick to make the
+/// variables non-negative). Therefore, the LHS here is non-negative.
+/// Since 0 <= (-c%d) < d, k is the quotient of dividing the LHS by d and
+/// is therefore non-negative as well.
+///
+/// So we have
+/// ((b_1%d)y_1 + ... + (b_n%d)y_n - (-c%d))/d >= 0.
+///
+/// The constraint is violated when added (it would be useless otherwise)
+/// so we immediately try to move it to a column.
LogicalResult LexSimplexBase::addCut(unsigned row) {
- int64_t denom = tableau(row, 0);
+ int64_t d = tableau(row, 0);
addZeroRow(/*makeRestricted=*/true);
- tableau(nRow - 1, 0) = denom;
- tableau(nRow - 1, 1) = -mod(-tableau(row, 1), denom);
- tableau(nRow - 1, 2) = 0; // M has all factors in it.
- for (unsigned col = 3; col < nCol; ++col)
- tableau(nRow - 1, col) = mod(tableau(row, col), denom);
+ tableau(nRow - 1, 0) = d;
+ tableau(nRow - 1, 1) = -mod(-tableau(row, 1), d); // -c%d.
+ tableau(nRow - 1, 2) = 0;
+ for (unsigned col = 3 + nSymbol; col < nCol; ++col)
+ tableau(nRow - 1, col) = mod(tableau(row, col), d); // b_i%d.
return moveRowUnknownToColumn(nRow - 1);
}
@@ -185,7 +276,7 @@ Optional<unsigned> LexSimplex::maybeGetNonIntegralVarRow() const {
if (u.orientation == Orientation::Column)
continue;
// If the sample value is of the form (a/d)M + b/d, we need b to be
- // divisible by d. We assume M is very large and contains all possible
+ // divisible by d. We assume M contains all possible
// factors and is divisible by everything.
unsigned row = u.pos;
if (tableau(row, 1) % tableau(row, 0) != 0)
@@ -195,28 +286,34 @@ Optional<unsigned> LexSimplex::maybeGetNonIntegralVarRow() const {
}
MaybeOptimum<SmallVector<int64_t, 8>> LexSimplex::findIntegerLexMin() {
- while (!empty) {
- restoreRationalConsistency();
- if (empty)
- return OptimumKind::Empty;
-
- if (Optional<unsigned> maybeRow = maybeGetNonIntegralVarRow()) {
- // Failure occurs when the polytope is integer empty.
- if (failed(addCut(*maybeRow)))
- return OptimumKind::Empty;
- continue;
- }
+ // We first try to make the tableau consistent.
+ if (restoreRationalConsistency().failed())
+ return OptimumKind::Empty;
- MaybeOptimum<SmallVector<Fraction, 8>> sample = getRationalSample();
- assert(!sample.isEmpty() && "If we reached here the sample should exist!");
- if (sample.isUnbounded())
- return OptimumKind::Unbounded;
- return llvm::to_vector<8>(
- llvm::map_range(*sample, std::mem_fn(&Fraction::getAsInteger)));
+ // Then, if the sample value is integral, we are done.
+ while (Optional<unsigned> maybeRow = maybeGetNonIntegralVarRow()) {
+ // Otherwise, for the variable whose row has a non-integral sample value,
+ // we add a cut, a constraint that removes this rational point
+ // while preserving all integer points, thus keeping the lexmin the same.
+ // We then again try to make the tableau with the new constraint
+ // consistent. This continues until the tableau becomes empty, in which
+ // case there is no integer point, or until there are no variables with
+ // non-integral sample values.
+ //
+ // Failure indicates that the tableau became empty, which occurs when the
+ // polytope is integer empty.
+ if (addCut(*maybeRow).failed())
+ return OptimumKind::Empty;
+ if (restoreRationalConsistency().failed())
+ return OptimumKind::Empty;
}
- // Polytope is integer empty.
- return OptimumKind::Empty;
+ MaybeOptimum<SmallVector<Fraction, 8>> sample = getRationalSample();
+ assert(!sample.isEmpty() && "If we reached here the sample should exist!");
+ if (sample.isUnbounded())
+ return OptimumKind::Unbounded;
+ return llvm::to_vector<8>(
+ llvm::map_range(*sample, std::mem_fn(&Fraction::getAsInteger)));
}
bool LexSimplex::isSeparateInequality(ArrayRef<int64_t> coeffs) {
@@ -228,6 +325,319 @@ bool LexSimplex::isSeparateInequality(ArrayRef<int64_t> coeffs) {
bool LexSimplex::isRedundantInequality(ArrayRef<int64_t> coeffs) {
return isSeparateInequality(getComplementIneq(coeffs));
}
+
+SmallVector<int64_t, 8>
+SymbolicLexSimplex::getSymbolicSampleNumerator(unsigned row) const { // Symbol coefficients then the constant; NOT divided by the denominator in column 0.
+ SmallVector<int64_t, 8> sample;
+ sample.reserve(nSymbol + 1);
+ for (unsigned col = 3; col < 3 + nSymbol; ++col) // Symbol columns [3, 3 + nSymbol).
+ sample.push_back(tableau(row, col));
+ sample.push_back(tableau(row, 1)); // Constant term goes last; the big M coefficient (column 2) is omitted.
+ return sample;
+}
+
+void LexSimplexBase::appendSymbol() { // Add a new variable and mark it as the last symbol.
+ appendVariable();
+ swapColumns(3 + nSymbol, nCol - 1); // Move the new column (presumably at nCol - 1) to the end of the symbol columns [3, 3 + nSymbol).
+ var.back().isSymbol = true;
+ nSymbol++;
+}
+
+static bool isRangeDivisibleBy(ArrayRef<int64_t> range, int64_t divisor) { // True iff every element is a multiple of `divisor` (vacuously true for an empty range).
+ assert(divisor > 0 && "divisor must be positive!");
+ return llvm::all_of(range, [divisor](int64_t x) { return x % divisor == 0; });
+}
+
+bool SymbolicLexSimplex::isSymbolicSampleIntegral(unsigned row) const { // The big M coefficient (column 2) is ignored; M is assumed divisible by everything.
+ int64_t denom = tableau(row, 0);
+ return tableau(row, 1) % denom == 0 && // Constant term.
+ isRangeDivisibleBy(tableau.getRow(row).slice(3, nSymbol), denom); // Symbol coefficients.
+}
+
+/// This proceeds similarly to LexSimplex::addCut(). We are given a row that has
+/// a symbolic sample value with fractional coefficients.
+///
+/// Let the row be
+/// (c + coeffM*M + sum_i a_i*s_i + sum_j b_j*y_j)/d,
+/// where s_1, ... s_m are the symbols and
+/// y_1, ... y_n are the other basis unknowns.
+///
+/// As in LexSimplex::addCut, for this to be an integer, we want
+///
+/// coeffM*M + sum_j b_j*y_j = -c + sum_i (-a_i*s_i) (mod d)
+///
+/// This time, a_1*s_1 + ... + a_m*s_m may not be an integer. We find that
+///
+/// sum_j (b_j%d)y_j = ((-c%d) + sum_i (-a_i%d)s_i)%d + k*d for some integer k
+///
+/// where we take a modulo of the whole symbolic expression on the right to
+/// bring it into the range [0, d - 1]. Therefore, as in LexSimplex::addCut,
+/// k is the quotient on dividing the LHS by d, and since LHS >= 0, we have
+/// k >= 0 as well. We realize the modulo of the symbolic expression by adding a
+/// division variable
+///
+/// q = floor(((-c%d) + sum_i (-a_i%d)s_i)/d)
+///
+/// to the symbol domain, so the equality becomes
+///
+/// sum_j (b_j%d)y_j = (-c%d) + sum_i (-a_i%d)s_i - q*d + k*d for some integer k
+///
+/// So the cut is
+/// (sum_j (b_j%d)y_j - (-c%d) - sum_i (-a_i%d)s_i + q*d)/d >= 0
+/// This constraint is violated when added so we immediately try to move it to a
+/// column.
+LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) {
+ int64_t d = tableau(row, 0);
+
+ // Add the division variable `q` described above to the symbol domain.
+ // q = ((-c%d) + sum_i (-a_i%d)s_i)/d.
+ SmallVector<int64_t, 8> domainDivCoeffs;
+ domainDivCoeffs.reserve(nSymbol + 1);
+ for (unsigned col = 3; col < 3 + nSymbol; ++col)
+ domainDivCoeffs.push_back(mod(-tableau(row, col), d)); // (-a_i%d)s_i
+ domainDivCoeffs.push_back(mod(-tableau(row, 1), d)); // -c%d.
+
+ domainSimplex.addDivisionVariable(domainDivCoeffs, d);
+ domainPoly.addLocalFloorDiv(domainDivCoeffs, d);
+
+ // Update `this` to account for the additional symbol we just added.
+ appendSymbol();
+
+ // Add the cut (sum_i (b_i%d)y_i - (-c%d) + sum_i -(-a_i%d)s_i + q*d)/d >= 0.
+ addZeroRow(/*makeRestricted=*/true);
+ tableau(nRow - 1, 0) = d;
+ tableau(nRow - 1, 2) = 0;
+
+ tableau(nRow - 1, 1) = -mod(-tableau(row, 1), d); // -(-c%d).
+ for (unsigned col = 3; col < 3 + nSymbol - 1; ++col)
+ tableau(nRow - 1, col) = -mod(-tableau(row, col), d); // -(-a_i%d)s_i.
+ tableau(nRow - 1, 3 + nSymbol - 1) = d; // q*d.
+
+ for (unsigned col = 3 + nSymbol; col < nCol; ++col)
+ tableau(nRow - 1, col) = mod(tableau(row, col), d); // (b_i%d)y_i.
+ return moveRowUnknownToColumn(nRow - 1);
+}
+
+void SymbolicLexSimplex::recordOutput(SymbolicLexMin &result) const { // Record the lexmin for the current domain, or mark the whole domain unbounded.
+ Matrix output(0, domainPoly.getNumIds() + 1); // Columns: one per domain id, plus the constant term.
+ output.reserveRows(result.lexmin.getNumOutputs());
+ for (const Unknown &u : var) {
+ if (u.isSymbol)
+ continue; // Symbols are inputs, not outputs.
+
+ if (u.orientation == Orientation::Column) {
+ // M + u has a sample value of zero so u has a sample value of -M, i.e.,
+ // unbounded.
+ result.unboundedDomain.unionInPlace(domainPoly);
+ return;
+ }
+
+ int64_t denom = tableau(u.pos, 0);
+ if (tableau(u.pos, 2) < denom) {
+ // M + u has a sample value of fM + something, where f < 1, so
+ // u = (f - 1)M + something, which has a negative coefficient for M,
+ // and so is unbounded.
+ result.unboundedDomain.unionInPlace(domainPoly);
+ return;
+ }
+ assert(tableau(u.pos, 2) == denom &&
+ "Coefficient of M should not be greater than 1!");
+
+ SmallVector<int64_t, 8> sample = getSymbolicSampleNumerator(u.pos);
+ for (int64_t &elem : sample) {
+ assert(elem % denom == 0 && "coefficients must be integral!");
+ elem /= denom; // Exact division, per the assert above.
+ }
+ output.appendExtraRow(sample);
+ }
+ result.lexmin.addPiece(domainPoly, output); // Every variable was bounded: add the piece for this domain.
+}
+
+Optional<unsigned> SymbolicLexSimplex::maybeGetAlwaysViolatedRow() { // Return a row whose sample value is negative everywhere in the symbol domain, if any.
+ // First look for rows that are clearly violated just from the big M
+ // coefficient, without needing to perform any simplex queries on the domain.
+ for (unsigned row = 0; row < nRow; ++row)
+ if (tableau(row, 2) < 0)
+ return row;
+
+ for (unsigned row = 0; row < nRow; ++row) {
+ if (tableau(row, 2) > 0)
+ continue; // Positive M coefficient: the row is non-negative for large M.
+ if (domainSimplex.isSeparateInequality(getSymbolicSampleNumerator(row))) {
+ // Sample numerator always takes negative values in the symbol domain.
+ return row;
+ }
+ }
+ return {}; // No row is violated throughout the domain.
+}
+
+Optional<unsigned> SymbolicLexSimplex::maybeGetNonIntegralVarRow() { // Row of the first row-orientation variable whose symbolic sample is non-integral.
+ for (const Unknown &u : var) {
+ if (u.orientation == Orientation::Column)
+ continue; // Only row-orientation variables are checked (this skips all symbols).
+ assert(!u.isSymbol && "Symbol should not be in row orientation!");
+ if (!isSymbolicSampleIntegral(u.pos))
+ return u.pos;
+ }
+ return {}; // Every checked sample value is integral.
+}
+
+/// The non-branching pivots are just the ones moving the rows
+/// that are always violated in the symbol domain; fails if one can't be moved.
+LogicalResult SymbolicLexSimplex::doNonBranchingPivots() {
+ while (Optional<unsigned> row = maybeGetAlwaysViolatedRow())
+ if (moveRowUnknownToColumn(*row).failed())
+ return failure(); // No pivot exists for an always-violated row.
+ return success();
+}
+
+SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
+ SymbolicLexMin result(nSymbol, var.size() - nSymbol);
+
+ /// The algorithm is more naturally expressed recursively, but we implement
+ /// it iteratively here to avoid potential issues with stack overflows in the
+ /// compiler. We explicitly maintain the stack frames in a vector.
+ ///
+ /// To "recurse", we store the current "stack frame", i.e., state variables
+ /// that we will need when we "return", into `stack`, increment `level`, and
+ /// `continue`. To "tail recurse", we just `continue`.
+ /// To "return", we decrement `level` and `continue`.
+ ///
+ /// When there is no stack frame for the current `level`, this indicates that
+ /// we have just "recursed" or "tail recursed". When there does exist one,
+ /// this indicates that we have just "returned" from recursing. There is only
+ /// one point at which non-tail calls occur so we always "return" there.
+ unsigned level = 1;
+ struct StackFrame {
+ int splitIndex;
+ unsigned snapshot;
+ unsigned domainSnapshot;
+ IntegerRelation::CountsSnapshot domainPolyCounts;
+ };
+ SmallVector<StackFrame, 8> stack;
+
+ while (level > 0) {
+ assert(level >= stack.size());
+ if (level > stack.size()) {
+ if (empty || domainSimplex.findIntegerLexMin().isEmpty()) {
+ // No integer points; return.
+ --level;
+ continue;
+ }
+
+ if (doNonBranchingPivots().failed()) {
+ // Could not find pivots for violated constraints; return.
+ --level;
+ continue;
+ }
+
+ unsigned splitRow;
+ SmallVector<int64_t, 8> symbolicSample;
+ for (splitRow = 0; splitRow < nRow; ++splitRow) {
+ if (tableau(splitRow, 2) > 0)
+ continue;
+ assert(tableau(splitRow, 2) == 0 &&
+ "Non-branching pivots should have been handled already!");
+
+ symbolicSample = getSymbolicSampleNumerator(splitRow);
+ if (domainSimplex.isRedundantInequality(symbolicSample))
+ continue;
+
+ // It's neither redundant nor separate, so it takes both positive and
+ // negative values, and hence constitutes a row for which we need to
+ // split the domain and separately run each case.
+ assert(!domainSimplex.isSeparateInequality(symbolicSample) &&
+ "Non-branching pivots should have been handled already!");
+ break;
+ }
+
+ if (splitRow < nRow) {
+ unsigned domainSnapshot = domainSimplex.getSnapshot();
+ IntegerRelation::CountsSnapshot domainPolyCounts =
+ domainPoly.getCounts();
+
+ // First, we consider the part of the domain where the row is not
+ // violated. We don't have to do any pivots for the row in this case,
+ // but we record the additional constraint that defines this part of
+ // the domain.
+ domainSimplex.addInequality(symbolicSample);
+ domainPoly.addInequality(symbolicSample);
+
+ // Recurse.
+ //
+ // On return, the basis as a set is preserved but not the internal
+ // ordering within rows or columns. Thus, we take note of the index of
+ // the Unknown that caused the split, which may be in a different
+ // row when we come back from recursing. We will need this to recurse
+ // on the other part of the split domain, where the row is violated.
+ //
+ // Note that we have to capture the index above and not a reference to
+ // the Unknown itself, since the array it lives in might get
+ // reallocated.
+ int splitIndex = rowUnknown[splitRow];
+ unsigned snapshot = getSnapshot();
+ stack.push_back(
+ {splitIndex, snapshot, domainSnapshot, domainPolyCounts});
+ ++level;
+ continue;
+ }
+
+ // The tableau is rationally consistent for the current domain.
+ // Now we look for non-integral sample values and add cuts for them.
+ if (Optional<unsigned> row = maybeGetNonIntegralVarRow()) {
+ if (addSymbolicCut(*row).failed()) {
+ // No integral points; return.
+ --level;
+ continue;
+ }
+
+ // Rerun this level with the added cut constraint (tail recurse).
+ continue;
+ }
+
+ // Record output and return.
+ recordOutput(result);
+ --level;
+ continue;
+ }
+
+ if (level == stack.size()) {
+ // We have "returned" from "recursing".
+ const StackFrame &frame = stack.back();
+ domainPoly.truncate(frame.domainPolyCounts);
+ domainSimplex.rollback(frame.domainSnapshot);
+ rollback(frame.snapshot);
+ const Unknown &u = unknownFromIndex(frame.splitIndex);
+
+ // Drop the frame. We don't need it anymore.
+ stack.pop_back();
+
+ // Now we consider the part of the domain where the unknown `splitIndex`
+ // was negative.
+ assert(u.orientation == Orientation::Row &&
+ "The split row should have been returned to row orientation!");
+ SmallVector<int64_t, 8> splitIneq =
+ getComplementIneq(getSymbolicSampleNumerator(u.pos));
+ if (moveRowUnknownToColumn(u.pos).failed()) {
+ // The unknown can't be made non-negative; return.
+ --level;
+ continue;
+ }
+
+ // The unknown can be made negative; recurse with the corresponding domain
+ // constraints.
+ domainSimplex.addInequality(splitIneq);
+ domainPoly.addInequality(splitIneq);
+
+ // We are now taking care of the second half of the domain and we don't
+ // need to do anything else here after returning, so it's a tail recurse.
+ continue;
+ }
+ }
+
+ return result;
+}
+
bool LexSimplex::rowIsViolated(unsigned row) const {
if (tableau(row, 2) < 0)
return true;
@@ -243,19 +653,20 @@ Optional<unsigned> LexSimplex::maybeGetViolatedRow() const {
return {};
}
-// We simply look for violated rows and keep trying to move them to column
-// orientation, which always succeeds unless the constraints have no solution
-// in which case we just give up and return.
-void LexSimplex::restoreRationalConsistency() {
- while (Optional<unsigned> maybeViolatedRow = maybeGetViolatedRow()) {
- LogicalResult status = moveRowUnknownToColumn(*maybeViolatedRow);
- if (failed(status))
- return;
- }
+/// We simply look for violated rows and keep trying to move them to column
+/// orientation, which always succeeds unless the constraints have no solution
+/// in which case we just give up and return failure. We also fail immediately
+/// if the tableau is already known to be empty.
+LogicalResult LexSimplex::restoreRationalConsistency() {
+ if (empty)
+ return failure();
+ while (Optional<unsigned> maybeViolatedRow = maybeGetViolatedRow())
+ if (moveRowUnknownToColumn(*maybeViolatedRow).failed())
+ return failure();
+ return success();
}
// Move the row unknown to column orientation while preserving lexicopositivity
-// of the basis transform.
+// of the basis transform. The sample value of the row must be negative.
//
// We only consider pivots where the pivot element is positive. Suppose no such
// pivot exists, i.e., some violated row has no positive coefficient for any
@@ -318,7 +729,7 @@ void LexSimplex::restoreRationalConsistency() {
// minimizes the change in sample value.
LogicalResult LexSimplexBase::moveRowUnknownToColumn(unsigned row) {
Optional<unsigned> maybeColumn;
- for (unsigned col = 3; col < nCol; ++col) {
+ for (unsigned col = 3 + nSymbol; col < nCol; ++col) {
if (tableau(row, col) <= 0)
continue;
maybeColumn =
@@ -336,6 +747,7 @@ LogicalResult LexSimplexBase::moveRowUnknownToColumn(unsigned row) {
unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA,
unsigned colB) const {
+ // First, let's consider the non-symbolic case.
// A pivot causes the following change. (in the diagram the matrix elements
// are shown as rationals and there is no common denominator used)
//
@@ -359,7 +771,7 @@ unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA,
// (-p/a)M + (-b/a), i.e. 0 to -(pM + b)/a. Thus the change in the sample
// value is -s/a.
//
- // If the variable is the pivot row, it sampel value goes from s to 0, for a
+ // If the variable is the pivot row, its sample value goes from s to 0, for a
// change of -s.
//
// If the variable is a non-pivot row, its sample value changes from
@@ -373,8 +785,12 @@ unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA,
// comparisons involved and can be ignored, since -s is strictly positive.
//
// Thus we take away this common factor and just return 0, 1/a, 1, or c/a as
- // appropriate. This allows us to run the entire algorithm without ever having
- // to fix a value of M.
+ // appropriate. This allows us to run the entire algorithm treating M
+ // symbolically, as the pivot to be performed does not depend on the value
+ // of M, so long as the sample value s is negative. Note that this is not
+ // because of any special feature of M; by the same argument, we ignore the
+ // symbols too. The caller ensures that the sample value s is negative for
+ // all possible values of the symbols.
auto getSampleChangeCoeffForVar = [this, row](unsigned col,
const Unknown &u) -> Fraction {
int64_t a = tableau(row, col);
@@ -489,6 +905,7 @@ void SimplexBase::pivot(Pivot pair) { pivot(pair.row, pair.column); }
/// element.
void SimplexBase::pivot(unsigned pivotRow, unsigned pivotCol) {
assert(pivotCol >= getNumFixedCols() && "Refusing to pivot invalid column");
+ assert(!unknownFromColumn(pivotCol).isSymbol);
swapRowWithCol(pivotRow, pivotCol);
std::swap(tableau(pivotRow, 0), tableau(pivotRow, pivotCol));
@@ -778,6 +1195,9 @@ void SimplexBase::undo(UndoLogEntry entry) {
assert(var.back().orientation == Orientation::Column &&
"Variable to be removed must be in column orientation!");
+ if (var.back().isSymbol)
+ nSymbol--;
+
// Move this variable to the last column and remove the column from the
// tableau.
swapColumns(var.back().pos, nCol - 1);
diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
index 4149d85d8759f..2cb6ada89397a 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -8,6 +8,7 @@
#include "./Utils.h"
#include "mlir/Analysis/Presburger/IntegerRelation.h"
+#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "mlir/Analysis/Presburger/Simplex.h"
#include <gmock/gmock.h>
@@ -1134,6 +1135,229 @@ TEST(IntegerPolyhedronTest, findIntegerLexMin) {
">= 0, -11*z + 5*y - 3*x + 7 >= 0)"));
}
+void expectSymbolicIntegerLexMin( // Check both the lexmin pieces and the unbounded domain computed for `polyStr`.
+ StringRef polyStr,
+ ArrayRef<std::pair<StringRef, SmallVector<SmallVector<int64_t, 8>, 8>>>
+ expectedLexminRepr,
+ ArrayRef<StringRef> expectedUnboundedDomainRepr) {
+ IntegerPolyhedron poly = parsePoly(polyStr);
+
+ ASSERT_NE(poly.getNumDimIds(), 0u); // Need at least one dim to minimize...
+ ASSERT_NE(poly.getNumSymbolIds(), 0u); // ...and at least one symbol.
+
+ PWMAFunction expectedLexmin =
+ parsePWMAF(/*numInputs=*/poly.getNumSymbolIds(),
+ /*numOutputs=*/poly.getNumDimIds(), expectedLexminRepr);
+
+ PresburgerSet expectedUnboundedDomain = parsePresburgerSetFromPolyStrings(
+ poly.getNumSymbolIds(), expectedUnboundedDomainRepr);
+
+ SymbolicLexMin result = poly.findSymbolicIntegerLexMin();
+
+ EXPECT_TRUE(result.lexmin.isEqual(expectedLexmin));
+ if (!result.lexmin.isEqual(expectedLexmin)) {
+ llvm::errs() << "got:\n";
+ result.lexmin.dump();
+ llvm::errs() << "expected:\n";
+ expectedLexmin.dump();
+ }
+
+ EXPECT_TRUE(result.unboundedDomain.isEqual(expectedUnboundedDomain));
+ if (!result.unboundedDomain.isEqual(expectedUnboundedDomain))
+ result.unboundedDomain.dump(); // NOTE(review): only the computed set is printed; the expected set is not.
+}
+
+void expectSymbolicIntegerLexMin( // Overload for sets whose lexmin is never unbounded.
+ StringRef polyStr,
+ ArrayRef<std::pair<StringRef, SmallVector<SmallVector<int64_t, 8>, 8>>>
+ result) {
+ expectSymbolicIntegerLexMin(polyStr, result, {}); // Expect an empty unbounded domain.
+}
+
+TEST(IntegerPolyhedronTest, findSymbolicIntegerLexMin) {
+ expectSymbolicIntegerLexMin("(x)[a] : (x - a >= 0)",
+ {
+ {"(a) : ()", {{1, 0}}}, // a
+ });
+
+ expectSymbolicIntegerLexMin(
+ "(x)[a, b] : (x - a >= 0, x - b >= 0)",
+ {
+ {"(a, b) : (a - b >= 0)", {{1, 0, 0}}}, // a
+ {"(a, b) : (b - a - 1 >= 0)", {{0, 1, 0}}}, // b
+ });
+
+ expectSymbolicIntegerLexMin(
+ "(x)[a, b, c] : (x -a >= 0, x - b >= 0, x - c >= 0)",
+ {
+ {"(a, b, c) : (a - b >= 0, a - c >= 0)", {{1, 0, 0, 0}}}, // a
+ {"(a, b, c) : (b - a - 1 >= 0, b - c >= 0)", {{0, 1, 0, 0}}}, // b
+ {"(a, b, c) : (c - a - 1 >= 0, c - b - 1 >= 0)", {{0, 0, 1, 0}}}, // c
+ });
+
+ expectSymbolicIntegerLexMin("(x, y)[a] : (x - a >= 0, x + y >= 0)",
+ {
+ {"(a) : ()", {{1, 0}, {-1, 0}}}, // (a, -a)
+ });
+
+ expectSymbolicIntegerLexMin(
+ "(x, y)[a] : (x - a >= 0, x + y >= 0, y >= 0)",
+ {
+ {"(a) : (a >= 0)", {{1, 0}, {0, 0}}}, // (a, 0)
+ {"(a) : (-a - 1 >= 0)", {{1, 0}, {-1, 0}}}, // (a, -a)
+ });
+
+ expectSymbolicIntegerLexMin(
+ "(x, y)[a, b, c] : (x - a >= 0, y - b >= 0, c - x - y >= 0)",
+ {
+ {"(a, b, c) : (c - a - b >= 0)",
+ {{1, 0, 0, 0}, {0, 1, 0, 0}}}, // (a, b)
+ });
+
+ expectSymbolicIntegerLexMin(
+ "(x, y, z)[a, b, c] : (c - z >= 0, b - y >= 0, x + y + z - a == 0)",
+ {
+ {"(a, b, c) : ()",
+ {{1, -1, -1, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}}}, // (a - b - c, b, c)
+ });
+
+ expectSymbolicIntegerLexMin(
+ "(x)[a, b] : (a >= 0, b >= 0, x >= 0, a + b + x - 1 >= 0)",
+ {
+ {"(a, b) : (a >= 0, b >= 0, a + b - 1 >= 0)", {{0, 0, 0}}}, // 0
+ {"(a, b) : (a == 0, b == 0)", {{0, 0, 1}}}, // 1
+ });
+
+ expectSymbolicIntegerLexMin(
+ "(x)[a, b] : (1 - a >= 0, a >= 0, 1 - b >= 0, b >= 0, 1 - x >= 0, x >= "
+ "0, a + b + x - 1 >= 0)",
+ {
+ {"(a, b) : (1 - a >= 0, a >= 0, 1 - b >= 0, b >= 0, a + b - 1 >= 0)",
+ {{0, 0, 0}}}, // 0
+ {"(a, b) : (a == 0, b == 0)", {{0, 0, 1}}}, // 1
+ });
+
+ expectSymbolicIntegerLexMin(
+ "(x, y, z)[a, b] : (x - a == 0, y - b == 0, x >= 0, y >= 0, z >= 0, x + "
+ "y + z - 1 >= 0)",
+ {
+ {"(a, b) : (a >= 0, b >= 0, 1 - a - b >= 0)",
+ {{1, 0, 0}, {0, 1, 0}, {-1, -1, 1}}}, // (a, b, 1 - a - b)
+ {"(a, b) : (a >= 0, b >= 0, a + b - 2 >= 0)",
+ {{1, 0, 0}, {0, 1, 0}, {0, 0, 0}}}, // (a, b, 0)
+ });
+
+ expectSymbolicIntegerLexMin("(x)[a, b] : (x - a == 0, x - b >= 0)",
+ {
+ {"(a, b) : (a - b >= 0)", {{1, 0, 0}}}, // a
+ });
+
+ expectSymbolicIntegerLexMin(
+ "(q)[a] : (a - 1 - 3*q == 0, q >= 0)",
+ {
+ {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)",
+ {{0, 1, 0}}}, // a floordiv 3
+ });
+
+ expectSymbolicIntegerLexMin(
+ "(r, q)[a] : (a - r - 3*q == 0, q >= 0, 1 - r >= 0, r >= 0)",
+ {
+ {"(a) : (a - 0 - 3*(a floordiv 3) == 0, a >= 0)",
+ {{0, 0, 0}, {0, 1, 0}}}, // (0, a floordiv 3)
+ {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)",
+ {{0, 0, 1}, {0, 1, 0}}}, // (1, a floordiv 3)
+ });
+
+ expectSymbolicIntegerLexMin(
+ "(r, q)[a] : (a - r - 3*q == 0, q >= 0, 2 - r >= 0, r - 1 >= 0)",
+ {
+ {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)",
+ {{0, 0, 1}, {0, 1, 0}}}, // (1, a floordiv 3)
+ {"(a) : (a - 2 - 3*(a floordiv 3) == 0, a >= 0)",
+ {{0, 0, 2}, {0, 1, 0}}}, // (2, a floordiv 3)
+ });
+
+ expectSymbolicIntegerLexMin(
+ "(r, q)[a] : (a - r - 3*q == 0, q >= 0, r >= 0)",
+ {
+ {"(a) : (a - 3*(a floordiv 3) == 0, a >= 0)",
+ {{0, 0, 0}, {0, 1, 0}}}, // (0, a floordiv 3)
+ {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)",
+ {{0, 0, 1}, {0, 1, 0}}}, // (1, a floordiv 3)
+ {"(a) : (a - 2 - 3*(a floordiv 3) == 0, a >= 0)",
+ {{0, 0, 2}, {0, 1, 0}}}, // (2, a floordiv 3)
+ });
+
+ expectSymbolicIntegerLexMin(
+ "(x, y, z, w)[g] : ("
+ // x, y, z, w are boolean variables.
+ "1 - x >= 0, x >= 0, 1 - y >= 0, y >= 0,"
+ "1 - z >= 0, z >= 0, 1 - w >= 0, w >= 0,"
+ // We have some constraints on them:
+ "x + y + z - 1 >= 0," // x or y or z
+ "x + y + w - 1 >= 0," // x or y or w
+ "1 - x + 1 - y + 1 - w - 1 >= 0," // ~x or ~y or ~w
+ // What's the lexmin solution using exactly g true vars?
+ "g - x - y - z - w == 0)",
+ {
+ {"(g) : (g - 1 == 0)",
+ {{0, 0}, {0, 1}, {0, 0}, {0, 0}}}, // (0, 1, 0, 0)
+ {"(g) : (g - 2 == 0)",
+ {{0, 0}, {0, 0}, {0, 1}, {0, 1}}}, // (0, 0, 1, 1)
+ {"(g) : (g - 3 == 0)",
+ {{0, 0}, {0, 1}, {0, 1}, {0, 1}}}, // (0, 1, 1, 1)
+ });
+
+ // Bezout's lemma: if a, b are constants,
+ // the set of values that ax + by can take is all multiples of gcd(a, b).
+ expectSymbolicIntegerLexMin(
+ // If (x, y) is a solution for a given [a, r], then so is (x - 5, y + 2).
+ // So the lexmin is unbounded if it exists.
+ "(x, y)[a, r] : (a >= 0, r - a + 14*x + 35*y == 0)", {},
+ // According to Bezout's lemma, 14x + 35y can take on all multiples
+ // of 7 and no other values. So the solution exists iff r - a is a
+ // multiple of 7.
+ {"(a, r) : (a >= 0, r - a - 7*((r - a) floordiv 7) == 0)"});
+
+ // The lexmins are unbounded.
+ expectSymbolicIntegerLexMin("(x, y)[a] : (9*x - 4*y - 2*a >= 0)", {},
+ {"(a) : ()"});
+
+ // Test cases adapted from isl.
+ expectSymbolicIntegerLexMin(
+ // a = 2b - 2(c - b), c - b >= 0.
+ // So b is minimized when c = b.
+ "(b, c)[a] : (a - 4*b + 2*c == 0, c - b >= 0)",
+ {
+ {"(a) : (a - 2*(a floordiv 2) == 0)",
+ {{0, 1, 0}, {0, 1, 0}}}, // (a floordiv 2, a floordiv 2)
+ });
+
+ expectSymbolicIntegerLexMin(
+ // 0 <= b <= 255, 1 <= a - 512b <= 509,
+ // b + 8 >= 1 + 16*((b + 8) floordiv 16) // i.e. b % 16 != 8
+ "(b)[a] : (255 - b >= 0, b >= 0, a - 512*b - 1 >= 0, 512*b -a + 509 >= "
+ "0, b + 7 - 16*((8 + b) floordiv 16) >= 0)",
+ {
+ {"(a) : (255 - (a floordiv 512) >= 0, a >= 0, a - 512*(a floordiv "
+ "512) - 1 >= 0, 512*(a floordiv 512) - a + 509 >= 0, (a floordiv "
+ "512) + 7 - 16*((8 + (a floordiv 512)) floordiv 16) >= 0)",
+ {{0, 1, 0, 0}}}, // a floordiv 512
+ });
+
+ expectSymbolicIntegerLexMin(
+ "(a, b)[K, N, x, y] : (N - K - 2 >= 0, K + 4 - N >= 0, x - 4 >= 0, x + 6 "
+ "- 2*N >= 0, K+N - x - 1 >= 0, a - N + 1 >= 0, K+N-1-a >= 0,a + 6 - b - "
+ "N >= 0, 2*N - 4 - a >= 0,"
+ "2*N - 3*K + a - b >= 0, 4*N - K + 1 - 3*b >= 0, b - N >= 0, a - x - 1 "
+ ">= 0)",
+ {{
+ "(K, N, x, y) : (x + 6 - 2*N >= 0, 2*N - 5 - x >= 0, x + 1 -3*K + N "
+ ">= 0, N + K - 2 - x >= 0, x - 4 >= 0)",
+ {{0, 0, 1, 0, 1}, {0, 1, 0, 0, 0}} // (1 + x, N)
+ }});
+}
+
static void
expectComputedVolumeIsValidOverapprox(const IntegerPolyhedron &poly,
Optional<uint64_t> trueVolume,
More information about the Mlir-commits
mailing list