[Mlir-commits] [mlir] 00b293e - [MLIR][Presburger] refactor subtraction to be non-recursive

Arjun P llvmlistbot at llvm.org
Thu Apr 7 07:20:22 PDT 2022


Author: Arjun P
Date: 2022-04-07T15:20:19+01:00
New Revision: 00b293e83f6bb84f970eea972f022d578923d832

URL: https://github.com/llvm/llvm-project/commit/00b293e83f6bb84f970eea972f022d578923d832
DIFF: https://github.com/llvm/llvm-project/commit/00b293e83f6bb84f970eea972f022d578923d832.diff

LOG: [MLIR][Presburger] refactor subtraction to be non-recursive

Subtraction was previously implemented recursively. This refactors it to be
non-recursive to avoid issues with potential stack overflows.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D123248

Added: 
    

Modified: 
    mlir/lib/Analysis/Presburger/PresburgerRelation.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index 1515897db68f4..1bc77ed402028 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -129,18 +129,17 @@ static SmallVector<int64_t, 8> getIneqCoeffsFromIdx(const IntegerRelation &rel,
   return getNegatedCoeffs(eqCoeffs);
 }
 
-/// Return the set difference b \ s and accumulate the result into `result`.
-/// `simplex` must correspond to b.
+/// Return the set difference b \ s.
 ///
-/// In the following, U denotes union, ^ denotes intersection, \ denotes set
+/// In the following, U denotes union, /\ denotes intersection, \ denotes set
 /// difference and ~ denotes complement.
-/// Let b be the IntegerRelation and s = (U_i s_i) be the set. We want
-/// b \ (U_i s_i).
 ///
-/// Let s_i = ^_j s_ij, where each s_ij is a single inequality. To compute
-/// b \ s_i = b ^ ~s_i, we partition s_i based on the first violated inequality:
-/// ~s_i = (~s_i1) U (s_i1 ^ ~s_i2) U (s_i1 ^ s_i2 ^ ~s_i3) U ...
-/// And the required result is (b ^ ~s_i1) U (b ^ s_i1 ^ ~s_i2) U ...
+/// Let s = (U_i s_i). We want  b \ (U_i s_i).
+///
+/// Let s_i = /\_j s_ij, where each s_ij is a single inequality. To compute
+/// b \ s_i = b /\ ~s_i, we partition s_i based on the first violated
+/// inequality: ~s_i = (~s_i1) U (s_i1 /\ ~s_i2) U (s_i1 /\ s_i2 /\ ~s_i3) U ...
+/// And the required result is (b /\ ~s_i1) U (b /\ s_i1 /\ ~s_i2) U ...
 /// We recurse by subtracting U_{j > i} S_j from each of these parts and
 /// returning the union of the results. Each equality is handled as a
 /// conjunction of two inequalities.
