[Mlir-commits] [mlir] d7ec4d0 - [MLIR] PresburgerSet subtraction: fix bug where the set `b` was not restored properly on return

Arjun P llvmlistbot at llvm.org
Sun Dec 12 03:44:09 PST 2021


Author: Arjun P
Date: 2021-12-12T17:14:04+05:30
New Revision: d7ec4d0be34f4d99cedff1a06e12f0a664d039d9

URL: https://github.com/llvm/llvm-project/commit/d7ec4d0be34f4d99cedff1a06e12f0a664d039d9
DIFF: https://github.com/llvm/llvm-project/commit/d7ec4d0be34f4d99cedff1a06e12f0a664d039d9.diff

LOG: [MLIR] PresburgerSet subtraction: fix bug where the set `b` was not restored properly on return

When subtracting `b \ c`, if `c` contains divisions, the corresponding division
constraints get added to `b`. `b` must be restored to its original state
when returning, but these added division constraints were not removed in
one of the return paths. This patch fixes this and deduplicates the
restoration logic by encapsulating it in a lambda `restoreState`. The patch
also includes a regression test for the bug fix.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D115577

Added: 
    

Modified: 
    mlir/lib/Analysis/PresburgerSet.cpp
    mlir/unittests/Analysis/PresburgerSetTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Analysis/PresburgerSet.cpp b/mlir/lib/Analysis/PresburgerSet.cpp
index 50976e2fd76e9..aa1b8e70c3ddd 100644
--- a/mlir/lib/Analysis/PresburgerSet.cpp
+++ b/mlir/lib/Analysis/PresburgerSet.cpp
@@ -190,7 +190,25 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
     return;
   }
   FlatAffineConstraints sI = s.getFlatAffineConstraints(i);
-  unsigned bInitNumLocals = b.getNumLocalIds();
+
+  // Below, we append some additional constraints and ids to b. We want to
+  // rollback b to its initial state before returning, which we will do by
+  // removing all constraints beyond the original number of inequalities
+  // and equalities, so we store these counts first.
+  const unsigned bInitNumIneqs = b.getNumInequalities();
+  const unsigned bInitNumEqs = b.getNumEqualities();
+  const unsigned bInitNumLocals = b.getNumLocalIds();
+  // Similarly, we also want to rollback simplex to its original state.
+  const unsigned initialSnapshot = simplex.getSnapshot();
+
+  // Automatically restore the original state when we return.
+  auto restoreState = [&]() {
+    b.removeIdRange(FlatAffineConstraints::IdKind::Local, bInitNumLocals,
+                    b.getNumLocalIds());
+    b.removeInequalityRange(bInitNumIneqs, b.getNumInequalities());
+    b.removeEqualityRange(bInitNumEqs, b.getNumEqualities());
+    simplex.rollback(initialSnapshot);
+  };
 
   // Find out which inequalities of sI correspond to division inequalities for
   // the local variables of sI.
@@ -219,7 +237,6 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
     isDivInequality[maybePair->second] = true;
   }
 
-  unsigned initialSnapshot = simplex.getSnapshot();
   unsigned offset = simplex.getNumConstraints();
   unsigned numLocalsAdded = b.getNumLocalIds() - bInitNumLocals;
   simplex.appendVariable(numLocalsAdded);
@@ -229,9 +246,9 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
 
   if (simplex.isEmpty()) {
     /// b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
-    simplex.rollback(initialSnapshot);
-    b.removeIdRange(FlatAffineConstraints::IdKind::Local, bInitNumLocals,
-                    b.getNumLocalIds());
+    /// We are ignoring level i completely, so we restore the state
+    /// *before* going to level i + 1.
+    restoreState();
     subtractRecursively(b, simplex, s, i + 1, result);
     return;
   }
@@ -270,13 +287,6 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
     simplex.addInequality(ineq);
   };
 
