[Mlir-commits] [mlir] 94750af - [MLIR][Presburger] Support divisions in union of two PWMAFunction

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 19 05:03:23 PST 2023


Author: Groverkss
Date: 2023-01-19T18:33:00+05:30
New Revision: 94750af83640cd702d80c53ab99d5bb303e55796

URL: https://github.com/llvm/llvm-project/commit/94750af83640cd702d80c53ab99d5bb303e55796
DIFF: https://github.com/llvm/llvm-project/commit/94750af83640cd702d80c53ab99d5bb303e55796.diff

LOG: [MLIR][Presburger] Support divisions in union of two PWMAFunction

This patch adds support for divisions in the union of two PWMAFunctions. This is
now possible because previous patches made MultiAffineFunction (MAF) store its
divisions explicitly. This patch also refactors the previous implementation,
moving the code that computes the set of points where one MAF is
lexicographically "better" than another into MAF (`MultiAffineFunction::getLexSet`).

Reviewed By: arjunp

Differential Revision: https://reviews.llvm.org/D138118

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
    mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
    mlir/lib/Analysis/Presburger/PWMAFunction.cpp
    mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
    mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
index 4ba0f44cbc82a..ea3456624e72d 100644
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -23,6 +23,10 @@
 namespace mlir {
 namespace presburger {
 
+/// Enum representing a binary comparison operator: equal, not equal, less than,
+/// less than or equal, greater than, greater than or equal.
+enum class OrderingKind { EQ, NE, LT, LE, GT, GE };
+
 /// This class represents a multi-affine function with the domain as Z^d, where
 /// `d` is the number of domain variables of the function. For example:
 ///
@@ -65,7 +69,10 @@ class MultiAffineFunction {
   /// Get the `i^th` output expression.
   ArrayRef<MPInt> getOutputExpr(unsigned i) const { return output.getRow(i); }
 
-  // Remove the specified range of outputs.
+  /// Get the divisions used in this function.
+  const DivisionRepr &getDivs() const { return divs; }
+
+  /// Remove the specified range of outputs.
   void removeOutputs(unsigned start, unsigned end);
 
   /// Given a MAF `other`, merges division variables such that both functions
@@ -89,6 +96,14 @@ class MultiAffineFunction {
 
   void subtract(const MultiAffineFunction &other);
 
+  /// Return the set of domain points where the output of `this` and `other`
+  /// are ordered lexicographically according to `comp` (currently only `LT`
+  /// and `GT` are supported). For example, if `comp` is `LT`, the returned
+  /// set contains all points where the output of `this` is lexicographically
+  /// less than the output of `other`.
+  PresburgerSet getLexSet(OrderingKind comp,
+                          const MultiAffineFunction &other) const;
+
   /// Get this function as a relation.
   IntegerRelation getAsRelation() const;
 
@@ -181,6 +196,9 @@ class PWMAFunction {
     return valueAt(getMPIntVec(point));
   }
 
+  /// Return all the pieces of this piece-wise function.
+  ArrayRef<Piece> getAllPieces() const { return pieces; }
+
   /// Return whether `this` and `other` are equal as PWMAFunctions, i.e. whether
   /// they have the same dimensions, the same domain and they take the same
   /// value at every point in the domain.

diff  --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
index 03a5dfb0631e3..998a70c677bf5 100644
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
@@ -90,11 +90,14 @@ class PresburgerSpace {
                            numLocals);
   }
 
-  // Get the domain/range space of this space. The returned space is a set
-  // space.
+  /// Get the domain/range space of this space. The returned space is a set
+  /// space.
   PresburgerSpace getDomainSpace() const;
   PresburgerSpace getRangeSpace() const;
 
+  /// Get a copy of this space with all local variables removed.
+  PresburgerSpace getSpaceWithoutLocals() const;
+
   unsigned getNumDomainVars() const { return numDomain; }
   unsigned getNumRangeVars() const { return numRange; }
   unsigned getNumSetDimVars() const { return numRange; }

diff  --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
index c31d50ad29ff4..64b9ba6bf7a0e 100644
--- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
+++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
@@ -172,6 +172,93 @@ void MultiAffineFunction::mergeDivs(MultiAffineFunction &other) {
   other.assertIsConsistent();
 }
 
+PresburgerSet
+MultiAffineFunction::getLexSet(OrderingKind comp,
+                               const MultiAffineFunction &other) const {
+  assert(getSpace().isCompatible(other.getSpace()) &&
+         "Output space of funcs should be compatible");
+
+  // Copy both functions and merge their divisions into a common local space.
+  MultiAffineFunction funcA = *this;
+  MultiAffineFunction funcB = other;
+  funcA.mergeDivs(funcB);
+
+  // We create the set `result`, corresponding to the set where the output of
+  // funcA is lexicographically larger than that of funcB (for `GT`). This is
+  // done by creating a PresburgerSet with the following constraints:
+  //
+  //    (outA[0] > outB[0]) U
+  //    (outA[0] = outB[0], outA[1] > outB[1]) U
+  //    (outA[0] = outB[0], outA[1] = outB[1], outA[2] > outB[2]) U
+  //    ...
+  //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1])
+  //
+  // where `n` is the number of outputs.
+  // For `LT`, the reverse strict inequality is used at each level:
+  //
+  //    (outA[0] < outB[0]) U
+  //    (outA[0] = outB[0], outA[1] < outB[1]) U
+  //    (outA[0] = outB[0], outA[1] = outB[1], outA[2] < outB[2]) U
+  //    ...
+  //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1])
+  PresburgerSpace resultSpace = funcA.getDomainSpace();
+  PresburgerSet result =
+      PresburgerSet::getEmpty(resultSpace.getSpaceWithoutLocals());
+  IntegerPolyhedron levelSet(
+      /*numReservedInequalities=*/1 + 2 * resultSpace.getNumLocalVars(),
+      /*numReservedEqualities=*/funcA.getNumOutputs(),
+      /*numReservedCols=*/resultSpace.getNumVars() + 1, resultSpace);
+
+  // Add division inequalities to `levelSet`.
+  for (unsigned i = 0, e = funcA.getNumDivs(); i < e; ++i) {
+    levelSet.addInequality(getDivUpperBound(funcA.divs.getDividend(i),
+                                            funcA.divs.getDenom(i),
+                                            funcA.divs.getDivOffset() + i));
+    levelSet.addInequality(getDivLowerBound(funcA.divs.getDividend(i),
+                                            funcA.divs.getDenom(i),
+                                            funcA.divs.getDivOffset() + i));
+  }
+
+  for (unsigned level = 0; level < funcA.getNumOutputs(); ++level) {
+    // Create the expression `outA - outB` for this level.
+    SmallVector<MPInt, 8> subExpr =
+        subtractExprs(funcA.getOutputExpr(level), funcB.getOutputExpr(level));
+
+    // TODO: Implement all comparison cases; only LT and GT are supported.
+    switch (comp) {
+    case OrderingKind::LT:
+      // For less than, we add an upper bound of -1:
+      //        outA - outB <= -1
+      //        outA <= outB - 1
+      //        outA < outB
+      levelSet.addBound(IntegerPolyhedron::BoundType::UB, subExpr, MPInt(-1));
+      break;
+    case OrderingKind::GT:
+      // For greater than, we add a lower bound of 1:
+      //        outA - outB >= 1
+      //        outA >= outB + 1
+      //        outA > outB
+      levelSet.addBound(IntegerPolyhedron::BoundType::LB, subExpr, MPInt(1));
+      break;
+    case OrderingKind::GE:
+    case OrderingKind::LE:
+    case OrderingKind::EQ:
+    case OrderingKind::NE:
+      assert(false && "Not implemented case");
+    }
+
+    // Union the set with the result.
+    result.unionInPlace(levelSet);
+    // The last inequality in `levelSet` is the bound we inserted. We remove
+    // it before the next iteration.
+    levelSet.removeInequality(levelSet.getNumInequalities() - 1);
+    // Add equality `outA - outB == 0` for this level for next iteration.
+    levelSet.addEquality(subExpr);
+  }
+
+  return result;
+}
+
 /// Two PWMAFunctions are equal if they have the same dimensionalities,
 /// the same domain, and take the same value at every point in the domain.
 bool PWMAFunction::isEqual(const PWMAFunction &other) const {
@@ -195,6 +282,8 @@ bool PWMAFunction::isEqual(const PWMAFunction &other) const {
 
 void PWMAFunction::addPiece(const Piece &piece) {
   assert(piece.isConsistent() && "Piece should be consistent");
+  assert(piece.domain.intersect(getDomain()).isIntegerEmpty() &&
+         "Piece should be disjoint from the function");
   pieces.push_back(piece);
 }
 
@@ -263,85 +352,23 @@ PWMAFunction PWMAFunction::unionFunction(
 }
 
 /// A tiebreak function which breaks ties by comparing the outputs
-/// lexicographically. If `lexMin` is true, then the ties are broken by
-/// taking the lexicographically smaller output and otherwise, by taking the
-/// lexicographically larger output.
-template <bool lexMin>
+/// lexicographically based on the given comparison operator.
+/// It is templated on `comp` so that it can be passed as a tiebreak function.
+template <OrderingKind comp>
 static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA,
                                  const PWMAFunction::Piece &pieceB) {
-  // TODO: Support local variables here.
-  assert(pieceA.output.getSpace().isCompatible(pieceB.output.getSpace()) &&
-         "Pieces should be compatible");
-  assert(pieceA.domain.getSpace().getNumLocalVars() == 0 &&
-         "Local variables are not supported yet.");
-
-  PresburgerSpace compatibleSpace = pieceA.domain.getSpace();
-  const PresburgerSpace &space = pieceA.domain.getSpace();
-
-  // We first create the set `result`, corresponding to the set where output
-  // of pieceA is lexicographically larger/smaller than pieceB. This is done by
-  // creating a PresburgerSet with the following constraints:
-  //
-  //    (outA[0] > outB[0]) U
-  //    (outA[0] = outB[0], outA[1] > outA[1]) U
-  //    (outA[0] = outB[0], outA[1] = outA[1], outA[2] > outA[2]) U
-  //    ...
-  //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1])
-  //
-  // where `n` is the number of outputs.
-  // If `lexMin` is set, the complement inequality is used:
-  //
-  //    (outA[0] < outB[0]) U
-  //    (outA[0] = outB[0], outA[1] < outA[1]) U
-  //    (outA[0] = outB[0], outA[1] = outA[1], outA[2] < outA[2]) U
-  //    ...
-  //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1])
-  PresburgerSet result = PresburgerSet::getEmpty(compatibleSpace);
-  IntegerPolyhedron levelSet(
-      /*numReservedInequalities=*/1,
-      /*numReservedEqualities=*/pieceA.output.getNumOutputs(),
-      /*numReservedCols=*/space.getNumVars() + 1, space);
-  for (unsigned level = 0; level < pieceA.output.getNumOutputs(); ++level) {
-
-    // Create the expression `outA - outB` for this level.
-    SmallVector<MPInt, 8> subExpr = subtractExprs(
-        pieceA.output.getOutputExpr(level), pieceB.output.getOutputExpr(level));
-
-    if (lexMin) {
-      // For lexMin, we add an upper bound of -1:
-      //        outA - outB <= -1
-      //        outA <= outB - 1
-      //        outA < outB
-      levelSet.addBound(IntegerPolyhedron::BoundType::UB, subExpr, MPInt(-1));
-    } else {
-      // For lexMax, we add a lower bound of 1:
-      //        outA - outB >= 1
-      //        outA > outB + 1
-      //        outA > outB
-      levelSet.addBound(IntegerPolyhedron::BoundType::LB, subExpr, MPInt(1));
-    }
-
-    // Union the set with the result.
-    result.unionInPlace(levelSet);
-    // There is only 1 inequality in `levelSet`, so the index is always 0.
-    levelSet.removeInequality(0);
-    // Add equality `outA - outB == 0` for this level for next iteration.
-    levelSet.addEquality(subExpr);
-  }
-
-  // We then intersect `result` with the domain of pieceA and pieceB, to only
-  // tiebreak on the domain where both are defined.
+  PresburgerSet result = pieceA.output.getLexSet(comp, pieceB.output);
   result = result.intersect(pieceA.domain).intersect(pieceB.domain);
 
   return result;
 }
 
 PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) {
-  return unionFunction(func, tiebreakLex</*lexMin=*/true>);
+  return unionFunction(func, tiebreakLex</*comp=*/OrderingKind::LT>);
 }
 
 PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) {
-  return unionFunction(func, tiebreakLex</*lexMin=*/false>);
+  return unionFunction(func, tiebreakLex</*comp=*/OrderingKind::GT>);
 }
 
 void MultiAffineFunction::subtract(const MultiAffineFunction &other) {

diff  --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
index 648860c7756ef..e15db1edf8cb4 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
@@ -22,6 +22,12 @@ PresburgerSpace PresburgerSpace::getRangeSpace() const {
   return PresburgerSpace::getSetSpace(numRange, numSymbols, numLocals);
 }
 
+PresburgerSpace PresburgerSpace::getSpaceWithoutLocals() const {
+  PresburgerSpace space = *this;
+  space.removeVarRange(VarKind::Local, 0, numLocals);
+  return space;
+}
+
 unsigned PresburgerSpace::getNumVarKind(VarKind kind) const {
   if (kind == VarKind::Domain)
     return getNumDomainVars();

diff  --git a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
index cebc7fa2ec6e6..ee2931e78185c 100644
--- a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
@@ -395,3 +395,44 @@ TEST(PWMAFunction, unionLexMinComplex) {
   EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
   EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
 }
+
+TEST(PWMAFunction, unionLexMinWithDivs) {
+  {
+    PWMAFunction func1 = parsePWMAF({
+        {"(x, y) : (x mod 5 == 0)", "(x, y) -> (x, 1)"},
+    });
+
+    PWMAFunction func2 = parsePWMAF({
+        {"(x, y) : (x mod 7 == 0)", "(x, y) -> (x + y, 2)"},
+    });
+
+    PWMAFunction result = parsePWMAF({
+        {"(x, y) : (x mod 5 == 0, x mod 7 >= 1)", "(x, y) -> (x, 1)"},
+        {"(x, y) : (x mod 7 == 0, x mod 5 >= 1)", "(x, y) -> (x + y, 2)"},
+        {"(x, y) : (x mod 5 == 0, x mod 7 == 0, y >= 0)", "(x, y) -> (x, 1)"},
+        {"(x, y) : (x mod 7 == 0, x mod 5 == 0, y <= -1)",
+         "(x, y) -> (x + y, 2)"},
+    });
+
+    EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
+  }
+
+  {
+    PWMAFunction func1 = parsePWMAF({
+        {"(x) : (x >= 0, x <= 1000)", "(x) -> (x floordiv 16)"},
+    });
+
+    PWMAFunction func2 = parsePWMAF({
+        {"(x) : (x >= 0, x <= 1000)", "(x) -> ((x + 10) floordiv 17)"},
+    });
+
+    PWMAFunction result = parsePWMAF({
+        {"(x) : (x >= 0, x <= 1000, x floordiv 16 <= (x + 10) floordiv 17)",
+         "(x) -> (x floordiv 16)"},
+        {"(x) : (x >= 0, x <= 1000, x floordiv 16 >= (x + 10) floordiv 17 + 1)",
+         "(x) -> ((x + 10) floordiv 17)"},
+    });
+
+    EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
+  }
+}


        


More information about the Mlir-commits mailing list