[Mlir-commits] [mlir] 00b293e - [MLIR][Presburger] refactor subtraction to be non-recursive
Arjun P
llvmlistbot at llvm.org
Thu Apr 7 07:20:22 PDT 2022
Author: Arjun P
Date: 2022-04-07T15:20:19+01:00
New Revision: 00b293e83f6bb84f970eea972f022d578923d832
URL: https://github.com/llvm/llvm-project/commit/00b293e83f6bb84f970eea972f022d578923d832
DIFF: https://github.com/llvm/llvm-project/commit/00b293e83f6bb84f970eea972f022d578923d832.diff
LOG: [MLIR][Presburger] refactor subtraction to be non-recursive
Subtraction was previously implemented recursively. This refactors it to be
non-recursive to avoid issues with potential stack overflows.
Reviewed By: Groverkss
Differential Revision: https://reviews.llvm.org/D123248
Added:
Modified:
mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index 1515897db68f4..1bc77ed402028 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -129,18 +129,17 @@ static SmallVector<int64_t, 8> getIneqCoeffsFromIdx(const IntegerRelation &rel,
return getNegatedCoeffs(eqCoeffs);
}
-/// Return the set difference b \ s and accumulate the result into `result`.
-/// `simplex` must correspond to b.
+/// Return the set difference b \ s.
///
-/// In the following, U denotes union, ^ denotes intersection, \ denotes set
+/// In the following, U denotes union, /\ denotes intersection, \ denotes set
/// difference and ~ denotes complement.
-/// Let b be the IntegerRelation and s = (U_i s_i) be the set. We want
-/// b \ (U_i s_i).
///
-/// Let s_i = ^_j s_ij, where each s_ij is a single inequality. To compute
-/// b \ s_i = b ^ ~s_i, we partition s_i based on the first violated inequality:
-/// ~s_i = (~s_i1) U (s_i1 ^ ~s_i2) U (s_i1 ^ s_i2 ^ ~s_i3) U ...
-/// And the required result is (b ^ ~s_i1) U (b ^ s_i1 ^ ~s_i2) U ...
+/// Let s = (U_i s_i). We want b \ (U_i s_i).
+///
+/// Let s_i = /\_j s_ij, where each s_ij is a single inequality. To compute
+/// b \ s_i = b /\ ~s_i, we partition s_i based on the first violated
+/// inequality: ~s_i = (~s_i1) U (s_i1 /\ ~s_i2) U (s_i1 /\ s_i2 /\ ~s_i3) U ...
+/// And the required result is (b /\ ~s_i1) U (b /\ s_i1 /\ ~s_i2) U ...
/// We recurse by subtracting U_{j > i} S_j from each of these parts and
/// returning the union of the results. Each equality is handled as a
/// conjunction of two inequalities.
@@ -162,151 +161,192 @@ static SmallVector<int64_t, 8> getIneqCoeffsFromIdx(const IntegerRelation &rel,
/// that some constraints are redundant. These redundant constraints are
/// ignored.
///
-/// b should not have duplicate divs because this might lead to existing
-/// divs disappearing in the call to mergeLocalIds below, which cannot be
-/// handled.
-static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
- const PresburgerRelation &s, unsigned i,
- PresburgerRelation &result) {
-
- if (i == s.getNumDisjuncts()) {
- result.unionInPlace(b);
- return;
- }
+static PresburgerRelation getSetDifference(IntegerRelation b,
+ const PresburgerRelation &s) {
+ assert(b.isSpaceCompatible(s) && "Spaces should match");
+ if (b.isEmptyByGCDTest())
+ return PresburgerRelation::getEmpty(b.getSpaceWithoutLocals());
- IntegerRelation sI = s.getDisjunct(i);
- // Remove the duplicate divs up front to avoid them possibly disappearing
- // in the call to mergeLocalIds below.
- sI.removeDuplicateDivs();
-
- // Below, we append some additional constraints and ids to b. We want to
- // rollback b to its initial state before returning, which we will do by
- // removing all constraints beyond the original number of inequalities
- // and equalities, so we store these counts first.
- IntegerRelation::CountsSnapshot initBCounts = b.getCounts();
- // Similarly, we also want to rollback simplex to its original state.
- unsigned initialSnapshot = simplex.getSnapshot();
-
- // Find out which inequalities of sI correspond to division inequalities for
- // the local variables of sI.
- std::vector<MaybeLocalRepr> repr(sI.getNumLocalIds());
- sI.getLocalReprs(repr);
-
- // Add sI's locals to b, after b's locals. Also add b's locals to sI, before
- // sI's locals.
- b.mergeLocalIds(sI);
- unsigned numLocalsAdded =
- b.getNumLocalIds() - initBCounts.getSpace().getNumLocalIds();
- // Update simplex to also include the new locals in `b` from merging.
- simplex.appendVariable(numLocalsAdded);
-
- // Equalities are processed by considering them as a pair of inequalities.
- // The first sI.getNumInequalities() elements are for sI's inequalities;
- // then a pair of inequalities occurs for each of sI's equalities.
- // If the equality is expr == 0, the first element in the pair
- // corresponds to expr >= 0, and the second to expr <= 0.
- llvm::SmallBitVector canIgnoreIneq(sI.getNumInequalities() +
- 2 * sI.getNumEqualities());
-
- // Add all division inequalities to `b`.
- for (MaybeLocalRepr &maybeInequality : repr) {
- assert(maybeInequality.kind == ReprKind::Inequality &&
- "Subtraction is not supported when a representation of the local "
- "variables of the subtrahend cannot be found!");
- unsigned lb = maybeInequality.repr.inequalityPair.lowerBoundIdx;
- unsigned ub = maybeInequality.repr.inequalityPair.upperBoundIdx;
-
- b.addInequality(sI.getInequality(lb));
- b.addInequality(sI.getInequality(ub));
-
- assert(lb != ub &&
- "Upper and lower bounds must be different inequalities!");
-
- // We just added these inequalities to `b`, so there is no point considering
- // the parts where these inequalities occur complemented -- such parts are
- // empty. Therefore, we mark that these can be ignored.
- canIgnoreIneq[lb] = true;
- canIgnoreIneq[ub] = true;
- }
-
- unsigned offset = simplex.getNumConstraints();
- unsigned snapshotBeforeIntersect = simplex.getSnapshot();
- simplex.intersectIntegerRelation(sI);
-
- if (simplex.isEmpty()) {
- // b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
- // We are ignoring level i completely, so we restore the state
- // *before* going to level i + 1.
- b.truncate(initBCounts);
- simplex.rollback(initialSnapshot);
- subtractRecursively(b, simplex, s, i + 1, result);
- return;
- }
+ // Remove duplicate divs up front here to avoid existing
+ // divs disappearing in the call to mergeLocalIds below.
+ b.removeDuplicateDivs();
- simplex.detectRedundant();
-
- unsigned totalNewSimplexInequalities =
- 2 * sI.getNumEqualities() + sI.getNumInequalities();
- // Redundant inequalities can be safely ignored. This is not required for
- // correctness but improves performance and results in a more compact
- // representation of the set difference.
- for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
- canIgnoreIneq[j] = simplex.isMarkedRedundant(offset + j);
- simplex.rollback(snapshotBeforeIntersect);
-
- SmallVector<unsigned, 8> ineqsToProcess(totalNewSimplexInequalities);
- for (unsigned i = 0; i < totalNewSimplexInequalities; ++i)
- if (!canIgnoreIneq[i])
- ineqsToProcess.push_back(i);
-
- // Recurse with the part b ^ ~ineq. Note that b is modified throughout
- // subtractRecursively. At the time this function is called, the current b is
- // actually equal to b ^ s_i1 ^ s_i2 ^ ... ^ s_ij, and ineq is the next
- // inequality, s_{i,j+1}. This function recurses into the next level i + 1
- // with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}.
- auto recurseWithInequality = [&, i](ArrayRef<int64_t> ineq) {
- b.addInequality(ineq);
- simplex.addInequality(ineq);
- subtractRecursively(b, simplex, s, i + 1, result);
+ PresburgerRelation result =
+ PresburgerRelation::getEmpty(b.getSpaceWithoutLocals());
+ Simplex simplex(b);
+
+ // This algorithm is more naturally expressed recursively, but we implement
+ // it iteratively here to avoid issues with stack sizes.
+ //
+ // Each level of the recursion has five stack variables.
+ struct Frame {
+ // A snapshot of the simplex state to rollback to.
+ unsigned simplexSnapshot;
+ // A CountsSnapshot of `b` to rollback to.
+ IntegerRelation::CountsSnapshot bCounts;
+ // The IntegerRelation currently being operated on.
+ IntegerRelation sI;
+ // A list of indexes (see getIneqCoeffsFromIdx) of inequalities to be
+ // processed.
+ SmallVector<unsigned, 8> ineqsToProcess;
+ // The index of the last inequality that was processed at this level.
+ // This is empty when we are coming to this level for the first time.
+ Optional<unsigned> lastIneqProcessed;
};
+ SmallVector<Frame, 2> frames;
+
+ // When we "recurse", we ensure the current frame is stored in `frames` and
+ // increment `level`. When we "tail recurse", we just increment `level`,
+ // without storing any frame. Accordingly, when we return, we return to the
+ // last level that has a frame associated with it.
+ unsigned level = 1;
+ while (level > 0) {
+ if (level - 1 >= s.getNumDisjuncts()) {
+ // No more parts to subtract; add to the result and return.
+ result.unionInPlace(b);
+ level = frames.size();
+ continue;
+ }
- // For each inequality ineq, we first recurse with the part where ineq
- // is not satisfied, and then add the ineq to b and simplex because
- // ineq must be satisfied by all later parts.
- auto processInequality = [&](ArrayRef<int64_t> ineq) {
- unsigned snapshot = simplex.getSnapshot();
- IntegerRelation::CountsSnapshot bCounts = b.getCounts();
- recurseWithInequality(getComplementIneq(ineq));
- simplex.rollback(snapshot);
- b.truncate(bCounts);
-
- b.addInequality(ineq);
- simplex.addInequality(ineq);
- };
+ if (level > frames.size()) {
+ // No frame for this level yet, so we have just recursed into this level.
+ IntegerRelation sI = s.getDisjunct(level - 1);
+ // Remove the duplicate divs up front to avoid them possibly disappearing
+ // in the call to mergeLocalIds below.
+ sI.removeDuplicateDivs();
+
+ // Below, we append some additional constraints and ids to b. We want to
+ // rollback b to its initial state before returning, which we will do by
+ // removing all constraints beyond the original number of inequalities
+ // and equalities, so we store these counts first.
+ IntegerRelation::CountsSnapshot initBCounts = b.getCounts();
+ // Similarly, we also want to rollback simplex to its original state.
+ unsigned initialSnapshot = simplex.getSnapshot();
+
+ // Find out which inequalities of sI correspond to division inequalities
+ // for the local variables of sI.
+ std::vector<MaybeLocalRepr> repr(sI.getNumLocalIds());
+ sI.getLocalReprs(repr);
+
+ // Add sI's locals to b, after b's locals. Only those locals of sI which
+ // do not already exist in b will be added. (i.e., duplicate divisions
+ // will not be added.) Also add b's locals to sI, in such a way that both
+ // have the same locals in the same order in the end.
+ b.mergeLocalIds(sI);
+
+ // Mark which inequalities of sI are division inequalities and add all
+ // such inequalities to b.
+ llvm::SmallBitVector canIgnoreIneq(sI.getNumInequalities() +
+ 2 * sI.getNumEqualities());
+ for (MaybeLocalRepr &maybeInequality : repr) {
+ assert(
+ maybeInequality.kind == ReprKind::Inequality &&
+ "Subtraction is not supported when a representation of the local "
+ "variables of the subtrahend cannot be found!");
+ unsigned lb = maybeInequality.repr.inequalityPair.lowerBoundIdx;
+ unsigned ub = maybeInequality.repr.inequalityPair.upperBoundIdx;
+
+ b.addInequality(sI.getInequality(lb));
+ b.addInequality(sI.getInequality(ub));
+
+ assert(lb != ub &&
+ "Upper and lower bounds must be different inequalities!");
+ canIgnoreIneq[lb] = true;
+ canIgnoreIneq[ub] = true;
+ }
- for (unsigned idx : ineqsToProcess)
- processInequality(getIneqCoeffsFromIdx(sI, idx));
-}
+ unsigned offset = simplex.getNumConstraints();
+ unsigned numLocalsAdded =
+ b.getNumLocalIds() - initBCounts.getSpace().getNumLocalIds();
+ simplex.appendVariable(numLocalsAdded);
+
+ unsigned snapshotBeforeIntersect = simplex.getSnapshot();
+ simplex.intersectIntegerRelation(sI);
+
+ if (simplex.isEmpty()) {
+ // b /\ s_i is empty, so b \ s_i = b. We move directly to i + 1.
+ // We are ignoring level i completely, so we restore the state
+ // *before* going to the next level. We are "tail recursing", so
+ // we don't add a frame before going to the next level.
+ b.truncate(initBCounts);
+ simplex.rollback(initialSnapshot);
+ ++level;
+ continue;
+ }
-/// Return the set difference disjunct \ set.
-///
-/// The disjunct here is modified in subtractRecursively, so it cannot be a
-/// const reference even though it is restored to its original state before
-/// returning from that function.
-static PresburgerRelation getSetDifference(IntegerRelation disjunct,
- const PresburgerRelation &set) {
- assert(disjunct.isSpaceCompatible(set) && "Spaces should match");
- if (disjunct.isEmptyByGCDTest())
- return PresburgerRelation::getEmpty(disjunct.getSpaceWithoutLocals());
-
- // Remove duplicate divs up front here as subtractRecursively does not support
- // this set having duplicate divs.
- disjunct.removeDuplicateDivs();
+ simplex.detectRedundant();
+
+ // Equalities are added to simplex as a pair of inequalities.
+ unsigned totalNewSimplexInequalities =
+ 2 * sI.getNumEqualities() + sI.getNumInequalities();
+ for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
+ canIgnoreIneq[j] = simplex.isMarkedRedundant(offset + j);
+ simplex.rollback(snapshotBeforeIntersect);
+
+ SmallVector<unsigned, 8> ineqsToProcess(totalNewSimplexInequalities);
+ for (unsigned i = 0; i < totalNewSimplexInequalities; ++i)
+ if (!canIgnoreIneq[i])
+ ineqsToProcess.push_back(i);
+
+ if (ineqsToProcess.empty()) {
+ // Nothing to process; return. (we have no frame to pop.)
+ level = frames.size();
+ continue;
+ }
+
+ unsigned simplexSnapshot = simplex.getSnapshot();
+ IntegerRelation::CountsSnapshot bCounts = b.getCounts();
+ frames.push_back(Frame{simplexSnapshot, bCounts, sI, ineqsToProcess,
+ /*lastIneqProcessed=*/llvm::None});
+ // We have completed the initial setup for this level.
+ // Fallthrough to the main recursive part below.
+ }
+
+ // For each inequality ineq, we first recurse with the part where ineq
+ // is not satisfied, and then add ineq to b and simplex because
+ // ineq must be satisfied by all later parts.
+ if (level == frames.size()) {
+ Frame &frame = frames.back();
+ if (frame.lastIneqProcessed) {
+ // Let the current value of b be b' and
+ // let the initial value of b when we first came to this level be b.
+ //
+ // b' is equal to b /\ s_i1 /\ s_i2 /\ ... /\ s_i{j-1} /\ ~s_ij.
+ // We had previously recursed with the part where s_ij was not
+ // satisfied; all further parts satisfy s_ij, so we rollback to the
+ // state before adding this complement constraint, and add s_ij to b.
+ simplex.rollback(frame.simplexSnapshot);
+ b.truncate(frame.bCounts);
+ SmallVector<int64_t, 8> ineq =
+ getIneqCoeffsFromIdx(frame.sI, *frame.lastIneqProcessed);
+ b.addInequality(ineq);
+ simplex.addInequality(ineq);
+ }
+
+ if (frame.ineqsToProcess.empty()) {
+ // No ineqs left to process; pop this level's frame and return.
+ frames.pop_back();
+ level = frames.size();
+ continue;
+ }
+
+ // "Recurse" with the part where the ineq is not satisfied.
+ frame.bCounts = b.getCounts();
+ frame.simplexSnapshot = simplex.getSnapshot();
+
+ unsigned idx = frame.ineqsToProcess.back();
+ SmallVector<int64_t, 8> ineq =
+ getComplementIneq(getIneqCoeffsFromIdx(frame.sI, idx));
+ b.addInequality(ineq);
+ simplex.addInequality(ineq);
+
+ frame.ineqsToProcess.pop_back();
+ frame.lastIneqProcessed = idx;
+ ++level;
+ continue;
+ }
+ }
- PresburgerRelation result =
- PresburgerRelation::getEmpty(disjunct.getSpaceWithoutLocals());
- Simplex simplex(disjunct);
- subtractRecursively(disjunct, simplex, set, 0, result);
return result;
}
More information about the Mlir-commits mailing list