-  // processInequality appends some additional constraints to b. We want to
-  // rollback b to its initial state before returning, which we will do by
-  // removing all constraints beyond the original number of inequalities
-  // and equalities, so we store these counts first.
-  unsigned bInitNumIneqs = b.getNumInequalities();
-  unsigned bInitNumEqs = b.getNumEqualities();
-
   // Process all the inequalities, ignoring redundant inequalities and division
   // inequalities. The result is correct whether or not we ignore these, but
   // ignoring them makes the result simpler.
@@ -302,13 +312,7 @@ static void subtractRecursively(FlatAffineConstraints &b, Simplex &simplex,
       processInequality(getNegatedCoeffs(coeffs));
   }
 
-  // Rollback b and simplex to their initial states.
-  b.removeIdRange(FlatAffineConstraints::IdKind::Local, bInitNumLocals,
-                  b.getNumLocalIds());
-  b.removeInequalityRange(bInitNumIneqs, b.getNumInequalities());
-  b.removeEqualityRange(bInitNumEqs, b.getNumEqualities());
-
-  simplex.rollback(initialSnapshot);
+  restoreState();
 }
 
 /// Return the set difference fac \ set.

diff  --git a/mlir/unittests/Analysis/PresburgerSetTest.cpp b/mlir/unittests/Analysis/PresburgerSetTest.cpp
index fe7a2dfbe0cb6..5a44b8db9d8cd 100644
--- a/mlir/unittests/Analysis/PresburgerSetTest.cpp
+++ b/mlir/unittests/Analysis/PresburgerSetTest.cpp
@@ -22,6 +22,17 @@
 
 namespace mlir {
 
+/// Parses a FlatAffineConstraints from a StringRef. It is expected that the
+/// string represents a valid IntegerSet, otherwise it will violate a gtest
+/// assertion.
+static FlatAffineConstraints parseFAC(StringRef str, MLIRContext *context) {
+  FailureOr<FlatAffineConstraints> fac = parseIntegerSetToFAC(str, context);
+
+  EXPECT_TRUE(succeeded(fac));
+
+  return *fac;
+}
+
 /// Compute the union of s and t, and check that each of the given points
 /// belongs to the union iff it belongs to at least one of s and t.
 static void testUnionAtPoints(PresburgerSet s, PresburgerSet t,
@@ -620,6 +631,7 @@ void expectEqual(const PresburgerSet &s, const PresburgerSet &t) {
 void expectEmpty(PresburgerSet s) { EXPECT_TRUE(s.isIntegerEmpty()); }
 
 TEST(SetTest, divisions) {
+  MLIRContext context;
   // Note: we currently need to add the equalities as inequalities to the FAC
   // since detecting divisions based on equalities is not yet supported.
 
@@ -644,17 +656,12 @@ TEST(SetTest, divisions) {
   expectEqual(odds.complement(), evens);
   // even multiples of 3 = multiples of 6.
   expectEqual(multiples3.intersect(evens), multiples6);
-}
-
-/// Parses a FlatAffineConstraints from a StringRef. It is expected that the
-/// string represents a valid IntegerSet, otherwise it will violate a gtest
-/// assertion.
-static FlatAffineConstraints parseFAC(StringRef str, MLIRContext *context) {
-  FailureOr<FlatAffineConstraints> fac = parseIntegerSetToFAC(str, context);
-
-  EXPECT_TRUE(succeeded(fac));
 
-  return *fac;
+  PresburgerSet setA =
+      makeSetFromFACs(1, {parseFAC("(x) : (-x >= 0)", &context)});
+  PresburgerSet setB =
+      makeSetFromFACs(1, {parseFAC("(x) : (x floordiv 2 - 4 >= 0)", &context)});
+  EXPECT_TRUE(setA.subtract(setB).isEqual(setA));
 }
 
 /// Coalesce `set` and check that the `newSet` is equal to `set and that


        


More information about the Mlir-commits mailing list