[Mlir-commits] [mlir] 8a7ead6 - [MLIR][Presburger] Support computing a representation of a set that only has locals that are divs

Arjun P llvmlistbot at llvm.org
Sat Jun 25 06:23:34 PDT 2022


Author: Arjun P
Date: 2022-06-25T14:23:32+01:00
New Revision: 8a7ead691bad29b86017d9e42fa63a57c8c0d629

URL: https://github.com/llvm/llvm-project/commit/8a7ead691bad29b86017d9e42fa63a57c8c0d629
DIFF: https://github.com/llvm/llvm-project/commit/8a7ead691bad29b86017d9e42fa63a57c8c0d629.diff

LOG: [MLIR][Presburger] Support computing a representation of a set that only has locals that are divs

This paves the way for integer-exact projection, and for supporting
non-division locals in subtraction, complement, and equality checks.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D127463

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
    mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
    mlir/include/mlir/Analysis/Presburger/Simplex.h
    mlir/lib/Analysis/Presburger/IntegerRelation.cpp
    mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
    mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 4a866c17dd3b1..935307d4bb88e 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -26,6 +26,7 @@ namespace presburger {
 
 class IntegerRelation;
 class IntegerPolyhedron;
+class PresburgerSet; // Return type of computeReprWithOnlyDivLocals().
 
 /// An IntegerRelation represents the set of points from a PresburgerSpace that
 /// satisfy a list of affine constraints. Affine constraints can be inequalities
@@ -93,6 +94,17 @@ class IntegerRelation {
   /// Returns a reference to the underlying space.
   const PresburgerSpace &getSpace() const { return space; }
 
+  /// Set the space to `oSpace`, which should have the same number of ids as
+  /// the current space; a mismatch results in an assert failure.
+  void setSpace(const PresburgerSpace &oSpace);
+
+  /// Set the space to `oSpace`, which should not have any local ids.
+  /// `oSpace` can have fewer ids than the current space; in that case, the
+  /// extra ids in `this` that are not accounted for by `oSpace` will be
+  /// considered as local ids. `oSpace` should not have more ids than the
+  /// current space; having more results in an assert failure.
+  void setSpaceExceptLocals(const PresburgerSpace &oSpace);
+
   /// Returns a copy of the space without locals.
   PresburgerSpace getSpaceWithoutLocals() const {
     return PresburgerSpace::getRelationSpace(space.getNumDomainIds(),
@@ -497,6 +509,9 @@ class IntegerRelation {
   /// locals that have been added to `this`.
   unsigned mergeLocalIds(IntegerRelation &other);
 
+  /// Check whether all local ids (if any) have a division representation.
+  bool hasOnlyDivLocals() const;
+
   /// Changes the partition between dimensions and symbols. Depending on the new
   /// symbol count, either a chunk of dimensional identifiers immediately before
   /// the split become symbols, or some of the symbols immediately after the
@@ -739,6 +754,12 @@ class IntegerPolyhedron : public IntegerRelation {
   /// first added identifier.
   unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override;
 
+  /// Compute an equivalent representation of the same set, such that all local
+  /// ids have division representations. Locals without a division
+  /// representation are eliminated; the result may involve div locals and may
+  /// also be a union of several convex disjuncts.
+  PresburgerSet computeReprWithOnlyDivLocals() const;
+
   /// Compute the symbolic integer lexmin of the polyhedron.
   /// This finds, for every assignment to the symbols, the lexicographically
   /// minimum value attained by the dimensions. For example, the symbolic lexmin

diff  --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
index e4aa36599537b..89a3deb30e689 100644
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
@@ -55,6 +55,14 @@ class PresburgerRelation {
 
   const PresburgerSpace &getSpace() const { return space; }
 
+  /// Set the space to `oSpace`. `oSpace` should not contain any local ids.
+  /// `oSpace` need not have the same number of ids as the current space;
+  /// it could have more or fewer. If it has fewer, the extra ids become
+  /// locals of the disjuncts. It can also have more, in which case the
+  /// disjuncts will have fewer locals. If its total number of ids
+  /// exceeds that of some disjunct, an assert failure will occur.
+  void setSpace(const PresburgerSpace &oSpace);
+
   /// Return a reference to the list of disjuncts.
   ArrayRef<IntegerRelation> getAllDisjuncts() const;
 
@@ -117,6 +125,9 @@ class PresburgerRelation {
   /// disjuncts in the union.
   PresburgerRelation coalesce() const;
 
+  /// Check whether every local id of every disjunct has a div representation.
+  bool hasOnlyDivLocals() const;
+
   /// Print the set's internal state.
   void print(raw_ostream &os) const;
   void dump() const;

diff  --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
index 4caaf78bd4716..f583c59ef24be 100644
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -572,10 +572,28 @@ class SymbolicLexSimplex : public LexSimplexBase {
   /// `constraints`, and no other ids.
   SymbolicLexSimplex(const IntegerPolyhedron &constraints,
                      const IntegerPolyhedron &symbolDomain)
-      : LexSimplexBase(constraints), domainPoly(symbolDomain),
-        domainSimplex(symbolDomain) {
-    assert(domainPoly.getNumIds() == constraints.getNumSymbolIds());
-    assert(domainPoly.getNumDimIds() == constraints.getNumSymbolIds());
+      : SymbolicLexSimplex(constraints,
+                           constraints.getIdKindOffset(IdKind::Symbol),
+                           symbolDomain) {
+    assert(constraints.getNumSymbolIds() == symbolDomain.getNumIds());
+  }
+
+  /// An overload to select some other subrange of ids as symbols for lexmin.
+  /// The symbol ids are the range of ids with absolute index
+  /// [symbolOffset, symbolOffset + symbolDomain.getNumIds()).
+  /// `symbolDomain` should only have dim ids.
+  SymbolicLexSimplex(const IntegerPolyhedron &constraints,
+                     unsigned symbolOffset,
+                     const IntegerPolyhedron &symbolDomain)
+      : LexSimplexBase(/*nVar=*/constraints.getNumIds(), symbolOffset,
+                       symbolDomain.getNumIds()),
+        domainPoly(symbolDomain), domainSimplex(symbolDomain) {
+    // TODO consider supporting this case (just return the input constraints).
+    // NOTE(review): this checks #symbols > 0; the message says non-symbols.
+    assert(domainPoly.getNumIds() > 0 &&
+           "there must be some non-symbols to optimize!");
+    assert(domainPoly.getNumIds() == domainPoly.getNumDimIds());
+    intersectIntegerRelation(constraints);
   }
 
   /// The lexmin will be stored as a function `lexmin` from symbols to
@@ -583,6 +601,9 @@ class SymbolicLexSimplex : public LexSimplexBase {
   ///
   /// For some values of the symbols, the lexmin may be unbounded.
   /// These parts of the symbol domain will be stored in `unboundedDomain`.
+  ///
+  /// The spaces of the sets in the result are compatible with the
+  /// `symbolDomain` passed to the SymbolicLexSimplex constructor.
   SymbolicLexMin computeSymbolicIntegerLexMin();
 
 private:

diff  --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 25d89f93d93d1..7376747663e62 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -38,6 +38,19 @@ std::unique_ptr<IntegerPolyhedron> IntegerPolyhedron::clone() const {
   return std::make_unique<IntegerPolyhedron>(*this);
 }
 
+void IntegerRelation::setSpace(const PresburgerSpace &oSpace) {
+  assert(space.getNumIds() == oSpace.getNumIds() && "invalid space!");
+  space = oSpace; // Same total id count; the id kinds may differ.
+}
+
+void IntegerRelation::setSpaceExceptLocals(const PresburgerSpace &oSpace) {
+  assert(oSpace.getNumLocalIds() == 0 && "no locals should be present!");
+  assert(oSpace.getNumIds() <= getNumIds() && "invalid space!");
+  unsigned newNumLocals = getNumIds() - oSpace.getNumIds();
+  space = oSpace;
+  space.insertId(IdKind::Local, 0, newNumLocals); // Surplus ids -> locals.
+}
+
 void IntegerRelation::append(const IntegerRelation &other) {
   assert(space.isEqual(other.getSpace()) && "Spaces must be equal.");
 
@@ -152,6 +165,67 @@ void IntegerRelation::truncate(const CountsSnapshot &counts) {
   removeEqualityRange(counts.getNumEqs(), getNumEqualities());
 }
 
+PresburgerSet IntegerPolyhedron::computeReprWithOnlyDivLocals() const {
+  // If there are no locals, we're done.
+  if (getNumLocalIds() == 0)
+    return PresburgerSet(*this);
+
+  // Move all the non-div locals to the end, as the current API to
+  // SymbolicLexMin requires these to form a contiguous range.
+  //
+  // Take a copy so we can perform mutations.
+  IntegerPolyhedron copy = *this;
+  std::vector<MaybeLocalRepr> reprs;
+  copy.getLocalReprs(reprs);
+
+  // Iterate through all the locals. The last `numNonDivLocals` are the locals
+  // that have been scanned already and do not have division representations.
+  unsigned numNonDivLocals = 0;
+  unsigned offset = copy.getIdKindOffset(IdKind::Local);
+  for (unsigned i = 0, e = copy.getNumLocalIds(); i < e - numNonDivLocals;) {
+    if (!reprs[i]) {
+      // Whenever we come across a local that does not have a division
+      // representation, we swap it to the `numNonDivLocals`-th last position
+      // and increment `numNonDivLocals`. `reprs` also needs to be swapped.
+      copy.swapId(offset + i, offset + e - numNonDivLocals - 1);
+      std::swap(reprs[i], reprs[e - numNonDivLocals - 1]);
+      ++numNonDivLocals;
+      continue;
+    }
+    ++i;
+  }
+
+  // If there are no non-div locals, we're done.
+  if (numNonDivLocals == 0)
+    return PresburgerSet(*this);
+
+  // We compute the symbolic integer lexmin by treating the non-div locals as
+  // "non-symbols" and considering everything else as "symbols". This will
+  // compute a function mapping assignments to "symbols" to the
+  // lexicographically minimal valid assignment of "non-symbols", when a
+  // satisfying assignment exists. It separately returns the set of assignments
+  // to the "symbols" such that a satisfying assignment to the "non-symbols"
+  // exists but the lexmin is unbounded. We basically want to find the set of
+  // values of the "symbols" such that an assignment to the "non-symbols"
+  // exists, which is the union of the domain of the returned lexmin function
+  // and the returned set of assignments to the "symbols" that makes the lexmin
+  // unbounded.
+  SymbolicLexMin lexminResult =
+      SymbolicLexSimplex(copy, /*symbolOffset=*/0,
+                         IntegerPolyhedron(PresburgerSpace::getSetSpace(
+                             /*numDims=*/copy.getNumIds() - numNonDivLocals)))
+          .computeSymbolicIntegerLexMin();
+  PresburgerSet result =
+      lexminResult.lexmin.getDomain().unionSet(lexminResult.unboundedDomain);
+
+  // The result lies in a set space where every id is a dim. Give it this set's
+  // space without locals; the surplus (div local) ids then become locals.
+  PresburgerSpace space = getSpace();
+  space.removeIdRange(IdKind::Local, 0, getNumLocalIds());
+  result.setSpace(space);
+  return result;
+}
+
 SymbolicLexMin IntegerPolyhedron::findSymbolicIntegerLexMin() const {
   // Compute the symbolic lexmin of the dims and locals, with the symbols being
   // the actual symbols of this set.
@@ -1120,6 +1194,13 @@ unsigned IntegerRelation::mergeLocalIds(IntegerRelation &other) {
   return relA.getNumLocalIds() - oldALocals;
 }
 
+bool IntegerRelation::hasOnlyDivLocals() const {
+  std::vector<MaybeLocalRepr> reprs;
+  getLocalReprs(reprs);
+  return llvm::all_of(reprs, // Vacuously true if there are no locals.
+                      [](const MaybeLocalRepr &repr) { return bool(repr); });
+}
+
 void IntegerRelation::removeDuplicateDivs() {
   std::vector<SmallVector<int64_t, 8>> divs;
   SmallVector<unsigned, 4> denoms;

diff  --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index 9ce59d769d43c..1b5d48b9cf896 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -21,6 +21,13 @@ PresburgerRelation::PresburgerRelation(const IntegerRelation &disjunct)
   unionInPlace(disjunct);
 }
 
+void PresburgerRelation::setSpace(const PresburgerSpace &oSpace) {
+  assert(oSpace.getNumLocalIds() == 0 && "no locals should be present");
+  space = oSpace;
+  for (IntegerRelation &disjunct : disjuncts) // Surplus ids become locals.
+    disjunct.setSpaceExceptLocals(space);
+}
+
 unsigned PresburgerRelation::getNumDisjuncts() const {
   return disjuncts.size();
 }
@@ -770,6 +777,12 @@ PresburgerRelation PresburgerRelation::coalesce() const {
   return SetCoalescer(*this).coalesce();
 }
 
+bool PresburgerRelation::hasOnlyDivLocals() const {
+  return llvm::all_of(disjuncts, [](const IntegerRelation &rel) {
+    return rel.hasOnlyDivLocals();
+  }); // Vacuously true for a union with no disjuncts.
+}
+
 void PresburgerRelation::print(raw_ostream &os) const {
   os << "Number of Disjuncts: " << getNumDisjuncts() << "\n";
   for (const IntegerRelation &disjunct : disjuncts) {

diff  --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
index ba3a0024f9732..0c98f488ea074 100644
--- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
@@ -751,6 +751,54 @@ TEST(SetTest, computeVolume) {
                                         /*resultBound=*/{});
 }
 
+// Both helpers below project out the last `numToProject` dims of `poly`, i.e.,
+// convert them to locals, and then check computeReprWithOnlyDivLocals().
+void testComputeReprAtPoints(IntegerPolyhedron poly,
+                             ArrayRef<SmallVector<int64_t, 4>> points,
+                             unsigned numToProject) {
+  poly.convertIdKind(IdKind::SetDim, poly.getNumDimIds() - numToProject,
+                     poly.getNumDimIds(), IdKind::Local);
+  PresburgerSet repr = poly.computeReprWithOnlyDivLocals();
+  EXPECT_TRUE(repr.hasOnlyDivLocals());
+  EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace()));
+  for (const SmallVector<int64_t, 4> &point : points) {
+    EXPECT_EQ(poly.containsPointNoLocal(point).hasValue(),
+              repr.containsPoint(point));
+  }
+}
+
+void testComputeRepr(IntegerPolyhedron poly, const PresburgerSet &expected,
+                     unsigned numToProject) {
+  poly.convertIdKind(IdKind::SetDim, poly.getNumDimIds() - numToProject,
+                     poly.getNumDimIds(), IdKind::Local);
+  PresburgerSet repr = poly.computeReprWithOnlyDivLocals();
+  EXPECT_TRUE(repr.hasOnlyDivLocals());
+  EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace()));
+  EXPECT_TRUE(repr.isEqual(expected));
+}
+
+TEST(SetTest, computeReprWithOnlyDivLocals) {
+  testComputeReprAtPoints(parsePoly("(x, y) : (x - 2*y == 0)"),
+                          {{1, 0}, {2, 1}, {3, 0}, {4, 2}, {5, 3}},
+                          /*numToProject=*/0);
+  testComputeReprAtPoints(parsePoly("(x, e) : (x - 2*e == 0)"),
+                          {{1}, {2}, {3}, {4}, {5}}, /*numToProject=*/1);
+
+  // Tests to check that the space is preserved.
+  testComputeReprAtPoints(parsePoly("(x, y)[z, w] : ()"), {},
+                          /*numToProject=*/1);
+  testComputeReprAtPoints(parsePoly("(x, y)[z, w] : (z - (w floordiv 2) == 0)"),
+                          {},
+                          /*numToProject=*/1);
+
+  // Bezout's lemma: if a, b are integer constants, the values ax + by takes
+  // over all integers x, y are exactly the multiples of gcd(a, b) (here 3).
+  testComputeRepr(
+      parsePoly("(x, e, f) : (x - 15*e - 21*f == 0)"),
+      PresburgerSet(parsePoly({"(x) : (x - 3*(x floordiv 3) == 0)"})),
+      /*numToProject=*/2);
+}
 TEST(SetTest, subtractOutputSizeRegression) {
   PresburgerSet set1 =
       parsePresburgerSetFromPolyStrings(1, {"(i) : (i >= 0, 10 - i >= 0)"});


        


More information about the Mlir-commits mailing list