[Mlir-commits] [mlir] 1ba6043 - [MLIR][Presburger] Refactor subtraction in preparation for making it iterative

Arjun P llvmlistbot at llvm.org
Wed Apr 6 08:35:27 PDT 2022


Author: Arjun P
Date: 2022-04-06T16:35:28+01:00
New Revision: 1ba6043332a3e11365fe617b5550880b5d6ade8a

URL: https://github.com/llvm/llvm-project/commit/1ba6043332a3e11365fe617b5550880b5d6ade8a
DIFF: https://github.com/llvm/llvm-project/commit/1ba6043332a3e11365fe617b5550880b5d6ade8a.diff

LOG: [MLIR][Presburger] Refactor subtraction in preparation for making it iterative

Refactor the operation of subtraction by
- removing the usage of SimplexRollbackScopeExit since this
  can't be used in the iterative version
- reducing the number of stack variables to make the
  iterative version easier to follow

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D123156

Added: 
    

Modified: 
    mlir/lib/Analysis/Presburger/PresburgerRelation.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index 12e0792dc9dd3..1515897db68f4 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -100,6 +100,35 @@ PresburgerRelation::intersect(const PresburgerRelation &set) const {
   return result;
 }
 
+/// Return the coefficients of the ineq in `rel` specified by  `idx`.
+/// `idx` can refer not only to an actual inequality of `rel`, but also
+/// to either of the inequalities that make up an equality in `rel`.
+///
+/// When 0 <= idx < rel.getNumInequalities(), this returns the coeffs of the
+/// idx-th inequality of `rel`.
+///
+/// Otherwise, it is then considered to index into the ineqs corresponding to
+/// eqs of `rel`, and it must hold that
+///
+/// 0 <= idx - rel.getNumInequalities() < 2*getNumEqualities().
+///
+/// For every eq `coeffs == 0` there are two possible ineqs to index into.
+/// The first is coeffs >= 0 and the second is coeffs <= 0.
+static SmallVector<int64_t, 8> getIneqCoeffsFromIdx(const IntegerRelation &rel,
+                                                    unsigned idx) {
+  assert(idx < rel.getNumInequalities() + 2 * rel.getNumEqualities() &&
+         "idx out of bounds!");
+  if (idx < rel.getNumInequalities())
+    return llvm::to_vector<8>(rel.getInequality(idx));
+
+  idx -= rel.getNumInequalities();
+  ArrayRef<int64_t> eqCoeffs = rel.getEquality(idx / 2);
+
+  if (idx % 2 == 0)
+    return llvm::to_vector<8>(eqCoeffs);
+  return getNegatedCoeffs(eqCoeffs);
+}
+
 /// Return the set difference b \ s and accumulate the result into `result`.
 /// `simplex` must correspond to b.
 ///
