[Mlir-commits] [mlir] 93b9f50 - [MLIR][Presburger] IntegerRelation: implement partial rollback support

Arjun P llvmlistbot at llvm.org
Wed Mar 23 17:27:11 PDT 2022


Author: Arjun P
Date: 2022-03-24T00:27:21Z
New Revision: 93b9f50b4c6e84626f976df95602af3ecbb98ce4

URL: https://github.com/llvm/llvm-project/commit/93b9f50b4c6e84626f976df95602af3ecbb98ce4
DIFF: https://github.com/llvm/llvm-project/commit/93b9f50b4c6e84626f976df95602af3ecbb98ce4.diff

LOG: [MLIR][Presburger] IntegerRelation: implement partial rollback support

It is often necessary to roll back IntegerRelations to an earlier state. Although providing full rollback support is non-trivial, we really only need to support the case where the only changes made are appending ids or appending constraints, and these additions then need to be rolled back. This patch adds support for rolling back in such situations by recording the number of ids and constraints of each kind and providing support to truncate the IntegerRelation to those counts by removing the appended ids and constraints. This already simplifies subtraction a little bit and will also be useful in the implementation of symbolic integer lexmin.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D122178

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
    mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
    mlir/lib/Analysis/Presburger/IntegerRelation.cpp
    mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
    mlir/lib/Analysis/Presburger/PresburgerSpace.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index ea2c6aea15d00..41c0500e52367 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -148,6 +148,30 @@ class IntegerRelation : public PresburgerLocalSpace {
     return inequalities.getRow(idx);
   }
 
+  /// The struct CountsSnapshot stores the count of each IdKind, and also of
+  /// each constraint type. getCounts() returns a CountsSnapshot object
+  /// describing the current state of the IntegerRelation. truncate() truncates
+  /// all ids of each IdKind and all constraints of both kinds beyond the counts
+  /// in the specified CountsSnapshot object. This can be used to achieve
+  /// rudimentary rollback support. As long as none of the existing constraints
+  /// or ids are disturbed, and only additional ids or constraints are added,
+  /// this addition can be rolled back using truncate.
+  struct CountsSnapshot {
+  public:
+    CountsSnapshot(const PresburgerLocalSpace &space, unsigned numIneqs,
+                   unsigned numEqs)
+        : space(space), numIneqs(numIneqs), numEqs(numEqs) {}
+    const PresburgerLocalSpace &getSpace() const { return space; }
+    unsigned getNumIneqs() const { return numIneqs; }
+    unsigned getNumEqs() const { return numEqs; }
+
+  private:
+    PresburgerLocalSpace space;
+    unsigned numIneqs, numEqs;
+  };
+  CountsSnapshot getCounts() const;
+  void truncate(const CountsSnapshot &counts);
+
   /// Insert `num` identifiers of the specified kind at position `pos`.
   /// Positions are relative to the kind of identifier. The coefficient columns
   /// corresponding to the added identifiers are initialized to zero. Return the
