[Mlir-commits] [mlir] 56863ad - [MLIR][Presburger] Implement findSymbolicIntegerLexMax for IntegerRelation

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 26 07:17:15 PDT 2023


Author: iambrj
Date: 2023-07-26T19:36:29+05:30
New Revision: 56863adf8eecebb16b4e3a901ba5ddc8b7074f01

URL: https://github.com/llvm/llvm-project/commit/56863adf8eecebb16b4e3a901ba5ddc8b7074f01
DIFF: https://github.com/llvm/llvm-project/commit/56863adf8eecebb16b4e3a901ba5ddc8b7074f01.diff

LOG: [MLIR][Presburger] Implement findSymbolicIntegerLexMax for IntegerRelation

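This patch implements findSymbolicIntegerLexMax for IntegerRelation. It also
renames SymbolicLexMin to SymbolicLexOpt (and its `lexmin` member to `lexopt`),
since the same result type is now returned by both findSymbolicIntegerLexMin
and findSymbolicIntegerLexMax.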

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D156023

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
    mlir/include/mlir/Analysis/Presburger/Simplex.h
    mlir/lib/Analysis/Presburger/IntegerRelation.cpp
    mlir/lib/Analysis/Presburger/Simplex.cpp
    mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
    mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 9646894736de06..369b31511afdc0 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -29,7 +29,7 @@ class IntegerRelation;
 class IntegerPolyhedron;
 class PresburgerSet;
 class PresburgerRelation;
-struct SymbolicLexMin;
+struct SymbolicLexOpt;
 
 /// The type of bound: equal, lower bound or upper bound.
 enum class BoundType { EQ, LB, UB };
@@ -659,15 +659,18 @@ class IntegerRelation {
   /// x = a if b <= a, a <= c
   /// x = b if a <  b, b <= c
   ///
-  /// This function is stored in the `lexmin` function in the result.
+  /// This function is stored in the `lexopt` function in the result.
   /// Some assignments to the symbols might make the set empty.
   /// Such points are not part of the function's domain.
   /// In the above example, this happens when max(a, b) > c.
   ///
   /// For some values of the symbols, the lexmin may be unbounded.
-  /// `SymbolicLexMin` stores these parts of the symbolic domain in a separate
+  /// `SymbolicLexOpt` stores these parts of the symbolic domain in a separate
   /// `PresburgerSet`, `unboundedDomain`.
-  SymbolicLexMin findSymbolicIntegerLexMin() const;
+  SymbolicLexOpt findSymbolicIntegerLexMin() const;
+
+  /// Like findSymbolicIntegerLexMin, but computes the lexmax instead of lexmin.
+  SymbolicLexOpt findSymbolicIntegerLexMax() const;
 
   /// Return the set difference of this set and the given set, i.e.,
   /// return `this \ set`.

diff  --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
index 470d483cbb5648..79a42d6c38d411 100644
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -43,7 +43,7 @@ class GBRSimplex;
 /// these constraints that are redundant, i.e. a subset of constraints that
 /// doesn't constrain the affine set further after adding the non-redundant
 /// constraints. The LexSimplex class provides support for computing the
-/// lexicographic minimum of an IntegerRelation. The SymbolicLexMin class
+/// lexicographic minimum of an IntegerRelation. The SymbolicLexSimplex class
 /// provides support for computing symbolic lexicographic minimums. All of these
 /// classes can be constructed from an IntegerRelation, and all inherit common
 /// functionality from SimplexBase.
@@ -529,18 +529,18 @@ class LexSimplex : public LexSimplexBase {
   std::optional<unsigned> maybeGetNonIntegralVarRow() const;
 };
 
-/// Represents the result of a symbolic lexicographic minimization computation.
-struct SymbolicLexMin {
-  SymbolicLexMin(const PresburgerSpace &space)
-      : lexmin(space),
+/// Result of a symbolic lexicographic optimization, i.e., a lexmin or lexmax.
+struct SymbolicLexOpt {
+  SymbolicLexOpt(const PresburgerSpace &space)
+      : lexopt(space),
         unboundedDomain(PresburgerSet::getEmpty(space.getDomainSpace())) {}
 
-  /// This maps assignments of symbols to the corresponding lexmin.
+  /// Maps each assignment of the symbols to the corresponding lexmin/lexmax.
   /// Takes no value when no integer sample exists for the assignment or if the
-  /// lexmin is unbounded.
-  PWMAFunction lexmin;
-  /// Contains all assignments to the symbols that made the lexmin unbounded.
-  /// Note that the symbols of the input set to the symbolic lexmin are dims
+  /// lexopt is unbounded.
+  PWMAFunction lexopt;
+  /// Contains all assignments to the symbols that made the lexopt unbounded.
+  /// Note that the symbols of the input set to the symbolic lexopt are dims
   /// of this PrebsurgerSet.
   PresburgerSet unboundedDomain;
 };
@@ -575,13 +575,13 @@ struct SymbolicLexMin {
 /// where it is.
 class SymbolicLexSimplex : public LexSimplexBase {
 public:
-  /// `constraints` is the set for which the symbolic lexmin will be computed.
-  /// `symbolDomain` is the set of values of the symbols for which the lexmin
+  /// `constraints` is the set for which the symbolic lexopt will be computed.
+  /// `symbolDomain` is the set of values of the symbols for which the lexopt
   /// will be computed. `symbolDomain` should have a dim var for every symbol in
   /// `constraints`, and no other vars. `isSymbol` specifies which vars of
   /// `constraints` should be considered as symbols.
   ///
-  /// The resulting SymbolicLexMin's space will be compatible with that of
+  /// The resulting SymbolicLexOpt's space will be compatible with that of
   /// symbolDomain.
   SymbolicLexSimplex(const IntegerRelation &constraints,
                      const IntegerPolyhedron &symbolDomain,
@@ -594,7 +594,7 @@ class SymbolicLexSimplex : public LexSimplexBase {
            "there must be some non-symbols to optimize!");
   }
 
-  /// An overload to select some subrange of ids as symbols for lexmin.
+  /// An overload to select some subrange of ids as symbols for lexopt.
   /// The symbol ids are the range of ids with absolute index
   /// [symbolOffset, symbolOffset + symbolDomain.getNumVars())
   SymbolicLexSimplex(const IntegerRelation &constraints, unsigned symbolOffset,
@@ -604,7 +604,7 @@ class SymbolicLexSimplex : public LexSimplexBase {
                                                 symbolOffset,
                                                 symbolDomain.getNumVars())) {}
 
-  /// An overload to select the symbols of `constraints` as symbols for lexmin.
+  /// An overload to select the symbols of `constraints` as symbols for lexopt.
   SymbolicLexSimplex(const IntegerRelation &constraints,
                      const IntegerPolyhedron &symbolDomain)
       : SymbolicLexSimplex(constraints,
@@ -614,7 +614,7 @@ class SymbolicLexSimplex : public LexSimplexBase {
            "symbolDomain must have as many vars as constraints has symbols!");
   }
 
-  /// The lexmin will be stored as a function `lexmin` from symbols to
+  /// The lexmin will be stored as a function `lexopt` from symbols to
   /// non-symbols in the result.
   ///
   /// For some values of the symbols, the lexmin may be unbounded.
@@ -622,7 +622,7 @@ class SymbolicLexSimplex : public LexSimplexBase {
   ///
   /// The spaces of the sets in the result are compatible with the symbolDomain
   /// passed in the SymbolicLexSimplex constructor.
-  SymbolicLexMin computeSymbolicIntegerLexMin();
+  SymbolicLexOpt computeSymbolicIntegerLexMin();
 
 private:
   /// Perform all pivots that do not require branching.
@@ -670,7 +670,7 @@ class SymbolicLexSimplex : public LexSimplexBase {
 
   /// Record a lexmin. The tableau must be consistent with all variables
   /// having symbolic samples with integer coefficients.
-  void recordOutput(SymbolicLexMin &result) const;
+  void recordOutput(SymbolicLexOpt &result) const;
 
   /// The symbol domain.
   IntegerPolyhedron domainPoly;

diff  --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 75c6adbf6bbc2b..b2359aea461ae7 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -172,7 +172,7 @@ PresburgerRelation IntegerRelation::computeReprWithOnlyDivLocals() const {
     return PresburgerRelation(*this);
 
   // Move all the non-div locals to the end, as the current API to
-  // SymbolicLexMin requires these to form a contiguous range.
+  // SymbolicLexOpt requires these to form a contiguous range.
   //
   // Take a copy so we can perform mutations.
   IntegerRelation copy = *this;
@@ -211,13 +211,13 @@ PresburgerRelation IntegerRelation::computeReprWithOnlyDivLocals() const {
   // exists, which is the union of the domain of the returned lexmin function
   // and the returned set of assignments to the "symbols" that makes the lexmin
   // unbounded.
-  SymbolicLexMin lexminResult =
+  SymbolicLexOpt lexminResult =
       SymbolicLexSimplex(copy, /*symbolOffset*/ 0,
                          IntegerPolyhedron(PresburgerSpace::getSetSpace(
                              /*numDims=*/copy.getNumVars() - numNonDivLocals)))
           .computeSymbolicIntegerLexMin();
   PresburgerRelation result =
-      lexminResult.lexmin.getDomain().unionSet(lexminResult.unboundedDomain);
+      lexminResult.lexopt.getDomain().unionSet(lexminResult.unboundedDomain);
 
   // The result set might lie in the wrong space -- all its ids are dims.
   // Set it to the desired space and return.
@@ -227,7 +227,7 @@ PresburgerRelation IntegerRelation::computeReprWithOnlyDivLocals() const {
   return result;
 }
 
-SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const {
+SymbolicLexOpt IntegerRelation::findSymbolicIntegerLexMin() const {
   // Symbol and Domain vars will be used as symbols for symbolic lexmin.
   // In other words, for every value of the symbols and domain, return the
   // lexmin value of the (range, locals).
@@ -239,7 +239,7 @@ SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const {
   // Compute the symbolic lexmin of the dims and locals, with the symbols being
   // the actual symbols of this set.
   // The resultant space of lexmin is the space of the relation itself.
-  SymbolicLexMin result =
+  SymbolicLexOpt result =
       SymbolicLexSimplex(*this,
                          IntegerPolyhedron(PresburgerSpace::getSetSpace(
                              /*numDims=*/getNumDomainVars(),
@@ -249,11 +249,49 @@ SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const {
 
   // We want to return only the lexmin over the dims, so strip the locals from
   // the computed lexmin.
-  result.lexmin.removeOutputs(result.lexmin.getNumOutputs() - getNumLocalVars(),
-                              result.lexmin.getNumOutputs());
+  result.lexopt.removeOutputs(result.lexopt.getNumOutputs() - getNumLocalVars(),
+                              result.lexopt.getNumOutputs());
   return result;
 }
 
+/// Computes the symbolic lexmax by reduction to findSymbolicIntegerLexMin:
+/// 1. Copy `this` and replace every range variable y by -y, i.e., negate the
+///    range-variable coefficients of every equality and inequality.
+/// 2. The symbolic lexmin of the copy is the negated lexmax of `this`.
+/// 3. Negate every output row (coefficients and constant) of that lexmin.
+/// Piece domains and the unbounded domain only involve the domain and symbol
+/// vars, which were not negated, so they are reused unchanged.
+SymbolicLexOpt IntegerRelation::findSymbolicIntegerLexMax() const {
+  IntegerRelation flippedRel = *this;
+  // Flip range sign by flipping the sign of range variables in all constraints.
+  for (unsigned j = getNumDomainVars(),
+                b = getNumDomainVars() + getNumRangeVars();
+       j < b; j++) {
+    for (unsigned i = 0, a = getNumEqualities(); i < a; i++)
+      flippedRel.atEq(i, j) = -1 * atEq(i, j);
+    for (unsigned i = 0, a = getNumInequalities(); i < a; i++)
+      flippedRel.atIneq(i, j) = -1 * atIneq(i, j);
+  }
+  // The lexmin of the range-negated relation is the negated lexmax.
+  SymbolicLexOpt flippedLexMin = flippedRel.findSymbolicIntegerLexMin();
+  SymbolicLexOpt result(flippedLexMin.lexopt.getSpace());
+  // Undo the negation by negating every output row of every piece.
+  for (const PWMAFunction::Piece &flippedPiece :
+       flippedLexMin.lexopt.getAllPieces()) {
+    Matrix mat = flippedPiece.output.getOutputMatrix();
+    for (unsigned i = 0, e = mat.getNumRows(); i < e; i++)
+      mat.negateRow(i);
+    // Pass along the division representations of the piece's locals, as
+    // recordOutput does when building lexmin pieces; the two-argument
+    // MultiAffineFunction constructor takes none.
+    MultiAffineFunction maf(flippedPiece.output.getSpace(), mat,
+                            flippedPiece.output.getLocalReprs());
+    result.lexopt.addPiece({flippedPiece.domain, maf});
+  }
+  result.unboundedDomain = flippedLexMin.unboundedDomain;
+  return result;
+}
+
 PresburgerRelation
 IntegerRelation::subtract(const PresburgerRelation &set) const {
   return PresburgerRelation(*this).subtract(set);

diff  --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index 61c39bd315f187..eff312b69e1de1 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -435,9 +435,9 @@ LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) {
   return moveRowUnknownToColumn(cutRow);
 }
 
-void SymbolicLexSimplex::recordOutput(SymbolicLexMin &result) const {
+void SymbolicLexSimplex::recordOutput(SymbolicLexOpt &result) const {
   Matrix output(0, domainPoly.getNumVars() + 1);
-  output.reserveRows(result.lexmin.getNumOutputs());
+  output.reserveRows(result.lexopt.getNumOutputs());
   for (const Unknown &u : var) {
     if (u.isSymbol)
       continue;
@@ -469,10 +469,10 @@ void SymbolicLexSimplex::recordOutput(SymbolicLexMin &result) const {
   }
 
   // Store the output in a MultiAffineFunction and add it the result.
-  PresburgerSpace funcSpace = result.lexmin.getSpace();
+  PresburgerSpace funcSpace = result.lexopt.getSpace();
   funcSpace.insertVar(VarKind::Local, 0, domainPoly.getNumLocalVars());
 
-  result.lexmin.addPiece(
+  result.lexopt.addPiece(
       {PresburgerSet(domainPoly),
        MultiAffineFunction(funcSpace, output, domainPoly.getLocalReprs())});
 }
@@ -515,8 +515,8 @@ LogicalResult SymbolicLexSimplex::doNonBranchingPivots() {
   return success();
 }
 
-SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
-  SymbolicLexMin result(PresburgerSpace::getRelationSpace(
+SymbolicLexOpt SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
+  SymbolicLexOpt result(PresburgerSpace::getRelationSpace(
       /*numDomain=*/domainPoly.getNumDimVars(),
       /*numRange=*/var.size() - nSymbol,
       /*numSymbols=*/domainPoly.getNumSymbolVars()));

diff  --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
index 6beb9384c8bf22..ba035e84ff1fd7 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -1198,13 +1198,13 @@ void expectSymbolicIntegerLexMin(
   ASSERT_NE(poly.getNumDimVars(), 0u);
   ASSERT_NE(poly.getNumSymbolVars(), 0u);
 
-  SymbolicLexMin result = poly.findSymbolicIntegerLexMin();
+  SymbolicLexOpt result = poly.findSymbolicIntegerLexMin();
 
   if (expectedLexminRepr.empty()) {
-    EXPECT_TRUE(result.lexmin.getDomain().isIntegerEmpty());
+    EXPECT_TRUE(result.lexopt.getDomain().isIntegerEmpty());
   } else {
     PWMAFunction expectedLexmin = parsePWMAF(expectedLexminRepr);
-    EXPECT_TRUE(result.lexmin.isEqual(expectedLexmin));
+    EXPECT_TRUE(result.lexopt.isEqual(expectedLexmin));
   }
 
   if (expectedUnboundedDomainRepr.empty()) {

diff  --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index 1a9241bfe3ffcd..dd20e058e358dd 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -126,7 +126,7 @@ TEST(IntegerRelationTest, applyDomainAndRange) {
 }
 
 TEST(IntegerRelationTest, symbolicLexmin) {
-  SymbolicLexMin lexmin =
+  SymbolicLexOpt lexmin =
       parseRelationFromSet("(a, x)[b] : (x - a >= 0, x - b >= 0)", 1)
           .findSymbolicIntegerLexMin();
 
@@ -135,5 +135,43 @@ TEST(IntegerRelationTest, symbolicLexmin) {
       {"(a)[b] : (b - a - 1 >= 0)", "(a)[b] -> (b)"}, // b
   });
   EXPECT_TRUE(lexmin.unboundedDomain.isIntegerEmpty());
-  EXPECT_TRUE(lexmin.lexmin.isEqual(expectedLexmin));
+  EXPECT_TRUE(lexmin.lexopt.isEqual(expectedLexmin));
+}
+
+TEST(IntegerRelationTest, symbolicLexmax) {
+  // Each case checks that no symbol assignment leaves the lexmax unbounded
+  // and that the piecewise lexmax equals the expected function.
+  // x <= a and x <= b, so the lexmax of x is min(a, b).
+  SymbolicLexOpt lexmax1 =
+      parseRelationFromSet("(a, x)[b] : (a - x >= 0, b - x >= 0)", 1)
+          .findSymbolicIntegerLexMax();
+  EXPECT_TRUE(lexmax1.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmax1.lexopt.isEqual(parsePWMAF({
+      {"(a)[b] : (a - b >= 0)", "(a)[b] -> (b)"},
+      {"(a)[b] : (b - a - 1 >= 0)", "(a)[b] -> (a)"},
+  })));
+
+  // j <= N - i is the only upper bound on j, so the lexmax of j is N - i.
+  SymbolicLexOpt lexmax2 =
+      parseRelationFromSet("(i, j)[N] : (i >= 0, j >= 0, N - i - j >= 0)", 1)
+          .findSymbolicIntegerLexMax();
+  EXPECT_TRUE(lexmax2.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmax2.lexopt.isEqual(parsePWMAF({
+      {"(i)[N] : (i >= 0, N - i >= 0)", "(i)[N] -> (N - i)"},
+  })));
+
+  // y <= x + 2N and y <= 4N - x, so the lexmax of y is min(x + 2N, 4N - x);
+  // the two pieces switch over between x = N and x = N + 1.
+  SymbolicLexOpt lexmax3 =
+      parseRelationFromSet("(x, y)[N] : (x >= 0, 2 * N - x >= 0, y >= 0, x - y "
+                           "+ 2 * N >= 0, 4 * N - x - y >= 0)",
+                           1)
+          .findSymbolicIntegerLexMax();
+  EXPECT_TRUE(lexmax3.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmax3.lexopt.isEqual(parsePWMAF({
+      {"(x)[N] : (x >= 0, 2 * N - x >= 0, x - N - 1 >= 0)",
+       "(x)[N] -> (4 * N - x)"},
+      {"(x)[N] : (x >= 0, 2 * N - x >= 0, -x + N >= 0)",
+       "(x)[N] -> (x + 2 * N)"},
+  })));
 }


        


More information about the Mlir-commits mailing list