@@ -133,15 +162,13 @@ PresburgerRelation::intersect(const PresburgerRelation &set) const {
 /// that some constraints are redundant. These redundant constraints are
 /// ignored.
 ///
-/// b and simplex are callee saved, i.e., their values on return are
-/// semantically equivalent to their values when the function is called.
-///
 /// b should not have duplicate divs because this might lead to existing
 /// divs disappearing in the call to mergeLocalIds below, which cannot be
 /// handled.
 static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
                                 const PresburgerRelation &s, unsigned i,
                                 PresburgerRelation &result) {
+
   if (i == s.getNumDisjuncts()) {
     result.unionInPlace(b);
     return;
@@ -156,17 +183,9 @@ static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
   // rollback b to its initial state before returning, which we will do by
   // removing all constraints beyond the original number of inequalities
   // and equalities, so we store these counts first.
-  const IntegerRelation::CountsSnapshot bCounts = b.getCounts();
+  IntegerRelation::CountsSnapshot initBCounts = b.getCounts();
   // Similarly, we also want to rollback simplex to its original state.
-  const unsigned initialSnapshot = simplex.getSnapshot();
-
-  auto restoreState = [&]() {
-    b.truncate(bCounts);
-    simplex.rollback(initialSnapshot);
-  };
-
-  // Automatically restore the original state when we return.
-  auto stateRestorer = llvm::make_scope_exit(restoreState);
+  unsigned initialSnapshot = simplex.getSnapshot();
 
   // Find out which inequalities of sI correspond to division inequalities for
   // the local variables of sI.
@@ -176,31 +195,41 @@ static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
   // Add sI's locals to b, after b's locals. Also add b's locals to sI, before
   // sI's locals.
   b.mergeLocalIds(sI);
+  unsigned numLocalsAdded =
+      b.getNumLocalIds() - initBCounts.getSpace().getNumLocalIds();
+  // Update simplex to also include the new locals in `b` from merging.
+  simplex.appendVariable(numLocalsAdded);
 
-  // Mark which inequalities of sI are division inequalities and add all such
-  // inequalities to b.
-  llvm::SmallBitVector isDivInequality(sI.getNumInequalities());
+  // Equalities are processed by considering them as a pair of inequalities.
+  // The first sI.getNumInequalities() elements are for sI's inequalities;
+  // then a pair of inequalities occurs for each of sI's equalities.
+  // If the equality is expr == 0, the first element in the pair
+  // corresponds to expr >= 0, and the second to expr <= 0.
+  llvm::SmallBitVector canIgnoreIneq(sI.getNumInequalities() +
+                                     2 * sI.getNumEqualities());
+
+  // Add all division inequalities to `b`.
   for (MaybeLocalRepr &maybeInequality : repr) {
     assert(maybeInequality.kind == ReprKind::Inequality &&
            "Subtraction is not supported when a representation of the local "
            "variables of the subtrahend cannot be found!");
-    auto lb = maybeInequality.repr.inequalityPair.lowerBoundIdx;
-    auto ub = maybeInequality.repr.inequalityPair.upperBoundIdx;
+    unsigned lb = maybeInequality.repr.inequalityPair.lowerBoundIdx;
+    unsigned ub = maybeInequality.repr.inequalityPair.upperBoundIdx;
 
     b.addInequality(sI.getInequality(lb));
     b.addInequality(sI.getInequality(ub));
 
     assert(lb != ub &&
            "Upper and lower bounds must be different inequalities!");
-    isDivInequality[lb] = true;
-    isDivInequality[ub] = true;
+
+    // We just added these inequalities to `b`, so there is no point considering
+    // the parts where these inequalities occur complemented -- such parts are
+    // empty. Therefore, we mark that these can be ignored.
+    canIgnoreIneq[lb] = true;
+    canIgnoreIneq[ub] = true;
   }
 
   unsigned offset = simplex.getNumConstraints();
-  unsigned numLocalsAdded =
-      b.getNumLocalIds() - bCounts.getSpace().getNumLocalIds();
-  simplex.appendVariable(numLocalsAdded);
-
   unsigned snapshotBeforeIntersect = simplex.getSnapshot();
   simplex.intersectIntegerRelation(sI);
 
@@ -208,73 +237,55 @@ static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
     // b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
     // We are ignoring level i completely, so we restore the state
     // *before* going to level i + 1.
-    restoreState();
+    b.truncate(initBCounts);
+    simplex.rollback(initialSnapshot);
     subtractRecursively(b, simplex, s, i + 1, result);
-
-    // We already restored the state above and the recursive call should have
-    // restored to the same state before returning, so we don't need to restore
-    // the state again.
-    stateRestorer.release();
     return;
   }
 
   simplex.detectRedundant();
 
-  // Equalities are added to simplex as a pair of inequalities.
   unsigned totalNewSimplexInequalities =
       2 * sI.getNumEqualities() + sI.getNumInequalities();
-  llvm::SmallBitVector isMarkedRedundant(totalNewSimplexInequalities);
+  // Redundant inequalities can be safely ignored. This is not required for
+  // correctness but improves performance and results in a more compact
+  // representation of the set difference.
   for (unsigned j = 0; j < totalNewSimplexInequalities; j++)
-    isMarkedRedundant[j] = simplex.isMarkedRedundant(offset + j);
-
+    canIgnoreIneq[j] = simplex.isMarkedRedundant(offset + j);
   simplex.rollback(snapshotBeforeIntersect);
 
+  SmallVector<unsigned, 8> ineqsToProcess(totalNewSimplexInequalities);
+  for (unsigned i = 0; i < totalNewSimplexInequalities; ++i)
+    if (!canIgnoreIneq[i])
+      ineqsToProcess.push_back(i);
+
   // Recurse with the part b ^ ~ineq. Note that b is modified throughout
   // subtractRecursively. At the time this function is called, the current b is
   // actually equal to b ^ s_i1 ^ s_i2 ^ ... ^ s_ij, and ineq is the next
   // inequality, s_{i,j+1}. This function recurses into the next level i + 1
   // with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}.
   auto recurseWithInequality = [&, i](ArrayRef<int64_t> ineq) {
-    SimplexRollbackScopeExit scopeExit(simplex);
     b.addInequality(ineq);
     simplex.addInequality(ineq);
     subtractRecursively(b, simplex, s, i + 1, result);
-    b.removeInequality(b.getNumInequalities() - 1);
   };
 
   // For each inequality ineq, we first recurse with the part where ineq
   // is not satisfied, and then add the ineq to b and simplex because
   // ineq must be satisfied by all later parts.
   auto processInequality = [&](ArrayRef<int64_t> ineq) {
+    unsigned snapshot = simplex.getSnapshot();
+    IntegerRelation::CountsSnapshot bCounts = b.getCounts();
     recurseWithInequality(getComplementIneq(ineq));
+    simplex.rollback(snapshot);
+    b.truncate(bCounts);
+
     b.addInequality(ineq);
     simplex.addInequality(ineq);
   };
 
-  // Process all the inequalities, ignoring redundant inequalities and division
-  // inequalities. The result is correct whether or not we ignore these, but
-  // ignoring them makes the result simpler.
-  for (unsigned j = 0, e = sI.getNumInequalities(); j < e; j++) {
-    if (isMarkedRedundant[j])
-      continue;
-    if (isDivInequality[j])
-      continue;
-    processInequality(sI.getInequality(j));
-  }
-
-  offset = sI.getNumInequalities();
-  for (unsigned j = 0, e = sI.getNumEqualities(); j < e; ++j) {
-    ArrayRef<int64_t> coeffs = sI.getEquality(j);
-    // For each equality, process the positive and negative inequalities that
-    // make up this equality. If Simplex found an inequality to be redundant, we
-    // skip it as above to make the result simpler. Divisions are always
-    // represented in terms of inequalities and not equalities, so we do not
-    // check for division inequalities here.
-    if (!isMarkedRedundant[offset + 2 * j])
-      processInequality(coeffs);
-    if (!isMarkedRedundant[offset + 2 * j + 1])
-      processInequality(getNegatedCoeffs(coeffs));
-  }
+  for (unsigned idx : ineqsToProcess)
+    processInequality(getIneqCoeffsFromIdx(sI, idx));
 }
 
 /// Return the set difference disjunct \ set.


        


More information about the Mlir-commits mailing list