[Mlir-commits] [mlir] 479c4f6 - [MLIR][Presburger] Refactor division representation to DivisionRepr

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 7 07:06:43 PDT 2022


Author: Groverkss
Date: 2022-07-07T15:05:28+01:00
New Revision: 479c4f648a021f1efdc30312bab804a71447e15f

URL: https://github.com/llvm/llvm-project/commit/479c4f648a021f1efdc30312bab804a71447e15f
DIFF: https://github.com/llvm/llvm-project/commit/479c4f648a021f1efdc30312bab804a71447e15f.diff

LOG: [MLIR][Presburger] Refactor division representation to DivisionRepr

This patch refactors existing implementations of division representation storage
into a new class, DivisionRepr. This refactoring is done so that the common
division utilities can be shared in an upcoming patch.

Reviewed By: arjunp

Differential Revision: https://reviews.llvm.org/D129146

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
    mlir/include/mlir/Analysis/Presburger/Matrix.h
    mlir/include/mlir/Analysis/Presburger/Utils.h
    mlir/lib/Analysis/Presburger/IntegerRelation.cpp
    mlir/lib/Analysis/Presburger/Matrix.cpp
    mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
    mlir/lib/Analysis/Presburger/Utils.cpp
    mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
    mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index a49d50e081a13..333669f69c2e1 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -352,22 +352,16 @@ class IntegerRelation {
   Optional<SmallVector<int64_t, 8>>
   containsPointNoLocal(ArrayRef<int64_t> point) const;
 
-  /// Find equality and pairs of inequality constraints identified by their
-  /// position indices, using which an explicit representation for each local
-  /// variable can be computed. The indices of the constraints are stored in
-  /// `MaybeLocalRepr` struct. If no such pair can be found, the kind attribute
-  /// in `MaybeLocalRepr` is set to None.
+  /// Returns a `DivisonRepr` representing the division representation of local
+  /// variables in the constraint system.
   ///
-  /// The dividends of the explicit representations are stored in `dividends`
-  /// and the denominators in `denominators`. If no explicit representation
-  /// could be found for the `i^th` local variable, `denominators[i]` is set
-  /// to 0.
-  void getLocalReprs(std::vector<SmallVector<int64_t, 8>> &dividends,
-                     SmallVector<unsigned, 4> &denominators,
-                     std::vector<MaybeLocalRepr> &repr) const;
-  void getLocalReprs(std::vector<MaybeLocalRepr> &repr) const;
-  void getLocalReprs(std::vector<SmallVector<int64_t, 8>> &dividends,
-                     SmallVector<unsigned, 4> &denominators) const;
+  /// If `repr` is not `nullptr`, the equality and pairs of inequality
+  /// constraints identified by their position indices using which an explicit
+  /// representation for each local variable can be computed are set in `repr`
+  /// in the form of a `MaybeLocalRepr` struct. If no such inequality
+  /// pair/equality can be found, the kind attribute in `MaybeLocalRepr` is set
+  /// to None.
+  DivisionRepr getLocalReprs(std::vector<MaybeLocalRepr> *repr = nullptr) const;
 
   /// The type of bound: equal, lower bound or upper bound.
   enum BoundType { EQ, LB, UB };

diff  --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h
index 5b53cc5f045f9..bf32aafc6019d 100644
--- a/mlir/include/mlir/Analysis/Presburger/Matrix.h
+++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h
@@ -89,6 +89,9 @@ class Matrix {
   MutableArrayRef<int64_t> getRow(unsigned row);
   ArrayRef<int64_t> getRow(unsigned row) const;
 
+  /// Set the specified row to `elems`.
+  void setRow(unsigned row, ArrayRef<int64_t> elems);
+
   /// Insert columns having positions pos, pos + 1, ... pos + count - 1.
   /// Columns that were at positions 0 to pos - 1 will stay where they are;
   /// columns that were at positions pos to nColumns - 1 will be pushed to the

diff  --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h
index 52321c5413def..e322cb9189396 100644
--- a/mlir/include/mlir/Analysis/Presburger/Utils.h
+++ b/mlir/include/mlir/Analysis/Presburger/Utils.h
@@ -17,6 +17,8 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 
+#include "mlir/Analysis/Presburger/Matrix.h"
+
 namespace mlir {
 namespace presburger {
 
@@ -102,6 +104,80 @@ struct MaybeLocalRepr {
   } repr;
 };
 
+/// Class storing division representation of local variables of a constraint
+/// system. The coefficients of the dividends are stored in order:
+/// [nonLocalVars, localVars, constant]. Each local variable may or may not have
+/// a representation. If the local does not have a representation, the dividend
+/// of the division has no meaning and the denominator is zero.
+///
+/// The i^th division here, represents the division representation of the
+/// variable at position `divOffset + i` in the constraint system.
+class DivisionRepr {
+public:
+  DivisionRepr(unsigned numVars, unsigned numDivs)
+      : dividends(numDivs, numVars + 1), denoms(numDivs, 0) {}
+
+  DivisionRepr(unsigned numVars) : dividends(numVars + 1, 0) {}
+
+  unsigned getNumVars() const { return dividends.getNumColumns() - 1; }
+  unsigned getNumDivs() const { return dividends.getNumRows(); }
+  unsigned getNumNonDivs() const { return getNumVars() - getNumDivs(); }
+  // Get the offset from where division variables start.
+  unsigned getDivOffset() const { return getNumVars() - getNumDivs(); }
+
+  // Check whether the `i^th` division has a division representation or not.
+  bool hasRepr(unsigned i) const { return denoms[i] != 0; }
+  // Check whether all the divisions have a division representation or not.
+  bool hasAllReprs() const {
+    return all_of(denoms, [](unsigned denom) { return denom != 0; });
+  }
+
+  // Clear the division representation of the i^th local variable.
+  void clearRepr(unsigned i) { denoms[i] = 0; }
+
+  // Get the dividend of the `i^th` division.
+  MutableArrayRef<int64_t> getDividend(unsigned i) {
+    return dividends.getRow(i);
+  }
+  ArrayRef<int64_t> getDividend(unsigned i) const {
+    return dividends.getRow(i);
+  }
+
+  // Get the `i^th` denominator.
+  unsigned &getDenom(unsigned i) { return denoms[i]; }
+  unsigned getDenom(unsigned i) const { return denoms[i]; }
+
+  ArrayRef<unsigned> getDenoms() const { return denoms; }
+
+  void setDividend(unsigned i, ArrayRef<int64_t> dividend) {
+    dividends.setRow(i, dividend);
+  }
+
+  /// Removes duplicate divisions. On every possible duplicate division found,
+  /// `merge(i, j)`, where `i`, `j` are current index of the duplicate
+  /// divisions, is called and division at index `j` is merged into division at
+  /// index `i`. If `merge(i, j)` returns `true`, the divisions are merged i.e.
+  /// `j^th` division gets eliminated and it's each instance is replaced by
+  /// `i^th` division. If it returns `false`, the divisions are not merged.
+  /// `merge` can also do side effects, For example it can merge the local
+  /// variables in IntegerRelation.
+  void
+  removeDuplicateDivs(llvm::function_ref<bool(unsigned i, unsigned j)> merge);
+
+  void print(raw_ostream &os) const;
+  void dump() const;
+
+private:
+  /// Each row of the Matrix represents a single division dividend. The
+  /// `i^th` row represents the dividend of the variable at `divOffset + i`
+  /// in the constraint system (and the `i^th` local variable).
+  Matrix dividends;
+
+  /// Denominators of each division. If a denominator of a division is `0`, the
+  /// division variable is considered to not have a division representation.
+  SmallVector<unsigned, 4> denoms;
+};
+
 /// If `q` is defined to be equal to `expr floordiv d`, this equivalent to
 /// saying that `q` is an integer and `q` is subject to the inequalities
 /// `0 <= expr - d*q <= c - 1` (quotient remainder theorem).
@@ -135,25 +211,9 @@ llvm::SmallBitVector getSubrangeBitVector(unsigned len, unsigned setOffset,
 /// `MaybeLocalRepr` is set to None.
 MaybeLocalRepr computeSingleVarRepr(const IntegerRelation &cst,
                                     ArrayRef<bool> foundRepr, unsigned pos,
-                                    SmallVector<int64_t, 8> &dividend,
+                                    MutableArrayRef<int64_t> dividend,
                                     unsigned &divisor);
 
-/// Given dividends of divisions `divs` and denominators `denoms`, detects and
-/// removes duplicate divisions. `localOffset` is the offset in dividend of a
-/// division from where local variables start.
-///
-/// On every possible duplicate division found, `merge(i, j)`, where `i`, `j`
-/// are current index of the duplicate divisions, is called and division at
-/// index `j` is merged into division at index `i`. If `merge(i, j)` returns
-/// `true`, the divisions are merged i.e. `j^th` division gets eliminated and
-/// it's each instance is replaced by `i^th` division. If it returns `false`,
-/// the divisions are not merged. `merge` can also do side effects, For example
-/// it can merge the local variables in IntegerRelation.
-void removeDuplicateDivs(
-    std::vector<SmallVector<int64_t, 8>> &divs,
-    SmallVectorImpl<unsigned> &denoms, unsigned localOffset,
-    llvm::function_ref<bool(unsigned i, unsigned j)> merge);
-
 /// Given two relations, A and B, add additional local vars to the sets such
 /// that both have the union of the local vars in each set, without changing
 /// the set of points that lie in A and B.

diff  --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index b455f4818acf6..d2427eb39b734 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -175,8 +175,8 @@ PresburgerRelation IntegerRelation::computeReprWithOnlyDivLocals() const {
   //
   // Take a copy so we can perform mutations.
   IntegerRelation copy = *this;
-  std::vector<MaybeLocalRepr> reprs;
-  copy.getLocalReprs(reprs);
+  std::vector<MaybeLocalRepr> reprs(getNumLocalVars());
+  copy.getLocalReprs(&reprs);
 
   // Iterate through all the locals. The last `numNonDivLocals` are the locals
   // that have been scanned already and do not have division representations.
@@ -912,56 +912,39 @@ IntegerRelation::containsPointNoLocal(ArrayRef<int64_t> point) const {
   return copy.findIntegerSample();
 }
 
-void IntegerRelation::getLocalReprs(std::vector<MaybeLocalRepr> &repr) const {
-  std::vector<SmallVector<int64_t, 8>> dividends(getNumLocalVars());
-  SmallVector<unsigned, 4> denominators(getNumLocalVars());
-  getLocalReprs(dividends, denominators, repr);
-}
-
-void IntegerRelation::getLocalReprs(
-    std::vector<SmallVector<int64_t, 8>> &dividends,
-    SmallVector<unsigned, 4> &denominators) const {
-  std::vector<MaybeLocalRepr> repr(getNumLocalVars());
-  getLocalReprs(dividends, denominators, repr);
-}
-
-void IntegerRelation::getLocalReprs(
-    std::vector<SmallVector<int64_t, 8>> &dividends,
-    SmallVector<unsigned, 4> &denominators,
-    std::vector<MaybeLocalRepr> &repr) const {
-
-  repr.resize(getNumLocalVars());
-  dividends.resize(getNumLocalVars());
-  denominators.resize(getNumLocalVars());
-
+DivisionRepr
+IntegerRelation::getLocalReprs(std::vector<MaybeLocalRepr> *repr) const {
   SmallVector<bool, 8> foundRepr(getNumVars(), false);
   for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; ++i)
     foundRepr[i] = true;
 
-  unsigned divOffset = getNumDimAndSymbolVars();
+  unsigned localOffset = getVarKindOffset(VarKind::Local);
+  DivisionRepr divs(getNumVars(), getNumLocalVars());
   bool changed;
   do {
     // Each time changed is true, at end of this iteration, one or more local
     // vars have been detected as floor divs.
     changed = false;
     for (unsigned i = 0, e = getNumLocalVars(); i < e; ++i) {
-      if (!foundRepr[i + divOffset]) {
-        MaybeLocalRepr res = computeSingleVarRepr(
-            *this, foundRepr, divOffset + i, dividends[i], denominators[i]);
-        if (!res)
+      if (!foundRepr[i + localOffset]) {
+        MaybeLocalRepr res =
+            computeSingleVarRepr(*this, foundRepr, localOffset + i,
+                                 divs.getDividend(i), divs.getDenom(i));
+        if (!res) {
+          // No representation was found, so clear the representation and
+          // continue.
+          divs.clearRepr(i);
           continue;
-        foundRepr[i + divOffset] = true;
-        repr[i] = res;
+        }
+        foundRepr[localOffset + i] = true;
+        if (repr)
+          (*repr)[i] = res;
         changed = true;
       }
     }
   } while (changed);
 
-  // Set 0 denominator for variables for which no division representation
-  // could be found.
-  for (unsigned i = 0, e = repr.size(); i < e; ++i)
-    if (!repr[i])
-      denominators[i] = 0;
+  return divs;
 }
 
 /// Tightens inequalities given that we are dealing with integer spaces. This is
@@ -1211,23 +1194,16 @@ unsigned IntegerRelation::mergeLocalVars(IntegerRelation &other) {
 }
 
 bool IntegerRelation::hasOnlyDivLocals() const {
-  std::vector<MaybeLocalRepr> reprs;
-  getLocalReprs(reprs);
-  return llvm::all_of(reprs,
-                      [](const MaybeLocalRepr &repr) { return bool(repr); });
+  return getLocalReprs().hasAllReprs();
 }
 
 void IntegerRelation::removeDuplicateDivs() {
-  std::vector<SmallVector<int64_t, 8>> divs;
-  SmallVector<unsigned, 4> denoms;
-
-  getLocalReprs(divs, denoms);
+  DivisionRepr divs = getLocalReprs();
   auto merge = [this](unsigned i, unsigned j) -> bool {
     eliminateRedundantLocalVar(i, j);
     return true;
   };
-  presburger::removeDuplicateDivs(divs, denoms,
-                                  getVarKindOffset(VarKind::Local), merge);
+  divs.removeDuplicateDivs(merge);
 }
 
 /// Removes local variables using equalities. Each equality is checked if it

diff  --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index 17bbb8185ed4e..c9767ae3cee2b 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -92,6 +92,13 @@ ArrayRef<int64_t> Matrix::getRow(unsigned row) const {
   return {&data[row * nReservedColumns], nColumns};
 }
 
+void Matrix::setRow(unsigned row, ArrayRef<int64_t> elems) {
+  assert(elems.size() == getNumColumns() &&
+         "elems size must match row length!");
+  for (unsigned i = 0, e = getNumColumns(); i < e; ++i)
+    at(row, i) = elems[i];
+}
+
 void Matrix::insertColumn(unsigned pos) { insertColumns(pos, 1); }
 void Matrix::insertColumns(unsigned pos, unsigned count) {
   if (count == 0)

diff  --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index a2e524ea9fd7c..c6a88f285c1c0 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -253,10 +253,8 @@ static PresburgerRelation getSetDifference(IntegerRelation b,
       //
       // Careful! This has to be done after the merge above; otherwise, the
       // dividends won't contain the new ids inserted during the merge.
-      std::vector<MaybeLocalRepr> repr;
-      std::vector<SmallVector<int64_t, 8>> dividends;
-      SmallVector<unsigned, 4> divisors;
-      sI.getLocalReprs(dividends, divisors, repr);
+      std::vector<MaybeLocalRepr> repr(sI.getNumLocalVars());
+      DivisionRepr divs = sI.getLocalReprs(&repr);
 
       // Mark which inequalities of sI are division inequalities and add all
       // such inequalities to b.
@@ -301,10 +299,10 @@ static PresburgerRelation getSetDifference(IntegerRelation b,
           // not be because they were never a part of sI; we just infer them
           // from the equality and add them only to b.
           b.addInequality(
-              getDivLowerBound(dividends[i], divisors[i],
+              getDivLowerBound(divs.getDividend(i), divs.getDenom(i),
                                sI.getVarKindOffset(VarKind::Local) + i));
           b.addInequality(
-              getDivUpperBound(dividends[i], divisors[i],
+              getDivUpperBound(divs.getDividend(i), divs.getDenom(i),
                                sI.getVarKindOffset(VarKind::Local) + i));
         }
       }

diff  --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp
index 9a6fe16fe2e36..978ced1a967f8 100644
--- a/mlir/lib/Analysis/Presburger/Utils.cpp
+++ b/mlir/lib/Analysis/Presburger/Utils.cpp
@@ -21,7 +21,7 @@ using namespace presburger;
 /// Normalize a division's `dividend` and the `divisor` by their GCD. For
 /// example: if the dividend and divisor are [2,0,4] and 4 respectively,
 /// they get normalized to [1,0,2] and 2.
-static void normalizeDivisionByGCD(SmallVectorImpl<int64_t> &dividend,
+static void normalizeDivisionByGCD(MutableArrayRef<int64_t> dividend,
                                    unsigned &divisor) {
   if (divisor == 0 || dividend.empty())
     return;
@@ -89,7 +89,7 @@ static void normalizeDivisionByGCD(SmallVectorImpl<int64_t> &dividend,
 /// normalized by GCD.
 static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
                                 unsigned ubIneq, unsigned lbIneq,
-                                SmallVector<int64_t, 8> &expr,
+                                MutableArrayRef<int64_t> expr,
                                 unsigned &divisor) {
 
   assert(pos <= cst.getNumVars() && "Invalid variable position");
@@ -97,6 +97,7 @@ static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
          "Invalid upper bound inequality position");
   assert(lbIneq <= cst.getNumInequalities() &&
          "Invalid upper bound inequality position");
+  assert(expr.size() == cst.getNumCols() && "Invalid expression size");
 
   // Extract divisor from the lower bound.
   divisor = cst.atIneq(lbIneq, pos);
@@ -126,7 +127,6 @@ static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
   // The inequality pair can be used to extract the division.
   // Set `expr` to the dividend of the division except the constant term, which
   // is set below.
-  expr.resize(cst.getNumCols(), 0);
   for (i = 0, e = cst.getNumVars(); i < e; ++i)
     if (i != pos)
       expr[i] = cst.atIneq(ubIneq, i);
@@ -152,11 +152,12 @@ static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
 /// set to the denominator of the division. The final division expression is
 /// normalized by GCD.
 static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
-                                unsigned eqInd, SmallVector<int64_t, 8> &expr,
+                                unsigned eqInd, MutableArrayRef<int64_t> expr,
                                 unsigned &divisor) {
 
   assert(pos <= cst.getNumVars() && "Invalid variable position");
   assert(eqInd <= cst.getNumEqualities() && "Invalid equality position");
+  assert(expr.size() == cst.getNumCols() && "Invalid expression size");
 
   // Extract divisor, the divisor can be negative and hence its sign information
   // is stored in `signDiv` to reverse the sign of dividend's coefficients.
@@ -169,7 +170,6 @@ static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
   // The divisor is always a positive integer.
   divisor = tempDiv * signDiv;
 
-  expr.resize(cst.getNumCols(), 0);
   for (unsigned i = 0, e = cst.getNumVars(); i < e; ++i)
     if (i != pos)
       expr[i] = -signDiv * cst.atEq(eqInd, i);
@@ -215,10 +215,11 @@ static bool checkExplicitRepresentation(const IntegerRelation &cst,
 /// `MaybeLocalRepr` is set to None.
 MaybeLocalRepr presburger::computeSingleVarRepr(
     const IntegerRelation &cst, ArrayRef<bool> foundRepr, unsigned pos,
-    SmallVector<int64_t, 8> &dividend, unsigned &divisor) {
+    MutableArrayRef<int64_t> dividend, unsigned &divisor) {
   assert(pos < cst.getNumVars() && "invalid position");
   assert(foundRepr.size() == cst.getNumVars() &&
          "Size of foundRepr does not match total number of variables");
+  assert(dividend.size() == cst.getNumCols() && "Invalid dividend size");
 
   SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
   cst.getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices, &eqIndices);
@@ -261,57 +262,6 @@ llvm::SmallBitVector presburger::getSubrangeBitVector(unsigned len,
   return vec;
 }
 
-void presburger::removeDuplicateDivs(
-    std::vector<SmallVector<int64_t, 8>> &divs,
-    SmallVectorImpl<unsigned> &denoms, unsigned localOffset,
-    llvm::function_ref<bool(unsigned i, unsigned j)> merge) {
-
-  // Find and merge duplicate divisions.
-  // TODO: Add division normalization to support divisions that differ by
-  // a constant.
-  // TODO: Add division ordering such that a division representation for local
-  // variable at position `i` only depends on local variables at position <
-  // `i`. This would make sure that all divisions depending on other local
-  // variables that can be merged, are merged.
-  for (unsigned i = 0; i < divs.size(); ++i) {
-    // Check if a division representation exists for the `i^th` local var.
-    if (denoms[i] == 0)
-      continue;
-    // Check if a division exists which is a duplicate of the division at `i`.
-    for (unsigned j = i + 1; j < divs.size(); ++j) {
-      // Check if a division representation exists for the `j^th` local var.
-      if (denoms[j] == 0)
-        continue;
-      // Check if the denominators match.
-      if (denoms[i] != denoms[j])
-        continue;
-      // Check if the representations are equal.
-      if (divs[i] != divs[j])
-        continue;
-
-      // Merge divisions at position `j` into division at position `i`. If
-      // merge fails, do not merge these divs.
-      bool mergeResult = merge(i, j);
-      if (!mergeResult)
-        continue;
-
-      // Update division information to reflect merging.
-      for (unsigned k = 0, g = divs.size(); k < g; ++k) {
-        SmallVector<int64_t, 8> &div = divs[k];
-        if (denoms[k] != 0) {
-          div[localOffset + i] += div[localOffset + j];
-          div.erase(div.begin() + localOffset + j);
-        }
-      }
-
-      divs.erase(divs.begin() + j);
-      denoms.erase(denoms.begin() + j);
-      // Since `j` can never be zero, we do not need to worry about overflows.
-      --j;
-    }
-  }
-}
-
 void presburger::mergeLocalVars(
     IntegerRelation &relA, IntegerRelation &relB,
     llvm::function_ref<bool(unsigned i, unsigned j)> merge) {
@@ -327,23 +277,17 @@ void presburger::mergeLocalVars(
   relB.insertVar(VarKind::Local, 0, initLocals);
 
   // Get division representations from each rel.
-  std::vector<SmallVector<int64_t, 8>> divsA, divsB;
-  SmallVector<unsigned, 4> denomsA, denomsB;
-  relA.getLocalReprs(divsA, denomsA);
-  relB.getLocalReprs(divsB, denomsB);
-
-  // Copy division information for relB into `divsA` and `denomsA`, so that
-  // these have the combined division information of both rels. Since newly
-  // added local variables in relA and relB have no constraints, they will not
-  // have any division representation.
-  std::copy(divsB.begin() + initLocals, divsB.end(),
-            divsA.begin() + initLocals);
-  std::copy(denomsB.begin() + initLocals, denomsB.end(),
-            denomsA.begin() + initLocals);
-
-  // Merge all divisions by removing duplicate divisions.
-  unsigned localOffset = relA.getVarKindOffset(VarKind::Local);
-  presburger::removeDuplicateDivs(divsA, denomsA, localOffset, merge);
+  DivisionRepr divsA = relA.getLocalReprs();
+  DivisionRepr divsB = relB.getLocalReprs();
+
+  for (unsigned i = initLocals, e = divsB.getNumDivs(); i < e; ++i) {
+    divsA.setDividend(i, divsB.getDividend(i));
+    divsA.getDenom(i) = divsB.getDenom(i);
+  }
+
+  // Remove duplicate divisions from divsA. The removing duplicate divisions
+  // call, calls `merge` to effectively merge divisions in relA and relB.
+  divsA.removeDuplicateDivs(merge);
 }
 
 SmallVector<int64_t, 8> presburger::getDivUpperBound(ArrayRef<int64_t> dividend,
@@ -412,3 +356,59 @@ SmallVector<int64_t, 8> presburger::getComplementIneq(ArrayRef<int64_t> ineq) {
   --coeffs.back();
   return coeffs;
 }
+
+void DivisionRepr::removeDuplicateDivs(
+    llvm::function_ref<bool(unsigned i, unsigned j)> merge) {
+
+  // Find and merge duplicate divisions.
+  // TODO: Add division normalization to support divisions that differ by
+  // a constant.
+  // TODO: Add division ordering such that a division representation for local
+  // variable at position `i` only depends on local variables at position <
+  // `i`. This would make sure that all divisions depending on other local
+  // variables that can be merged, are merged.
+  for (unsigned i = 0; i < getNumDivs(); ++i) {
+    // Check if a division representation exists for the `i^th` local var.
+    if (denoms[i] == 0)
+      continue;
+    // Check if a division exists which is a duplicate of the division at `i`.
+    for (unsigned j = i + 1; j < getNumDivs(); ++j) {
+      // Check if a division representation exists for the `j^th` local var.
+      if (denoms[j] == 0)
+        continue;
+      // Check if the denominators match.
+      if (denoms[i] != denoms[j])
+        continue;
+      // Check if the representations are equal.
+      if (dividends.getRow(i) != dividends.getRow(j))
+        continue;
+
+      // Merge divisions at position `j` into division at position `i`. If
+      // merge fails, do not merge these divs.
+      bool mergeResult = merge(i, j);
+      if (!mergeResult)
+        continue;
+
+      // Update division information to reflect merging.
+      unsigned divOffset = getDivOffset();
+      dividends.addToColumn(divOffset + j, divOffset + i, /*scale=*/1);
+      dividends.removeColumn(divOffset + j);
+      dividends.removeRow(j);
+      denoms.erase(denoms.begin() + j);
+
+      // Since `j` can never be zero, we do not need to worry about overflows.
+      --j;
+    }
+  }
+}
+
+void DivisionRepr::print(raw_ostream &os) const {
+  os << "Dividends:\n";
+  dividends.print(os);
+  os << "Denominators\n";
+  for (unsigned i = 0, e = denoms.size(); i < e; ++i)
+    os << denoms[i] << " ";
+  os << "\n";
+}
+
+void DivisionRepr::dump() const { print(llvm::errs()); }

diff  --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
index ffe92cd1bcff7..593d818127384 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
@@ -865,7 +865,7 @@ static bool detectAsFloorDiv(const FlatAffineValueConstraints &cst,
     if (exprs[i])
       foundRepr[i] = true;
 
-  SmallVector<int64_t, 8> dividend;
+  SmallVector<int64_t, 8> dividend(cst.getNumCols());
   unsigned divisor;
   auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor);
 

diff  --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
index 581558cf9bbfb..a6c721d453461 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -608,23 +608,19 @@ TEST(IntegerPolyhedronTest, addConstantLowerBound) {
 static void checkDivisionRepresentation(
     IntegerPolyhedron &poly,
     const std::vector<SmallVector<int64_t, 8>> &expectedDividends,
-    const SmallVectorImpl<unsigned> &expectedDenominators) {
-  std::vector<SmallVector<int64_t, 8>> dividends;
-  SmallVector<unsigned, 4> denominators;
-
-  poly.getLocalReprs(dividends, denominators);
+    ArrayRef<unsigned> expectedDenominators) {
+  DivisionRepr divs = poly.getLocalReprs();
 
   // Check that the `denominators` and `expectedDenominators` match.
-  EXPECT_TRUE(expectedDenominators == denominators);
+  EXPECT_TRUE(expectedDenominators == divs.getDenoms());
 
   // Check that the `dividends` and `expectedDividends` match. If the
   // denominator for a division is zero, we ignore its dividend.
-  EXPECT_TRUE(dividends.size() == expectedDividends.size());
-  for (unsigned i = 0, e = dividends.size(); i < e; ++i) {
-    if (denominators[i] != 0) {
-      EXPECT_TRUE(expectedDividends[i] == dividends[i]);
-    }
-  }
+  EXPECT_TRUE(divs.getNumDivs() == expectedDividends.size());
+  for (unsigned i = 0, e = divs.getNumDivs(); i < e; ++i)
+    if (divs.hasRepr(i))
+      for (unsigned j = 0, f = divs.getNumVars() + 1; j < f; ++j)
+        EXPECT_TRUE(expectedDividends[i][j] == divs.getDividend(i)[j]);
 }
 
 TEST(IntegerPolyhedronTest, computeLocalReprSimple) {


        


More information about the Mlir-commits mailing list