@@ -162,151 +161,192 @@ static SmallVector<int64_t, 8> getIneqCoeffsFromIdx(const IntegerRelation &rel,
 /// that some constraints are redundant. These redundant constraints are
 /// ignored.
 ///
-/// b should not have duplicate divs because this might lead to existing
-/// divs disappearing in the call to mergeLocalIds below, which cannot be
-/// handled.
-static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
-                                const PresburgerRelation &s, unsigned i,
-                                PresburgerRelation &result) {
-
-  if (i == s.getNumDisjuncts()) {
-    result.unionInPlace(b);
-    return;
-  }
+static PresburgerRelation getSetDifference(IntegerRelation b,
+                                           const PresburgerRelation &s) {
+  assert(b.isSpaceCompatible(s) && "Spaces should match");
+  if (b.isEmptyByGCDTest())
+    return PresburgerRelation::getEmpty(b.getSpaceWithoutLocals());
 
-  IntegerRelation sI = s.getDisjunct(i);
-  // Remove the duplicate divs up front to avoid them possibly disappearing
-  // in the call to mergeLocalIds below.
-  sI.removeDuplicateDivs();
-
-  // Below, we append some additional constraints and ids to b. We want to
-  // rollback b to its initial state before returning, which we will do by
-  // removing all constraints beyond the original number of inequalities
-  // and equalities, so we store these counts first.
-  IntegerRelation::CountsSnapshot initBCounts = b.getCounts();
-  // Similarly, we also want to rollback simplex to its original state.
-  unsigned initialSnapshot = simplex.getSnapshot();
-
-  // Find out which inequalities of sI correspond to division inequalities for
-  // the local variables of sI.
-  std::vector<MaybeLocalRepr> repr(sI.getNumLocalIds());
-  sI.getLocalReprs(repr);
-
-  // Add sI's locals to b, after b's locals. Also add b's locals to sI, before
-  // sI's locals.
-  b.mergeLocalIds(sI);
-  unsigned numLocalsAdded =
-      b.getNumLocalIds() - initBCounts.getSpace().getNumLocalIds();
-  // Update simplex to also include the new locals in `b` from merging.
-  simplex.appendVariable(numLocalsAdded);
-
-  // Equalities are processed by considering them as a pair of inequalities.
-  // The first sI.getNumInequalities() elements are for sI's inequalities;
-  // then a pair of inequalities occurs for each of sI's equalities.
-  // If the equality is expr == 0, the first element in the pair
-  // corresponds to expr >= 0, and the second to expr <= 0.
-  llvm::SmallBitVector canIgnoreIneq(sI.getNumInequalities() +
-                                     2 * sI.getNumEqualities());
-
-  // Add all division inequalities to `b`.
-  for (MaybeLocalRepr &maybeInequality : repr) {
-    assert(maybeInequality.kind == ReprKind::Inequality &&
-           "Subtraction is not supported when a representation of the local "
-           "variables of the subtrahend cannot be found!");
-    unsigned lb = maybeInequality.repr.inequalityPair.lowerBoundIdx;
-    unsigned ub = maybeInequality.repr.inequalityPair.upperBoundIdx;
-
-    b.addInequality(sI.getInequality(lb));
-    b.addInequality(sI.getInequality(ub));
-
-    assert(lb != ub &&
-           "Upper and lower bounds must be different inequalities!");
-
-    // We just added these inequalities to `b`, so there is no point considering
-    // the parts where these inequalities occur complemented -- such parts are
-    // empty. Therefore, we mark that these can be ignored.
-    canIgnoreIneq[lb] = true;
-    canIgnoreIneq[ub] = true;
-  }
-
-  unsigned offset = simplex.getNumConstraints();
-  unsigned snapshotBeforeIntersect = simplex.getSnapshot();
-  simplex.intersectIntegerRelation(sI);
-
-  if (simplex.isEmpty()) {
-    // b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
-    // We are ignoring level i completely, so we restore the state
-    // *before* going to level i + 1.
-    b.truncate(initBCounts);
-    simplex.rollback(initialSnapshot);
-    subtractRecursively(b, simplex, s, i + 1, result);
-    return;
-  }
+  // Remove duplicate divs up front here to avoid existing
+  // divs disappearing in the call to mergeLocalIds below.
+  b.removeDuplicateDivs();
 
-  simplex.detectRedundant();
-
-  unsigned totalNewSimplexInequalities =
-      2 * sI.getNumEqualities() + sI.getNumInequalities();
-  // Redundant inequalities can be safely ignored. This is not required for
-  // correctness but improves performance and results in a more compact
-  // representation of the set difference.
-  for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
-    canIgnoreIneq[j] = simplex.isMarkedRedundant(offset + j);
-  simplex.rollback(snapshotBeforeIntersect);
-
-  SmallVector<unsigned, 8> ineqsToProcess(totalNewSimplexInequalities);
-  for (unsigned i = 0; i < totalNewSimplexInequalities; ++i)
-    if (!canIgnoreIneq[i])
-      ineqsToProcess.push_back(i);
-
-  // Recurse with the part b ^ ~ineq. Note that b is modified throughout
-  // subtractRecursively. At the time this function is called, the current b is
-  // actually equal to b ^ s_i1 ^ s_i2 ^ ... ^ s_ij, and ineq is the next
-  // inequality, s_{i,j+1}. This function recurses into the next level i + 1
-  // with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}.
-  auto recurseWithInequality = [&, i](ArrayRef<int64_t> ineq) {
-    b.addInequality(ineq);
-    simplex.addInequality(ineq);
-    subtractRecursively(b, simplex, s, i + 1, result);
+  PresburgerRelation result =
+      PresburgerRelation::getEmpty(b.getSpaceWithoutLocals());
+  Simplex simplex(b);
+
+  // This algorithm is more naturally expressed recursively, but we implement
+  // it iteratively here to avoid issues with stack sizes.
+  //
+  // Each level of the recursion has five stack variables.
+  struct Frame {
+    // A snapshot of the simplex state to rollback to.
+    unsigned simplexSnapshot;
+    // A CountsSnapshot of `b` to rollback to.
+    IntegerRelation::CountsSnapshot bCounts;
+    // The IntegerRelation currently being operated on.
+    IntegerRelation sI;
+    // A list of indexes (see getIneqCoeffsFromIdx) of inequalities to be
+    // processed.
+    SmallVector<unsigned, 8> ineqsToProcess;
+    // The index of the last inequality that was processed at this level.
+    // This is empty when we are coming to this level for the first time.
+    Optional<unsigned> lastIneqProcessed;
   };
+  SmallVector<Frame, 2> frames;
+
+  // When we "recurse", we ensure the current frame is stored in `frames` and
+  // increment `level`. When we "tail recurse", we just increment `level`,
+  // without storing any frame. Accordingly, when we return, we return to the
+  // last level that has a frame associated with it.
+  unsigned level = 1;
+  while (level > 0) {
+    if (level - 1 >= s.getNumDisjuncts()) {
+      // No more parts to subtract; add to the result and return.
+      result.unionInPlace(b);
+      level = frames.size();
+      continue;
+    }
 
-  // For each inequality ineq, we first recurse with the part where ineq
-  // is not satisfied, and then add the ineq to b and simplex because
-  // ineq must be satisfied by all later parts.
-  auto processInequality = [&](ArrayRef<int64_t> ineq) {
-    unsigned snapshot = simplex.getSnapshot();
-    IntegerRelation::CountsSnapshot bCounts = b.getCounts();
-    recurseWithInequality(getComplementIneq(ineq));
-    simplex.rollback(snapshot);
-    b.truncate(bCounts);
-
-    b.addInequality(ineq);
-    simplex.addInequality(ineq);
-  };
+    if (level > frames.size()) {
+      // No frame for this level yet, so we have just recursed into this level.
+      IntegerRelation sI = s.getDisjunct(level - 1);
+      // Remove the duplicate divs up front to avoid them possibly disappearing
+      // in the call to mergeLocalIds below.
+      sI.removeDuplicateDivs();
+
+      // Below, we append some additional constraints and ids to b. We want to
+      // rollback b to its initial state before returning, which we will do by
+      // removing all constraints beyond the original number of inequalities
+      // and equalities, so we store these counts first.
+      IntegerRelation::CountsSnapshot initBCounts = b.getCounts();
+      // Similarly, we also want to rollback simplex to its original state.
+      unsigned initialSnapshot = simplex.getSnapshot();
+
+      // Find out which inequalities of sI correspond to division inequalities
+      // for the local variables of sI.
+      std::vector<MaybeLocalRepr> repr(sI.getNumLocalIds());
+      sI.getLocalReprs(repr);
+
+      // Add sI's locals to b, after b's locals. Only those locals of sI which
+      // do not already exist in b will be added. (i.e., duplicate divisions
+      // will not be added.) Also add b's locals to sI, in such a way that both
+      // have the same locals in the same order in the end.
+      b.mergeLocalIds(sI);
+
+      // Mark which inequalities of sI are division inequalities and add all
+      // such inequalities to b.
+      llvm::SmallBitVector canIgnoreIneq(sI.getNumInequalities() +
+                                         2 * sI.getNumEqualities());
+      for (MaybeLocalRepr &maybeInequality : repr) {
+        assert(
+            maybeInequality.kind == ReprKind::Inequality &&
+            "Subtraction is not supported when a representation of the local "
+            "variables of the subtrahend cannot be found!");
+        unsigned lb = maybeInequality.repr.inequalityPair.lowerBoundIdx;
+        unsigned ub = maybeInequality.repr.inequalityPair.upperBoundIdx;
+
+        b.addInequality(sI.getInequality(lb));
+        b.addInequality(sI.getInequality(ub));
+
+        assert(lb != ub &&
+               "Upper and lower bounds must be different inequalities!");
+        canIgnoreIneq[lb] = true;
+        canIgnoreIneq[ub] = true;
+      }
 
-  for (unsigned idx : ineqsToProcess)
-    processInequality(getIneqCoeffsFromIdx(sI, idx));
-}
+      unsigned offset = simplex.getNumConstraints();
+      unsigned numLocalsAdded =
+          b.getNumLocalIds() - initBCounts.getSpace().getNumLocalIds();
+      simplex.appendVariable(numLocalsAdded);
+
+      unsigned snapshotBeforeIntersect = simplex.getSnapshot();
+      simplex.intersectIntegerRelation(sI);
+
+      if (simplex.isEmpty()) {
+        // b /\ s_i is empty, so b \ s_i = b. We move directly to i + 1.
+        // We are ignoring level i completely, so we restore the state
+        // *before* going to the next level. We are "tail recursing", so
+        // we don't add a frame before going to the next level.
+        b.truncate(initBCounts);
+        simplex.rollback(initialSnapshot);
+        ++level;
+        continue;
+      }
 
-/// Return the set difference disjunct \ set.
-///
-/// The disjunct here is modified in subtractRecursively, so it cannot be a
-/// const reference even though it is restored to its original state before
-/// returning from that function.
-static PresburgerRelation getSetDifference(IntegerRelation disjunct,
-                                           const PresburgerRelation &set) {
-  assert(disjunct.isSpaceCompatible(set) && "Spaces should match");
-  if (disjunct.isEmptyByGCDTest())
-    return PresburgerRelation::getEmpty(disjunct.getSpaceWithoutLocals());
-
-  // Remove duplicate divs up front here as subtractRecursively does not support
-  // this set having duplicate divs.
-  disjunct.removeDuplicateDivs();
+      simplex.detectRedundant();
+
+      // Equalities are added to simplex as a pair of inequalities.
+      unsigned totalNewSimplexInequalities =
+          2 * sI.getNumEqualities() + sI.getNumInequalities();
+      for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
+        canIgnoreIneq[j] = simplex.isMarkedRedundant(offset + j);
+      simplex.rollback(snapshotBeforeIntersect);
+
+      SmallVector<unsigned, 8> ineqsToProcess(totalNewSimplexInequalities);
+      for (unsigned i = 0; i < totalNewSimplexInequalities; ++i)
+        if (!canIgnoreIneq[i])
+          ineqsToProcess.push_back(i);
+
+      if (ineqsToProcess.empty()) {
+        // Nothing to process; return. (we have no frame to pop.)
+        level = frames.size();
+        continue;
+      }
+
+      unsigned simplexSnapshot = simplex.getSnapshot();
+      IntegerRelation::CountsSnapshot bCounts = b.getCounts();
+      frames.push_back(Frame{simplexSnapshot, bCounts, sI, ineqsToProcess,
+                             /*lastIneqProcessed=*/llvm::None});
+      // We have completed the initial setup for this level.
+      // Fallthrough to the main recursive part below.
+    }
+
+    // For each inequality ineq, we first recurse with the part where ineq
+    // is not satisfied, and then add ineq to b and simplex because
+    // ineq must be satisfied by all later parts.
+    if (level == frames.size()) {
+      Frame &frame = frames.back();
+      if (frame.lastIneqProcessed) {
+        // Let the current value of b be b' and
+        // let the initial value of b when we first came to this level be b.
+        //
+        // b' is equal to b /\ s_i1 /\ s_i2 /\ ... /\ s_i{j-1} /\ ~s_ij.
+        // We had previously recursed with the part where s_ij was not
+        // satisfied; all further parts satisfy s_ij, so we rollback to the
+        // state before adding this complement constraint, and add s_ij to b.
+        simplex.rollback(frame.simplexSnapshot);
+        b.truncate(frame.bCounts);
+        SmallVector<int64_t, 8> ineq =
+            getIneqCoeffsFromIdx(frame.sI, *frame.lastIneqProcessed);
+        b.addInequality(ineq);
+        simplex.addInequality(ineq);
+      }
+
+      if (frame.ineqsToProcess.empty()) {
+        // No ineqs left to process; pop this level's frame and return.
+        frames.pop_back();
+        level = frames.size();
+        continue;
+      }
+
+      // "Recurse" with the part where the ineq is not satisfied.
+      frame.bCounts = b.getCounts();
+      frame.simplexSnapshot = simplex.getSnapshot();
+
+      unsigned idx = frame.ineqsToProcess.back();
+      SmallVector<int64_t, 8> ineq =
+          getComplementIneq(getIneqCoeffsFromIdx(frame.sI, idx));
+      b.addInequality(ineq);
+      simplex.addInequality(ineq);
+
+      frame.ineqsToProcess.pop_back();
+      frame.lastIneqProcessed = idx;
+      ++level;
+      continue;
+    }
+  }
 
-  PresburgerRelation result =
-      PresburgerRelation::getEmpty(disjunct.getSpaceWithoutLocals());
-  Simplex simplex(disjunct);
-  subtractRecursively(disjunct, simplex, set, 0, result);
   return result;
 }
 


        


More information about the Mlir-commits mailing list