[Mlir-commits] [mlir] a18f843 - [MLIR][Presburger] Support lexicographic max/min union of two PWMAFunction

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 6 08:09:23 PDT 2022


Author: Groverkss
Date: 2022-07-06T16:08:20+01:00
New Revision: a18f843f075f3fbe461d0f114a28e6e383d7c736

URL: https://github.com/llvm/llvm-project/commit/a18f843f075f3fbe461d0f114a28e6e383d7c736
DIFF: https://github.com/llvm/llvm-project/commit/a18f843f075f3fbe461d0f114a28e6e383d7c736.diff

LOG: [MLIR][Presburger] Support lexicographic max/min union of two PWMAFunction

This patch implements a lexicographic max/min union of two PWMAFunctions.

The lexmax/lexmin union of two functions is defined as a function defined on
the union of the input domains of both functions, such that when only one of the
functions is defined, it outputs the same as that function, and if both are
defined, it outputs the lexmax/lexmin of the two outputs. On points where
neither function is defined, the union is not defined either.

Reviewed By: arjunp

Differential Revision: https://reviews.llvm.org/D128829

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
    mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
    mlir/lib/Analysis/Presburger/IntegerRelation.cpp
    mlir/lib/Analysis/Presburger/PWMAFunction.cpp
    mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 720309a7c1e98..a49d50e081a13 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -118,7 +118,7 @@ class IntegerRelation {
   /// intersection with no simplification of any sort attempted.
   void append(const IntegerRelation &other);
 
-  /// Return the intersection of the two sets.
+  /// Return the intersection of the two relations.
   /// If there are locals, they will be merged.
   IntegerRelation intersect(IntegerRelation other) const;
 
@@ -608,6 +608,10 @@ class IntegerRelation {
   /// `PresburgerSet`, `unboundedDomain`.
   SymbolicLexMin findSymbolicIntegerLexMin() const;
 
+  /// Return the set difference of this relation and the given relation, i.e.,
+  /// return `this \ set`.
+  PresburgerRelation subtract(const PresburgerRelation &set) const;
+
   void print(raw_ostream &os) const;
   void dump() const;
 
@@ -790,6 +794,14 @@ class IntegerPolyhedron : public IntegerRelation {
   /// column position (i.e., not relative to the kind of variable) of the
   /// first added variable.
   unsigned insertVar(VarKind kind, unsigned pos, unsigned num = 1) override;
+
+  /// Return the intersection of the two sets.
+  /// If there are locals, they will be merged.
+  IntegerPolyhedron intersect(const IntegerPolyhedron &other) const;
+
+  /// Return the set difference of this set and the given set, i.e.,
+  /// return `this \ set`.
+  PresburgerSet subtract(const PresburgerSet &other) const;
 };
 
 } // namespace presburger

diff  --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
index 00de5b5ac96a6..c4626a2945f01 100644
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -54,8 +54,15 @@ class MultiAffineFunction {
   bool isConsistent() const {
     return output.getNumColumns() == domainSet.getNumVars() + 1;
   }
-  const IntegerPolyhedron &getDomain() const { return domainSet; }
+
+  /// Get the space of the input domain of this function.
   const PresburgerSpace &getDomainSpace() const { return domainSet.getSpace(); }
+  /// Get the input domain of this function.
+  const IntegerPolyhedron &getDomain() const { return domainSet; }
+  /// Get a matrix with each row representing row^th output expression.
+  const Matrix &getOutputMatrix() const { return output; }
+  /// Get the `i^th` output expression.
+  ArrayRef<int64_t> getOutputExpr(unsigned i) const { return output.getRow(i); }
 
   /// Insert `num` variables of the specified kind at position `pos`.
   /// Positions are relative to the kind of variable. The coefficient columns
@@ -138,6 +145,7 @@ class PWMAFunction {
 
   void addPiece(const MultiAffineFunction &piece);
   void addPiece(const IntegerPolyhedron &domain, const Matrix &output);
+  void addPiece(const PresburgerSet &domain, const Matrix &output);
 
   const MultiAffineFunction &getPiece(unsigned i) const { return pieces[i]; }
   unsigned getNumPieces() const { return pieces.size(); }
@@ -163,10 +171,41 @@ class PWMAFunction {
   /// TODO: refactor so that this can be accomplished through removeVarRange.
   void truncateOutput(unsigned count);
 
+  /// Return a function defined on the union of the domains of this and func,
+  /// such that when only one of the functions is defined, it outputs the same
+  /// as that function, and if both are defined, it outputs the lexmax/lexmin of
+  /// the two outputs. On points where neither function is defined, the returned
+  /// function is not defined either.
+  ///
+  /// Currently this does not support PWMAFunctions which have pieces containing
+  /// local variables.
+  /// TODO: Support local variables in pieces.
+  PWMAFunction unionLexMin(const PWMAFunction &func);
+  PWMAFunction unionLexMax(const PWMAFunction &func);
+
   void print(raw_ostream &os) const;
   void dump() const;
 
 private:
+  /// Return a function defined on the union of the domains of `this` and
+  /// `func`, such that when only one of the functions is defined, it outputs
+  /// the same as that function, and if neither is defined, the returned
+  /// function is not defined either.
+  ///
+  /// The provided `tiebreak` function determines which of the two functions'
+  /// output should be used on inputs where both the functions are defined. More
+  /// precisely, given two `MultiAffineFunction`s `mafA` and `mafB`, `tiebreak`
+  /// returns the subset of the intersection of the two functions' domains where
+  /// the output of `mafA` should be used.
+  ///
+  /// The PresburgerSet returned by `tiebreak` should be disjoint.
+  /// TODO: Remove this constraint of returning disjoint set.
+  PWMAFunction
+  unionFunction(const PWMAFunction &func,
+                llvm::function_ref<PresburgerSet(MultiAffineFunction mafA,
+                                                 MultiAffineFunction mafB)>
+                    tiebreak) const;
+
   PresburgerSpace space;
 
   /// The list of pieces in this piece-wise MultiAffineFunction.

diff  --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index fd4c038c6d82e..b455f4818acf6 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -252,6 +252,11 @@ SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const {
   return result;
 }
 
+PresburgerRelation
+IntegerRelation::subtract(const PresburgerRelation &set) const {
+  return PresburgerRelation(*this).subtract(set); // Lift to a union first.
+}
+
 unsigned IntegerRelation::insertVar(VarKind kind, unsigned pos, unsigned num) {
   assert(pos <= getNumVarKind(kind));
 
@@ -2284,3 +2289,11 @@ unsigned IntegerPolyhedron::insertVar(VarKind kind, unsigned pos,
          "Domain has to be zero in a set");
   return IntegerRelation::insertVar(kind, pos, num);
 }
+IntegerPolyhedron
+IntegerPolyhedron::intersect(const IntegerPolyhedron &other) const {
+  return IntegerPolyhedron(IntegerRelation::intersect(other)); // As a set.
+}
+
+PresburgerSet IntegerPolyhedron::subtract(const PresburgerSet &other) const {
+  return PresburgerSet(IntegerRelation::subtract(other)); // As a set.
+}

diff  --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
index 39c828a20c609..18b5d0e7c68dc 100644
--- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
+++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
@@ -211,6 +211,11 @@ void PWMAFunction::addPiece(const IntegerPolyhedron &domain,
   addPiece(MultiAffineFunction(domain, output));
 }
 
+void PWMAFunction::addPiece(const PresburgerSet &domain, const Matrix &output) {
+  for (const IntegerRelation &disjunct : domain.getAllDisjuncts())
+    addPiece(IntegerPolyhedron(disjunct), output);
+}
+
 void PWMAFunction::print(raw_ostream &os) const {
   os << pieces.size() << " pieces:\n";
   for (const MultiAffineFunction &piece : pieces)
@@ -218,3 +223,138 @@ void PWMAFunction::print(raw_ostream &os) const {
 }
 
 void PWMAFunction::dump() const { print(llvm::errs()); }
+
+PWMAFunction PWMAFunction::unionFunction(
+    const PWMAFunction &func,
+    llvm::function_ref<PresburgerSet(MultiAffineFunction mafA,
+                                     MultiAffineFunction mafB)>
+        tiebreak) const {
+  assert(getNumOutputs() == func.getNumOutputs() &&
+         "Number of outputs of functions should be same.");
+  assert(getSpace().isCompatible(func.getSpace()) &&
+         "Space is not compatible.");
+
+  // The algorithm used here is as follows:
+  // - Add the output of funcB for the part of the domain where both funcA and
+  //   funcB are defined, and `tiebreak` chooses the output of funcB.
+  // - Add the output of funcA, where funcB is not defined or `tiebreak` chooses
+  //   funcA over funcB.
+  // - Add the output of funcB, where funcA is not defined.
+
+  // Add parts of the common domain where funcB's output is used. Also
+  // add all the parts where funcA's output is used, both common and non-common.
+  PWMAFunction result(getSpace(), getNumOutputs());
+  for (const MultiAffineFunction &funcA : pieces) {
+    PresburgerSet dom(funcA.getDomain());
+    for (const MultiAffineFunction &funcB : func.pieces) {
+      PresburgerSet better = tiebreak(funcB, funcA);
+      // Add the output of funcB, where it is better than output of funcA.
+      // The disjuncts in "better" will be disjoint as tiebreak should
+      // guarantee that.
+      result.addPiece(better, funcB.getOutputMatrix());
+      dom = dom.subtract(better);
+    }
+    // Add output of funcA, where it is better than funcB, or funcB is not
+    // defined.
+    //
+    // `dom` here is guaranteed to be disjoint from already added pieces
+    // because the pieces added before are either:
+    // - Subsets of the domain of other MAFs in `this`, which are guaranteed
+    //   to be disjoint from `dom`, or
+    // - They are one of the pieces added for `funcB`, and we have been
+    //   subtracting all such pieces from `dom`, so `dom` is disjoint from those
+    //   pieces as well.
+    result.addPiece(dom, funcA.getOutputMatrix());
+  }
+
+  // Add parts of funcB which are not shared with funcA.
+  PresburgerSet dom = getDomain();
+  for (const MultiAffineFunction &funcB : func.pieces)
+    result.addPiece(funcB.getDomain().subtract(dom), funcB.getOutputMatrix());
+
+  return result;
+}
+
+/// A tiebreak function which breaks ties by comparing the outputs
+/// lexicographically. Returns the subset of the common domain of `mafA` and
+/// `mafB` on which `mafA`'s output is strictly lexicographically smaller (if
+/// `lexMin` is true) or strictly larger (otherwise) than `mafB`'s output.
+template <bool lexMin>
+static PresburgerSet tiebreakLex(const MultiAffineFunction &mafA,
+                                 const MultiAffineFunction &mafB) {
+  // TODO: Support local variables here.
+  assert(mafA.getDomainSpace().isCompatible(mafB.getDomainSpace()) &&
+         "Domain spaces should be compatible.");
+  assert(mafA.getNumOutputs() == mafB.getNumOutputs() &&
+         "Number of outputs of both functions should be same.");
+  assert(mafA.getDomain().getNumLocalVars() == 0 &&
+         "Local variables are not supported yet.");
+
+  PresburgerSpace compatibleSpace = mafA.getDomain().getSpaceWithoutLocals();
+  const PresburgerSpace &space = mafA.getDomain().getSpace();
+
+  // We first create the set `result`, corresponding to the set where output
+  // of mafA is lexicographically larger/smaller than mafB. This is done by
+  // creating a PresburgerSet with the following constraints:
+  //
+  //    (outA[0] > outB[0]) U
+  //    (outA[0] = outB[0], outA[1] > outB[1]) U
+  //    (outA[0] = outB[0], outA[1] = outB[1], outA[2] > outB[2]) U
+  //    ...
+  //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1])
+  //
+  // where `n` is the number of outputs.
+  // If `lexMin` is set, the strict inequalities are reversed:
+  //
+  //    (outA[0] < outB[0]) U
+  //    (outA[0] = outB[0], outA[1] < outB[1]) U
+  //    (outA[0] = outB[0], outA[1] = outB[1], outA[2] < outB[2]) U
+  //    ...
+  //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1])
+  PresburgerSet result = PresburgerSet::getEmpty(compatibleSpace);
+  IntegerPolyhedron levelSet(/*numReservedInequalities=*/1,
+                             /*numReservedEqualities=*/mafA.getNumOutputs(),
+                             /*numReservedCols=*/space.getNumVars() + 1, space);
+  for (unsigned level = 0; level < mafA.getNumOutputs(); ++level) {
+
+    // Create the expression `outA - outB` for this level.
+    SmallVector<int64_t, 8> subExpr =
+        subtract(mafA.getOutputExpr(level), mafB.getOutputExpr(level));
+
+    if (lexMin) {
+      // For lexMin, we add an upper bound of -1:
+      //        outA - outB <= -1
+      //        outA <= outB - 1
+      //        outA < outB
+      levelSet.addBound(IntegerPolyhedron::BoundType::UB, subExpr, -1);
+    } else {
+      // For lexMax, we add a lower bound of 1:
+      //        outA - outB >= 1
+      //        outA >= outB + 1
+      //        outA > outB
+      levelSet.addBound(IntegerPolyhedron::BoundType::LB, subExpr, 1);
+    }
+
+    // Union the set with the result.
+    result.unionInPlace(levelSet);
+    // There is only 1 inequality in `levelSet`, so the index is always 0.
+    levelSet.removeInequality(0);
+    // Add equality `outA - outB == 0` for this level for next iteration.
+    levelSet.addEquality(subExpr);
+  }
+
+  // We then intersect `result` with the domains of mafA and mafB, to only
+  // tiebreak on the domain where both are defined.
+  result = result.intersect(PresburgerSet(mafA.getDomain()))
+               .intersect(PresburgerSet(mafB.getDomain()));
+
+  return result;
+}
+
+PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) {
+  return unionFunction(func, tiebreakLex</*lexMin=*/true>); // Prefer lexmin.
+}
+
+PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) {
+  return unionFunction(func, tiebreakLex</*lexMin=*/false>); // Prefer lexmax.
+}

diff  --git a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
index 27d02524ec15b..c99139607bdd5 100644
--- a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
@@ -189,3 +189,331 @@ TEST(PWMAFunction, eliminateRedundantLocalIdRegressionTest) {
       });
   EXPECT_TRUE(pwmafA.isEqual(pwmafB));
 }
+
+TEST(PWMAFunction, unionLexMaxSimple) {
+  // func2 is better than func1, but func2's domain is empty.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{0, 1}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (1 == 0)", {{0, 2}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func1));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func1));
+  }
+
+  // func2 is better than func1 on a subset of func1.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{0, 1}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, 2}}},
+        });
+
+    PWMAFunction result = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (-1 - x >= 0)", {{0, 1}}},
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, 2}}},
+            {"(x) : (x - 11 >= 0)", {{0, 1}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
+  }
+
+  // func1 and func2 are defined over the whole domain with different outputs.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{1, 0}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{-1, 0}}},
+        });
+
+    PWMAFunction result = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0)", {{1, 0}}},
+            {"(x) : (-1 - x >= 0)", {{-1, 0}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
+  }
+
+  // func1 and func2 have disjoint domains.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, 1}}},
+            {"(x) : (x - 71 >= 0, 80 - x >= 0)", {{0, 1}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x - 20 >= 0, 41 - x >= 0)", {{0, 2}}},
+            {"(x) : (x - 101 >= 0, 120 - x >= 0)", {{0, 2}}},
+        });
+
+    PWMAFunction result = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, 1}}},
+            {"(x) : (x - 71 >= 0, 80 - x >= 0)", {{0, 1}}},
+            {"(x) : (x - 20 >= 0, 41 - x >= 0)", {{0, 2}}},
+            {"(x) : (x - 101 >= 0, 120 - x >= 0)", {{0, 2}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
+  }
+}
+
+TEST(PWMAFunction, unionLexMinSimple) {
+  // func2 is better than func1, but func2's domain is empty.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{0, -1}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (1 == 0)", {{0, -2}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMin(func2).isEqual(func1));
+    EXPECT_TRUE(func2.unionLexMin(func1).isEqual(func1));
+  }
+
+  // func2 is better than func1 on a subset of func1.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{0, -1}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, -2}}},
+        });
+
+    PWMAFunction result = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (-1 - x >= 0)", {{0, -1}}},
+            {"(x) : (x >= 0, 10 - x >= 0)", {{0, -2}}},
+            {"(x) : (x - 11 >= 0)", {{0, -1}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
+    EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
+  }
+
+  // func1 and func2 are defined over the whole domain with different outputs.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{-1, 0}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : ()", {{1, 0}}},
+        });
+
+    PWMAFunction result = parsePWMAF(
+        /*numInputs=*/1, /*numOutputs=*/1,
+        {
+            {"(x) : (x >= 0)", {{-1, 0}}},
+            {"(x) : (-1 - x >= 0)", {{1, 0}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
+    EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
+  }
+}
+
+TEST(PWMAFunction, unionLexMaxComplex) {
+  // Union of functions producing 4 different pieces of output.
+  //
+  // x >= 21               --> func1 (func2 not defined)
+  // x <= 9                --> func2 (func1 not defined)
+  // 10 <= x <= 20, y >  0 --> func1 (x + y  > x - y for y >  0)
+  // 10 <= x <= 20, y <= 0 --> func2 (x + y <= x - y for y <= 0)
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/2, /*numOutputs=*/1,
+        {
+            {"(x, y) : (x >= 10)", {{1, 1, 0}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/2, /*numOutputs=*/1,
+        {
+            {"(x, y) : (x <= 20)", {{1, -1, 0}}},
+        });
+
+    PWMAFunction result = parsePWMAF(/*numInputs=*/2, /*numOutputs=*/1,
+                                     {{"(x, y) : (x >= 10, x <= 20, y >= 1)",
+                                       {
+                                           {1, 1, 0},
+                                       }},
+                                      {"(x, y) : (x >= 21)",
+                                       {
+                                           {1, 1, 0},
+                                       }},
+                                      {"(x, y) : (x <= 9)",
+                                       {
+                                           {1, -1, 0},
+                                       }},
+                                      {"(x, y) : (x >= 10, x <= 20, y <= 0)",
+                                       {
+                                           {1, -1, 0},
+                                       }}});
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
+  }
+
+  // Functions with more than one output, with contribution from both functions.
+  //
+  // If y >= 1, func1 is better because in the first output,
+  // x + y (func1) > x (func2), when y >= 1
+  //
+  // If y == 0, the first output is same for both functions, so we look at the
+  // second output. -2x + 4 (func1) > 2x - 2 (func2) when 0 <= x <= 1, so we
+  // take func1 for this domain and func2 for the remaining.
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/2, /*numOutputs=*/2,
+        {
+            {"(x, y) : (x >= 0, y >= 0)", {{1, 1, 0}, {-2, 0, 4}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/2, /*numOutputs=*/2,
+        {
+            {"(x, y) : (x >= 0, y >= 0)", {{1, 0, 0}, {2, 0, -2}}},
+        });
+
+    PWMAFunction result = parsePWMAF(/*numInputs=*/2, /*numOutputs=*/2,
+                                     {{"(x, y) : (x >= 0, y >= 1)",
+                                       {
+                                           {1, 1, 0},
+                                           {-2, 0, 4},
+                                       }},
+                                      {"(x, y) : (x >= 0, x <= 1, y == 0)",
+                                       {
+                                           {1, 1, 0},
+                                           {-2, 0, 4},
+                                       }},
+                                      {"(x, y) : (x >= 2, y == 0)",
+                                       {
+                                           {1, 0, 0},
+                                           {2, 0, -2},
+                                       }}});
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
+  }
+
+  // Function with three boolean variables `a, b, c` used to control which
+  // output will be taken lexicographically.
+  //
+  // a == 1                 --> Take func2
+  // a == 0, b == 1         --> Take func1
+  // a == 0, b == 0, c == 1 --> Take func2
+  {
+    PWMAFunction func1 = parsePWMAF(
+        /*numInputs=*/3, /*numOutputs=*/3,
+        {
+            {"(a, b, c) : (a >= 0, 1 - a >= 0, b >= 0, 1 - b >= 0, c "
+             ">= 0, 1 - c >= 0)",
+             {{0, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 0, 0}}},
+        });
+
+    PWMAFunction func2 = parsePWMAF(
+        /*numInputs=*/3, /*numOutputs=*/3,
+        {
+            {"(a, b, c) : (a >= 0, 1 - a >= 0, b >= 0, 1 - b >= 0, c >= 0, 1 - "
+             "c >= 0)",
+             {{1, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 1, 0}}},
+        });
+
+    PWMAFunction result = parsePWMAF(
+        /*numInputs=*/3, /*numOutputs=*/3,
+        {
+            {"(a, b, c) : (a - 1 == 0, b >= 0, 1 - b >= 0, c >= 0, 1 - c >= 0)",
+             {{1, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 1, 0}}},
+            {"(a, b, c) : (a == 0, b - 1 == 0, c >= 0, 1 - c >= 0)",
+             {{0, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 0, 0}}},
+            {"(a, b, c) : (a == 0, b == 0, c >= 0, 1 - c >= 0)",
+             {{1, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 1, 0}}},
+        });
+
+    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
+    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
+  }
+}
+
+TEST(PWMAFunction, unionLexMinComplex) {
+  // Regression test checking if lexicographic tiebreak produces disjoint
+  // domains.
+  //
+  // If x == 1, func1 is better since in the first output,
+  // -x (func1) < 0 (func2) when x == 1.
+  //
+  // If x == 0, func1 and func2 both have the same first output. So we take a
+  // look at the second output. func2 is better since in the second output,
+  // y - 1 (func2) < y (func1).
+  PWMAFunction func1 = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          {"(x, y) : (x >= 0, x <= 1, y >= 0, y <= 1)",
+           {{-1, 0, 0}, {0, 1, 0}}},
+      });
+
+  PWMAFunction func2 = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          {"(x, y) : (x >= 0, x <= 1, y >= 0, y <= 1)",
+           {{0, 0, 0}, {0, 1, -1}}},
+      });
+
+  PWMAFunction result = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          {"(x, y) : (x == 1, y >= 0, y <= 1)", {{-1, 0, 0}, {0, 1, 0}}},
+          {"(x, y) : (x == 0, y >= 0, y <= 1)", {{0, 0, 0}, {0, 1, -1}}},
+      });
+
+  EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
+  EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
+}


        


More information about the Mlir-commits mailing list