[Mlir-commits] [mlir] 4a57f5d - [MLIR] PresburgerSet: support divisions in operations

Arjun P llvmlistbot at llvm.org
Fri Sep 24 03:06:51 PDT 2021


Author: Arjun P
Date: 2021-09-24T15:36:47+05:30
New Revision: 4a57f5d1e1c5eff98fd03932f9a0f8efa13c3a77

URL: https://github.com/llvm/llvm-project/commit/4a57f5d1e1c5eff98fd03932f9a0f8efa13c3a77
DIFF: https://github.com/llvm/llvm-project/commit/4a57f5d1e1c5eff98fd03932f9a0f8efa13c3a77.diff

LOG: [MLIR] PresburgerSet: support divisions in operations

Add support for intersecting, subtracting, complementing and checking equality of sets having divisions.

Reviewed By: bondhugula

Differential Revision: https://reviews.llvm.org/D110138

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/PresburgerSet.h
    mlir/lib/Analysis/PresburgerSet.cpp
    mlir/unittests/Analysis/PresburgerSetTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/PresburgerSet.h b/mlir/include/mlir/Analysis/PresburgerSet.h
index 4863c876745aa..41dfc0018efab 100644
--- a/mlir/include/mlir/Analysis/PresburgerSet.h
+++ b/mlir/include/mlir/Analysis/PresburgerSet.h
@@ -67,17 +67,18 @@ class PresburgerSet {
   void print(raw_ostream &os) const;
   void dump() const;
 
-  /// Return the complement of this set. Computing the complement of a set
-  /// containing divisions is not yet supported.
+  /// Return the complement of this set. All local variables in the set must
+  /// correspond to floor divisions.
   PresburgerSet complement() const;
 
   /// Return the set difference of this set and the given set, i.e.,
-  /// return `this \ set`. Subtracting when either set contains divisions is not
-  /// yet supported.
+  /// return `this \ set`. All local variables in `set` must correspond
+  /// to floor divisions, but local variables in `this` need not correspond to
+  /// divisions.
   PresburgerSet subtract(const PresburgerSet &set) const;
 
   /// Return true if this set is equal to the given set, and false otherwise.
-  /// Checking equality when either set contains divisions is not yet supported.
+  /// All local variables in both sets must correspond to floor divisions.
   bool isEqual(const PresburgerSet &set) const;
 
   /// Return a universe set of the specified type that contains all points.

diff  --git a/mlir/lib/Analysis/PresburgerSet.cpp b/mlir/lib/Analysis/PresburgerSet.cpp
index 81fc8cc1eafd2..7beeb99d340bb 100644
--- a/mlir/lib/Analysis/PresburgerSet.cpp
+++ b/mlir/lib/Analysis/PresburgerSet.cpp
@@ -106,16 +106,20 @@ PresburgerSet PresburgerSet::getEmptySet(unsigned nDim, unsigned nSym) {
 //
 // We directly compute (S_1 or S_2 ...) and (T_1 or T_2 ...)
 // as (S_1 and T_1) or (S_1 and T_2) or ...
+//
+// If S_i or T_j have local variables, then (S_i and T_j) contains the local
+// variables of both.
 PresburgerSet PresburgerSet::intersect(const PresburgerSet &set) const {
   assertDimensionsCompatible(set, *this);
 
   PresburgerSet result(nDim, nSym);
   for (const FlatAffineConstraints &csA : flatAffineConstraints) {
     for (const FlatAffineConstraints &csB : set.flatAffineConstraints) {
-      FlatAffineConstraints intersection(csA);
-      intersection.append(csB);
-      if (!intersection.isEmpty())
-        result.unionFACInPlace(std::move(intersection));
+      FlatAffineConstraints csACopy = csA, csBCopy = csB;
+      csACopy.mergeLocalIds(csBCopy);
+      csACopy.append(std::move(csBCopy));
+      if (!csACopy.isEmpty())
+        result.unionFACInPlace(std::move(csACopy));
     }
   }
   return result;
@@ -160,6 +164,17 @@ static SmallVector<int64_t, 8> getComplementIneq(ArrayRef<int64_t> ineq) {
 /// returning the union of the results. Each equality is handled as a
 /// conjunction of two inequalities.
 ///
+/// Note that the same approach works even if an inequality involves a floor
+/// division. For example, the complement of x <= 7*floor(x/7) is still
+/// x > 7*floor(x/7). Since b \ s_i contains the inequalities of both b and s_i
+/// (or the complements of those inequalities), b \ s_i may contain the
+/// divisions present in both b and s_i. Therefore, we need to add the local
+/// division variables of both b and s_i to each part in the result. This means
+/// adding the local variables of both b and s_i, as well as the corresponding
+/// division inequalities to each part. Since the division inequalities are
+/// added to each part, we can skip the parts where the complement of any
+/// division inequality is added, as these parts will become empty anyway.
+///
 /// As a heuristic, we try adding all the constraints and check if simplex
 /// says that the intersection is empty. If it is, then subtracting this FAC is
 /// a no-op and we just skip it. Also, in the process we find out that some
@@ -174,27 +189,63 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
     result.unionFACInPlace(b);
     return;
   }
-  const FlatAffineConstraints &sI = s.getFlatAffineConstraints(i);
-  assert(sI.getNumLocalIds() == 0 &&
-         "Subtracting sets with divisions is not yet supported!");
+  FlatAffineConstraints sI = s.getFlatAffineConstraints(i);
+  unsigned bInitNumLocals = b.getNumLocalIds();
+
+  // Find out which inequalities of sI correspond to division inequalities for
+  // the local variables of sI.
+  std::vector<llvm::Optional<std::pair<unsigned, unsigned>>> repr(
+      sI.getNumLocalIds());
+  sI.getLocalReprLbUbPairs(repr);
+
+  // Add sI's locals to b, after b's locals. Also add b's locals to sI, before
+  // sI's locals.
+  b.mergeLocalIds(sI);
+
+  // Mark which inequalities of sI are division inequalities and add all such
+  // inequalities to b.
+  llvm::SmallBitVector isDivInequality(sI.getNumInequalities());
+  for (Optional<std::pair<unsigned, unsigned>> &maybePair : repr) {
+    assert(maybePair &&
+           "Subtraction is not supported when a representation of the local "
+           "variables of the subtrahend cannot be found!");
+
+    b.addInequality(sI.getInequality(maybePair->first));
+    b.addInequality(sI.getInequality(maybePair->second));
+
+    assert(maybePair->first != maybePair->second &&
+           "Upper and lower bounds must be different inequalities!");
+    isDivInequality[maybePair->first] = true;
+    isDivInequality[maybePair->second] = true;
+  }
+
   unsigned initialSnapshot = simplex.getSnapshot();
   unsigned offset = simplex.getNumConstraints();
+  unsigned numLocalsAdded = b.getNumLocalIds() - bInitNumLocals;
+  simplex.appendVariable(numLocalsAdded);
+
+  unsigned snapshotBeforeIntersect = simplex.getSnapshot();
   simplex.intersectFlatAffineConstraints(sI);
 
   if (simplex.isEmpty()) {
     /// b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
     simplex.rollback(initialSnapshot);
+    b.removeIdRange(FlatAffineConstraints::IdKind::Local, bInitNumLocals,
+                    b.getNumLocalIds());
     subtractRecursively(b, simplex, s, i + 1, result);
     return;
   }
 
   simplex.detectRedundant();
-  llvm::SmallBitVector isMarkedRedundant;
-  for (unsigned j = 0; j < 2 * sI.getNumEqualities() + sI.getNumInequalities();
-       j++)
-    isMarkedRedundant.push_back(simplex.isMarkedRedundant(offset + j));
 
-  simplex.rollback(initialSnapshot);
+  // Equalities are added to simplex as a pair of inequalities.
+  unsigned totalNewSimplexInequalities =
+      2 * sI.getNumEqualities() + sI.getNumInequalities();
+  llvm::SmallBitVector isMarkedRedundant(totalNewSimplexInequalities);
+  for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
+    isMarkedRedundant[j] = simplex.isMarkedRedundant(offset + j);
+
+  simplex.rollback(snapshotBeforeIntersect);
 
   // Recurse with the part b ^ ~ineq. Note that b is modified throughout
   // subtractRecursively. At the time this function is called, the current b is
@@ -223,20 +274,28 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
   // rollback b to its initial state before returning, which we will do by
   // removing all constraints beyond the original number of inequalities
   // and equalities, so we store these counts first.
-  unsigned originalNumIneqs = b.getNumInequalities();
-  unsigned originalNumEqs = b.getNumEqualities();
+  unsigned bInitNumIneqs = b.getNumInequalities();
+  unsigned bInitNumEqs = b.getNumEqualities();
 
+  // Process all the inequalities, ignoring redundant inequalities and division
+  // inequalities. The result is correct whether or not we ignore these, but
+  // ignoring them makes the result simpler.
   for (unsigned j = 0, e = sI.getNumInequalities(); j < e; j++) {
     if (isMarkedRedundant[j])
       continue;
+    if (isDivInequality[j])
+      continue;
     processInequality(sI.getInequality(j));
   }
 
   offset = sI.getNumInequalities();
   for (unsigned j = 0, e = sI.getNumEqualities(); j < e; ++j) {
-    const ArrayRef<int64_t> &coeffs = sI.getEquality(j);
-    // Same as the above loop for inequalities, done once each for the positive
-    // and negative inequalities that make up this equality.
+    ArrayRef<int64_t> coeffs = sI.getEquality(j);
+    // For each equality, process the positive and negative inequalities that
+    // make up this equality. If Simplex found an inequality to be redundant, we
+    // skip it as above to make the result simpler. Divisions are always
+    // represented in terms of inequalities and not equalities, so we do not
+    // check for division inequalities here.
     if (!isMarkedRedundant[offset + 2 * j])
       processInequality(coeffs);
     if (!isMarkedRedundant[offset + 2 * j + 1])
@@ -244,11 +303,10 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
   }
 
   // Rollback b and simplex to their initial states.
-  for (unsigned i = b.getNumInequalities(); i > originalNumIneqs; --i)
-    b.removeInequality(i - 1);
-
-  for (unsigned i = b.getNumEqualities(); i > originalNumEqs; --i)
-    b.removeEquality(i - 1);
+  b.removeIdRange(FlatAffineConstraints::IdKind::Local, bInitNumLocals,
+                  b.getNumLocalIds());
+  b.removeInequalityRange(bInitNumIneqs, b.getNumInequalities());
+  b.removeEqualityRange(bInitNumEqs, b.getNumEqualities());
 
   simplex.rollback(initialSnapshot);
 }
@@ -261,8 +319,6 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
 PresburgerSet PresburgerSet::getSetDifference(FlatAffineConstraints fac,
                                               const PresburgerSet &set) {
   assertDimensionsCompatible(fac, set);
-  assert(fac.getNumLocalIds() == 0 &&
-         "Subtracting sets with divisions is not yet supported!");
   if (fac.isEmptyByGCDTest())
     return PresburgerSet::getEmptySet(fac.getNumDimIds(),
                                       fac.getNumSymbolIds());

diff  --git a/mlir/unittests/Analysis/PresburgerSetTest.cpp b/mlir/unittests/Analysis/PresburgerSetTest.cpp
index ac065fa74879c..4ae76d7f0c329 100644
--- a/mlir/unittests/Analysis/PresburgerSetTest.cpp
+++ b/mlir/unittests/Analysis/PresburgerSetTest.cpp
@@ -80,12 +80,17 @@ static void testComplementAtPoints(PresburgerSet s,
 }
 
 /// Construct a FlatAffineConstraints from a set of inequality and
-/// equality constraints.
+/// equality constraints. `numIds` is the total number of ids, of which
+/// `numLocals` is the number of local ids.
 static FlatAffineConstraints
-makeFACFromConstraints(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs,
-                       ArrayRef<SmallVector<int64_t, 4>> eqs) {
-  FlatAffineConstraints fac(ineqs.size(), eqs.size(), dims + 1, dims,
-                            /*numSymbols=*/0, /*numLocals=*/0);
+makeFACFromConstraints(unsigned numIds, ArrayRef<SmallVector<int64_t, 4>> ineqs,
+                       ArrayRef<SmallVector<int64_t, 4>> eqs,
+                       unsigned numLocals = 0) {
+  FlatAffineConstraints fac(/*numReservedInequalities=*/ineqs.size(),
+                            /*numReservedEqualities=*/eqs.size(),
+                            /*numReservedCols=*/numIds + 1,
+                            /*numDims=*/numIds - numLocals,
+                            /*numSymbols=*/0, numLocals);
   for (const SmallVector<int64_t, 4> &eq : eqs)
     fac.addEquality(eq);
   for (const SmallVector<int64_t, 4> &ineq : ineqs)
@@ -93,14 +98,22 @@ makeFACFromConstraints(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs,
   return fac;
 }
 
+/// Construct a FlatAffineConstraints having `numDims` dimensions from the given
+/// set of inequality constraints. This is a convenience function to be used
+/// when the FAC to be constructed does not have any local ids and does not have
+/// equalities.
 static FlatAffineConstraints
-makeFACFromIneqs(unsigned dims, ArrayRef<SmallVector<int64_t, 4>> ineqs) {
-  return makeFACFromConstraints(dims, ineqs, {});
+makeFACFromIneqs(unsigned numDims, ArrayRef<SmallVector<int64_t, 4>> ineqs) {
+  return makeFACFromConstraints(numDims, ineqs, /*eqs=*/{});
 }
 
-static PresburgerSet makeSetFromFACs(unsigned dims,
+/// Construct a PresburgerSet having `numDims` dimensions and no symbols from
+/// the given list of FlatAffineConstraints. Each FAC in `facs` should also have
+/// `numDims` dimensions and no symbols, although it can have any number of
+/// local ids.
+static PresburgerSet makeSetFromFACs(unsigned numDims,
                                      ArrayRef<FlatAffineConstraints> facs) {
-  PresburgerSet set = PresburgerSet::getEmptySet(dims);
+  PresburgerSet set = PresburgerSet::getEmptySet(numDims);
   for (const FlatAffineConstraints &fac : facs)
     set.unionFACInPlace(fac);
   return set;
@@ -592,4 +605,37 @@ TEST(SetTest, isEqual) {
   EXPECT_FALSE(rect.complement().isEqual(square.complement()));
 }
 
+void expectEqual(PresburgerSet s, PresburgerSet t) {
+  EXPECT_TRUE(s.isEqual(t));
+}
+
+void expectEmpty(PresburgerSet s) { EXPECT_TRUE(s.isIntegerEmpty()); }
+
+TEST(SetTest, divisions) {
+  // Note: each FAC also states the division bounds of q as inequalities,
+  // since detecting divisions based on equalities is not yet supported.
+
+  // evens = {x : exists q, x = 2q}.
+  PresburgerSet evens{
+      makeFACFromConstraints(2, {{1, -2, 0}, {-1, 2, 1}}, {{1, -2, 0}}, 1)};
+  // odds = {x : exists q, x = 2q + 1}.
+  PresburgerSet odds{
+      makeFACFromConstraints(2, {{1, -2, 0}, {-1, 2, 1}}, {{1, -2, -1}}, 1)};
+  // multiples3 = {x : exists q, x = 3q}.
+  PresburgerSet multiples3{
+      makeFACFromConstraints(2, {{1, -3, 0}, {-1, 3, 2}}, {{1, -3, 0}}, 1)};
+  // multiples6 = {x : exists q, x = 6q}.
+  PresburgerSet multiples6{
+      makeFACFromConstraints(2, {{1, -6, 0}, {-1, 6, 5}}, {{1, -6, 0}}, 1)};
+
+  // evens /\ odds = empty.
+  expectEmpty(PresburgerSet(evens).intersect(PresburgerSet(odds)));
+  // evens U odds = universe.
+  expectEqual(evens.unionSet(odds), PresburgerSet::getUniverse(1));
+  expectEqual(evens.complement(), odds);
+  expectEqual(odds.complement(), evens);
+  // even multiples of 3 = multiples of 6.
+  expectEqual(multiples3.intersect(evens), multiples6);
+}
+
 } // namespace mlir


        


More information about the Mlir-commits mailing list