[Mlir-commits] [mlir] fd26d86 - [MLIR][Presburger] subtract: fix support for divs defined by equalities

Arjun P llvmlistbot at llvm.org
Tue Jun 28 12:24:48 PDT 2022


Author: Arjun P
Date: 2022-06-28T20:24:51+01:00
New Revision: fd26d86f5f662f71337e4ce266f122564e25466d

URL: https://github.com/llvm/llvm-project/commit/fd26d86f5f662f71337e4ce266f122564e25466d
DIFF: https://github.com/llvm/llvm-project/commit/fd26d86f5f662f71337e4ce266f122564e25466d.diff

LOG: [MLIR][Presburger] subtract: fix support for divs defined by equalities

Also added test cases to test this. Both IntegerRelation::addLocalFloorDiv and the fixed implementation of subtraction need to compute division inequalities from dividend and divisor, so this also adds helper util functions to avoid duplicating this logic.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D128736

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Presburger/Utils.h
    mlir/lib/Analysis/Presburger/IntegerRelation.cpp
    mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
    mlir/lib/Analysis/Presburger/Utils.cpp
    mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h
index 5e887fbf0cab..c735ddd037c8 100644
--- a/mlir/include/mlir/Analysis/Presburger/Utils.h
+++ b/mlir/include/mlir/Analysis/Presburger/Utils.h
@@ -101,6 +101,25 @@ struct MaybeLocalRepr {
   } repr;
 };
 
+/// If `q` is defined to be equal to `expr floordiv d`, this is equivalent to
+/// saying that `q` is an integer and `q` is subject to the inequalities
+/// `0 <= expr - d*q <= d - 1` (quotient remainder theorem, for `d > 0`).
+///
+/// Rearranging, we get the bounds on `q`: d*q <= expr <= d*q + d - 1.
+///
+/// `getDivUpperBound` returns `d*q <= expr`, and
+/// `getDivLowerBound` returns `expr <= d*q + d - 1`.
+///
+/// The parameter `dividend` corresponds to `expr` above, `divisor` to `d`, and
+/// `localVarIdx` to the position of `q` in the coefficient list (whose last
+/// entry is the constant term). The coefficient of `q` in `dividend` must be
+/// zero, as it is not allowed for a local variable to be a floor division of
+/// an expression involving itself.
+SmallVector<int64_t, 8> getDivUpperBound(ArrayRef<int64_t> dividend,
+                                         int64_t divisor, unsigned localVarIdx);
+SmallVector<int64_t, 8> getDivLowerBound(ArrayRef<int64_t> dividend,
+                                         int64_t divisor, unsigned localVarIdx);
+
 /// Check if the pos^th variable can be expressed as a floordiv of an affine
 /// function of other variables (where the divisor is a positive constant).
 /// `foundRepr` contains a boolean for each variable indicating if the

diff  --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index d9c3c86105b2..ef23acbab8f9 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -1326,21 +1326,12 @@ void IntegerRelation::addLocalFloorDiv(ArrayRef<int64_t> dividend,
 
   appendVar(VarKind::Local);
 
-  // Add two constraints for this new variable 'q'.
-  SmallVector<int64_t, 8> bound(dividend.size() + 1);
-
-  // dividend - q * divisor >= 0
-  std::copy(dividend.begin(), dividend.begin() + dividend.size() - 1,
-            bound.begin());
-  bound.back() = dividend.back();
-  bound[getNumVars() - 1] = -divisor;
-  addInequality(bound);
-
-  // -dividend +qdivisor * q + divisor - 1 >= 0
-  std::transform(bound.begin(), bound.end(), bound.begin(),
-                 std::negate<int64_t>());
-  bound[bound.size() - 1] += divisor - 1;
-  addInequality(bound);
+  SmallVector<int64_t, 8> dividendCopy(dividend.begin(), dividend.end());
+  dividendCopy.insert(dividendCopy.end() - 1, 0);
+  addInequality(
+      getDivLowerBound(dividendCopy, divisor, dividendCopy.size() - 2));
+  addInequality(
+      getDivUpperBound(dividendCopy, divisor, dividendCopy.size() - 2));
 }
 
 /// Finds an equality that equates the specified variable to a constant.
@@ -2281,4 +2272,4 @@ unsigned IntegerPolyhedron::insertVar(VarKind kind, unsigned pos,
   assert((kind != VarKind::Domain || num == 0) &&
          "Domain has to be zero in a set");
   return IntegerRelation::insertVar(kind, pos, num);
-}
\ No newline at end of file
+}

