[Mlir-commits] [mlir] c2fcaf8 - [MLIR][Presburger] Simplex: refactor (symbolic)lex to support specifying multiple varKinds as symbols
Arjun P
llvmlistbot at llvm.org
Fri Jul 1 09:47:39 PDT 2022
Author: Arjun P
Date: 2022-07-01T17:47:39+01:00
New Revision: c2fcaf84e5a33a18385620453fdbb73fee58cbbc
URL: https://github.com/llvm/llvm-project/commit/c2fcaf84e5a33a18385620453fdbb73fee58cbbc
DIFF: https://github.com/llvm/llvm-project/commit/c2fcaf84e5a33a18385620453fdbb73fee58cbbc.diff
LOG: [MLIR][Presburger] Simplex: refactor (symbolic)lex to support specifying multiple varKinds as symbols
This is also required to support lexmin for relations.
Reviewed By: Groverkss
Differential Revision: https://reviews.llvm.org/D128931
Added:
Modified:
mlir/include/mlir/Analysis/Presburger/Simplex.h
mlir/include/mlir/Analysis/Presburger/Utils.h
mlir/lib/Analysis/Presburger/Simplex.cpp
mlir/lib/Analysis/Presburger/Utils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
index b265897673661..2c3aedea8452f 100644
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -23,6 +23,7 @@
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/raw_ostream.h"
@@ -210,14 +211,18 @@ class SimplexBase {
protected:
/// Construct a SimplexBase with the specified number of variables and fixed
- /// columns.
+ /// columns. The first overload should be used when there are no symbols.
+ /// With the second overload, `isSymbol` is a bitmask denoting which vars
+ /// are symbols; the marked vars are moved to the columns immediately
+ /// following the fixed columns. The size of `isSymbol` must be `nVar`.
///
/// For example, Simplex uses two fixed columns: the denominator and the
/// constant term, whereas LexSimplex has an extra fixed column for the
/// so-called big M parameter. For more information see the documentation for
/// LexSimplex.
- SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset,
- unsigned nSymbol);
+ SimplexBase(unsigned nVar, bool mustUseBigM);
+ SimplexBase(unsigned nVar, bool mustUseBigM,
+ const llvm::SmallBitVector &isSymbol);
enum class Orientation { Row, Column };
@@ -422,12 +427,16 @@ class LexSimplexBase : public SimplexBase {
unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); }
protected:
- LexSimplexBase(unsigned nVar, unsigned symbolOffset, unsigned nSymbol)
- : SimplexBase(nVar, /*mustUseBigM=*/true, symbolOffset, nSymbol) {}
+ LexSimplexBase(unsigned nVar) : SimplexBase(nVar, /*mustUseBigM=*/true) {}
+ LexSimplexBase(unsigned nVar, const llvm::SmallBitVector &isSymbol)
+ : SimplexBase(nVar, /*mustUseBigM=*/true, isSymbol) {}
explicit LexSimplexBase(const IntegerRelation &constraints)
- : LexSimplexBase(constraints.getNumVars(),
- constraints.getVarKindOffset(VarKind::Symbol),
- constraints.getNumSymbolVars()) {
+ : LexSimplexBase(constraints.getNumVars()) {
+ intersectIntegerRelation(constraints);
+ }
+ explicit LexSimplexBase(const IntegerRelation &constraints,
+ const llvm::SmallBitVector &isSymbol)
+ : LexSimplexBase(constraints.getNumVars(), isSymbol) {
intersectIntegerRelation(constraints);
}
@@ -470,13 +479,12 @@ class LexSimplexBase : public SimplexBase {
/// provides support for integer-exact redundancy and separateness checks.
class LexSimplex : public LexSimplexBase {
public:
- explicit LexSimplex(unsigned nVar)
- : LexSimplexBase(nVar, /*symbolOffset=*/0, /*nSymbol=*/0) {}
+ explicit LexSimplex(unsigned nVar) : LexSimplexBase(nVar) {}
+ /// Note that LexSimplex does NOT support symbolic lexmin; use
+ /// SymbolicLexSimplex if that is required. The VarKinds of the passed
+ /// IntegerRelation are ignored: symbols are treated as ordinary vars.
explicit LexSimplex(const IntegerRelation &constraints)
- : LexSimplexBase(constraints) {
- assert(constraints.getNumSymbolVars() == 0 &&
- "LexSimplex does not support symbols!");
- }
+ : LexSimplexBase(constraints) {}
/// Return the lexicographically minimum rational solution to the constraints.
MaybeOptimum<SmallVector<Fraction, 8>> findRationalLexMin();
@@ -521,10 +529,9 @@ class LexSimplex : public LexSimplexBase {
/// Represents the result of a symbolic lexicographic minimization computation.
struct SymbolicLexMin {
- SymbolicLexMin(unsigned nSymbols, unsigned nNonSymbols)
- : lexmin(PresburgerSpace::getSetSpace(nSymbols), nNonSymbols),
- unboundedDomain(
- PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(nSymbols))) {}
+ SymbolicLexMin(const PresburgerSpace &domainSpace, unsigned numOutputs)
+ : lexmin(domainSpace, numOutputs),
+ unboundedDomain(PresburgerSet::getEmpty(domainSpace)) {}
/// This maps assignments of symbols to the corresponding lexmin.
/// Takes no value when no integer sample exists for the assignment or if the
@@ -569,30 +576,40 @@ class SymbolicLexSimplex : public LexSimplexBase {
/// `constraints` is the set for which the symbolic lexmin will be computed.
/// `symbolDomain` is the set of values of the symbols for which the lexmin
/// will be computed. `symbolDomain` should have a dim var for every symbol in
- /// `constraints`, and no other vars.
+ /// `constraints`, and no other vars. `isSymbol` has one bit per var of
+ /// `constraints`; the vars whose bits are set are the symbols.
+ ///
+ /// The resulting SymbolicLexMin's `lexmin` and `unboundedDomain` are built
+ /// on the space of `symbolDomain`.
SymbolicLexSimplex(const IntegerRelation &constraints,
- const IntegerPolyhedron &symbolDomain)
- : SymbolicLexSimplex(constraints,
- constraints.getVarKindOffset(VarKind::Symbol),
- symbolDomain) {
- assert(constraints.getNumSymbolVars() == symbolDomain.getNumVars());
+ const IntegerPolyhedron &symbolDomain,
+ const llvm::SmallBitVector &isSymbol)
+ : LexSimplexBase(constraints, isSymbol), domainPoly(symbolDomain),
+ domainSimplex(symbolDomain) {
+ // TODO consider supporting this case. It amounts
+ // to just returning the input constraints. NOTE(review): the assert below tests domainPoly.getNumVars(), i.e. the number of symbols, which does not match its "non-symbols" message; also, nothing here checks that isSymbol.count() == symbolDomain.getNumVars() -- confirm intent.
+ assert(domainPoly.getNumVars() > 0 &&
+ "there must be some non-symbols to optimize!");
}
- /// An overload to select some other subrange of ids as symbols for lexmin.
+ /// An overload to select a contiguous subrange of ids as symbols for lexmin.
/// The symbol ids are the range of ids with absolute index
/// [symbolOffset, symbolOffset + symbolDomain.getNumVars())
- /// symbolDomain should only have dim ids.
SymbolicLexSimplex(const IntegerRelation &constraints, unsigned symbolOffset,
const IntegerPolyhedron &symbolDomain)
- : LexSimplexBase(/*nVar=*/constraints.getNumVars(), symbolOffset,
- symbolDomain.getNumVars()),
- domainPoly(symbolDomain), domainSimplex(symbolDomain) {
- // TODO consider supporting this case. It amounts
- // to just returning the input constraints.
- assert(domainPoly.getNumVars() > 0 &&
- "there must be some non-symbols to optimize!");
- assert(domainPoly.getNumVars() == domainPoly.getNumDimVars());
- intersectIntegerRelation(constraints);
+ : SymbolicLexSimplex(constraints, symbolDomain,
+ getSubrangeBitVector(constraints.getNumVars(),
+ symbolOffset,
+ symbolDomain.getNumVars())) {}
+
+ /// An overload to select the symbols of `constraints` as symbols for lexmin; `symbolDomain` must have one var per symbol of `constraints`.
+ SymbolicLexSimplex(const IntegerRelation &constraints,
+ const IntegerPolyhedron &symbolDomain)
+ : SymbolicLexSimplex(constraints,
+ constraints.getVarKindOffset(VarKind::Symbol),
+ symbolDomain) {
+ assert(constraints.getNumSymbolVars() == symbolDomain.getNumVars() &&
+ "symbolDomain must have as many vars as constraints has symbols!");
}
/// The lexmin will be stored as a function `lexmin` from symbols to
@@ -678,9 +695,7 @@ class Simplex : public SimplexBase {
enum class Direction { Up, Down };
Simplex() = delete;
- explicit Simplex(unsigned nVar)
- : SimplexBase(nVar, /*mustUseBigM=*/false, /*symbolOffset=*/0,
- /*nSymbol=*/0) {}
+ explicit Simplex(unsigned nVar) : SimplexBase(nVar, /*mustUseBigM=*/false) {}
explicit Simplex(const IntegerRelation &constraints)
: Simplex(constraints.getNumVars()) {
intersectIntegerRelation(constraints);
diff --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h
index c735ddd037c80..52321c5413def 100644
--- a/mlir/include/mlir/Analysis/Presburger/Utils.h
+++ b/mlir/include/mlir/Analysis/Presburger/Utils.h
@@ -15,6 +15,7 @@
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallBitVector.h"
namespace mlir {
namespace presburger {
@@ -120,6 +121,9 @@ SmallVector<int64_t, 8> getDivUpperBound(ArrayRef<int64_t> dividend,
SmallVector<int64_t, 8> getDivLowerBound(ArrayRef<int64_t> dividend,
int64_t divisor, unsigned localVarIdx);
+llvm::SmallBitVector getSubrangeBitVector(unsigned len, unsigned setOffset,
+ unsigned numSet); // Returns a `len`-bit vector in which exactly the bits [setOffset, setOffset + numSet) are set.
+
/// Check if the pos^th variable can be expressed as a floordiv of an affine
/// function of other variables (where the divisor is a positive constant).
/// `foundRepr` contains a boolean for each variable indicating if the
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index 0574f155cd433..6f541363f3867 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -31,23 +31,28 @@ scaleAndAddForAssert(ArrayRef<int64_t> a, int64_t scale, ArrayRef<int64_t> b) {
return res;
}
-SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset,
- unsigned nSymbol)
- : usingBigM(mustUseBigM), nRedundant(0), nSymbol(nSymbol),
+SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM)
+ : usingBigM(mustUseBigM), nRedundant(0), nSymbol(0),
tableau(0, getNumFixedCols() + nVar), empty(false) {
- assert(symbolOffset + nSymbol <= nVar);
-
colUnknown.insert(colUnknown.begin(), getNumFixedCols(), nullIndex);
for (unsigned i = 0; i < nVar; ++i) {
var.emplace_back(Orientation::Column, /*restricted=*/false,
/*pos=*/getNumFixedCols() + i);
colUnknown.push_back(i);
}
+}
- // Move the symbols to be in columns [3, 3 + nSymbol).
- for (unsigned i = 0; i < nSymbol; ++i) {
- var[symbolOffset + i].isSymbol = true;
- swapColumns(var[symbolOffset + i].pos, getNumFixedCols() + i);
+SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM,
+ const llvm::SmallBitVector &isSymbol)
+ : SimplexBase(nVar, mustUseBigM) {
+ assert(isSymbol.size() == nVar && "invalid bitmask!");
+ // Invariant: nSymbol symbols have been marked so far. set_bits() visits
+ // indices in increasing order, so the marked symbols occupy the columns
+ // [getNumFixedCols(), getNumFixedCols() + nSymbol) in var-index order.
+ for (unsigned symbolIdx : isSymbol.set_bits()) {
+ var[symbolIdx].isSymbol = true;
+ swapColumns(var[symbolIdx].pos, getNumFixedCols() + nSymbol);
+ ++nSymbol;
}
}
@@ -502,7 +507,7 @@ LogicalResult SymbolicLexSimplex::doNonBranchingPivots() {
}
SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
- SymbolicLexMin result(nSymbol, var.size() - nSymbol);
+ SymbolicLexMin result(domainPoly.getSpace(), var.size() - nSymbol); // One output per non-symbol var.
/// The algorithm is more naturally expressed recursively, but we implement
/// it iteratively here to avoid potential issues with stack overflows in the
diff --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp
index 4da07c0f84a75..9a6fe16fe2e36 100644
--- a/mlir/lib/Analysis/Presburger/Utils.cpp
+++ b/mlir/lib/Analysis/Presburger/Utils.cpp
@@ -253,6 +253,14 @@ MaybeLocalRepr presburger::computeSingleVarRepr(
return repr;
}
+llvm::SmallBitVector presburger::getSubrangeBitVector(unsigned len,
+ unsigned setOffset,
+ unsigned numSet) {
+ llvm::SmallBitVector vec(len, false); // `len` bits, all initially clear.
+ vec.set(setOffset, setOffset + numSet); // Sets the half-open range [setOffset, setOffset + numSet).
+ return vec;
+}
+
void presburger::removeDuplicateDivs(
std::vector<SmallVector<int64_t, 8>> &divs,
SmallVectorImpl<unsigned> &denoms, unsigned localOffset,
More information about the Mlir-commits
mailing list