[Mlir-commits] [mlir] c2fcaf8 - [MLIR][Presburger] Simplex: refactor (symbolic)lex to support specifying multiple varKinds as symbols

Arjun P llvmlistbot at llvm.org
Fri Jul 1 09:47:39 PDT 2022


Author: Arjun P
Date: 2022-07-01T17:47:39+01:00
New Revision: c2fcaf84e5a33a18385620453fdbb73fee58cbbc

URL: https://github.com/llvm/llvm-project/commit/c2fcaf84e5a33a18385620453fdbb73fee58cbbc
DIFF: https://github.com/llvm/llvm-project/commit/c2fcaf84e5a33a18385620453fdbb73fee58cbbc.diff

LOG: [MLIR][Presburger] Simplex: refactor (symbolic)lex to support specifying multiple varKinds as symbols

This is also required to support lexmin for relations.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D128931

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Presburger/Simplex.h
    mlir/include/mlir/Analysis/Presburger/Utils.h
    mlir/lib/Analysis/Presburger/Simplex.cpp
    mlir/lib/Analysis/Presburger/Utils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
index b265897673661..2c3aedea8452f 100644
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -23,6 +23,7 @@
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/StringSaver.h"
 #include "llvm/Support/raw_ostream.h"
@@ -210,14 +211,18 @@ class SimplexBase {
 
 protected:
   /// Construct a SimplexBase with the specified number of variables and fixed
-  /// columns.
+  /// columns. The first overload should be used when there are no symbols.
+  /// With the second overload, `isSymbol` is a bitmask denoting which vars
+  /// are symbols; the marked vars are moved to the columns right after the
+  /// fixed columns. The size of `isSymbol` must be `nVar`.
   ///
   /// For example, Simplex uses two fixed columns: the denominator and the
   /// constant term, whereas LexSimplex has an extra fixed column for the
   /// so-called big M parameter. For more information see the documentation for
   /// LexSimplex.
-  SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset,
-              unsigned nSymbol);
+  SimplexBase(unsigned nVar, bool mustUseBigM);
+  SimplexBase(unsigned nVar, bool mustUseBigM,
+              const llvm::SmallBitVector &isSymbol);
 
   enum class Orientation { Row, Column };
 
