[Mlir-commits] [mlir] 1ba6043 - [MLIR][Presburger] Refactor subtraction in preparation for making it iterative
Arjun P
llvmlistbot at llvm.org
Wed Apr 6 08:35:27 PDT 2022
Author: Arjun P
Date: 2022-04-06T16:35:28+01:00
New Revision: 1ba6043332a3e11365fe617b5550880b5d6ade8a
URL: https://github.com/llvm/llvm-project/commit/1ba6043332a3e11365fe617b5550880b5d6ade8a
DIFF: https://github.com/llvm/llvm-project/commit/1ba6043332a3e11365fe617b5550880b5d6ade8a.diff
LOG: [MLIR][Presburger] Refactor subtraction in preparation for making it iterative
Refactor the operation of subtraction by
- removing the usage of SimplexRollbackScopeExit since this
can't be used in the iterative version
- reducing the number of stack variables to make the
iterative version easier to follow
Reviewed By: Groverkss
Differential Revision: https://reviews.llvm.org/D123156
Added:
Modified:
mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index 12e0792dc9dd3..1515897db68f4 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -100,6 +100,35 @@ PresburgerRelation::intersect(const PresburgerRelation &set) const {
return result;
}
+/// Return the coefficients of the ineq in `rel` specified by `idx`.
+/// `idx` can refer not only to an actual inequality of `rel`, but also
+/// to either of the inequalities that make up an equality in `rel`.
+///
+/// When 0 <= idx < rel.getNumInequalities(), this returns the coeffs of the
+/// idx-th inequality of `rel`.
+///
+/// Otherwise, `idx` indexes into the ineqs corresponding to the eqs of
+/// `rel`, and it must hold that
+///
+/// 0 <= idx - rel.getNumInequalities() < 2 * rel.getNumEqualities().
+///
+/// For every eq `coeffs == 0` there are two possible ineqs to index into.
+/// The first (even offset) is coeffs >= 0 and the second is coeffs <= 0.
+static SmallVector<int64_t, 8> getIneqCoeffsFromIdx(const IntegerRelation &rel,
+ unsigned idx) {
+ assert(idx < rel.getNumInequalities() + 2 * rel.getNumEqualities() &&
+ "idx out of bounds!");
+ if (idx < rel.getNumInequalities())
+ return llvm::to_vector<8>(rel.getInequality(idx));
+
+ idx -= rel.getNumInequalities();
+ ArrayRef<int64_t> eqCoeffs = rel.getEquality(idx / 2);
+
+ if (idx % 2 == 0)
+ return llvm::to_vector<8>(eqCoeffs);
+ return getNegatedCoeffs(eqCoeffs);
+}
+
/// Return the set difference b \ s and accumulate the result into `result`.
/// `simplex` must correspond to b.
///
@@ -133,15 +162,16 @@ PresburgerRelation::intersect(const PresburgerRelation &set) const {
/// that some constraints are redundant. These redundant constraints are
/// ignored.
///
-/// b and simplex are callee saved, i.e., their values on return are
-/// semantically equivalent to their values when the function is called.
-///
+/// b and simplex are not restored on return; a caller that needs their
+/// original state must snapshot them before the call and roll them back
+/// afterwards, as processInequality below does.
+///
/// b should not have duplicate divs because this might lead to existing
/// divs disappearing in the call to mergeLocalIds below, which cannot be
/// handled.
static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
const PresburgerRelation &s, unsigned i,
PresburgerRelation &result) {
if (i == s.getNumDisjuncts()) {
result.unionInPlace(b);
return;
@@ -156,17 +186,9 @@ static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
// rollback b to its initial state before returning, which we will do by
// removing all constraints beyond the original number of inequalities
// and equalities, so we store these counts first.
- const IntegerRelation::CountsSnapshot bCounts = b.getCounts();
+ IntegerRelation::CountsSnapshot initBCounts = b.getCounts();
// Similarly, we also want to rollback simplex to its original state.
- const unsigned initialSnapshot = simplex.getSnapshot();
-
- auto restoreState = [&]() {
- b.truncate(bCounts);
- simplex.rollback(initialSnapshot);
- };
-
- // Automatically restore the original state when we return.
- auto stateRestorer = llvm::make_scope_exit(restoreState);
+ unsigned initialSnapshot = simplex.getSnapshot();
// Find out which inequalities of sI correspond to division inequalities for
// the local variables of sI.
@@ -176,31 +198,41 @@ static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
// Add sI's locals to b, after b's locals. Also add b's locals to sI, before
// sI's locals.
b.mergeLocalIds(sI);
+ unsigned numLocalsAdded =
+ b.getNumLocalIds() - initBCounts.getSpace().getNumLocalIds();
+ // Update simplex to also include the new locals in `b` from merging.
+ simplex.appendVariable(numLocalsAdded);
- // Mark which inequalities of sI are division inequalities and add all such
- // inequalities to b.
- llvm::SmallBitVector isDivInequality(sI.getNumInequalities());
+ // Equalities are processed by considering them as a pair of inequalities.
+ // The first sI.getNumInequalities() elements are for sI's inequalities;
+ // then a pair of inequalities occurs for each of sI's equalities.
+ // If the equality is expr == 0, the first element in the pair
+ // corresponds to expr >= 0, and the second to expr <= 0.
+ llvm::SmallBitVector canIgnoreIneq(sI.getNumInequalities() +
+ 2 * sI.getNumEqualities());
+
+ // Add all division inequalities to `b`.
for (MaybeLocalRepr &maybeInequality : repr) {
assert(maybeInequality.kind == ReprKind::Inequality &&
"Subtraction is not supported when a representation of the local "
"variables of the subtrahend cannot be found!");
- auto lb = maybeInequality.repr.inequalityPair.lowerBoundIdx;
- auto ub = maybeInequality.repr.inequalityPair.upperBoundIdx;
+ unsigned lb = maybeInequality.repr.inequalityPair.lowerBoundIdx;
+ unsigned ub = maybeInequality.repr.inequalityPair.upperBoundIdx;
b.addInequality(sI.getInequality(lb));
b.addInequality(sI.getInequality(ub));
assert(lb != ub &&
"Upper and lower bounds must be different inequalities!");
- isDivInequality[lb] = true;
- isDivInequality[ub] = true;
+
+ // We just added these inequalities to `b`, so there is no point considering
+ // the parts where these inequalities occur complemented -- such parts are
+ // empty. Therefore, we mark that these can be ignored.
+ canIgnoreIneq[lb] = true;
+ canIgnoreIneq[ub] = true;
}
unsigned offset = simplex.getNumConstraints();
- unsigned numLocalsAdded =
- b.getNumLocalIds() - bCounts.getSpace().getNumLocalIds();
- simplex.appendVariable(numLocalsAdded);
-
unsigned snapshotBeforeIntersect = simplex.getSnapshot();
simplex.intersectIntegerRelation(sI);
@@ -208,73 +240,58 @@ static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
// b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
// We are ignoring level i completely, so we restore the state
// *before* going to level i + 1.
- restoreState();
+ b.truncate(initBCounts);
+ simplex.rollback(initialSnapshot);
subtractRecursively(b, simplex, s, i + 1, result);
-
- // We already restored the state above and the recursive call should have
- // restored to the same state before returning, so we don't need to restore
- // the state again.
- stateRestorer.release();
return;
}
simplex.detectRedundant();
- // Equalities are added to simplex as a pair of inequalities.
unsigned totalNewSimplexInequalities =
2 * sI.getNumEqualities() + sI.getNumInequalities();
- llvm::SmallBitVector isMarkedRedundant(totalNewSimplexInequalities);
+ // Redundant inequalities can also be safely ignored. This is not required
+ // for correctness but improves performance and results in a more compact
+ // representation of the set difference. Bits are only ever set here, so the
+ // division inequalities marked above remain ignored.
for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
- isMarkedRedundant[j] = simplex.isMarkedRedundant(offset + j);
-
+ if (simplex.isMarkedRedundant(offset + j))
+ canIgnoreIneq.set(j);
simplex.rollback(snapshotBeforeIntersect);
+ SmallVector<unsigned, 8> ineqsToProcess;
+ ineqsToProcess.reserve(totalNewSimplexInequalities);
+ for (unsigned j = 0; j < totalNewSimplexInequalities; ++j)
+ if (!canIgnoreIneq[j])
+ ineqsToProcess.push_back(j);
+
// Recurse with the part b ^ ~ineq. Note that b is modified throughout
// subtractRecursively. At the time this function is called, the current b is
// actually equal to b ^ s_i1 ^ s_i2 ^ ... ^ s_ij, and ineq is the next
// inequality, s_{i,j+1}. This function recurses into the next level i + 1
// with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}.
auto recurseWithInequality = [&, i](ArrayRef<int64_t> ineq) {
- SimplexRollbackScopeExit scopeExit(simplex);
b.addInequality(ineq);
simplex.addInequality(ineq);
subtractRecursively(b, simplex, s, i + 1, result);
- b.removeInequality(b.getNumInequalities() - 1);
};
// For each inequality ineq, we first recurse with the part where ineq
// is not satisfied, and then add the ineq to b and simplex because
// ineq must be satisfied by all later parts.
auto processInequality = [&](ArrayRef<int64_t> ineq) {
+ unsigned snapshot = simplex.getSnapshot();
+ IntegerRelation::CountsSnapshot bCounts = b.getCounts();
recurseWithInequality(getComplementIneq(ineq));
+ simplex.rollback(snapshot);
+ b.truncate(bCounts);
+
b.addInequality(ineq);
simplex.addInequality(ineq);
};
- // Process all the inequalities, ignoring redundant inequalities and division
- // inequalities. The result is correct whether or not we ignore these, but
- // ignoring them makes the result simpler.
- for (unsigned j = 0, e = sI.getNumInequalities(); j < e; j++) {
- if (isMarkedRedundant[j])
- continue;
- if (isDivInequality[j])
- continue;
- processInequality(sI.getInequality(j));
- }
-
- offset = sI.getNumInequalities();
- for (unsigned j = 0, e = sI.getNumEqualities(); j < e; ++j) {
- ArrayRef<int64_t> coeffs = sI.getEquality(j);
- // For each equality, process the positive and negative inequalities that
- // make up this equality. If Simplex found an inequality to be redundant, we
- // skip it as above to make the result simpler. Divisions are always
- // represented in terms of inequalities and not equalities, so we do not
- // check for division inequalities here.
- if (!isMarkedRedundant[offset + 2 * j])
- processInequality(coeffs);
- if (!isMarkedRedundant[offset + 2 * j + 1])
- processInequality(getNegatedCoeffs(coeffs));
- }
+ for (unsigned idx : ineqsToProcess)
+ processInequality(getIneqCoeffsFromIdx(sI, idx));
}
/// Return the set difference disjunct \ set.
More information about the Mlir-commits
mailing list