[Mlir-commits] [mlir] mlir/Presburger: reinstate use of LogicalResult (PR #97415)

Ramkumar Ramachandra llvmlistbot at llvm.org
Tue Jul 2 07:09:24 PDT 2024


https://github.com/artagnon updated https://github.com/llvm/llvm-project/pull/97415

>From 170dbd75018d3f6c90c80e00b164fbaa28d82eac Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Tue, 2 Jul 2024 12:51:33 +0100
Subject: [PATCH 1/2] mlir/Presburger: reinstate use of LogicalResult

Follow up on a request, made after d0fee98 (mlir/Presburger: strip
dependency on MLIRSupport) landed, to reinstate the use of LogicalResult
in Presburger. Since db791b2 (mlir/LogicalResult: move into llvm),
LogicalResult lives in LLVM, so it can be used again while still meeting
the goal of keeping the Presburger library free of MLIR dependencies.
---
 .../Analysis/Presburger/IntegerRelation.h     | 12 ++--
 .../mlir/Analysis/Presburger/Simplex.h        | 12 ++--
 .../Analysis/FlatLinearValueConstraints.cpp   |  4 +-
 .../Analysis/Presburger/IntegerRelation.cpp   | 26 ++++----
 .../Presburger/PresburgerRelation.cpp         | 59 ++++++++++---------
 mlir/lib/Analysis/Presburger/Simplex.cpp      | 57 +++++++++---------
 mlir/lib/Analysis/Presburger/Utils.cpp        | 30 +++++-----
 7 files changed, 106 insertions(+), 94 deletions(-)

diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 5e5cd898b7518..a27fc8c37eeda 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -21,13 +21,17 @@
 #include "mlir/Analysis/Presburger/Utils.h"
 #include "llvm/ADT/DynamicAPInt.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/LogicalResult.h"
 #include <optional>
 
 namespace mlir {
 namespace presburger {
 using llvm::DynamicAPInt;
+using llvm::failure;
 using llvm::int64fromDynamicAPInt;
+using llvm::LogicalResult;
 using llvm::SmallVectorImpl;
+using llvm::success;
 
 class IntegerRelation;
 class IntegerPolyhedron;
@@ -478,7 +482,7 @@ class IntegerRelation {
   /// equality detection; if successful, the constant is substituted for the
   /// variable everywhere in the constraint system and then removed from the
   /// system.
-  bool constantFoldVar(unsigned pos);
+  LogicalResult constantFoldVar(unsigned pos);
 
   /// This method calls `constantFoldVar` for the specified range of variables,
   /// `num` variables starting at position `pos`.
@@ -501,7 +505,7 @@ class IntegerRelation {
   /// 3) this   = {0 <= d0 <= 5, 1 <= d1 <= 9}
   ///    other  = {2 <= d0 <= 6, 5 <= d1 <= 15},
   ///    output = {0 <= d0 <= 6, 1 <= d1 <= 15}
-  bool unionBoundingBox(const IntegerRelation &other);
+  LogicalResult unionBoundingBox(const IntegerRelation &other);
 
   /// Returns the smallest known constant bound for the extent of the specified
   /// variable (pos^th), i.e., the smallest known constant that is greater
@@ -774,8 +778,8 @@ class IntegerRelation {
   /// Eliminates a single variable at `position` from equality and inequality
   /// constraints. Returns `success` if the variable was eliminated, and
   /// `failure` otherwise.
-  inline bool gaussianEliminateVar(unsigned position) {
-    return gaussianEliminateVars(position, position + 1) == 1;
+  inline LogicalResult gaussianEliminateVar(unsigned position) {
+    return success(gaussianEliminateVars(position, position + 1) == 1);
   }
 
   /// Removes local variables using equalities. Each equality is checked if it
diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
index f413636e06910..4c40c4cdcb655 100644
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -445,7 +445,7 @@ class LexSimplexBase : public SimplexBase {
   /// lexicopositivity of the basis transform. The row must have a non-positive
   /// sample value. If this is not possible, return failure. This occurs when
   /// the constraints have no solution or the sample value is zero.
-  bool moveRowUnknownToColumn(unsigned row);
+  LogicalResult moveRowUnknownToColumn(unsigned row);
 
   /// Given a row that has a non-integer sample value, add an inequality to cut
   /// away this fractional sample value from the polytope without removing any
@@ -459,7 +459,7 @@ class LexSimplexBase : public SimplexBase {
   ///
   /// Return failure if the tableau became empty, and success if it didn't.
   /// Failure status indicates that the polytope was integer empty.
-  bool addCut(unsigned row);
+  LogicalResult addCut(unsigned row);
 
   /// Undo the addition of the last constraint. This is only called while
   /// rolling back.
@@ -511,7 +511,7 @@ class LexSimplex : public LexSimplexBase {
   MaybeOptimum<SmallVector<Fraction, 8>> getRationalSample() const;
 
   /// Make the tableau configuration consistent.
-  bool restoreRationalConsistency();
+  LogicalResult restoreRationalConsistency();
 
   /// Return whether the specified row is violated;
   bool rowIsViolated(unsigned row) const;
@@ -626,7 +626,7 @@ class SymbolicLexSimplex : public LexSimplexBase {
   /// Return failure if the tableau became empty, indicating that the polytope
   /// is always integer empty in the current symbol domain.
   /// Return success otherwise.
-  bool doNonBranchingPivots();
+  LogicalResult doNonBranchingPivots();
 
   /// Get a row that is always violated in the current domain, if one exists.
   std::optional<unsigned> maybeGetAlwaysViolatedRow();
@@ -647,7 +647,7 @@ class SymbolicLexSimplex : public LexSimplexBase {
   /// at the time of the call. (This function may modify the symbol domain, but
   /// failure statu indicates that the polytope was empty for all symbol values
   /// in the initial domain.)
-  bool addSymbolicCut(unsigned row);
+  LogicalResult addSymbolicCut(unsigned row);
 
   /// Get the numerator of the symbolic sample of the specific row.
   /// This is an affine expression in the symbols with integer coefficients.
@@ -820,7 +820,7 @@ class Simplex : public SimplexBase {
   ///
   /// Returns success if the unknown was successfully restored to a non-negative
   /// sample value, failure otherwise.
-  bool restoreRow(Unknown &u);
+  LogicalResult restoreRow(Unknown &u);
 
   /// Find a pivot to change the sample value of row in the specified
   /// direction while preserving tableau consistency, except that if the
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index 746cff525beb2..e628fb152b52f 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -1247,10 +1247,10 @@ LogicalResult FlatLinearValueConstraints::unionBoundingBox(
   if (!areVarsAligned(*this, otherCst)) {
     FlatLinearValueConstraints otherCopy(otherCst);
     mergeAndAlignVars(/*offset=*/getNumDimVars(), this, &otherCopy);
-    return success(IntegerPolyhedron::unionBoundingBox(otherCopy));
+    return IntegerPolyhedron::unionBoundingBox(otherCopy);
   }
 
-  return success(IntegerPolyhedron::unionBoundingBox(otherCst));
+  return IntegerPolyhedron::unionBoundingBox(otherCst);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 6b438692ff6f9..d7a3a933b75dd 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -26,6 +26,7 @@
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/raw_ostream.h"
 #include <algorithm>
 #include <cassert>
@@ -1552,22 +1553,22 @@ static int findEqualityToConstant(const IntegerRelation &cst, unsigned pos,
   return -1;
 }
 
-bool IntegerRelation::constantFoldVar(unsigned pos) {
+LogicalResult IntegerRelation::constantFoldVar(unsigned pos) {
   assert(pos < getNumVars() && "invalid position");
   int rowIdx;
   if ((rowIdx = findEqualityToConstant(*this, pos)) == -1)
-    return false;
+    return failure();
 
   // atEq(rowIdx, pos) is either -1 or 1.
   assert(atEq(rowIdx, pos) * atEq(rowIdx, pos) == 1);
   DynamicAPInt constVal = -atEq(rowIdx, getNumCols() - 1) / atEq(rowIdx, pos);
   setAndEliminate(pos, constVal);
-  return true;
+  return success();
 }
 
 void IntegerRelation::constantFoldVarRange(unsigned pos, unsigned num) {
   for (unsigned s = pos, t = pos, e = pos + num; s < e; s++) {
-    if (!constantFoldVar(t))
+    if (constantFoldVar(t).failed())
       t++;
   }
 }
@@ -1944,9 +1945,9 @@ void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow,
   for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
     if (atEq(r, pos) != 0) {
       // Use Gaussian elimination here (since we have an equality).
-      bool ret = gaussianEliminateVar(pos);
+      LogicalResult ret = gaussianEliminateVar(pos);
       (void)ret;
-      assert(ret && "Gaussian elimination guaranteed to succeed");
+      assert(ret.succeeded() && "Gaussian elimination guaranteed to succeed");
       LLVM_DEBUG(llvm::dbgs() << "FM output (through Gaussian elimination):\n");
       LLVM_DEBUG(dump());
       return;
@@ -2173,7 +2174,8 @@ static void getCommonConstraints(const IntegerRelation &a,
 
 // Computes the bounding box with respect to 'other' by finding the min of the
 // lower bounds and the max of the upper bounds along each of the dimensions.
-bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
+LogicalResult
+IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
   assert(space.isEqual(otherCst.getSpace()) && "Spaces should match.");
   assert(getNumLocalVars() == 0 && "local ids not supported yet here");
 
@@ -2201,13 +2203,13 @@ bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
     if (!extent.has_value())
       // TODO: symbolic extents when necessary.
       // TODO: handle union if a dimension is unbounded.
-      return false;
+      return failure();
 
     auto otherExtent = otherCst.getConstantBoundOnDimSize(
         d, &otherLb, &otherLbFloorDivisor, &otherUb);
     if (!otherExtent.has_value() || lbFloorDivisor != otherLbFloorDivisor)
       // TODO: symbolic extents when necessary.
-      return false;
+      return success();
 
     assert(lbFloorDivisor > 0 && "divisor always expected to be positive");
 
@@ -2227,7 +2229,7 @@ bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
       auto constLb = getConstantBound(BoundType::LB, d);
       auto constOtherLb = otherCst.getConstantBound(BoundType::LB, d);
       if (!constLb.has_value() || !constOtherLb.has_value())
-        return false;
+        return failure();
       std::fill(minLb.begin(), minLb.end(), 0);
       minLb.back() = std::min(*constLb, *constOtherLb);
     }
@@ -2243,7 +2245,7 @@ bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
       auto constUb = getConstantBound(BoundType::UB, d);
       auto constOtherUb = otherCst.getConstantBound(BoundType::UB, d);
       if (!constUb.has_value() || !constOtherUb.has_value())
-        return false;
+        return failure();
       std::fill(maxUb.begin(), maxUb.end(), 0);
       maxUb.back() = std::max(*constUb, *constOtherUb);
     }
@@ -2281,7 +2283,7 @@ bool IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
   // union (since the above are just the union along dimensions); we shouldn't
   // be discarding any other constraints on the symbols.
 
-  return true;
+  return success();
 }
 
 bool IntegerRelation::isColZero(unsigned pos) const {
diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index 5c4965c919ac3..e284ca82420ba 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -15,6 +15,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cassert>
 #include <functional>
@@ -753,18 +754,18 @@ class presburger::SetCoalescer {
   ///     \___\|/            \_____/
   ///
   ///
-  bool coalescePairCutCase(unsigned i, unsigned j);
+  LogicalResult coalescePairCutCase(unsigned i, unsigned j);
 
   /// Types the inequality `ineq` according to its `IneqType` for `simp` into
   /// `redundantIneqsB` and `cuttingIneqsB`. Returns success, if no separate
   /// inequalities were encountered. Otherwise, returns failure.
-  bool typeInequality(ArrayRef<DynamicAPInt> ineq, Simplex &simp);
+  LogicalResult typeInequality(ArrayRef<DynamicAPInt> ineq, Simplex &simp);
 
   /// Types the equality `eq`, i.e. for `eq` == 0, types both `eq` >= 0 and
   /// -`eq` >= 0 according to their `IneqType` for `simp` into
   /// `redundantIneqsB` and `cuttingIneqsB`. Returns success, if no separate
   /// inequalities were encountered. Otherwise, returns failure.
-  bool typeEquality(ArrayRef<DynamicAPInt> eq, Simplex &simp);
+  LogicalResult typeEquality(ArrayRef<DynamicAPInt> eq, Simplex &simp);
 
   /// Replaces the element at position `i` with the last element and erases
   /// the last element for both `disjuncts` and `simplices`.
@@ -775,7 +776,7 @@ class presburger::SetCoalescer {
   /// successfully coalesced. The simplices in `simplices` need to be the ones
   /// constructed from `disjuncts`. At this point, there are no empty
   /// disjuncts in `disjuncts` left.
-  bool coalescePair(unsigned i, unsigned j);
+  LogicalResult coalescePair(unsigned i, unsigned j);
 };
 
 /// Constructs a `SetCoalescer` from a `PresburgerRelation`. Only adds non-empty
@@ -818,7 +819,7 @@ PresburgerRelation SetCoalescer::coalesce() {
       cuttingIneqsB.clear();
       if (i == j)
         continue;
-      if (coalescePair(i, j)) {
+      if (coalescePair(i, j).succeeded()) {
         broken = true;
         break;
       }
@@ -902,7 +903,7 @@ void SetCoalescer::addCoalescedDisjunct(unsigned i, unsigned j,
 ///     \___\|/            \_____/
 ///
 ///
-bool SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) {
+LogicalResult SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) {
   /// All inequalities of `b` need to be redundant. We already know that the
   /// redundant ones are, so only the cutting ones remain to be checked.
   Simplex &simp = simplices[i];
@@ -910,7 +911,7 @@ bool SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) {
   if (llvm::any_of(cuttingIneqsA, [this, &simp](ArrayRef<DynamicAPInt> curr) {
         return !isFacetContained(curr, simp);
       }))
-    return false;
+    return failure();
   IntegerRelation newSet(disjunct.getSpace());
 
   for (ArrayRef<DynamicAPInt> curr : redundantIneqsA)
@@ -920,23 +921,25 @@ bool SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) {
     newSet.addInequality(curr);
 
   addCoalescedDisjunct(i, j, newSet);
-  return true;
+  return success();
 }
 
-bool SetCoalescer::typeInequality(ArrayRef<DynamicAPInt> ineq, Simplex &simp) {
+LogicalResult SetCoalescer::typeInequality(ArrayRef<DynamicAPInt> ineq,
+                                           Simplex &simp) {
   Simplex::IneqType type = simp.findIneqType(ineq);
   if (type == Simplex::IneqType::Redundant)
     redundantIneqsB.push_back(ineq);
   else if (type == Simplex::IneqType::Cut)
     cuttingIneqsB.push_back(ineq);
   else
-    return false;
-  return true;
+    return failure();
+  return success();
 }
 
-bool SetCoalescer::typeEquality(ArrayRef<DynamicAPInt> eq, Simplex &simp) {
-  if (!typeInequality(eq, simp))
-    return false;
+LogicalResult SetCoalescer::typeEquality(ArrayRef<DynamicAPInt> eq,
+                                         Simplex &simp) {
+  if (typeInequality(eq, simp).failed())
+    return failure();
   negEqs.push_back(getNegatedCoeffs(eq));
   ArrayRef<DynamicAPInt> inv(negEqs.back());
   return typeInequality(inv, simp);
@@ -951,7 +954,7 @@ void SetCoalescer::eraseDisjunct(unsigned i) {
   simplices.pop_back();
 }
 
-bool SetCoalescer::coalescePair(unsigned i, unsigned j) {
+LogicalResult SetCoalescer::coalescePair(unsigned i, unsigned j) {
 
   IntegerRelation &a = disjuncts[i];
   IntegerRelation &b = disjuncts[j];
@@ -959,7 +962,7 @@ bool SetCoalescer::coalescePair(unsigned i, unsigned j) {
   /// skipped.
   /// TODO: implement local id support.
   if (a.getNumLocalVars() != 0 || b.getNumLocalVars() != 0)
-    return false;
+    return failure();
   Simplex &simpA = simplices[i];
   Simplex &simpB = simplices[j];
 
@@ -969,34 +972,34 @@ bool SetCoalescer::coalescePair(unsigned i, unsigned j) {
   // inequality is encountered during typing, the two IntegerRelations
   // cannot be coalesced.
   for (int k = 0, e = a.getNumInequalities(); k < e; ++k)
-    if (!typeInequality(a.getInequality(k), simpB))
-      return false;
+    if (typeInequality(a.getInequality(k), simpB).failed())
+      return failure();
 
   for (int k = 0, e = a.getNumEqualities(); k < e; ++k)
-    if (!typeEquality(a.getEquality(k), simpB))
-      return false;
+    if (typeEquality(a.getEquality(k), simpB).failed())
+      return failure();
 
   std::swap(redundantIneqsA, redundantIneqsB);
   std::swap(cuttingIneqsA, cuttingIneqsB);
 
   for (int k = 0, e = b.getNumInequalities(); k < e; ++k)
-    if (!typeInequality(b.getInequality(k), simpA))
-      return false;
+    if (typeInequality(b.getInequality(k), simpA).failed())
+      return failure();
 
   for (int k = 0, e = b.getNumEqualities(); k < e; ++k)
-    if (!typeEquality(b.getEquality(k), simpA))
-      return false;
+    if (typeEquality(b.getEquality(k), simpA).failed())
+      return failure();
 
   // If there are no cutting inequalities of `a`, `b` is contained
   // within `a`.
   if (cuttingIneqsA.empty()) {
     eraseDisjunct(j);
-    return true;
+    return success();
   }
 
   // Try to apply the cut case
-  if (coalescePairCutCase(i, j))
-    return true;
+  if (coalescePairCutCase(i, j).succeeded())
+    return success();
 
   // Swap the vectors to compare the pair (j,i) instead of (i,j).
   std::swap(redundantIneqsA, redundantIneqsB);
@@ -1006,7 +1009,7 @@ bool SetCoalescer::coalescePair(unsigned i, unsigned j) {
   // within `a`.
   if (cuttingIneqsA.empty()) {
     eraseDisjunct(i);
-    return true;
+    return success();
   }
 
   // Try to apply the cut case
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index 4efc7a3755014..bebbf0325f430 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -17,6 +17,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cassert>
 #include <functional>
@@ -229,7 +230,7 @@ Direction flippedDirection(Direction direction) {
 /// add these to the set of ignored columns and continue to the next row. If we
 /// run out of rows, then A*y is zero and we are done.
 MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::findRationalLexMin() {
-  if (!restoreRationalConsistency()) {
+  if (restoreRationalConsistency().failed()) {
     markEmpty();
     return OptimumKind::Empty;
   }
@@ -274,7 +275,7 @@ MaybeOptimum<SmallVector<Fraction, 8>> LexSimplex::findRationalLexMin() {
 ///
 /// The constraint is violated when added (it would be useless otherwise)
 /// so we immediately try to move it to a column.
-bool LexSimplexBase::addCut(unsigned row) {
+LogicalResult LexSimplexBase::addCut(unsigned row) {
   DynamicAPInt d = tableau(row, 0);
   unsigned cutRow = addZeroRow(/*makeRestricted=*/true);
   tableau(cutRow, 0) = d;
@@ -301,7 +302,7 @@ std::optional<unsigned> LexSimplex::maybeGetNonIntegralVarRow() const {
 
 MaybeOptimum<SmallVector<DynamicAPInt, 8>> LexSimplex::findIntegerLexMin() {
   // We first try to make the tableau consistent.
-  if (!restoreRationalConsistency())
+  if (restoreRationalConsistency().failed())
     return OptimumKind::Empty;
 
   // Then, if the sample value is integral, we are done.
@@ -316,9 +317,9 @@ MaybeOptimum<SmallVector<DynamicAPInt, 8>> LexSimplex::findIntegerLexMin() {
     //
     // Failure indicates that the tableau became empty, which occurs when the
     // polytope is integer empty.
-    if (!addCut(*maybeRow))
+    if (addCut(*maybeRow).failed())
       return OptimumKind::Empty;
-    if (!restoreRationalConsistency())
+    if (restoreRationalConsistency().failed())
       return OptimumKind::Empty;
   }
 
@@ -411,7 +412,7 @@ bool SymbolicLexSimplex::isSymbolicSampleIntegral(unsigned row) const {
 /// (sum_i (b_i%d)y_i - (-c%d) - sum_i (-a_i%d)s_i + q*d)/d >= 0
 /// This constraint is violated when added so we immediately try to move it to a
 /// column.
-bool SymbolicLexSimplex::addSymbolicCut(unsigned row) {
+LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) {
   DynamicAPInt d = tableau(row, 0);
   if (isRangeDivisibleBy(tableau.getRow(row).slice(3, nSymbol), d)) {
     // The coefficients of symbols in the symbol numerator are divisible
@@ -523,11 +524,11 @@ std::optional<unsigned> SymbolicLexSimplex::maybeGetNonIntegralVarRow() {
 
 /// The non-branching pivots are just the ones moving the rows
 /// that are always violated in the symbol domain.
-bool SymbolicLexSimplex::doNonBranchingPivots() {
+LogicalResult SymbolicLexSimplex::doNonBranchingPivots() {
   while (std::optional<unsigned> row = maybeGetAlwaysViolatedRow())
-    if (!moveRowUnknownToColumn(*row))
-      return false;
-  return true;
+    if (moveRowUnknownToColumn(*row).failed())
+      return failure();
+  return success();
 }
 
 SymbolicLexOpt SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
@@ -567,7 +568,7 @@ SymbolicLexOpt SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
         continue;
       }
 
-      if (!doNonBranchingPivots()) {
+      if (doNonBranchingPivots().failed()) {
         // Could not find pivots for violated constraints; return.
         --level;
         continue;
@@ -627,7 +628,7 @@ SymbolicLexOpt SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
       // The tableau is rationally consistent for the current domain.
       // Now we look for non-integral sample values and add cuts for them.
       if (std::optional<unsigned> row = maybeGetNonIntegralVarRow()) {
-        if (!addSymbolicCut(*row)) {
+        if (addSymbolicCut(*row).failed()) {
           // No integral points; return.
           --level;
           continue;
@@ -661,7 +662,7 @@ SymbolicLexOpt SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
       SmallVector<DynamicAPInt, 8> splitIneq =
           getComplementIneq(getSymbolicSampleIneq(u.pos));
       normalizeRange(splitIneq);
-      if (!moveRowUnknownToColumn(u.pos)) {
+      if (moveRowUnknownToColumn(u.pos).failed()) {
         // The unknown can't be made non-negative; return.
         --level;
         continue;
@@ -699,13 +700,13 @@ std::optional<unsigned> LexSimplex::maybeGetViolatedRow() const {
 /// We simply look for violated rows and keep trying to move them to column
 /// orientation, which always succeeds unless the constraints have no solution
 /// in which case we just give up and return.
-bool LexSimplex::restoreRationalConsistency() {
+LogicalResult LexSimplex::restoreRationalConsistency() {
   if (empty)
-    return false;
+    return failure();
   while (std::optional<unsigned> maybeViolatedRow = maybeGetViolatedRow())
-    if (!moveRowUnknownToColumn(*maybeViolatedRow))
-      return false;
-  return true;
+    if (moveRowUnknownToColumn(*maybeViolatedRow).failed())
+      return failure();
+  return success();
 }
 
 // Move the row unknown to column orientation while preserving lexicopositivity
@@ -770,7 +771,7 @@ bool LexSimplex::restoreRationalConsistency() {
 // which is in contradiction to the fact that B.col(j) / B(i,j) must be
 // lexicographically smaller than B.col(k) / B(i,k), since it lexicographically
 // minimizes the change in sample value.
-bool LexSimplexBase::moveRowUnknownToColumn(unsigned row) {
+LogicalResult LexSimplexBase::moveRowUnknownToColumn(unsigned row) {
   std::optional<unsigned> maybeColumn;
   for (unsigned col = 3 + nSymbol, e = getNumColumns(); col < e; ++col) {
     if (tableau(row, col) <= 0)
@@ -780,10 +781,10 @@ bool LexSimplexBase::moveRowUnknownToColumn(unsigned row) {
   }
 
   if (!maybeColumn)
-    return false;
+    return failure();
 
   pivot(row, *maybeColumn);
-  return true;
+  return success();
 }
 
 unsigned LexSimplexBase::getLexMinPivotColumn(unsigned row, unsigned colA,
@@ -986,7 +987,7 @@ void SimplexBase::pivot(unsigned pivotRow, unsigned pivotCol) {
 /// Perform pivots until the unknown has a non-negative sample value or until
 /// no more upward pivots can be performed. Return success if we were able to
 /// bring the row to a non-negative sample value, and failure otherwise.
-bool Simplex::restoreRow(Unknown &u) {
+LogicalResult Simplex::restoreRow(Unknown &u) {
   assert(u.orientation == Orientation::Row &&
          "unknown should be in row position");
 
@@ -997,9 +998,9 @@ bool Simplex::restoreRow(Unknown &u) {
 
     pivot(*maybePivot);
     if (u.orientation == Orientation::Column)
-      return true; // the unknown is unbounded above.
+      return success(); // the unknown is unbounded above.
   }
-  return tableau(u.pos, 1) >= 0;
+  return success(tableau(u.pos, 1) >= 0);
 }
 
 /// Find a row that can be used to pivot the column in the specified direction.
@@ -1105,8 +1106,8 @@ void SimplexBase::markEmpty() {
 /// empty and we mark it as such.
 void Simplex::addInequality(ArrayRef<DynamicAPInt> coeffs) {
   unsigned conIndex = addRow(coeffs, /*makeRestricted=*/true);
-  bool result = restoreRow(con[conIndex]);
-  if (!result)
+  LogicalResult result = restoreRow(con[conIndex]);
+  if (result.failed())
     markEmpty();
 }
 
@@ -1384,7 +1385,7 @@ MaybeOptimum<Fraction> Simplex::computeOptimum(Direction direction,
   MaybeOptimum<Fraction> optimum = computeRowOptimum(direction, row);
   if (u.restricted && direction == Direction::Down &&
       (optimum.isUnbounded() || *optimum < Fraction(0, 1))) {
-    if (!restoreRow(u))
+    if (restoreRow(u).failed())
       llvm_unreachable("Could not restore row!");
   }
   return optimum;
@@ -1453,7 +1454,7 @@ void Simplex::detectRedundant(unsigned offset, unsigned count) {
     if (minimum.isUnbounded() || *minimum < Fraction(0, 1)) {
       // Constraint is unbounded below or can attain negative sample values and
       // hence is not redundant.
-      if (!restoreRow(u))
+      if (restoreRow(u).failed())
         llvm_unreachable("Could not restore non-redundant row!");
       continue;
     }
diff --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp
index 65190c6f07d4b..9b32972de2e0a 100644
--- a/mlir/lib/Analysis/Presburger/Utils.cpp
+++ b/mlir/lib/Analysis/Presburger/Utils.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Analysis/Presburger/PresburgerSpace.h"
 #include "llvm/ADT/STLFunctionalExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cassert>
 #include <cstdint>
@@ -95,10 +96,10 @@ static void normalizeDivisionByGCD(MutableArrayRef<DynamicAPInt> dividend,
 /// If successful, `expr` is set to dividend of the division and `divisor` is
 /// set to the denominator of the division, which will be positive.
 /// The final division expression is normalized by GCD.
-static bool getDivRepr(const IntegerRelation &cst, unsigned pos,
-                       unsigned ubIneq, unsigned lbIneq,
-                       MutableArrayRef<DynamicAPInt> expr,
-                       DynamicAPInt &divisor) {
+static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
+                                unsigned ubIneq, unsigned lbIneq,
+                                MutableArrayRef<DynamicAPInt> expr,
+                                DynamicAPInt &divisor) {
 
   assert(pos <= cst.getNumVars() && "Invalid variable position");
   assert(ubIneq <= cst.getNumInequalities() &&
@@ -120,7 +121,7 @@ static bool getDivRepr(const IntegerRelation &cst, unsigned pos,
       break;
 
   if (i < e)
-    return false;
+    return failure();
 
   // Then, check if the constant term is of the proper form.
   // Due to the form of the upper/lower bound inequalities, the sum of their
@@ -132,7 +133,7 @@ static bool getDivRepr(const IntegerRelation &cst, unsigned pos,
   // Check if `c` satisfies the condition `0 <= c <= divisor - 1`.
   // This also implictly checks that `divisor` is positive.
   if (!(0 <= c && c <= divisor - 1)) // NOLINT
-    return false;
+    return failure();
 
   // The inequality pair can be used to extract the division.
   // Set `expr` to the dividend of the division except the constant term, which
@@ -147,7 +148,7 @@ static bool getDivRepr(const IntegerRelation &cst, unsigned pos,
   expr.back() = cst.atIneq(ubIneq, cst.getNumCols() - 1) + c;
   normalizeDivisionByGCD(expr, divisor);
 
-  return true;
+  return success();
 }
 
 /// Check if the pos^th variable can be represented as a division using
@@ -161,9 +162,10 @@ static bool getDivRepr(const IntegerRelation &cst, unsigned pos,
 /// If successful, `expr` is set to dividend of the division and `divisor` is
 /// set to the denominator of the division. The final division expression is
 /// normalized by GCD.
-static bool getDivRepr(const IntegerRelation &cst, unsigned pos, unsigned eqInd,
-                       MutableArrayRef<DynamicAPInt> expr,
-                       DynamicAPInt &divisor) {
+static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
+                                unsigned eqInd,
+                                MutableArrayRef<DynamicAPInt> expr,
+                                DynamicAPInt &divisor) {
 
   assert(pos <= cst.getNumVars() && "Invalid variable position");
   assert(eqInd <= cst.getNumEqualities() && "Invalid equality position");
@@ -174,7 +176,7 @@ static bool getDivRepr(const IntegerRelation &cst, unsigned pos, unsigned eqInd,
   // Equality must involve the pos-th variable and hence `tempDiv` != 0.
   DynamicAPInt tempDiv = cst.atEq(eqInd, pos);
   if (tempDiv == 0)
-    return false;
+    return failure();
   int signDiv = tempDiv < 0 ? -1 : 1;
 
   // The divisor is always a positive integer.
@@ -187,7 +189,7 @@ static bool getDivRepr(const IntegerRelation &cst, unsigned pos, unsigned eqInd,
   expr.back() = -signDiv * cst.atEq(eqInd, cst.getNumCols() - 1);
   normalizeDivisionByGCD(expr, divisor);
 
-  return true;
+  return success();
 }
 
 // Returns `false` if the constraints depends on a variable for which an
@@ -238,7 +240,7 @@ MaybeLocalRepr presburger::computeSingleVarRepr(
   for (unsigned ubPos : ubIndices) {
     for (unsigned lbPos : lbIndices) {
       // Attempt to get divison representation from ubPos, lbPos.
-      if (!getDivRepr(cst, pos, ubPos, lbPos, dividend, divisor))
+      if (getDivRepr(cst, pos, ubPos, lbPos, dividend, divisor).failed())
         continue;
 
       if (!checkExplicitRepresentation(cst, foundRepr, dividend, pos))
@@ -251,7 +253,7 @@ MaybeLocalRepr presburger::computeSingleVarRepr(
   }
   for (unsigned eqPos : eqIndices) {
     // Attempt to get divison representation from eqPos.
-    if (!getDivRepr(cst, pos, eqPos, dividend, divisor))
+    if (getDivRepr(cst, pos, eqPos, dividend, divisor).failed())
       continue;
 
     if (!checkExplicitRepresentation(cst, foundRepr, dividend, pos))

>From 5e91890ebaf3c4cda566e5755b820e4b6af31d11 Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <ramkumar.ramachandra at codasip.com>
Date: Tue, 2 Jul 2024 15:06:57 +0100
Subject: [PATCH 2/2] Presburger/IntegerRelation: return failure() in unionBoundingBox on extent mismatch

---
 mlir/lib/Analysis/Presburger/IntegerRelation.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index d7a3a933b75dd..095a7dcb287f3 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -2209,7 +2209,7 @@ IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
         d, &otherLb, &otherLbFloorDivisor, &otherUb);
     if (!otherExtent.has_value() || lbFloorDivisor != otherLbFloorDivisor)
       // TODO: symbolic extents when necessary.
-      return success();
+      return failure();
 
     assert(lbFloorDivisor > 0 && "divisor always expected to be positive");
 



More information about the Mlir-commits mailing list