@@ -491,6 +515,11 @@ class IntegerRelation : public PresburgerLocalSpace {
   /// arrays as needed.
   void removeIdRange(unsigned idStart, unsigned idLimit);
 
+  using PresburgerSpace::truncateIdKind;
+  /// Truncate the ids to the number in the space of the specified
+  /// CountsSnapshot.
+  void truncateIdKind(IdKind kind, const CountsSnapshot &counts);
+
   /// A parameter that controls detection of an unrealistic number of
   /// constraints. If the number of constraints is this many times the number of
   /// variables, we consider such a system out of line with the intended use

diff  --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
index f9b03e43a2f8e..a832f00c29348 100644
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
@@ -109,6 +109,10 @@ class PresburgerSpace {
   /// idLimit). The range is relative to the kind of identifier.
   virtual void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit);
 
+  /// Truncate the ids of the specified kind to the specified number by dropping
+  /// some ids at the end. `num` must not exceed the current number of ids.
+  void truncateIdKind(IdKind kind, unsigned num);
+
   /// Returns true if both the spaces are equal i.e. if both spaces have the
   /// same number of identifiers of each kind (excluding Local Identifiers).
   bool isEqual(const PresburgerSpace &other) const;

diff  --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 611615bd347d6..4b41c23c0475c 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -126,6 +126,26 @@ void removeConstraintsInvolvingIdRange(IntegerRelation &poly, unsigned begin,
     if (!rangeIsZero(poly.getInequality(i - 1).slice(begin, count)))
       poly.removeInequality(i - 1);
 }
+
+IntegerRelation::CountsSnapshot IntegerRelation::getCounts() const {
+  return {PresburgerLocalSpace(*this), getNumInequalities(),
+          getNumEqualities()};
+}
+
+void IntegerRelation::truncateIdKind(IdKind kind,
+                                     const CountsSnapshot &counts) {
+  truncateIdKind(kind, counts.getSpace().getNumIdKind(kind));
+}
+
+void IntegerRelation::truncate(const CountsSnapshot &counts) {
+  truncateIdKind(IdKind::Domain, counts);
+  truncateIdKind(IdKind::Range, counts);
+  truncateIdKind(IdKind::Symbol, counts);
+  truncateIdKind(IdKind::Local, counts);
+  removeInequalityRange(counts.getNumIneqs(), getNumInequalities());
+  removeEqualityRange(counts.getNumEqs(), getNumEqualities());
+}
+
 unsigned IntegerRelation::insertId(IdKind kind, unsigned pos, unsigned num) {
   assert(pos <= getNumIdKind(kind));
 

diff  --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index 37a3b78adb269..934c80f0b6967 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -153,16 +153,12 @@ static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
   // rollback b to its initial state before returning, which we will do by
   // removing all constraints beyond the original number of inequalities
   // and equalities, so we store these counts first.
-  const unsigned bInitNumIneqs = b.getNumInequalities();
-  const unsigned bInitNumEqs = b.getNumEqualities();
-  const unsigned bInitNumLocals = b.getNumLocalIds();
+  const IntegerRelation::CountsSnapshot bCounts = b.getCounts();
   // Similarly, we also want to rollback simplex to its original state.
   const unsigned initialSnapshot = simplex.getSnapshot();
 
   auto restoreState = [&]() {
-    b.removeIdRange(IdKind::Local, bInitNumLocals, b.getNumLocalIds());
-    b.removeInequalityRange(bInitNumIneqs, b.getNumInequalities());
-    b.removeEqualityRange(bInitNumEqs, b.getNumEqualities());
+    b.truncate(bCounts);
     simplex.rollback(initialSnapshot);
   };
 
@@ -198,7 +194,8 @@ static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
   }
 
   unsigned offset = simplex.getNumConstraints();
-  unsigned numLocalsAdded = b.getNumLocalIds() - bInitNumLocals;
+  unsigned numLocalsAdded =
+      b.getNumLocalIds() - bCounts.getSpace().getNumLocalIds();
   simplex.appendVariable(numLocalsAdded);
 
   unsigned snapshotBeforeIntersect = simplex.getSnapshot();

diff  --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
index cbeba2416d3a7..c4e06aa0d6f9e 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
@@ -91,6 +91,12 @@ void PresburgerSpace::removeIdRange(IdKind kind, unsigned idStart,
     llvm_unreachable("PresburgerSpace does not support local identifiers!");
 }
 
+void PresburgerSpace::truncateIdKind(IdKind kind, unsigned num) {
+  unsigned curNum = getNumIdKind(kind);
+  assert(num <= curNum && "Can't truncate to more ids!");
+  removeIdRange(kind, num, curNum);
+}
+
 unsigned PresburgerLocalSpace::insertId(IdKind kind, unsigned pos,
                                         unsigned num) {
   if (kind == IdKind::Local) {


        


More information about the Mlir-commits mailing list