@@ -422,12 +427,16 @@ class LexSimplexBase : public SimplexBase {
   unsigned getSnapshot() { return SimplexBase::getSnapshotBasis(); }
 
 protected:
-  LexSimplexBase(unsigned nVar, unsigned symbolOffset, unsigned nSymbol)
-      : SimplexBase(nVar, /*mustUseBigM=*/true, symbolOffset, nSymbol) {}
+  LexSimplexBase(unsigned nVar) : SimplexBase(nVar, /*mustUseBigM=*/true) {}
+  LexSimplexBase(unsigned nVar, const llvm::SmallBitVector &isSymbol)
+      : SimplexBase(nVar, /*mustUseBigM=*/true, isSymbol) {}
   explicit LexSimplexBase(const IntegerRelation &constraints)
-      : LexSimplexBase(constraints.getNumVars(),
-                       constraints.getVarKindOffset(VarKind::Symbol),
-                       constraints.getNumSymbolVars()) {
+      : LexSimplexBase(constraints.getNumVars()) {
+    intersectIntegerRelation(constraints);
+  }
+  explicit LexSimplexBase(const IntegerRelation &constraints,
+                          const llvm::SmallBitVector &isSymbol)
+      : LexSimplexBase(constraints.getNumVars(), isSymbol) {
     intersectIntegerRelation(constraints);
   }
 
@@ -470,13 +479,12 @@ class LexSimplexBase : public SimplexBase {
 /// provides support for integer-exact redundancy and separateness checks.
 class LexSimplex : public LexSimplexBase {
 public:
-  explicit LexSimplex(unsigned nVar)
-      : LexSimplexBase(nVar, /*symbolOffset=*/0, /*nSymbol=*/0) {}
+  explicit LexSimplex(unsigned nVar) : LexSimplexBase(nVar) {}
+  /// Note that LexSimplex does NOT support symbolic lexmin; use
+  /// SymbolicLexSimplex if that is required. LexSimplex ignores the VarKinds
+  /// of the passed IntegerRelation. Symbols will be treated as ordinary vars.
   explicit LexSimplex(const IntegerRelation &constraints)
-      : LexSimplexBase(constraints) {
-    assert(constraints.getNumSymbolVars() == 0 &&
-           "LexSimplex does not support symbols!");
-  }
+      : LexSimplexBase(constraints) {}
 
   /// Return the lexicographically minimum rational solution to the constraints.
   MaybeOptimum<SmallVector<Fraction, 8>> findRationalLexMin();
@@ -521,10 +529,9 @@ class LexSimplex : public LexSimplexBase {
 
 /// Represents the result of a symbolic lexicographic minimization computation.
 struct SymbolicLexMin {
-  SymbolicLexMin(unsigned nSymbols, unsigned nNonSymbols)
-      : lexmin(PresburgerSpace::getSetSpace(nSymbols), nNonSymbols),
-        unboundedDomain(
-            PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(nSymbols))) {}
+  SymbolicLexMin(const PresburgerSpace &domainSpace, unsigned numOutputs)
+      : lexmin(domainSpace, numOutputs),
+        unboundedDomain(PresburgerSet::getEmpty(domainSpace)) {}
 
   /// This maps assignments of symbols to the corresponding lexmin.
   /// Takes no value when no integer sample exists for the assignment or if the
@@ -569,30 +576,40 @@ class SymbolicLexSimplex : public LexSimplexBase {
   /// `constraints` is the set for which the symbolic lexmin will be computed.
   /// `symbolDomain` is the set of values of the symbols for which the lexmin
   /// will be computed. `symbolDomain` should have a dim var for every symbol in
-  /// `constraints`, and no other vars.
+  /// `constraints`, and no other vars. `isSymbol` specifies which vars of
+  /// `constraints` should be considered as symbols.
+  ///
+  /// The domain of the resulting SymbolicLexMin (both `lexmin` and
+  /// `unboundedDomain`) will have the same space as `symbolDomain`.
   SymbolicLexSimplex(const IntegerRelation &constraints,
-                     const IntegerPolyhedron &symbolDomain)
-      : SymbolicLexSimplex(constraints,
-                           constraints.getVarKindOffset(VarKind::Symbol),
-                           symbolDomain) {
-    assert(constraints.getNumSymbolVars() == symbolDomain.getNumVars());
+                     const IntegerPolyhedron &symbolDomain,
+                     const llvm::SmallBitVector &isSymbol)
+      : LexSimplexBase(constraints, isSymbol), domainPoly(symbolDomain),
+        domainSimplex(symbolDomain) {
+    // TODO consider supporting this case. It amounts
+    // to just returning the input constraints.
+    assert(domainPoly.getNumVars() > 0 &&
+           "there must be some non-symbols to optimize!");
   }
 
-  /// An overload to select some other subrange of ids as symbols for lexmin.
+  /// An overload to select some subrange of ids as symbols for lexmin.
   /// The symbol ids are the range of ids with absolute index
   /// [symbolOffset, symbolOffset + symbolDomain.getNumVars())
-  /// symbolDomain should only have dim ids.
   SymbolicLexSimplex(const IntegerRelation &constraints, unsigned symbolOffset,
                      const IntegerPolyhedron &symbolDomain)
-      : LexSimplexBase(/*nVar=*/constraints.getNumVars(), symbolOffset,
-                       symbolDomain.getNumVars()),
-        domainPoly(symbolDomain), domainSimplex(symbolDomain) {
-    // TODO consider supporting this case. It amounts
-    // to just returning the input constraints.
-    assert(domainPoly.getNumVars() > 0 &&
-           "there must be some non-symbols to optimize!");
-    assert(domainPoly.getNumVars() == domainPoly.getNumDimVars());
-    intersectIntegerRelation(constraints);
+      : SymbolicLexSimplex(constraints, symbolDomain,
+                           getSubrangeBitVector(constraints.getNumVars(),
+                                                symbolOffset,
+                                                symbolDomain.getNumVars())) {}
+
+  /// An overload to select the symbols of `constraints` as symbols for lexmin.
+  SymbolicLexSimplex(const IntegerRelation &constraints,
+                     const IntegerPolyhedron &symbolDomain)
+      : SymbolicLexSimplex(constraints,
+                           constraints.getVarKindOffset(VarKind::Symbol),
+                           symbolDomain) {
+    assert(constraints.getNumSymbolVars() == symbolDomain.getNumVars() &&
+           "symbolDomain must have as many vars as constraints has symbols!");
   }
 
   /// The lexmin will be stored as a function `lexmin` from symbols to
@@ -678,9 +695,7 @@ class Simplex : public SimplexBase {
   enum class Direction { Up, Down };
 
   Simplex() = delete;
-  explicit Simplex(unsigned nVar)
-      : SimplexBase(nVar, /*mustUseBigM=*/false, /*symbolOffset=*/0,
-                    /*nSymbol=*/0) {}
+  explicit Simplex(unsigned nVar) : SimplexBase(nVar, /*mustUseBigM=*/false) {}
   explicit Simplex(const IntegerRelation &constraints)
       : Simplex(constraints.getNumVars()) {
     intersectIntegerRelation(constraints);

diff  --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h
index c735ddd037c80..52321c5413def 100644
--- a/mlir/include/mlir/Analysis/Presburger/Utils.h
+++ b/mlir/include/mlir/Analysis/Presburger/Utils.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallBitVector.h"
 
 namespace mlir {
 namespace presburger {
@@ -120,6 +121,9 @@ SmallVector<int64_t, 8> getDivUpperBound(ArrayRef<int64_t> dividend,
 SmallVector<int64_t, 8> getDivLowerBound(ArrayRef<int64_t> dividend,
                                          int64_t divisor, unsigned localVarIdx);
 
+llvm::SmallBitVector getSubrangeBitVector(unsigned len, unsigned setOffset,
+                                          unsigned numSet);
+
 /// Check if the pos^th variable can be expressed as a floordiv of an affine
 /// function of other variables (where the divisor is a positive constant).
 /// `foundRepr` contains a boolean for each variable indicating if the

diff  --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index 0574f155cd433..6f541363f3867 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -31,23 +31,28 @@ scaleAndAddForAssert(ArrayRef<int64_t> a, int64_t scale, ArrayRef<int64_t> b) {
   return res;
 }
 
-SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM, unsigned symbolOffset,
-                         unsigned nSymbol)
-    : usingBigM(mustUseBigM), nRedundant(0), nSymbol(nSymbol),
+SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM)
+    : usingBigM(mustUseBigM), nRedundant(0), nSymbol(0),
       tableau(0, getNumFixedCols() + nVar), empty(false) {
-  assert(symbolOffset + nSymbol <= nVar);
-
   colUnknown.insert(colUnknown.begin(), getNumFixedCols(), nullIndex);
   for (unsigned i = 0; i < nVar; ++i) {
     var.emplace_back(Orientation::Column, /*restricted=*/false,
                      /*pos=*/getNumFixedCols() + i);
     colUnknown.push_back(i);
   }
+}
 
-  // Move the symbols to be in columns [3, 3 + nSymbol).
-  for (unsigned i = 0; i < nSymbol; ++i) {
-    var[symbolOffset + i].isSymbol = true;
-    swapColumns(var[symbolOffset + i].pos, getNumFixedCols() + i);
+SimplexBase::SimplexBase(unsigned nVar, bool mustUseBigM,
+                         const llvm::SmallBitVector &isSymbol)
+    : SimplexBase(nVar, mustUseBigM) {
+  assert(isSymbol.size() == nVar && "invalid bitmask!");
+  // Invariant: nSymbol is the number of symbols that have been marked
+  // already and these occupy the columns
+  // [getNumFixedCols(), getNumFixedCols() + nSymbol).
+  for (unsigned symbolIdx : isSymbol.set_bits()) {
+    var[symbolIdx].isSymbol = true;
+    swapColumns(var[symbolIdx].pos, getNumFixedCols() + nSymbol);
+    ++nSymbol;
   }
 }
 
@@ -502,7 +507,7 @@ LogicalResult SymbolicLexSimplex::doNonBranchingPivots() {
 }
 
 SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
-  SymbolicLexMin result(nSymbol, var.size() - nSymbol);
+  SymbolicLexMin result(domainPoly.getSpace(), var.size() - nSymbol);
 
   /// The algorithm is more naturally expressed recursively, but we implement
   /// it iteratively here to avoid potential issues with stack overflows in the

diff  --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp
index 4da07c0f84a75..9a6fe16fe2e36 100644
--- a/mlir/lib/Analysis/Presburger/Utils.cpp
+++ b/mlir/lib/Analysis/Presburger/Utils.cpp
@@ -253,6 +253,14 @@ MaybeLocalRepr presburger::computeSingleVarRepr(
   return repr;
 }
 
+llvm::SmallBitVector presburger::getSubrangeBitVector(unsigned len,
+                                                      unsigned setOffset,
+                                                      unsigned numSet) {
+  llvm::SmallBitVector vec(len, false);
+  vec.set(setOffset, setOffset + numSet);
+  return vec;
+}
+
 void presburger::removeDuplicateDivs(
     std::vector<SmallVector<int64_t, 8>> &divs,
     SmallVectorImpl<unsigned> &denoms, unsigned localOffset,


        


More information about the Mlir-commits mailing list