[Mlir-commits] [mlir] bb2226a - [MLIR][Presburger] Refactor MultiAffineFunction to be defined over universe

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Sep 10 17:12:38 PDT 2022


Author: Groverkss
Date: 2022-09-11T01:12:09+01:00
New Revision: bb2226ac53aa255d7520146ab9e0048cfc43a225

URL: https://github.com/llvm/llvm-project/commit/bb2226ac53aa255d7520146ab9e0048cfc43a225
DIFF: https://github.com/llvm/llvm-project/commit/bb2226ac53aa255d7520146ab9e0048cfc43a225.diff

LOG: [MLIR][Presburger] Refactor MultiAffineFunction to be defined over universe

This patch refactors MAF to be defined over the universe in a given space
instead of being defined over a restricted domain.

The reasoning for this refactor is to store division representation for local
variables explicitly for the function outputs. This change is required for
unionLexMax/Min to support local variables which will be upstreamed after this
patch. Another reason for this refactor is to have a flattened form of
AffineMap as MultiAffineFunction.

Reviewed By: arjunp

Differential Revision: https://reviews.llvm.org/D131864

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Presburger/Matrix.h
    mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
    mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
    mlir/include/mlir/Analysis/Presburger/Simplex.h
    mlir/include/mlir/Analysis/Presburger/Utils.h
    mlir/lib/Analysis/Presburger/IntegerRelation.cpp
    mlir/lib/Analysis/Presburger/Matrix.cpp
    mlir/lib/Analysis/Presburger/PWMAFunction.cpp
    mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
    mlir/lib/Analysis/Presburger/Simplex.cpp
    mlir/lib/Analysis/Presburger/Utils.cpp
    mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
    mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
    mlir/unittests/Analysis/Presburger/Utils.h

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Analysis/Presburger/Matrix.h b/mlir/include/mlir/Analysis/Presburger/Matrix.h
index bf32aafc6019d..e9251ba5d031c 100644
--- a/mlir/include/mlir/Analysis/Presburger/Matrix.h
+++ b/mlir/include/mlir/Analysis/Presburger/Matrix.h
@@ -128,6 +128,8 @@ class Matrix {
 
   /// Add `scale` multiples of the source row to the target row.
   void addToRow(unsigned sourceRow, unsigned targetRow, int64_t scale);
+  /// Add `scale` multiples of the rowVec row to the specified row.
+  void addToRow(unsigned row, ArrayRef<int64_t> rowVec, int64_t scale);
 
   /// Add `scale` multiples of the source column to the target column.
   void addToColumn(unsigned sourceColumn, unsigned targetColumn, int64_t scale);

diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
index c4626a2945f01..63f3ecfca968f 100644
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -22,94 +22,93 @@
 namespace mlir {
 namespace presburger {
 
-/// This class represents a multi-affine function whose domain is given by an
-/// IntegerPolyhedron. This can be thought of as an IntegerPolyhedron with a
-/// tuple of integer values attached to every point in the polyhedron, with the
-/// value of each element of the tuple given by an affine expression in the vars
-/// of the polyhedron. For example we could have the domain
-///
-/// (x, y) : (x >= 5, y >= x)
-///
-/// and a tuple of three integers defined at every point in the polyhedron:
+/// This class represents a multi-affine function with the domain as Z^d, where
+/// `d` is the number of domain variables of the function. For example:
 ///
 /// (x, y) -> (x + 2, 2*x - 3y + 5, 2*x + y).
 ///
-/// In this way every point in the polyhedron has a tuple of integers associated
-/// with it. If the integer polyhedron has local vars, then the output
-/// expressions can use them as well. The output expressions are represented as
-/// a matrix with one row for every element in the output vector one column for
-/// each var, and an extra column at the end for the constant term.
+/// The output expressions are represented as a matrix with one row for every
+/// output, one column for each var including division variables, and an extra
+/// column at the end for the constant term.
 ///
 /// Checking equality of two such functions is supported, as well as finding the
 /// value of the function at a specified point.
 class MultiAffineFunction {
 public:
-  MultiAffineFunction(const IntegerPolyhedron &domain, const Matrix &output)
-      : domainSet(domain), output(output) {}
-  MultiAffineFunction(const Matrix &output, const PresburgerSpace &space)
-      : domainSet(space), output(output) {}
-
-  unsigned getNumInputs() const { return domainSet.getNumDimAndSymbolVars(); }
-  unsigned getNumOutputs() const { return output.getNumRows(); }
-  bool isConsistent() const {
-    return output.getNumColumns() == domainSet.getNumVars() + 1;
+  MultiAffineFunction(const PresburgerSpace &space, const Matrix &output)
+      : space(space), output(output),
+        divs(space.getNumVars() - space.getNumRangeVars()) {
+    assertIsConsistent();
+  }
+
+  MultiAffineFunction(const PresburgerSpace &space, const Matrix &output,
+                      const DivisionRepr &divs)
+      : space(space), output(output), divs(divs) {
+    assertIsConsistent();
   }
 
-  /// Get the space of the input domain of this function.
-  const PresburgerSpace &getDomainSpace() const { return domainSet.getSpace(); }
-  /// Get the input domain of this function.
-  const IntegerPolyhedron &getDomain() const { return domainSet; }
+  unsigned getNumDomainVars() const { return space.getNumDomainVars(); }
+  unsigned getNumSymbolVars() const { return space.getNumSymbolVars(); }
+  unsigned getNumOutputs() const { return space.getNumRangeVars(); }
+  unsigned getNumDivs() const { return space.getNumLocalVars(); }
+
+  /// Get the space of this function.
+  const PresburgerSpace &getSpace() const { return space; }
+  /// Get the domain/output space of the function. The returned space is a set
+  /// space.
+  PresburgerSpace getDomainSpace() const { return space.getDomainSpace(); }
+  PresburgerSpace getOutputSpace() const { return space.getRangeSpace(); }
+
   /// Get a matrix with each row representing row^th output expression.
   const Matrix &getOutputMatrix() const { return output; }
   /// Get the `i^th` output expression.
   ArrayRef<int64_t> getOutputExpr(unsigned i) const { return output.getRow(i); }
 
-  /// Insert `num` variables of the specified kind at position `pos`.
-  /// Positions are relative to the kind of variable. The coefficient columns
-  /// corresponding to the added variables are initialized to zero. Return the
-  /// absolute column position (i.e., not relative to the kind of variable)
-  /// of the first added variable.
-  unsigned insertVar(VarKind kind, unsigned pos, unsigned num = 1);
-
-  /// Remove the specified range of vars.
-  void removeVarRange(VarKind kind, unsigned varStart, unsigned varLimit);
-
-  /// Given a MAF `other`, merges local variables such that both funcitons
-  /// have union of local vars, without changing the set of points in domain or
-  /// the output.
-  void mergeLocalVars(MultiAffineFunction &other);
-
-  /// Return whether the outputs of `this` and `other` agree wherever both
-  /// functions are defined, i.e., the outputs should be equal for all points in
-  /// the intersection of the domains.
-  bool isEqualWhereDomainsOverlap(MultiAffineFunction other) const;
-
-  /// Return whether the `this` and `other` are equal. This is the case if
-  /// they lie in the same space, i.e. have the same dimensions, and their
-  /// domains are identical and their outputs are equal on their domain.
+  // Remove the specified range of outputs.
+  void removeOutputs(unsigned start, unsigned end);
+
+  /// Given a MAF `other`, merges division variables such that both functions
+  /// have the union of the division vars that exist in the functions.
+  void mergeDivs(MultiAffineFunction &other);
+
+  /// Return the output of the function at the given point.
+  SmallVector<int64_t, 8> valueAt(ArrayRef<int64_t> point) const;
+
+  /// Return whether the `this` and `other` are equal when the domain is
+  /// restricted to `domain`. This is the case if they lie in the same space,
+  /// and their outputs are equal for every point in `domain`.
   bool isEqual(const MultiAffineFunction &other) const;
+  bool isEqual(const MultiAffineFunction &other,
+               const IntegerPolyhedron &domain) const;
+  bool isEqual(const MultiAffineFunction &other,
+               const PresburgerSet &domain) const;
 
-  /// Get the value of the function at the specified point. If the point lies
-  /// outside the domain, an empty optional is returned.
-  Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
+  void subtract(const MultiAffineFunction &other);
 
-  /// Truncate the output dimensions to the first `count` dimensions.
-  ///
-  /// TODO: refactor so that this can be accomplished through removeVarRange.
-  void truncateOutput(unsigned count);
+  /// Get this function as a relation.
+  IntegerRelation getAsRelation() const;
 
   void print(raw_ostream &os) const;
   void dump() const;
 
 private:
-  /// The IntegerPolyhedron representing the domain over which the function is
-  /// defined.
-  IntegerPolyhedron domainSet;
+  /// Assert that the MAF is consistent.
+  void assertIsConsistent() const;
+
+  /// The space of this function. The domain variables are considered as the
+  /// input variables of the function. The range variables are considered as
+  /// the outputs. The symbols parametrize the function and locals are used to
+  /// represent divisions. Each local variable has a corresponding division
+  /// representation stored in `divs`.
+  PresburgerSpace space;
 
   /// The function's output is a tuple of integers, with the ith element of the
   /// tuple defined by the affine expression given by the ith row of this output
   /// matrix.
   Matrix output;
+
+  /// Storage for division representation for each local variable in space.
+  DivisionRepr divs;
 };
 
 /// This class represents a piece-wise MultiAffineFunction. This can be thought
@@ -132,33 +131,47 @@ class MultiAffineFunction {
 /// finding the value of the function at a point.
 class PWMAFunction {
 public:
-  PWMAFunction(const PresburgerSpace &space, unsigned numOutputs)
-      : space(space), numOutputs(numOutputs) {
-    assert(space.getNumDomainVars() == 0 &&
-           "Set type space should have zero domain vars.");
+  struct Piece {
+    PresburgerSet domain;
+    MultiAffineFunction output;
+
+    bool isConsistent() const {
+      return domain.getSpace().isCompatible(output.getDomainSpace());
+    }
+  };
+
+  PWMAFunction(const PresburgerSpace &space) : space(space) {
     assert(space.getNumLocalVars() == 0 &&
            "PWMAFunction cannot have local vars.");
-    assert(numOutputs >= 1 && "The function must output something!");
   }
 
+  // Get the space of this function.
   const PresburgerSpace &getSpace() const { return space; }
 
-  void addPiece(const MultiAffineFunction &piece);
-  void addPiece(const IntegerPolyhedron &domain, const Matrix &output);
-  void addPiece(const PresburgerSet &domain, const Matrix &output);
+  // Add a piece ([domain, output] pair) to this function.
+  void addPiece(const Piece &piece);
 
-  const MultiAffineFunction &getPiece(unsigned i) const { return pieces[i]; }
   unsigned getNumPieces() const { return pieces.size(); }
-  unsigned getNumOutputs() const { return numOutputs; }
-  unsigned getNumInputs() const { return space.getNumVars(); }
-  MultiAffineFunction &getPiece(unsigned i) { return pieces[i]; }
+  unsigned getNumVarKind(VarKind kind) const {
+    return space.getNumVarKind(kind);
+  }
+  unsigned getNumDomainVars() const { return space.getNumDomainVars(); }
+  unsigned getNumOutputs() const { return space.getNumRangeVars(); }
+  unsigned getNumSymbolVars() const { return space.getNumSymbolVars(); }
+
+  /// Remove the specified range of outputs.
+  void removeOutputs(unsigned start, unsigned end);
+
+  /// Get the domain/output space of the function. Each returned space is a
+  /// set space; the output space is the range space of `space`.
+  PresburgerSpace getDomainSpace() const { return space.getDomainSpace(); }
+  PresburgerSpace getOutputSpace() const { return space.getRangeSpace(); }
 
   /// Return the domain of this piece-wise MultiAffineFunction. This is the
   /// union of the domains of all the pieces.
   PresburgerSet getDomain() const;
 
-  /// Return the value at the specified point and an empty optional if the
-  /// point does not lie in the domain.
+  /// Return the output of the function at the given point.
   Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
 
   /// Return whether `this` and `other` are equal as PWMAFunctions, i.e. whether
@@ -166,11 +179,6 @@ class PWMAFunction {
   /// value at every point in the domain.
   bool isEqual(const PWMAFunction &other) const;
 
-  /// Truncate the output dimensions to the first `count` dimensions.
-  ///
-  /// TODO: refactor so that this can be accomplished through removeVarRange.
-  void truncateOutput(unsigned count);
-
   /// Return a function defined on the union of the domains of this and func,
   /// such that when only one of the functions is defined, it outputs the same
   /// as that function, and if both are defined, it outputs the lexmax/lexmin of
@@ -178,8 +186,8 @@ class PWMAFunction {
   /// function is not defined either.
   ///
   /// Currently this does not support PWMAFunctions which have pieces containing
-  /// local variables.
-  /// TODO: Support local variables in peices.
+  /// divisions.
+  /// TODO: Support division in pieces.
   PWMAFunction unionLexMin(const PWMAFunction &func);
   PWMAFunction unionLexMax(const PWMAFunction &func);
 
@@ -200,19 +208,17 @@ class PWMAFunction {
   ///
   /// The PresburgerSet returned by `tiebreak` should be disjoint.
   /// TODO: Remove this constraint of returning disjoint set.
-  PWMAFunction
-  unionFunction(const PWMAFunction &func,
-                llvm::function_ref<PresburgerSet(MultiAffineFunction mafA,
-                                                 MultiAffineFunction mafB)>
-                    tiebreak) const;
+  PWMAFunction unionFunction(
+      const PWMAFunction &func,
+      llvm::function_ref<PresburgerSet(Piece mafA, Piece mafB)> tiebreak) const;
 
+  /// The space of this function. The domain variables are considered as the
+  /// input variables of the function. The range variables are considered as
+  /// the outputs. The symbols parameterize the function.
   PresburgerSpace space;
 
-  /// The list of pieces in this piece-wise MultiAffineFunction.
-  SmallVector<MultiAffineFunction, 4> pieces;
-
-  /// The number of output vars.
-  unsigned numOutputs;
+  // The pieces of the PWMAFunction.
+  SmallVector<Piece, 4> pieces;
 };
 
 } // namespace presburger

diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
index ff2c2ad85edd9..03a5dfb0631e3 100644
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
@@ -90,6 +90,11 @@ class PresburgerSpace {
                            numLocals);
   }
 
+  // Get the domain/range space of this space. The returned space is a set
+  // space.
+  PresburgerSpace getDomainSpace() const;
+  PresburgerSpace getRangeSpace() const;
+
   unsigned getNumDomainVars() const { return numDomain; }
   unsigned getNumRangeVars() const { return numRange; }
   unsigned getNumSetDimVars() const { return numRange; }

diff --git a/mlir/include/mlir/Analysis/Presburger/Simplex.h b/mlir/include/mlir/Analysis/Presburger/Simplex.h
index 2c3aedea8452f..485a064c6ccef 100644
--- a/mlir/include/mlir/Analysis/Presburger/Simplex.h
+++ b/mlir/include/mlir/Analysis/Presburger/Simplex.h
@@ -529,9 +529,9 @@ class LexSimplex : public LexSimplexBase {
 
 /// Represents the result of a symbolic lexicographic minimization computation.
 struct SymbolicLexMin {
-  SymbolicLexMin(const PresburgerSpace &domainSpace, unsigned numOutputs)
-      : lexmin(domainSpace, numOutputs),
-        unboundedDomain(PresburgerSet::getEmpty(domainSpace)) {}
+  SymbolicLexMin(const PresburgerSpace &space)
+      : lexmin(space),
+        unboundedDomain(PresburgerSet::getEmpty(space.getDomainSpace())) {}
 
   /// This maps assignments of symbols to the corresponding lexmin.
   /// Takes no value when no integer sample exists for the assignment or if the

diff --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h
index 6d81c65b0faea..3801cb63af5ea 100644
--- a/mlir/include/mlir/Analysis/Presburger/Utils.h
+++ b/mlir/include/mlir/Analysis/Presburger/Utils.h
@@ -118,7 +118,7 @@ class DivisionRepr {
   DivisionRepr(unsigned numVars, unsigned numDivs)
       : dividends(numDivs, numVars + 1), denoms(numDivs, 0) {}
 
-  DivisionRepr(unsigned numVars) : dividends(numVars + 1, 0) {}
+  DivisionRepr(unsigned numVars) : dividends(0, numVars + 1) {}
 
   unsigned getNumVars() const { return dividends.getNumColumns() - 1; }
   unsigned getNumDivs() const { return dividends.getNumRows(); }
@@ -142,16 +142,25 @@ class DivisionRepr {
     return dividends.getRow(i);
   }
 
+  // For a given point containing values for each variable other than the
+  // division variables, try to find the values for each division variable from
+  // their division representation.
+  SmallVector<Optional<int64_t>, 4> divValuesAt(ArrayRef<int64_t> point) const;
+
   // Get the `i^th` denominator.
   unsigned &getDenom(unsigned i) { return denoms[i]; }
   unsigned getDenom(unsigned i) const { return denoms[i]; }
 
   ArrayRef<unsigned> getDenoms() const { return denoms; }
 
-  void setDividend(unsigned i, ArrayRef<int64_t> dividend) {
+  void setDiv(unsigned i, ArrayRef<int64_t> dividend, unsigned divisor) {
     dividends.setRow(i, dividend);
+    denoms[i] = divisor;
   }
 
+  void insertDiv(unsigned pos, ArrayRef<int64_t> dividend, unsigned divisor);
+  void insertDiv(unsigned pos, unsigned num = 1);
+
   /// Removes duplicate divisions. On every possible duplicate division found,
   /// `merge(i, j)`, where `i`, `j` are current index of the duplicate
   /// divisions, is called and division at index `j` is merged into division at

diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 18b48b5cdc8e2..03252ce0f4c8b 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -238,6 +238,7 @@ SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const {
                getVarKindEnd(VarKind::Domain));
   // Compute the symbolic lexmin of the dims and locals, with the symbols being
   // the actual symbols of this set.
+  // The resultant space of lexmin is the space of the relation itself.
   SymbolicLexMin result =
       SymbolicLexSimplex(*this,
                          IntegerPolyhedron(PresburgerSpace::getSetSpace(
@@ -248,8 +249,8 @@ SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const {
 
   // We want to return only the lexmin over the dims, so strip the locals from
   // the computed lexmin.
-  result.lexmin.truncateOutput(result.lexmin.getNumOutputs() -
-                               getNumLocalVars());
+  result.lexmin.removeOutputs(result.lexmin.getNumOutputs() - getNumLocalVars(),
+                              result.lexmin.getNumOutputs());
   return result;
 }
 

diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp
index c9767ae3cee2b..c51aa3c922eac 100644
--- a/mlir/lib/Analysis/Presburger/Matrix.cpp
+++ b/mlir/lib/Analysis/Presburger/Matrix.cpp
@@ -192,10 +192,14 @@ void Matrix::fillRow(unsigned row, int64_t value) {
 }
 
 void Matrix::addToRow(unsigned sourceRow, unsigned targetRow, int64_t scale) {
+  addToRow(targetRow, getRow(sourceRow), scale);
+}
+
+void Matrix::addToRow(unsigned row, ArrayRef<int64_t> rowVec, int64_t scale) {
   if (scale == 0)
     return;
   for (unsigned col = 0; col < nColumns; ++col)
-    at(targetRow, col) += scale * at(sourceRow, col);
+    at(row, col) += scale * rowVec[col];
 }
 
 void Matrix::addToColumn(unsigned sourceColumn, unsigned targetColumn,

diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
index 18b5d0e7c68dc..d1d3925c59462 100644
--- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
+++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
@@ -12,11 +12,25 @@
 using namespace mlir;
 using namespace presburger;
 
+void MultiAffineFunction::assertIsConsistent() const {
+  assert(space.getNumVars() - space.getNumRangeVars() + 1 ==
+             output.getNumColumns() &&
+         "Inconsistent number of output columns");
+  assert(space.getNumDomainVars() + space.getNumSymbolVars() ==
+             divs.getNumNonDivs() &&
+         "Inconsistent number of non-division variables in divs");
+  assert(space.getNumRangeVars() == output.getNumRows() &&
+         "Inconsistent number of output rows");
+  assert(space.getNumLocalVars() == divs.getNumDivs() &&
+         "Inconsistent number of divisions.");
+  assert(divs.hasAllReprs() && "All divisions should have a representation");
+}
+
 // Return the result of subtracting the two given vectors pointwise.
 // The vectors must be of the same size.
 // e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
-static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
-                                        ArrayRef<int64_t> vecB) {
+static SmallVector<int64_t, 8> subtractExprs(ArrayRef<int64_t> vecA,
+                                             ArrayRef<int64_t> vecB) {
   assert(vecA.size() == vecB.size() &&
          "Cannot subtract vectors of differing lengths!");
   SmallVector<int64_t, 8> result;
@@ -27,152 +41,135 @@ static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
 }
 
 PresburgerSet PWMAFunction::getDomain() const {
-  PresburgerSet domain = PresburgerSet::getEmpty(getSpace());
-  for (const MultiAffineFunction &piece : pieces)
-    domain.unionInPlace(piece.getDomain());
+  PresburgerSet domain = PresburgerSet::getEmpty(getDomainSpace());
+  for (const Piece &piece : pieces)
+    domain.unionInPlace(piece.domain);
   return domain;
 }
 
-Optional<SmallVector<int64_t, 8>>
+void MultiAffineFunction::print(raw_ostream &os) const {
+  space.print(os);
+  os << "Division Representation:\n";
+  divs.print(os);
+  os << "Output:\n";
+  output.print(os);
+}
+
+SmallVector<int64_t, 8>
 MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
-  assert(point.size() == domainSet.getNumDimAndSymbolVars() &&
+  assert(point.size() == getNumDomainVars() + getNumSymbolVars() &&
          "Point has incorrect dimensionality!");
 
-  Optional<SmallVector<int64_t, 8>> maybeLocalValues =
-      getDomain().containsPointNoLocal(point);
-  if (!maybeLocalValues)
-    return {};
-
-  // The point lies in the domain, so we need to compute the output value.
   SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
-  // The given point didn't include the values of locals which the output is a
-  // function of; we have computed one possible set of values and use them
-  // here. The function is not allowed to have local vars that take more than
-  // one possible value.
-  pointHomogenous.append(*maybeLocalValues);
+  // Get the division values at this point.
+  SmallVector<Optional<int64_t>, 8> divValues = divs.divValuesAt(point);
+  // The given point didn't include the values of the divs which the output is a
+  // function of; we have computed one possible set of values and use them here.
+  pointHomogenous.reserve(pointHomogenous.size() + divValues.size());
+  for (const Optional<int64_t> &divVal : divValues)
+    pointHomogenous.push_back(*divVal);
   // The matrix `output` has an affine expression in the ith row, corresponding
   // to the expression for the ith value in the output vector. The last column
   // of the matrix contains the constant term. Let v be the input point with
   // a 1 appended at the end. We can see that output * v gives the desired
   // output vector.
-  pointHomogenous.emplace_back(1);
+  pointHomogenous.push_back(1);
   SmallVector<int64_t, 8> result =
       output.postMultiplyWithColumn(pointHomogenous);
   assert(result.size() == getNumOutputs());
   return result;
 }
 
-Optional<SmallVector<int64_t, 8>>
-PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
-  assert(point.size() == getNumInputs() &&
-         "Point has incorrect dimensionality!");
-  for (const MultiAffineFunction &piece : pieces)
-    if (Optional<SmallVector<int64_t, 8>> output = piece.valueAt(point))
-      return output;
-  return {};
+bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
+  assert(space.isCompatible(other.space) &&
+         "Spaces should be compatible for equality check.");
+  return getAsRelation().isEqual(other.getAsRelation());
 }
 
-void MultiAffineFunction::print(raw_ostream &os) const {
-  os << "Domain:";
-  domainSet.print(os);
-  os << "Output:\n";
-  output.print(os);
-  os << "\n";
-}
+bool MultiAffineFunction::isEqual(const MultiAffineFunction &other,
+                                  const IntegerPolyhedron &domain) const {
+  assert(space.isCompatible(other.space) &&
+         "Spaces should be compatible for equality check.");
+  IntegerRelation restrictedThis = getAsRelation();
+  restrictedThis.intersectDomain(domain);
 
-void MultiAffineFunction::dump() const { print(llvm::errs()); }
+  IntegerRelation restrictedOther = other.getAsRelation();
+  restrictedOther.intersectDomain(domain);
 
-bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
-  return getDomainSpace().isCompatible(other.getDomainSpace()) &&
-         getDomain().isEqual(other.getDomain()) &&
-         isEqualWhereDomainsOverlap(other);
+  return restrictedThis.isEqual(restrictedOther);
 }
 
-unsigned MultiAffineFunction::insertVar(VarKind kind, unsigned pos,
-                                        unsigned num) {
-  assert(kind != VarKind::Domain && "Domain has to be zero in a set");
-  unsigned absolutePos = domainSet.getVarKindOffset(kind) + pos;
-  output.insertColumns(absolutePos, num);
-  return domainSet.insertVar(kind, pos, num);
+bool MultiAffineFunction::isEqual(const MultiAffineFunction &other,
+                                  const PresburgerSet &domain) const {
+  assert(space.isCompatible(other.space) &&
+         "Spaces should be compatible for equality check.");
+  return llvm::all_of(domain.getAllDisjuncts(),
+                      [&](const IntegerRelation &disjunct) {
+                        return isEqual(other, IntegerPolyhedron(disjunct));
+                      });
 }
 
-void MultiAffineFunction::removeVarRange(VarKind kind, unsigned varStart,
-                                         unsigned varLimit) {
-  output.removeColumns(varStart + domainSet.getVarKindOffset(kind),
-                       varLimit - varStart);
-  domainSet.removeVarRange(kind, varStart, varLimit);
-}
+void MultiAffineFunction::removeOutputs(unsigned start, unsigned end) {
+  assert(end <= getNumOutputs() && "Invalid range");
 
-void MultiAffineFunction::truncateOutput(unsigned count) {
-  assert(count <= output.getNumRows());
-  output.resizeVertically(count);
-}
+  if (start >= end)
+    return;
 
-void PWMAFunction::truncateOutput(unsigned count) {
-  assert(count <= numOutputs);
-  for (MultiAffineFunction &piece : pieces)
-    piece.truncateOutput(count);
-  numOutputs = count;
+  space.removeVarRange(VarKind::Range, start, end);
+  output.removeRows(start, end - start);
 }
 
-void MultiAffineFunction::mergeLocalVars(MultiAffineFunction &other) {
-  // Merge output local vars of both functions without using division
-  // information i.e. append local vars of `other` to `this` and insert
-  // local vars of `this` to `other` at the start of it's local vars.
-  output.insertColumns(domainSet.getVarKindEnd(VarKind::Local),
-                       other.domainSet.getNumLocalVars());
-  other.output.insertColumns(other.domainSet.getVarKindOffset(VarKind::Local),
-                             domainSet.getNumLocalVars());
+void MultiAffineFunction::mergeDivs(MultiAffineFunction &other) {
+  assert(space.isCompatible(other.space) && "Functions should be compatible");
 
-  auto merge = [this, &other](unsigned i, unsigned j) -> bool {
-    // Merge local at position j into local at position i in function domain.
-    domainSet.eliminateRedundantLocalVar(i, j);
-    other.domainSet.eliminateRedundantLocalVar(i, j);
+  unsigned nDivs = getNumDivs();
+  unsigned divOffset = divs.getDivOffset();
 
-    unsigned localOffset = domainSet.getVarKindOffset(VarKind::Local);
+  other.divs.insertDiv(0, nDivs);
 
-    // Merge local at position j into local at position i in output domain.
-    output.addToColumn(localOffset + j, localOffset + i, 1);
-    output.removeColumn(localOffset + j);
-    other.output.addToColumn(localOffset + j, localOffset + i, 1);
-    other.output.removeColumn(localOffset + j);
+  SmallVector<int64_t, 8> div(other.divs.getNumVars() + 1);
+  for (unsigned i = 0; i < nDivs; ++i) {
+    // Zero fill.
+    std::fill(div.begin(), div.end(), 0);
+    // Fill div with dividend from `divs`. Do not fill the constant.
+    std::copy(divs.getDividend(i).begin(), divs.getDividend(i).end() - 1,
+              div.begin());
+    // Fill constant.
+    div.back() = divs.getDividend(i).back();
+    other.divs.setDiv(i, div, divs.getDenom(i));
+  }
+
+  other.space.insertVar(VarKind::Local, 0, nDivs);
+  other.output.insertColumns(divOffset, nDivs);
+
+  auto merge = [&](unsigned i, unsigned j) {
+    // We only merge from local at pos j to local at pos i, where j > i.
+    if (i >= j)
+      return false;
 
+    // If j < nDivs, both divisions belong to `this` (as i < j). Since we do
+    // not want to merge duplicates in `this`, we ignore this call.
+    if (j < nDivs)
+      return false;
+
+    // Fold div j into div i in space and output: add col j to col i, drop j.
+    other.space.removeVarRange(VarKind::Local, j, j + 1);
+    other.output.addToColumn(divOffset + j, divOffset + i, 1);
+    other.output.removeColumn(divOffset + j);
     return true;
   };
 
-  presburger::mergeLocalVars(domainSet, other.domainSet, merge);
-}
+  other.divs.removeDuplicateDivs(merge);
 
-bool MultiAffineFunction::isEqualWhereDomainsOverlap(
-    MultiAffineFunction other) const {
-  if (!getDomainSpace().isCompatible(other.getDomainSpace()))
-    return false;
+  unsigned newDivs = other.divs.getNumDivs() - nDivs;
 
-  // `commonFunc` has the same output as `this`.
-  MultiAffineFunction commonFunc = *this;
-  // After this merge, `commonFunc` and `other` have the same local vars; they
-  // are merged.
-  commonFunc.mergeLocalVars(other);
-  // After this, the domain of `commonFunc` will be the intersection of the
-  // domains of `this` and `other`.
-  commonFunc.domainSet.append(other.domainSet);
-
-  // `commonDomainMatching` contains the subset of the common domain
-  // where the outputs of `this` and `other` match.
-  //
-  // We want to add constraints equating the outputs of `this` and `other`.
-  // However, `this` may have difference local vars from `other`, whereas we
-  // need both to have the same locals. Accordingly, we use `commonFunc.output`
-  // in place of `this->output`, since `commonFunc` has the same output but also
-  // has its locals merged.
-  IntegerPolyhedron commonDomainMatching = commonFunc.getDomain();
-  for (unsigned row = 0, e = getNumOutputs(); row < e; ++row)
-    commonDomainMatching.addEquality(
-        subtract(commonFunc.output.getRow(row), other.output.getRow(row)));
-
-  // If the whole common domain is a subset of commonDomainMatching, then they
-  // are equal and the two functions match on the whole common domain.
-  return commonFunc.getDomain().isSubsetOf(commonDomainMatching);
+  space.insertVar(VarKind::Local, nDivs, newDivs);
+  output.insertColumns(divOffset + nDivs, newDivs);
+  divs = other.divs;
+
+  // Check consistency.
+  assertIsConsistent();
+  other.assertIsConsistent();
 }
 
 /// Two PWMAFunctions are equal if they have the same dimensionalities,
@@ -188,89 +185,79 @@ bool PWMAFunction::isEqual(const PWMAFunction &other) const {
   // overlap, they take the same output value. If `this` and `other` have the
   // same domain (checked above), then this check passes iff the two functions
   // have the same output at every point in the domain.
-  for (const MultiAffineFunction &aPiece : this->pieces)
-    for (const MultiAffineFunction &bPiece : other.pieces)
-      if (!aPiece.isEqualWhereDomainsOverlap(bPiece))
-        return false;
-  return true;
+  return llvm::all_of(this->pieces, [&other](const Piece &pieceA) {
+    return llvm::all_of(other.pieces, [&pieceA](const Piece &pieceB) {
+      PresburgerSet commonDomain = pieceA.domain.intersect(pieceB.domain);
+      return pieceA.output.isEqual(pieceB.output, commonDomain);
+    });
+  });
 }
 
-void PWMAFunction::addPiece(const MultiAffineFunction &piece) {
-  assert(space.isCompatible(piece.getDomainSpace()) &&
-         "Piece to be added is not compatible with this PWMAFunction!");
-  assert(piece.isConsistent() && "Piece is internally inconsistent!");
-  assert(this->getDomain()
-             .intersect(PresburgerSet(piece.getDomain()))
-             .isIntegerEmpty() &&
-         "New piece's domain overlaps with that of existing pieces!");
+void PWMAFunction::addPiece(const Piece &piece) {
+  assert(piece.isConsistent() && "Piece should be consistent");
   pieces.push_back(piece);
 }
 
-void PWMAFunction::addPiece(const IntegerPolyhedron &domain,
-                            const Matrix &output) {
-  addPiece(MultiAffineFunction(domain, output));
-}
-
-void PWMAFunction::addPiece(const PresburgerSet &domain, const Matrix &output) {
-  for (const IntegerRelation &newDom : domain.getAllDisjuncts())
-    addPiece(IntegerPolyhedron(newDom), output);
-}
-
 void PWMAFunction::print(raw_ostream &os) const {
-  os << pieces.size() << " pieces:\n";
-  for (const MultiAffineFunction &piece : pieces)
-    piece.print(os);
+  space.print(os);
+  os << getNumPieces() << " pieces:\n";
+  for (const Piece &piece : pieces) {
+    os << "Domain of piece:\n";
+    piece.domain.print(os);
+    os << "Output of piece\n";
+    piece.output.print(os);
+  }
 }
 
 void PWMAFunction::dump() const { print(llvm::errs()); }
 
 PWMAFunction PWMAFunction::unionFunction(
     const PWMAFunction &func,
-    llvm::function_ref<PresburgerSet(MultiAffineFunction maf1,
-                                     MultiAffineFunction maf2)>
-        tiebreak) const {
+    llvm::function_ref<PresburgerSet(Piece maf1, Piece maf2)> tiebreak) const {
   assert(getNumOutputs() == func.getNumOutputs() &&
-         "Number of outputs of functions should be same.");
+         "Ranges of functions should be same.");
   assert(getSpace().isCompatible(func.getSpace()) &&
          "Space is not compatible.");
 
   // The algorithm used here is as follows:
-  // - Add the output of funcB for the part of the domain where both funcA and
-  //   funcB are defined, and `tiebreak` chooses the output of funcB.
-  // - Add the output of funcA, where funcB is not defined or `tiebreak` chooses
-  //   funcA over funcB.
-  // - Add the output of funcB, where funcA is not defined.
-
-  // Add parts of the common domain where funcB's output is used. Also
-  // add all the parts where funcA's output is used, both common and non-common.
-  PWMAFunction result(getSpace(), getNumOutputs());
-  for (const MultiAffineFunction &funcA : pieces) {
-    PresburgerSet dom(funcA.getDomain());
-    for (const MultiAffineFunction &funcB : func.pieces) {
-      PresburgerSet better = tiebreak(funcB, funcA);
-      // Add the output of funcB, where it is better than output of funcA.
+  // - Add the output of pieceB for the part of the domain where both pieceA and
+  //   pieceB are defined, and `tiebreak` chooses the output of pieceB.
+  // - Add the output of pieceA, where pieceB is not defined or `tiebreak`
+  // chooses
+  //   pieceA over pieceB.
+  // - Add the output of pieceB, where pieceA is not defined.
+
+  // Add parts of the common domain where pieceB's output is used. Also
+  // add all the parts where pieceA's output is used, both common and
+  // non-common.
+  PWMAFunction result(getSpace());
+  for (const Piece &pieceA : pieces) {
+    PresburgerSet dom(pieceA.domain);
+    for (const Piece &pieceB : func.pieces) {
+      PresburgerSet better = tiebreak(pieceB, pieceA);
+      // Add the output of pieceB, where it is better than output of pieceA.
       // The disjuncts in "better" will be disjoint as tiebreak should gurantee
       // that.
-      result.addPiece(better, funcB.getOutputMatrix());
+      result.addPiece({better, pieceB.output});
       dom = dom.subtract(better);
     }
-    // Add output of funcA, where it is better than funcB, or funcB is not
+    // Add output of pieceA, where it is better than pieceB, or pieceB is not
     // defined.
     //
     // `dom` here is guranteed to be disjoint from already added pieces
     // because because the pieces added before are either:
     // - Subsets of the domain of other MAFs in `this`, which are guranteed
     //   to be disjoint from `dom`, or
-    // - They are one of the pieces added for `funcB`, and we have been
+    // - They are one of the pieces added for `pieceB`, and we have been
     //   subtracting all such pieces from `dom`, so `dom` is disjoint from those
     //   pieces as well.
-    result.addPiece(dom, funcA.getOutputMatrix());
+    result.addPiece({dom, pieceA.output});
   }
 
-  // Add parts of funcB which are not shared with funcA.
+  // Add parts of pieceB which are not shared with pieceA.
   PresburgerSet dom = getDomain();
-  for (const MultiAffineFunction &funcB : func.pieces)
-    result.addPiece(funcB.getDomain().subtract(dom), funcB.getOutputMatrix());
+  for (const Piece &pieceB : func.pieces)
+    result.addPiece({pieceB.domain.subtract(dom), pieceB.output});
 
   return result;
 }
@@ -280,21 +267,19 @@ PWMAFunction PWMAFunction::unionFunction(
 /// taking the lexicographically smaller output and otherwise, by taking the
 /// lexicographically larger output.
 template <bool lexMin>
-static PresburgerSet tiebreakLex(const MultiAffineFunction &mafA,
-                                 const MultiAffineFunction &mafB) {
+static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA,
+                                 const PWMAFunction::Piece &pieceB) {
   // TODO: Support local variables here.
-  assert(mafA.getDomainSpace().isCompatible(mafB.getDomainSpace()) &&
-         "Domain spaces should be compatible.");
-  assert(mafA.getNumOutputs() == mafB.getNumOutputs() &&
-         "Number of outputs of both functions should be same.");
-  assert(mafA.getDomain().getNumLocalVars() == 0 &&
+  assert(pieceA.output.getSpace().isCompatible(pieceB.output.getSpace()) &&
+         "Pieces should be compatible");
+  assert(pieceA.domain.getSpace().getNumLocalVars() == 0 &&
          "Local variables are not supported yet.");
 
-  PresburgerSpace compatibleSpace = mafA.getDomain().getSpaceWithoutLocals();
-  const PresburgerSpace &space = mafA.getDomain().getSpace();
+  PresburgerSpace compatibleSpace = pieceA.domain.getSpace();
+  const PresburgerSpace &space = pieceA.domain.getSpace();
 
   // We first create the set `result`, corresponding to the set where output
-  // of mafA is lexicographically larger/smaller than mafB. This is done by
+  // of pieceA is lexicographically larger/smaller than pieceB. This is done by
   // creating a PresburgerSet with the following constraints:
   //
   //    (outA[0] > outB[0]) U
@@ -312,14 +297,15 @@ static PresburgerSet tiebreakLex(const MultiAffineFunction &mafA,
   //    ...
   //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1])
   PresburgerSet result = PresburgerSet::getEmpty(compatibleSpace);
-  IntegerPolyhedron levelSet(/*numReservedInequalities=*/1,
-                             /*numReservedEqualities=*/mafA.getNumOutputs(),
-                             /*numReservedCols=*/space.getNumVars() + 1, space);
-  for (unsigned level = 0; level < mafA.getNumOutputs(); ++level) {
+  IntegerPolyhedron levelSet(
+      /*numReservedInequalities=*/1,
+      /*numReservedEqualities=*/pieceA.output.getNumOutputs(),
+      /*numReservedCols=*/space.getNumVars() + 1, space);
+  for (unsigned level = 0; level < pieceA.output.getNumOutputs(); ++level) {
 
     // Create the expression `outA - outB` for this level.
-    SmallVector<int64_t, 8> subExpr =
-        subtract(mafA.getOutputExpr(level), mafB.getOutputExpr(level));
+    SmallVector<int64_t, 8> subExpr = subtractExprs(
+        pieceA.output.getOutputExpr(level), pieceB.output.getOutputExpr(level));
 
     if (lexMin) {
       // For lexMin, we add an upper bound of -1:
@@ -343,10 +329,9 @@ static PresburgerSet tiebreakLex(const MultiAffineFunction &mafA,
     levelSet.addEquality(subExpr);
   }
 
-  // We then intersect `result` with the domain of mafA and mafB, to only
+  // We then intersect `result` with the domain of pieceA and pieceB, to only
   // tiebreak on the domain where both are defined.
-  result = result.intersect(PresburgerSet(mafA.getDomain()))
-               .intersect(PresburgerSet(mafB.getDomain()));
+  result = result.intersect(pieceA.domain).intersect(pieceB.domain);
 
   return result;
 }
@@ -358,3 +343,93 @@ PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) {
 PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) {
   return unionFunction(func, tiebreakLex</*lexMin=*/false>);
 }
+
+void MultiAffineFunction::subtract(const MultiAffineFunction &other) {
+  assert(space.isCompatible(other.space) &&
+         "Spaces should be compatible for subtraction.");
+
+  MultiAffineFunction copyOther = other;
+  mergeDivs(copyOther);
+  for (unsigned i = 0, e = getNumOutputs(); i < e; ++i)
+    output.addToRow(i, copyOther.getOutputExpr(i), -1);
+
+  // Check consistency.
+  assertIsConsistent();
+}
+
+/// Adds division constraints corresponding to local variables, given a
+/// relation and division representations of the local variables in the
+/// relation.
+static void addDivisionConstraints(IntegerRelation &rel,
+                                   const DivisionRepr &divs) {
+  assert(divs.hasAllReprs() &&
+         "All divisions in divs should have a representation");
+  assert(rel.getNumVars() == divs.getNumVars() &&
+         "Relation and divs should have the same number of vars");
+  assert(rel.getNumLocalVars() == divs.getNumDivs() &&
+         "Relation and divs should have the same number of local vars");
+
+  for (unsigned i = 0, e = divs.getNumDivs(); i < e; ++i) {
+    rel.addInequality(getDivUpperBound(divs.getDividend(i), divs.getDenom(i),
+                                       divs.getDivOffset() + i));
+    rel.addInequality(getDivLowerBound(divs.getDividend(i), divs.getDenom(i),
+                                       divs.getDivOffset() + i));
+  }
+}
+
+IntegerRelation MultiAffineFunction::getAsRelation() const {
+  // Create a relation corresponding to the input space plus the divisions
+  // used in outputs.
+  IntegerRelation result(PresburgerSpace::getRelationSpace(
+      space.getNumDomainVars(), 0, space.getNumSymbolVars(),
+      space.getNumLocalVars()));
+  // Add division constraints corresponding to divisions used in outputs.
+  addDivisionConstraints(result, divs);
+  // The outputs are represented as range variables in the relation. We add
+  // range variables for the outputs.
+  result.insertVar(VarKind::Range, 0, getNumOutputs());
+
+  // Add equalities such that the i^th range variable is equal to the i^th
+  // output expression.
+  SmallVector<int64_t, 8> eq(result.getNumCols());
+  for (unsigned i = 0, e = getNumOutputs(); i < e; ++i) {
+    // TODO: Add functions to get VarKind offsets in output in MAF and use them
+    // here.
+    // The output expression does not contain range variables, while the
+    // equality does. So, we need to copy all variables and mark all range
+    // variables as 0 in the equality.
+    ArrayRef<int64_t> expr = getOutputExpr(i);
+    // Copy domain variables in `expr` to domain variables in `eq`.
+    std::copy(expr.begin(), expr.begin() + getNumDomainVars(), eq.begin());
+    // Fill the range variables in `eq` as zero.
+    std::fill(eq.begin() + result.getVarKindOffset(VarKind::Range),
+              eq.begin() + result.getVarKindEnd(VarKind::Range), 0);
+    // Copy remaining variables in `expr` to the remaining variables in `eq`.
+    std::copy(expr.begin() + getNumDomainVars(), expr.end(),
+              eq.begin() + result.getVarKindEnd(VarKind::Range));
+
+    // Set the i^th range var to -1 in `eq` to equate the output expression to
+    // this range var.
+    eq[result.getVarKindOffset(VarKind::Range) + i] = -1;
+    // Add the equality `rangeVar_i = output[i]`.
+    result.addEquality(eq);
+  }
+
+  return result;
+}
+
+void PWMAFunction::removeOutputs(unsigned start, unsigned end) {
+  space.removeVarRange(VarKind::Range, start, end);
+  for (Piece &piece : pieces)
+    piece.output.removeOutputs(start, end);
+}
+
+Optional<SmallVector<int64_t, 8>>
+PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
+  assert(point.size() == getNumDomainVars() + getNumSymbolVars());
+
+  for (const Piece &piece : pieces)
+    if (piece.domain.containsPoint(point))
+      return piece.output.valueAt(point);
+  return None;
+}

diff  --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
index abb8796848063..a79b229318a1f 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
@@ -13,6 +13,15 @@
 using namespace mlir;
 using namespace presburger;
 
+PresburgerSpace PresburgerSpace::getDomainSpace() const {
+  // TODO: Preserve identifiers here.
+  return PresburgerSpace::getSetSpace(numDomain, numSymbols, numLocals);
+}
+
+PresburgerSpace PresburgerSpace::getRangeSpace() const {
+  return PresburgerSpace::getSetSpace(numRange, numSymbols, numLocals);
+}
+
 unsigned PresburgerSpace::getNumVarKind(VarKind kind) const {
   if (kind == VarKind::Domain)
     return getNumDomainVars();

diff  --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index b856c8931114b..d3c66027b98e3 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -466,7 +466,14 @@ void SymbolicLexSimplex::recordOutput(SymbolicLexMin &result) const {
     }
     output.appendExtraRow(sample);
   }
-  result.lexmin.addPiece(domainPoly, output);
+
+  // Store the output in a MultiAffineFunction and add it to the result.
+  PresburgerSpace funcSpace = result.lexmin.getSpace();
+  funcSpace.insertVar(VarKind::Local, 0, domainPoly.getNumLocalVars());
+
+  result.lexmin.addPiece(
+      {PresburgerSet(domainPoly),
+       MultiAffineFunction(funcSpace, output, domainPoly.getLocalReprs())});
 }
 
 Optional<unsigned> SymbolicLexSimplex::maybeGetAlwaysViolatedRow() {
@@ -508,7 +515,10 @@ LogicalResult SymbolicLexSimplex::doNonBranchingPivots() {
 }
 
 SymbolicLexMin SymbolicLexSimplex::computeSymbolicIntegerLexMin() {
-  SymbolicLexMin result(domainPoly.getSpace(), var.size() - nSymbol);
+  SymbolicLexMin result(PresburgerSpace::getRelationSpace(
+      /*numDomain=*/domainPoly.getNumDimVars(),
+      /*numRange=*/var.size() - nSymbol,
+      /*numSymbols=*/domainPoly.getNumSymbolVars()));
 
   /// The algorithm is more naturally expressed recursively, but we implement
   /// it iteratively here to avoid potential issues with stack overflows in the

diff  --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp
index c3ac8232c6ec5..e22fa02cfce97 100644
--- a/mlir/lib/Analysis/Presburger/Utils.cpp
+++ b/mlir/lib/Analysis/Presburger/Utils.cpp
@@ -16,6 +16,8 @@
 #include "mlir/Support/MathExtras.h"
 #include <numeric>
 
+#include <numeric>
+
 using namespace mlir;
 using namespace presburger;
 
@@ -280,10 +282,8 @@ void presburger::mergeLocalVars(
   DivisionRepr divsA = relA.getLocalReprs();
   DivisionRepr divsB = relB.getLocalReprs();
 
-  for (unsigned i = initLocals, e = divsB.getNumDivs(); i < e; ++i) {
-    divsA.setDividend(i, divsB.getDividend(i));
-    divsA.getDenom(i) = divsB.getDenom(i);
-  }
+  for (unsigned i = initLocals, e = divsB.getNumDivs(); i < e; ++i)
+    divsA.setDiv(i, divsB.getDividend(i), divsB.getDenom(i));
 
   // Remove duplicate divisions from divsA. The removing duplicate divisions
   // call, calls `merge` to effectively merge divisions in relA and relB.
@@ -357,6 +357,55 @@ SmallVector<int64_t, 8> presburger::getComplementIneq(ArrayRef<int64_t> ineq) {
   return coeffs;
 }
 
+SmallVector<Optional<int64_t>, 4>
+DivisionRepr::divValuesAt(ArrayRef<int64_t> point) const {
+  assert(point.size() == getNumNonDivs() && "Incorrect point size");
+
+  SmallVector<Optional<int64_t>, 4> divValues(getNumDivs(), None);
+  bool changed = true;
+  while (changed) {
+    changed = false;
+    for (unsigned i = 0, e = getNumDivs(); i < e; ++i) {
+      // If division value is found, continue;
+      if (divValues[i])
+        continue;
+
+      ArrayRef<int64_t> dividend = getDividend(i);
+      int64_t divVal = 0;
+
+      // Check if we have all the division values required for this division.
+      unsigned j, f;
+      for (j = 0, f = getNumDivs(); j < f; ++j) {
+        if (dividend[getDivOffset() + j] == 0)
+          continue;
+        // Division value required, but not found yet.
+        if (!divValues[j])
+          break;
+        divVal += dividend[getDivOffset() + j] * divValues[j].value();
+      }
+
+      // We have some division values that are still not found, but are required
+      // to find the value of this division.
+      if (j < f)
+        continue;
+
+      // Fill remaining values.
+      divVal = std::inner_product(point.begin(), point.end(), dividend.begin(),
+                                  divVal);
+      // Add constant.
+      divVal += dividend.back();
+      // Take floor division with denominator.
+      divVal = floorDiv(divVal, denoms[i]);
+
+      // Set div value and continue.
+      divValues[i] = divVal;
+      changed = true;
+    }
+  }
+
+  return divValues;
+}
+
 void DivisionRepr::removeDuplicateDivs(
     llvm::function_ref<bool(unsigned i, unsigned j)> merge) {
 
@@ -402,6 +451,23 @@ void DivisionRepr::removeDuplicateDivs(
   }
 }
 
+void DivisionRepr::insertDiv(unsigned pos, ArrayRef<int64_t> dividend,
+                             unsigned divisor) {
+  assert(pos <= getNumDivs() && "Invalid insertion position");
+  assert(dividend.size() == getNumVars() + 1 && "Incorrect dividend size");
+
+  dividends.appendExtraRow(dividend);
+  denoms.insert(denoms.begin() + pos, divisor);
+  dividends.insertColumn(getDivOffset() + pos);
+}
+
+void DivisionRepr::insertDiv(unsigned pos, unsigned num) {
+  assert(pos <= getNumDivs() && "Invalid insertion position");
+  dividends.insertColumns(getDivOffset() + pos, num);
+  dividends.insertRows(pos, num);
+  denoms.insert(denoms.begin() + pos, num, 0);
+}
+
 void DivisionRepr::print(raw_ostream &os) const {
   os << "Dividends:\n";
   dividends.print(os);

diff  --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
index 2ed4cd62b4a09..de7413973e1b6 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -1171,7 +1171,7 @@ void expectSymbolicIntegerLexMin(
   ASSERT_NE(poly.getNumSymbolVars(), 0u);
 
   PWMAFunction expectedLexmin =
-      parsePWMAF(/*numInputs=*/poly.getNumSymbolVars(),
+      parsePWMAF(/*numInputs=*/0,
                  /*numOutputs=*/poly.getNumDimVars(), expectedLexminRepr,
                  /*numSymbols=*/poly.getNumSymbolVars());
 

diff  --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index c241d5d2f0e81..18efe55c2479a 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -130,7 +130,7 @@ TEST(IntegerRelationTest, symbolicLexmin) {
           .findSymbolicIntegerLexMin();
 
   PWMAFunction expectedLexmin =
-      parsePWMAF(/*numInputs=*/2,
+      parsePWMAF(/*numInputs=*/1,
                  /*numOutputs=*/1,
                  {
                      {"(a)[b] : (a - b >= 0)", {{1, 0, 0}}},     // a

diff  --git a/mlir/unittests/Analysis/Presburger/Utils.h b/mlir/unittests/Analysis/Presburger/Utils.h
index 3b4a479e97ded..417cce15e8428 100644
--- a/mlir/unittests/Analysis/Presburger/Utils.h
+++ b/mlir/unittests/Analysis/Presburger/Utils.h
@@ -73,14 +73,20 @@ inline PWMAFunction parsePWMAF(
     unsigned numSymbols = 0) {
   static MLIRContext context;
 
-  PWMAFunction result(PresburgerSpace::getSetSpace(
-                          /*numDims=*/numInputs - numSymbols, numSymbols),
-                      numOutputs);
+  PWMAFunction result(
+      PresburgerSpace::getRelationSpace(numInputs, numOutputs, numSymbols));
   for (const auto &pair : data) {
     IntegerPolyhedron domain = parsePoly(pair.first);
 
+    PresburgerSpace funcSpace = result.getSpace();
+    funcSpace.insertVar(VarKind::Local, 0, domain.getNumLocalVars());
+
     result.addPiece(
-        domain, makeMatrix(numOutputs, domain.getNumVars() + 1, pair.second));
+        {PresburgerSet(domain),
+         MultiAffineFunction(
+             funcSpace,
+             makeMatrix(numOutputs, domain.getNumVars() + 1, pair.second),
+             domain.getLocalReprs())});
   }
   return result;
 }


        


More information about the Mlir-commits mailing list