diff  --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index 9131aaeed3ac..8b14df655262 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -228,30 +228,37 @@ static PresburgerRelation getSetDifference(IntegerRelation b,
       // Similarly, we also want to rollback simplex to its original state.
       unsigned initialSnapshot = simplex.getSnapshot();
 
-      // Find out which inequalities of sI correspond to division inequalities
-      // for the local variables of sI.
-      std::vector<MaybeLocalRepr> repr(sI.getNumLocalVars());
-      sI.getLocalReprs(repr);
-
       // Add sI's locals to b, after b's locals. Only those locals of sI which
       // do not already exist in b will be added. (i.e., duplicate divisions
       // will not be added.) Also add b's locals to sI, in such a way that both
       // have the same locals in the same order in the end.
       b.mergeLocalVars(sI);
 
+      // Find out which inequalities of sI correspond to division inequalities
+      // for the local variables of sI.
+      //
+      // Careful! This has to be done after the merge above; otherwise, the
+      // dividends won't contain the new ids inserted during the merge.
+      std::vector<MaybeLocalRepr> repr;
+      std::vector<SmallVector<int64_t, 8>> dividends;
+      SmallVector<unsigned, 4> divisors;
+      sI.getLocalReprs(dividends, divisors, repr);
+
       // Mark which inequalities of sI are division inequalities and add all
       // such inequalities to b.
       llvm::SmallBitVector canIgnoreIneq(sI.getNumInequalities() +
                                          2 * sI.getNumEqualities());
-      for (MaybeLocalRepr &maybeRepr : repr) {
+      for (unsigned i = initBCounts.getSpace().getNumLocalVars(),
+                    e = sI.getNumLocalVars();
+           i < e; ++i) {
         assert(
-            maybeRepr &&
+            repr[i] &&
             "Subtraction is not supported when a representation of the local "
             "variables of the subtrahend cannot be found!");
 
-        if (maybeRepr.kind == ReprKind::Inequality) {
-          unsigned lb = maybeRepr.repr.inequalityPair.lowerBoundIdx;
-          unsigned ub = maybeRepr.repr.inequalityPair.upperBoundIdx;
+        if (repr[i].kind == ReprKind::Inequality) {
+          unsigned lb = repr[i].repr.inequalityPair.lowerBoundIdx;
+          unsigned ub = repr[i].repr.inequalityPair.upperBoundIdx;
 
           b.addInequality(sI.getInequality(lb));
           b.addInequality(sI.getInequality(ub));
@@ -261,14 +268,30 @@ static PresburgerRelation getSetDifference(IntegerRelation b,
           canIgnoreIneq[lb] = true;
           canIgnoreIneq[ub] = true;
         } else {
-          assert(maybeRepr.kind == ReprKind::Equality &&
+          assert(repr[i].kind == ReprKind::Equality &&
                  "ReprKind isn't inequality so should be equality");
-          unsigned idx = maybeRepr.repr.equalityIdx;
-          b.addEquality(sI.getEquality(idx));
-          // We can ignore both inequalities corresponding to this equality.
-          unsigned offset = sI.getNumInequalities();
-          canIgnoreIneq[offset + 2 * idx] = true;
-          canIgnoreIneq[offset + 2 * idx + 1] = true;
+
+          // Consider the case (x) : (x = 3e + 1), where e is a local.
+          // Its complement is (x) : (x = 3e) or (x = 3e + 2).
+          //
+          // This can be computed by considering the set to be
+          // (x) : (x = 3*(x floordiv 3) + 1).
+          //
+          // Now there are no equalities defining divisions; the division is
+          // defined by the standard division inequalities for e = x floordiv 3,
+          // i.e., 0 <= x - 3*e <= 2.
+          // So now as before, we add these division inequalities to b. The
+          // equality is now just an ordinary constraint that must be considered
+          // in the remainder of the algorithm. The division inequalities
+          // need not be considered, same as above, and they automatically will
+          // not be because they were never a part of sI; we just infer them
+          // from the equality and add them only to b.
+          b.addInequality(
+              getDivLowerBound(dividends[i], divisors[i],
+                               sI.getVarKindOffset(VarKind::Local) + i));
+          b.addInequality(
+              getDivUpperBound(dividends[i], divisors[i],
+                               sI.getVarKindOffset(VarKind::Local) + i));
         }
       }
 

diff  --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp
index e985c821fb9f..199261789f70 100644
--- a/mlir/lib/Analysis/Presburger/Utils.cpp
+++ b/mlir/lib/Analysis/Presburger/Utils.cpp
@@ -338,6 +338,29 @@ void presburger::mergeLocalVars(
   presburger::removeDuplicateDivs(divsA, denomsA, localOffset, merge);
 }
 
+SmallVector<int64_t, 8> presburger::getDivUpperBound(ArrayRef<int64_t> dividend,
+                                                     int64_t divisor,
+                                                     unsigned localVarIdx) {
+  assert(dividend[localVarIdx] == 0 &&
+         "Local to be set to division must have zero coeff!");
+  SmallVector<int64_t, 8> ineq(dividend.begin(), dividend.end());
+  ineq[localVarIdx] = -divisor; // expr - d*q >= 0, i.e. d*q <= expr.
+  return ineq;
+}
+
+SmallVector<int64_t, 8> presburger::getDivLowerBound(ArrayRef<int64_t> dividend,
+                                                     int64_t divisor,
+                                                     unsigned localVarIdx) {
+  assert(dividend[localVarIdx] == 0 &&
+         "Local to be set to division must have zero coeff!");
+  SmallVector<int64_t, 8> ineq(dividend.size());
+  std::transform(dividend.begin(), dividend.end(), ineq.begin(),
+                 std::negate<int64_t>()); // ineq = -expr
+  ineq[localVarIdx] = divisor; // ineq = -expr + d*q
+  ineq.back() += divisor - 1; // Last entry is the constant term.
+  return ineq; // -expr + d*q + d - 1 >= 0, i.e. expr <= d*q + d - 1.
+}
+
 int64_t presburger::gcdRange(ArrayRef<int64_t> range) {
   int64_t gcd = 0;
   for (int64_t elem : range) {

diff  --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
index 07d4565d116c..02f801ae98b7 100644
--- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
@@ -459,11 +459,50 @@ TEST(SetTest, divisions) {
   PresburgerSet setA{parsePoly("(x) : (-x >= 0)")};
   PresburgerSet setB{parsePoly("(x) : (x floordiv 2 - 4 >= 0)")};
   EXPECT_TRUE(setA.subtract(setB).isEqual(setA));
+}
+// Converts the last `numLocals` set dimensions of `poly` into local vars.
+void convertSuffixDimsToLocals(IntegerPolyhedron &poly, unsigned numLocals) {
+  poly.convertVarKind(VarKind::SetDim, poly.getNumDimVars() - numLocals,
+                      poly.getNumDimVars(), VarKind::Local);
+}
+// Parses `str`, then turns its last `numLocals` set dimensions into locals.
+inline IntegerPolyhedron parsePolyAndMakeLocals(StringRef str,
+                                                unsigned numLocals) {
+  IntegerPolyhedron poly = parsePoly(str); // Locals are spelled as dims.
+  convertSuffixDimsToLocals(poly, numLocals);
+  return poly;
+}
+
+TEST(SetTest, divisionsDefByEq) {
+  // evens = {x : exists q, x = 2q}.
+  PresburgerSet evens{
+      parsePolyAndMakeLocals("(x, y) : (x - 2 * y == 0)", /*numLocals=*/1)};
+
+  // odds = {x : exists q, x = 2q + 1}.
+  PresburgerSet odds{
+      parsePolyAndMakeLocals("(x, y) : (x - 2 * y - 1 == 0)", /*numLocals=*/1)};
+
+  // multiples3 = {x : exists q, x = 3q}.
+  PresburgerSet multiples3{
+      parsePolyAndMakeLocals("(x, y) : (x - 3 * y == 0)", /*numLocals=*/1)};
+
+  // multiples6 = {x : exists q, x = 6q}.
+  PresburgerSet multiples6{
+      parsePolyAndMakeLocals("(x, y) : (x - 6 * y == 0)", /*numLocals=*/1)};
+
+  // evens /\ odds = empty.
+  expectEmpty(PresburgerSet(evens).intersect(PresburgerSet(odds)));
+  // evens U odds = universe.
+  expectEqual(evens.unionSet(odds),
+              PresburgerSet::getUniverse(PresburgerSpace::getSetSpace((1))));
+  expectEqual(evens.complement(), odds); // Each is the other's complement.
+  expectEqual(odds.complement(), evens);
+  // even multiples of 3 = multiples of 6.
+  expectEqual(multiples3.intersect(evens), multiples6);
 
-  IntegerPolyhedron evensDefByEquality(PresburgerSpace::getSetSpace(
-      /*numDims=*/1, /*numSymbols=*/0, /*numLocals=*/1));
-  evensDefByEquality.addEquality({1, -2, 0});
-  expectEqual(evens, PresburgerSet(evensDefByEquality));
+  PresburgerSet evensDefByIneq{ // Same set; local written as x floordiv 2.
+      parsePoly("(x) : (x - 2 * (x floordiv 2) == 0)")};
+  expectEqual(evens, PresburgerSet(evensDefByIneq));
 }
 
 TEST(SetTest, subtractDuplicateDivsRegression) {


        


More information about the Mlir-commits mailing list