[Mlir-commits] [mlir] 00b293e - [MLIR][Presburger] refactor subtraction to be non-recursive
Arjun P
llvmlistbot at llvm.org
Thu Apr 7 07:20:22 PDT 2022
Author: Arjun P
Date: 2022-04-07T15:20:19+01:00
New Revision: 00b293e83f6bb84f970eea972f022d578923d832
URL: https://github.com/llvm/llvm-project/commit/00b293e83f6bb84f970eea972f022d578923d832
DIFF: https://github.com/llvm/llvm-project/commit/00b293e83f6bb84f970eea972f022d578923d832.diff
LOG: [MLIR][Presburger] refactor subtraction to be non-recursive
Subtraction was previously implemented recursively. This refactors it to be
non-recursive to avoid issues with potential stack overflows.
Reviewed By: Groverkss
Differential Revision: https://reviews.llvm.org/D123248
Added:
Modified:
mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index 1515897db68f4..1bc77ed402028 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -129,18 +129,17 @@ static SmallVector<int64_t, 8> getIneqCoeffsFromIdx(const IntegerRelation &rel,
return getNegatedCoeffs(eqCoeffs);
}
/// Return the set difference b \ s and accumulate the result into `result`.
/// `simplex` must correspond to b.
+/// Return the set difference b \ s.
///
/// In the following, U denotes union, ^ denotes intersection, \ denotes set
+/// In the following, U denotes union, /\ denotes intersection, \ denotes set
/// difference and ~ denotes complement.
/// Let b be the IntegerRelation and s = (U_i s_i) be the set. We want
/// b \ (U_i s_i).
///
/// Let s_i = ^_j s_ij, where each s_ij is a single inequality. To compute
/// b \ s_i = b ^ ~s_i, we partition s_i based on the first violated inequality:
/// ~s_i = (~s_i1) U (s_i1 ^ ~s_i2) U (s_i1 ^ s_i2 ^ ~s_i3) U ...
/// And the required result is (b ^ ~s_i1) U (b ^ s_i1 ^ ~s_i2) U ...
+/// Let s = (U_i s_i). We want b \ (U_i s_i).
+///
+/// Let s_i = /\_j s_ij, where each s_ij is a single inequality. To compute
+/// b \ s_i = b /\ ~s_i, we partition s_i based on the first violated
+/// inequality: ~s_i = (~s_i1) U (s_i1 /\ ~s_i2) U (s_i1 /\ s_i2 /\ ~s_i3) U ...
+/// And the required result is (b /\ ~s_i1) U (b /\ s_i1 /\ ~s_i2) U ...
/// We recurse by subtracting U_{j > i} S_j from each of these parts and
/// returning the union of the results. Each equality is handled as a
/// conjunction of two inequalities.
@@ -162,151 +161,192 @@ static SmallVector<int64_t, 8> getIneqCoeffsFromIdx(const IntegerRelation &rel,
/// that some constraints are redundant. These redundant constraints are
/// ignored.
///
/// b should not have duplicate divs because this might lead to existing
/// divs disappearing in the call to mergeLocalIds below, which cannot be
/// handled.
static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
 const PresburgerRelation &s, unsigned i,
 PresburgerRelation &result) {

 if (i == s.getNumDisjuncts()) {
 result.unionInPlace(b);
 return;
 }
+static PresburgerRelation getSetDifference(IntegerRelation b,
+ const PresburgerRelation &s) {
+ assert(b.isSpaceCompatible(s) && "Spaces should match");
+ if (b.isEmptyByGCDTest())
+ return PresburgerRelation::getEmpty(b.getSpaceWithoutLocals());
 IntegerRelation sI = s.getDisjunct(i);
 // Remove the duplicate divs up front to avoid them possibly disappearing
 // in the call to mergeLocalIds below.
 sI.removeDuplicateDivs();

 // Below, we append some additional constraints and ids to b. We want to
 // rollback b to its initial state before returning, which we will do by
 // removing all constraints beyond the original number of inequalities
 // and equalities, so we store these counts first.
 IntegerRelation::CountsSnapshot initBCounts = b.getCounts();
 // Similarly, we also want to rollback simplex to its original state.
 unsigned initialSnapshot = simplex.getSnapshot();

 // Find out which inequalities of sI correspond to division inequalities for
 // the local variables of sI.
 std::vector<MaybeLocalRepr> repr(sI.getNumLocalIds());
 sI.getLocalReprs(repr);

 // Add sI's locals to b, after b's locals. Also add b's locals to sI, before
 // sI's locals.
 b.mergeLocalIds(sI);
 unsigned numLocalsAdded =
 b.getNumLocalIds()  initBCounts.getSpace().getNumLocalIds();
 // Update simplex to also include the new locals in `b` from merging.
 simplex.appendVariable(numLocalsAdded);

 // Equalities are processed by considering them as a pair of inequalities.
 // The first sI.getNumInequalities() elements are for sI's inequalities;
 // then a pair of inequalities occurs for each of sI's equalities.
 // If the equality is expr == 0, the first element in the pair
 // corresponds to expr >= 0, and the second to expr <= 0.
 llvm::SmallBitVector canIgnoreIneq(sI.getNumInequalities() +
 2 * sI.getNumEqualities());

 // Add all division inequalities to `b`.
 for (MaybeLocalRepr &maybeInequality : repr) {
 assert(maybeInequality.kind == ReprKind::Inequality &&
 "Subtraction is not supported when a representation of the local "
 "variables of the subtrahend cannot be found!");
 unsigned lb = maybeInequality.repr.inequalityPair.lowerBoundIdx;
 unsigned ub = maybeInequality.repr.inequalityPair.upperBoundIdx;

 b.addInequality(sI.getInequality(lb));
 b.addInequality(sI.getInequality(ub));

 assert(lb != ub &&
 "Upper and lower bounds must be
diff erent inequalities!");

 // We just added these inequalities to `b`, so there is no point considering
 // the parts where these inequalities occur complemented  such parts are
 // empty. Therefore, we mark that these can be ignored.
 canIgnoreIneq[lb] = true;
 canIgnoreIneq[ub] = true;
 }

 unsigned offset = simplex.getNumConstraints();
 unsigned snapshotBeforeIntersect = simplex.getSnapshot();
 simplex.intersectIntegerRelation(sI);

 if (simplex.isEmpty()) {
 // b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
 // We are ignoring level i completely, so we restore the state
 // *before* going to level i + 1.
 b.truncate(initBCounts);
 simplex.rollback(initialSnapshot);
 subtractRecursively(b, simplex, s, i + 1, result);
 return;
 }
+ // Remove duplicate divs up front here to avoid existing
+ // divs disappearing in the call to mergeLocalIds below.
+ b.removeDuplicateDivs();
 simplex.detectRedundant();

 unsigned totalNewSimplexInequalities =
 2 * sI.getNumEqualities() + sI.getNumInequalities();
 // Redundant inequalities can be safely ignored. This is not required for
 // correctness but improves performance and results in a more compact
 // representation of the set difference.
 for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
 canIgnoreIneq[j] = simplex.isMarkedRedundant(offset + j);
 simplex.rollback(snapshotBeforeIntersect);

 SmallVector<unsigned, 8> ineqsToProcess(totalNewSimplexInequalities);
 for (unsigned i = 0; i < totalNewSimplexInequalities; ++i)
 if (!canIgnoreIneq[i])
 ineqsToProcess.push_back(i);

 // Recurse with the part b ^ ~ineq. Note that b is modified throughout
 // subtractRecursively. At the time this function is called, the current b is
 // actually equal to b ^ s_i1 ^ s_i2 ^ ... ^ s_ij, and ineq is the next
 // inequality, s_{i,j+1}. This function recurses into the next level i + 1
 // with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}.
 auto recurseWithInequality = [&, i](ArrayRef<int64_t> ineq) {
 b.addInequality(ineq);
 simplex.addInequality(ineq);
 subtractRecursively(b, simplex, s, i + 1, result);
+ PresburgerRelation result =
+ PresburgerRelation::getEmpty(b.getSpaceWithoutLocals());
+ Simplex simplex(b);
+
+ // This algorithm is more naturally expressed recursively, but we implement
+ // it iteratively here to avoid issues with stack sizes.
+ //
+ // Each level of the recursion has five stack variables.
+ struct Frame {
+ // A snapshot of the simplex state to rollback to.
+ unsigned simplexSnapshot;
+ // A CountsSnapshot of `b` to rollback to.
+ IntegerRelation::CountsSnapshot bCounts;
+ // The IntegerRelation currently being operated on.
+ IntegerRelation sI;
+ // A list of indexes (see getIneqCoeffsFromIdx) of inequalities to be
+ // processed.
+ SmallVector<unsigned, 8> ineqsToProcess;
+ // The index of the last inequality that was processed at this level.
+ // This is empty when we are coming to this level for the first time.
+ Optional<unsigned> lastIneqProcessed;
};
+ SmallVector<Frame, 2> frames;
+
+ // When we "recurse", we ensure the current frame is stored in `frames` and
+ // increment `level`. When we "tail recurse", we just increment `level`,
+ // without storing any frame. Accordingly, when we return, we return to the
+ // last level that has a frame associated with it.
+ unsigned level = 1;
+ while (level > 0) {
+ if (level  1 >= s.getNumDisjuncts()) {
+ // No more parts to subtract; add to the result and return.
+ result.unionInPlace(b);
+ level = frames.size();
+ continue;
+ }
 // For each inequality ineq, we first recurse with the part where ineq
 // is not satisfied, and then add the ineq to b and simplex because
 // ineq must be satisfied by all later parts.
 auto processInequality = [&](ArrayRef<int64_t> ineq) {
 unsigned snapshot = simplex.getSnapshot();
 IntegerRelation::CountsSnapshot bCounts = b.getCounts();
 recurseWithInequality(getComplementIneq(ineq));
 simplex.rollback(snapshot);
 b.truncate(bCounts);

 b.addInequality(ineq);
 simplex.addInequality(ineq);
 };
+ if (level > frames.size()) {
+ // No frame for this level yet, so we have just recursed into this level.
+ IntegerRelation sI = s.getDisjunct(level  1);
+ // Remove the duplicate divs up front to avoid them possibly disappearing
+ // in the call to mergeLocalIds below.
+ sI.removeDuplicateDivs();
+
+ // Below, we append some additional constraints and ids to b. We want to
+ // rollback b to its initial state before returning, which we will do by
+ // removing all constraints beyond the original number of inequalities
+ // and equalities, so we store these counts first.
+ IntegerRelation::CountsSnapshot initBCounts = b.getCounts();
+ // Similarly, we also want to rollback simplex to its original state.
+ unsigned initialSnapshot = simplex.getSnapshot();
+
+ // Find out which inequalities of sI correspond to division inequalities
+ // for the local variables of sI.
+ std::vector<MaybeLocalRepr> repr(sI.getNumLocalIds());
+ sI.getLocalReprs(repr);
+
+ // Add sI's locals to b, after b's locals. Only those locals of sI which
+ // do not already exist in b will be added. (i.e., duplicate divisions
+ // will not be added.) Also add b's locals to sI, in such a way that both
+ // have the same locals in the same order in the end.
+ b.mergeLocalIds(sI);
+
+ // Mark which inequalities of sI are division inequalities and add all
+ // such inequalities to b.
+ llvm::SmallBitVector canIgnoreIneq(sI.getNumInequalities() +
+ 2 * sI.getNumEqualities());
+ for (MaybeLocalRepr &maybeInequality : repr) {
+ assert(
+ maybeInequality.kind == ReprKind::Inequality &&
+ "Subtraction is not supported when a representation of the local "
+ "variables of the subtrahend cannot be found!");
+ unsigned lb = maybeInequality.repr.inequalityPair.lowerBoundIdx;
+ unsigned ub = maybeInequality.repr.inequalityPair.upperBoundIdx;
+
+ b.addInequality(sI.getInequality(lb));
+ b.addInequality(sI.getInequality(ub));
+
+ assert(lb != ub &&
+ "Upper and lower bounds must be
diff erent inequalities!");
+ canIgnoreIneq[lb] = true;
+ canIgnoreIneq[ub] = true;
+ }
 for (unsigned idx : ineqsToProcess)
 processInequality(getIneqCoeffsFromIdx(sI, idx));
}
+ unsigned offset = simplex.getNumConstraints();
+ unsigned numLocalsAdded =
+ b.getNumLocalIds()  initBCounts.getSpace().getNumLocalIds();
+ simplex.appendVariable(numLocalsAdded);
+
+ unsigned snapshotBeforeIntersect = simplex.getSnapshot();
+ simplex.intersectIntegerRelation(sI);
+
+ if (simplex.isEmpty()) {
+ // b /\ s_i is empty, so b \ s_i = b. We move directly to i + 1.
+ // We are ignoring level i completely, so we restore the state
+ // *before* going to the next level. We are "tail recursing", so
+ // we don't add a frame before going to the next level.
+ b.truncate(initBCounts);
+ simplex.rollback(initialSnapshot);
+ ++level;
+ continue;
+ }
/// Return the set difference disjunct \ set.
///
/// The disjunct here is modified in subtractRecursively, so it cannot be a
/// const reference even though it is restored to its original state before
/// returning from that function.
static PresburgerRelation getSetDifference(IntegerRelation disjunct,
 const PresburgerRelation &set) {
 assert(disjunct.isSpaceCompatible(set) && "Spaces should match");
 if (disjunct.isEmptyByGCDTest())
 return PresburgerRelation::getEmpty(disjunct.getSpaceWithoutLocals());

 // Remove duplicate divs up front here as subtractRecursively does not support
 // this set having duplicate divs.
 disjunct.removeDuplicateDivs();
+ simplex.detectRedundant();
+
+ // Equalities are added to simplex as a pair of inequalities.
+ unsigned totalNewSimplexInequalities =
+ 2 * sI.getNumEqualities() + sI.getNumInequalities();
+ for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
+ canIgnoreIneq[j] = simplex.isMarkedRedundant(offset + j);
+ simplex.rollback(snapshotBeforeIntersect);
+
+ SmallVector<unsigned, 8> ineqsToProcess(totalNewSimplexInequalities);
+ for (unsigned i = 0; i < totalNewSimplexInequalities; ++i)
+ if (!canIgnoreIneq[i])
+ ineqsToProcess.push_back(i);
+
+ if (ineqsToProcess.empty()) {
+ // Nothing to process; return. (we have no frame to pop.)
+ level = frames.size();
+ continue;
+ }
+
+ unsigned simplexSnapshot = simplex.getSnapshot();
+ IntegerRelation::CountsSnapshot bCounts = b.getCounts();
+ frames.push_back(Frame{simplexSnapshot, bCounts, sI, ineqsToProcess,
+ /*lastIneqProcessed=*/llvm::None});
+ // We have completed the initial setup for this level.
+ // Fallthrough to the main recursive part below.
+ }
+
+ // For each inequality ineq, we first recurse with the part where ineq
+ // is not satisfied, and then add ineq to b and simplex because
+ // ineq must be satisfied by all later parts.
+ if (level == frames.size()) {
+ Frame &frame = frames.back();
+ if (frame.lastIneqProcessed) {
+ // Let the current value of b be b' and
+ // let the initial value of b when we first came to this level be b.
+ //
+ // b' is equal to b /\ s_i1 /\ s_i2 /\ ... /\ s_i{j-1} /\ ~s_ij.
+ // We had previously recursed with the part where s_ij was not
+ // satisfied; all further parts satisfy s_ij, so we rollback to the
+ // state before adding this complement constraint, and add s_ij to b.
+ simplex.rollback(frame.simplexSnapshot);
+ b.truncate(frame.bCounts);
+ SmallVector<int64_t, 8> ineq =
+ getIneqCoeffsFromIdx(frame.sI, *frame.lastIneqProcessed);
+ b.addInequality(ineq);
+ simplex.addInequality(ineq);
+ }
+
+ if (frame.ineqsToProcess.empty()) {
+ // No ineqs left to process; pop this level's frame and return.
+ frames.pop_back();
+ level = frames.size();
+ continue;
+ }
+
+ // "Recurse" with the part where the ineq is not satisfied.
+ frame.bCounts = b.getCounts();
+ frame.simplexSnapshot = simplex.getSnapshot();
+
+ unsigned idx = frame.ineqsToProcess.back();
+ SmallVector<int64_t, 8> ineq =
+ getComplementIneq(getIneqCoeffsFromIdx(frame.sI, idx));
+ b.addInequality(ineq);
+ simplex.addInequality(ineq);
+
+ frame.ineqsToProcess.pop_back();
+ frame.lastIneqProcessed = idx;
+ ++level;
+ continue;
+ }
+ }
 PresburgerRelation result =
 PresburgerRelation::getEmpty(disjunct.getSpaceWithoutLocals());
 Simplex simplex(disjunct);
 subtractRecursively(disjunct, simplex, set, 0, result);
return result;
}
More information about the Mlir-commits
mailing list