[Mlir-commits] [mlir] 1ba6043  [MLIR][Presburger] Refactor subtraction in preparation for making it iterative
Arjun P
llvmlistbot at llvm.org
Wed Apr 6 08:35:27 PDT 2022
Author: Arjun P
Date: 2022-04-06T16:35:28+01:00
New Revision: 1ba6043332a3e11365fe617b5550880b5d6ade8a
URL: https://github.com/llvm/llvm-project/commit/1ba6043332a3e11365fe617b5550880b5d6ade8a
DIFF: https://github.com/llvm/llvm-project/commit/1ba6043332a3e11365fe617b5550880b5d6ade8a.diff
LOG: [MLIR][Presburger] Refactor subtraction in preparation for making it iterative
Refactor the operation of subtraction by
- removing the usage of SimplexRollbackScopeExit since this
  can't be used in the iterative version
- reducing the number of stack variables to make the
  iterative version easier to follow
Reviewed By: Groverkss
Differential Revision: https://reviews.llvm.org/D123156
Added:
Modified:
mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index 12e0792dc9dd3..1515897db68f4 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -100,6 +100,35 @@ PresburgerRelation::intersect(const PresburgerRelation &set) const {
return result;
}
+/// Return the coefficients of the ineq in `rel` specified by `idx`.
+/// `idx` can refer not only to an actual inequality of `rel`, but also
+/// to either of the inequalities that make up an equality in `rel`.
+///
+/// When 0 <= idx < rel.getNumInequalities(), this returns the coeffs of the
+/// idx-th inequality of `rel`.
+///
+/// Otherwise, it is then considered to index into the ineqs corresponding to
+/// eqs of `rel`, and it must hold that
+///
+/// 0 <= idx - rel.getNumInequalities() < 2*getNumEqualities().
+///
+/// For every eq `coeffs == 0` there are two possible ineqs to index into.
+/// The first is coeffs >= 0 and the second is coeffs <= 0.
+static SmallVector<int64_t, 8> getIneqCoeffsFromIdx(const IntegerRelation &rel,
+ unsigned idx) {
+ assert(idx < rel.getNumInequalities() + 2 * rel.getNumEqualities() &&
+ "idx out of bounds!");
+ if (idx < rel.getNumInequalities())
+ return llvm::to_vector<8>(rel.getInequality(idx));
+
+ idx -= rel.getNumInequalities();
+ ArrayRef<int64_t> eqCoeffs = rel.getEquality(idx / 2);
+
+ if (idx % 2 == 0)
+ return llvm::to_vector<8>(eqCoeffs);
+ return getNegatedCoeffs(eqCoeffs);
+}
+
/// Return the set difference b \ s and accumulate the result into `result`.
/// `simplex` must correspond to b.
///
@@ -133,15 +162,13 @@ PresburgerRelation::intersect(const PresburgerRelation &set) const {
/// that some constraints are redundant. These redundant constraints are
/// ignored.
///
/// b and simplex are callee saved, i.e., their values on return are
/// semantically equivalent to their values when the function is called.
///
/// b should not have duplicate divs because this might lead to existing
/// divs disappearing in the call to mergeLocalIds below, which cannot be
/// handled.
static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
const PresburgerRelation &s, unsigned i,
PresburgerRelation &result) {
+
if (i == s.getNumDisjuncts()) {
result.unionInPlace(b);
return;
@@ -156,17 +183,9 @@ static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
// rollback b to its initial state before returning, which we will do by
// removing all constraints beyond the original number of inequalities
// and equalities, so we store these counts first.
- const IntegerRelation::CountsSnapshot bCounts = b.getCounts();
+ IntegerRelation::CountsSnapshot initBCounts = b.getCounts();
// Similarly, we also want to rollback simplex to its original state.
- const unsigned initialSnapshot = simplex.getSnapshot();
-
- auto restoreState = [&]() {
- b.truncate(bCounts);
- simplex.rollback(initialSnapshot);
- };
-
- // Automatically restore the original state when we return.
- auto stateRestorer = llvm::make_scope_exit(restoreState);
+ unsigned initialSnapshot = simplex.getSnapshot();
// Find out which inequalities of sI correspond to division inequalities for
// the local variables of sI.
@@ -176,31 +195,41 @@ static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
// Add sI's locals to b, after b's locals. Also add b's locals to sI, before
// sI's locals.
b.mergeLocalIds(sI);
+ unsigned numLocalsAdded =
+ b.getNumLocalIds() - initBCounts.getSpace().getNumLocalIds();
+ // Update simplex to also include the new locals in `b` from merging.
+ simplex.appendVariable(numLocalsAdded);
- // Mark which inequalities of sI are division inequalities and add all such
- // inequalities to b.
- llvm::SmallBitVector isDivInequality(sI.getNumInequalities());
+ // Equalities are processed by considering them as a pair of inequalities.
+ // The first sI.getNumInequalities() elements are for sI's inequalities;
+ // then a pair of inequalities occurs for each of sI's equalities.
+ // If the equality is expr == 0, the first element in the pair
+ // corresponds to expr >= 0, and the second to expr <= 0.
+ llvm::SmallBitVector canIgnoreIneq(sI.getNumInequalities() +
+ 2 * sI.getNumEqualities());
+
+ // Add all division inequalities to `b`.
for (MaybeLocalRepr &maybeInequality : repr) {
assert(maybeInequality.kind == ReprKind::Inequality &&
"Subtraction is not supported when a representation of the local "
"variables of the subtrahend cannot be found!");
- auto lb = maybeInequality.repr.inequalityPair.lowerBoundIdx;
- auto ub = maybeInequality.repr.inequalityPair.upperBoundIdx;
+ unsigned lb = maybeInequality.repr.inequalityPair.lowerBoundIdx;
+ unsigned ub = maybeInequality.repr.inequalityPair.upperBoundIdx;
b.addInequality(sI.getInequality(lb));
b.addInequality(sI.getInequality(ub));
assert(lb != ub &&
"Upper and lower bounds must be different inequalities!");
- isDivInequality[lb] = true;
- isDivInequality[ub] = true;
+
+ // We just added these inequalities to `b`, so there is no point considering
+ // the parts where these inequalities occur complemented -- such parts are
+ // empty. Therefore, we mark that these can be ignored.
+ canIgnoreIneq[lb] = true;
+ canIgnoreIneq[ub] = true;
}
unsigned offset = simplex.getNumConstraints();
- unsigned numLocalsAdded =
- b.getNumLocalIds() - bCounts.getSpace().getNumLocalIds();
- simplex.appendVariable(numLocalsAdded);
-
unsigned snapshotBeforeIntersect = simplex.getSnapshot();
simplex.intersectIntegerRelation(sI);
@@ -208,73 +237,55 @@ static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
// b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
// We are ignoring level i completely, so we restore the state
// *before* going to level i + 1.
- restoreState();
+ b.truncate(initBCounts);
+ simplex.rollback(initialSnapshot);
subtractRecursively(b, simplex, s, i + 1, result);

- // We already restored the state above and the recursive call should have
- // restored to the same state before returning, so we don't need to restore
- // the state again.
- stateRestorer.release();
return;
}
simplex.detectRedundant();
- // Equalities are added to simplex as a pair of inequalities.
unsigned totalNewSimplexInequalities =
2 * sI.getNumEqualities() + sI.getNumInequalities();
- llvm::SmallBitVector isMarkedRedundant(totalNewSimplexInequalities);
+ // Redundant inequalities can be safely ignored. This is not required for
+ // correctness but improves performance and results in a more compact
+ // representation of the set difference.
for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
- isMarkedRedundant[j] = simplex.isMarkedRedundant(offset + j);
-
+ canIgnoreIneq[j] = simplex.isMarkedRedundant(offset + j);
simplex.rollback(snapshotBeforeIntersect);
+ SmallVector<unsigned, 8> ineqsToProcess(totalNewSimplexInequalities);
+ for (unsigned i = 0; i < totalNewSimplexInequalities; ++i)
+ if (!canIgnoreIneq[i])
+ ineqsToProcess.push_back(i);
+
// Recurse with the part b ^ ~ineq. Note that b is modified throughout
// subtractRecursively. At the time this function is called, the current b is
// actually equal to b ^ s_i1 ^ s_i2 ^ ... ^ s_ij, and ineq is the next
// inequality, s_{i,j+1}. This function recurses into the next level i + 1
// with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}.
auto recurseWithInequality = [&, i](ArrayRef<int64_t> ineq) {
- SimplexRollbackScopeExit scopeExit(simplex);
b.addInequality(ineq);
simplex.addInequality(ineq);
subtractRecursively(b, simplex, s, i + 1, result);
- b.removeInequality(b.getNumInequalities() - 1);
};
// For each inequality ineq, we first recurse with the part where ineq
// is not satisfied, and then add the ineq to b and simplex because
// ineq must be satisfied by all later parts.
auto processInequality = [&](ArrayRef<int64_t> ineq) {
+ unsigned snapshot = simplex.getSnapshot();
+ IntegerRelation::CountsSnapshot bCounts = b.getCounts();
recurseWithInequality(getComplementIneq(ineq));
+ simplex.rollback(snapshot);
+ b.truncate(bCounts);
+
b.addInequality(ineq);
simplex.addInequality(ineq);
};
- // Process all the inequalities, ignoring redundant inequalities and division
- // inequalities. The result is correct whether or not we ignore these, but
- // ignoring them makes the result simpler.
- for (unsigned j = 0, e = sI.getNumInequalities(); j < e; j++) {
- if (isMarkedRedundant[j])
- continue;
- if (isDivInequality[j])
- continue;
- processInequality(sI.getInequality(j));
- }
-
- offset = sI.getNumInequalities();
- for (unsigned j = 0, e = sI.getNumEqualities(); j < e; ++j) {
- ArrayRef<int64_t> coeffs = sI.getEquality(j);
- // For each equality, process the positive and negative inequalities that
- // make up this equality. If Simplex found an inequality to be redundant, we
- // skip it as above to make the result simpler. Divisions are always
- // represented in terms of inequalities and not equalities, so we do not
- // check for division inequalities here.
- if (!isMarkedRedundant[offset + 2 * j])
- processInequality(coeffs);
- if (!isMarkedRedundant[offset + 2 * j + 1])
- processInequality(getNegatedCoeffs(coeffs));
- }
+ for (unsigned idx : ineqsToProcess)
+ processInequality(getIneqCoeffsFromIdx(sI, idx));
}
/// Return the set difference disjunct \ set.
More information about the Mlir-commits mailing list