[Mlir-commits] [mlir] 1022feb - [MLIR][Presburger] Generating functions and quasi-polynomials for Barvinok's algorithm (#75702)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Dec 26 11:29:30 PST 2023


Author: Abhinav271828
Date: 2023-12-26T21:29:26+02:00
New Revision: 1022febd9df30abbd5c490b94290c4422ca15b01

URL: https://github.com/llvm/llvm-project/commit/1022febd9df30abbd5c490b94290c4422ca15b01
DIFF: https://github.com/llvm/llvm-project/commit/1022febd9df30abbd5c490b94290c4422ca15b01.diff

LOG: [MLIR][Presburger] Generating functions and quasi-polynomials for Barvinok's algorithm (#75702)

Define basic types and classes for Barvinok's algorithm, including
polyhedra, generating functions and quasi-polynomials.
The class definitions include methods for arithmetic manipulation,
printing, logical relations, etc.

Added: 
    mlir/include/mlir/Analysis/Presburger/QuasiPolynomial.h
    mlir/lib/Analysis/Presburger/GeneratingFunction.h
    mlir/lib/Analysis/Presburger/QuasiPolynomial.cpp
    mlir/unittests/Analysis/Presburger/QuasiPolynomialTest.cpp

Modified: 
    mlir/lib/Analysis/Presburger/CMakeLists.txt
    mlir/unittests/Analysis/Presburger/CMakeLists.txt
    mlir/unittests/Analysis/Presburger/Utils.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/QuasiPolynomial.h b/mlir/include/mlir/Analysis/Presburger/QuasiPolynomial.h
new file mode 100644
index 00000000000000..f8ce8524e41b21
--- /dev/null
+++ b/mlir/include/mlir/Analysis/Presburger/QuasiPolynomial.h
@@ -0,0 +1,71 @@
+//===- QuasiPolynomial.h - QuasiPolynomial Class ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Definition of the QuasiPolynomial class for Barvinok's algorithm,
+// which represents a single-valued function on a set of parameters.
+// It is an expression of the form
+// f(x) = \sum_i c_i * \prod_j ⌊g_{ij}(x)⌋
+// where c_i \in Q and
+// g_{ij} : Q^d -> Q are affine functionals over d parameters.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_PRESBURGER_QUASIPOLYNOMIAL_H
+#define MLIR_ANALYSIS_PRESBURGER_QUASIPOLYNOMIAL_H
+
+#include "mlir/Analysis/Presburger/Fraction.h"
+#include "mlir/Analysis/Presburger/PresburgerSpace.h"
+
+namespace mlir {
+namespace presburger {
+
+/// A quasi-polynomial, as used by Barvinok's algorithm.
+///
+/// It represents a function f : Q^n -> Q of the form
+///
+///   f(x) = \sum_i c_i * \prod_j ⌊g_{ij}(x)⌋
+///
+/// where every c_i \in Q is stored in `coefficients[i]` and every
+/// g_{ij} : Q^n -> Q is an affine functional stored in `affine[i][j]`.
+/// A term whose list of affine functionals is empty is just the constant c_i.
+class QuasiPolynomial : public PresburgerSpace {
+public:
+  /// Build a quasi-polynomial over `numVars` inputs. `coeffs[i]` is the
+  /// multiplier of the i-th term and `aff[i]` lists the affine functionals
+  /// whose floors are multiplied together in that term.
+  QuasiPolynomial(unsigned numVars, SmallVector<Fraction> coeffs = {},
+                  std::vector<std::vector<SmallVector<Fraction>>> aff = {});
+
+  /// Number of inputs of the function: domain variables plus symbol
+  /// variables (the constructor always creates zero symbols).
+  unsigned getNumInputs() const {
+    unsigned numDomain = getNumDomainVars();
+    unsigned numSymbols = getNumSymbolVars();
+    return numDomain + numSymbols;
+  }
+
+  /// Read-only view of the per-term coefficients c_i.
+  const SmallVector<Fraction> &getCoefficients() const { return coefficients; }
+
+  /// Read-only view of the per-term lists of affine functionals g_{ij}.
+  const std::vector<std::vector<SmallVector<Fraction>>> &getAffine() const {
+    return affine;
+  }
+
+  /// Term-wise arithmetic. The binary operators expect both operands to have
+  /// the same number of inputs.
+  QuasiPolynomial operator+(const QuasiPolynomial &x) const;
+  QuasiPolynomial operator-(const QuasiPolynomial &x) const;
+  QuasiPolynomial operator*(const QuasiPolynomial &x) const;
+  /// Divide every coefficient by the non-zero constant `x`.
+  QuasiPolynomial operator/(const Fraction x) const;
+
+  /// Return a copy with the identically-zero terms dropped.
+  QuasiPolynomial simplify();
+
+private:
+  /// coefficients[i] is c_i, the multiplier of the i-th term.
+  SmallVector<Fraction> coefficients;
+  /// affine[i][j] is g_{ij}, the j-th affine functional of the i-th term.
+  std::vector<std::vector<SmallVector<Fraction>>> affine;
+};
+
+} // namespace presburger
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_PRESBURGER_QUASIPOLYNOMIAL_H
\ No newline at end of file

diff  --git a/mlir/lib/Analysis/Presburger/CMakeLists.txt b/mlir/lib/Analysis/Presburger/CMakeLists.txt
index 22f1a4cac44055..e77e1623dae175 100644
--- a/mlir/lib/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/lib/Analysis/Presburger/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRPresburger
   PresburgerRelation.cpp
   PresburgerSpace.cpp
   PWMAFunction.cpp
+  QuasiPolynomial.cpp
   Simplex.cpp
   SlowMPInt.cpp
   Utils.cpp

diff  --git a/mlir/lib/Analysis/Presburger/GeneratingFunction.h b/mlir/lib/Analysis/Presburger/GeneratingFunction.h
new file mode 100644
index 00000000000000..8676b84c1c4df8
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/GeneratingFunction.h
@@ -0,0 +1,132 @@
+//===- GeneratingFunction.h - Generating Functions over Q^d -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Definition of the GeneratingFunction class for Barvinok's algorithm,
+// which represents a function over Q^n, parameterized by d parameters.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_PRESBURGER_GENERATINGFUNCTION_H
+#define MLIR_ANALYSIS_PRESBURGER_GENERATINGFUNCTION_H
+
+#include "mlir/Analysis/Presburger/Fraction.h"
+#include "mlir/Analysis/Presburger/Matrix.h"
+
+namespace mlir {
+namespace presburger {
+
+// A parametric point is a vector, each of whose elements
+// is an affine function of n parameters. Each row
+// in the matrix represents the affine function and
+// has n+1 elements.
+using ParamPoint = FracMatrix;
+
+// A point is simply a vector.
+using Point = SmallVector<Fraction>;
+
+// A class to describe the type of generating function
+// used to enumerate the integer points in a polytope.
+// Consists of a set of terms, where the ith term has
+// * a sign, ±1, stored in `signs[i]`
+// * a numerator, of the form x^{n},
+//      where n, stored in `numerators[i]`,
+//      is a parametric point.
+// * a denominator, of the form (1 - x^{d1})...(1 - x^{dn}),
+//      where each dj, stored in `denominators[i][j]`,
+//      is a vector.
+//
+// Represents functions f_p : Q^n -> Q of the form
+//
+// f_p(x) = \sum_i s_i * (x^n_i(p)) / (\prod_j (1 - x^d_{ij})
+//
+// where s_i is ±1,
+// n_i \in Q^d -> Q^n is an n-vector of affine functions on d parameters, and
+// g_{ij} \in Q^n are vectors.
+class GeneratingFunction {
+public:
+  GeneratingFunction(unsigned numParam, SmallVector<int, 8> signs,
+                     std::vector<ParamPoint> nums,
+                     std::vector<std::vector<Point>> dens)
+      : numParam(numParam), signs(signs), numerators(nums), denominators(dens) {
+    for (const ParamPoint &term : numerators)
+      assert(term.getNumColumns() == numParam + 1 &&
+             "dimensionality of numerator exponents does not match number of "
+             "parameters!");
+  }
+
+  unsigned getNumParams() { return numParam; }
+
+  SmallVector<int> getSigns() { return signs; }
+
+  std::vector<ParamPoint> getNumerators() { return numerators; }
+
+  std::vector<std::vector<Point>> getDenominators() { return denominators; }
+
+  GeneratingFunction operator+(const GeneratingFunction &gf) const {
+    assert(numParam == gf.getNumParams() &&
+           "two generating functions with 
diff erent numbers of parameters "
+           "cannot be added!");
+    SmallVector<int> sumSigns = signs;
+    sumSigns.append(gf.signs);
+
+    std::vector<ParamPoint> sumNumerators = numerators;
+    sumNumerators.insert(sumNumerators.end(), gf.numerators.begin(),
+                         gf.numerators.end());
+
+    std::vector<std::vector<Point>> sumDenominators = denominators;
+    sumDenominators.insert(sumDenominators.end(), gf.denominators.begin(),
+                           gf.denominators.end());
+    return GeneratingFunction(sumSigns, sumNumerators, sumDenominators);
+  }
+
+  llvm::raw_ostream &print(llvm::raw_ostream &os) const {
+    for (unsigned i = 0, e = signs.size(); i < e; i++) {
+      if (i == 0) {
+        if (signs[i] == -1)
+          os << "- ";
+      } else {
+        if (signs[i] == 1)
+          os << " + ";
+        else
+          os << " - ";
+      }
+
+      os << "x^[";
+      unsigned r = numerators[i].getNumRows();
+      for (unsigned j = 0; j < r - 1; j++) {
+        os << "[";
+        for (unsigned k = 0, c = numerators[i].getNumColumns(); k < c - 1; k++)
+          os << numerators[i].at(j, k) << ",";
+        os << numerators[i].getRow(j).back() << "],";
+      }
+      os << "[";
+      for (unsigned k = 0, c = numerators[i].getNumColumns(); k < c - 1; k++)
+        os << numerators[i].at(r - 1, k) << ",";
+      os << numerators[i].getRow(r - 1).back() << "]]/";
+
+      for (const Point &den : denominators[i]) {
+        os << "(x^[";
+        for (unsigned j = 0, e = den.size(); j < e - 1; j++)
+          os << den[j] << ",";
+        os << den.back() << "])";
+      }
+    }
+    return os;
+  }
+
+private:
+  unsigned numParam;
+  SmallVector<int, 8> signs;
+  std::vector<ParamPoint> numerators;
+  std::vector<std::vector<Point>> denominators;
+};
+
+} // namespace presburger
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_PRESBURGER_GENERATINGFUNCTION_H
\ No newline at end of file

diff  --git a/mlir/lib/Analysis/Presburger/QuasiPolynomial.cpp b/mlir/lib/Analysis/Presburger/QuasiPolynomial.cpp
new file mode 100644
index 00000000000000..902e3ced472f82
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/QuasiPolynomial.cpp
@@ -0,0 +1,113 @@
+//===- QuasiPolynomial.cpp - Quasipolynomial Class --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/QuasiPolynomial.h"
+#include "mlir/Analysis/Presburger/Fraction.h"
+#include "mlir/Analysis/Presburger/PresburgerSpace.h"
+#include "mlir/Analysis/Presburger/Utils.h"
+
+using namespace mlir;
+using namespace presburger;
+
+QuasiPolynomial::QuasiPolynomial(
+    unsigned numVars, SmallVector<Fraction> coeffs,
+    std::vector<std::vector<SmallVector<Fraction>>> aff)
+    : PresburgerSpace(/*numDomain=*/numVars, /*numRange=*/1, /*numSymbols=*/0,
+                      /*numLocals=*/0),
+      coefficients(coeffs), affine(aff) {
+  // For each term which involves at least one affine function,
+  for (const std::vector<SmallVector<Fraction>> &term : affine) {
+    if (term.size() == 0)
+      continue;
+    // the number of elements in each affine function is
+    // one more than the number of symbols.
+    for (const SmallVector<Fraction> &aff : term) {
+      assert(aff.size() == getNumInputs() + 1 &&
+             "dimensionality of affine functions does not match number of "
+             "symbols!");
+    }
+  }
+}
+
+// Sum of two quasi-polynomials over the same inputs: the concatenation of
+// their term lists (terms of `*this` first, then those of `x`).
+QuasiPolynomial QuasiPolynomial::operator+(const QuasiPolynomial &x) const {
+  assert(getNumInputs() == x.getNumInputs() &&
+         "two quasi-polynomials with different numbers of symbols cannot "
+         "be added!");
+  SmallVector<Fraction> sumCoeffs = coefficients;
+  sumCoeffs.append(x.coefficients);
+  std::vector<std::vector<SmallVector<Fraction>>> sumAff = affine;
+  sumAff.insert(sumAff.end(), x.affine.begin(), x.affine.end());
+  return QuasiPolynomial(getNumInputs(), sumCoeffs, sumAff);
+}
+
+// Difference of two quasi-polynomials over the same inputs: the terms of `x`
+// are appended to those of `*this` with their coefficients negated.
+QuasiPolynomial QuasiPolynomial::operator-(const QuasiPolynomial &x) const {
+  assert(getNumInputs() == x.getNumInputs() &&
+         "two quasi-polynomials with different numbers of symbols cannot "
+         "be subtracted!");
+  QuasiPolynomial qp(getNumInputs(), x.coefficients, x.affine);
+  for (Fraction &coeff : qp.coefficients)
+    coeff = -coeff;
+  return *this + qp;
+}
+
+// Product of two quasi-polynomials over the same inputs: every term of
+// `*this` is multiplied by every term of `x`. A product term's coefficient is
+// the product of the two coefficients, and its list of affine functions is
+// the concatenation of the two lists. Both loops use the same nesting, so
+// coefficients and affine lists stay aligned by index.
+QuasiPolynomial QuasiPolynomial::operator*(const QuasiPolynomial &x) const {
+  assert(getNumInputs() == x.getNumInputs() &&
+         "two quasi-polynomials with different numbers of "
+         "symbols cannot be multiplied!");
+
+  SmallVector<Fraction> coeffs;
+  coeffs.reserve(coefficients.size() * x.coefficients.size());
+  for (const Fraction &coeff : coefficients)
+    for (const Fraction &xcoeff : x.coefficients)
+      coeffs.push_back(coeff * xcoeff);
+
+  std::vector<std::vector<SmallVector<Fraction>>> aff;
+  aff.reserve(affine.size() * x.affine.size());
+  for (const std::vector<SmallVector<Fraction>> &term : affine) {
+    for (const std::vector<SmallVector<Fraction>> &xterm : x.affine) {
+      // Build the concatenated product directly in its final slot instead of
+      // in a scratch vector that would then be copied.
+      aff.emplace_back();
+      std::vector<SmallVector<Fraction>> &product = aff.back();
+      product.reserve(term.size() + xterm.size());
+      product.insert(product.end(), term.begin(), term.end());
+      product.insert(product.end(), xterm.begin(), xterm.end());
+    }
+  }
+
+  return QuasiPolynomial(getNumInputs(), coeffs, aff);
+}
+
+// Divide a quasi-polynomial by a non-zero constant. Only the coefficients are
+// scaled; the affine functions of every term are left as they are.
+QuasiPolynomial QuasiPolynomial::operator/(const Fraction x) const {
+  assert(x != 0 && "division by zero!");
+  QuasiPolynomial quotient = *this;
+  for (unsigned i = 0, e = quotient.coefficients.size(); i < e; ++i)
+    quotient.coefficients[i] /= x;
+  return quotient;
+}
+
+// Removes terms which evaluate to zero from the expression.
+// A term is identically zero if its coefficient is zero, or if some affine
+// function in its product has all of its entries equal to zero (that function
+// is then zero everywhere, so its floor, and hence the product, is zero).
+QuasiPolynomial QuasiPolynomial::simplify() {
+  // Each coefficient indexes its own list of affine functions below.
+  assert(coefficients.size() == affine.size() &&
+         "number of coefficients does not match number of terms!");
+  SmallVector<Fraction> newCoeffs;
+  std::vector<std::vector<SmallVector<Fraction>>> newAffine;
+  for (unsigned i = 0, e = coefficients.size(); i < e; i++) {
+    if (coefficients[i] == 0)
+      continue;
+    bool productIsZero =
+        llvm::any_of(affine[i], [](const SmallVector<Fraction> &function) {
+          return llvm::all_of(function,
+                              [](const Fraction &f) { return f == 0; });
+        });
+    if (productIsZero)
+      continue;
+    newCoeffs.push_back(coefficients[i]);
+    newAffine.push_back(affine[i]);
+  }
+  return QuasiPolynomial(getNumInputs(), newCoeffs, newAffine);
+}
\ No newline at end of file

diff  --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
index b6ce273e35a0ee..e37133354e53ca 100644
--- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_unittest(MLIRPresburgerTests
   PresburgerRelationTest.cpp
   PresburgerSpaceTest.cpp
   PWMAFunctionTest.cpp
+  QuasiPolynomialTest.cpp
   SimplexTest.cpp
   UtilsTest.cpp
 )

diff  --git a/mlir/unittests/Analysis/Presburger/QuasiPolynomialTest.cpp b/mlir/unittests/Analysis/Presburger/QuasiPolynomialTest.cpp
new file mode 100644
index 00000000000000..a84f0234067ab7
--- /dev/null
+++ b/mlir/unittests/Analysis/Presburger/QuasiPolynomialTest.cpp
@@ -0,0 +1,140 @@
+//===- QuasiPolynomialTest.cpp - Tests for QuasiPolynomial ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/QuasiPolynomial.h"
+#include "./Utils.h"
+#include "mlir/Analysis/Presburger/Fraction.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using namespace mlir;
+using namespace presburger;
+
+// Test the arithmetic operations on QuasiPolynomials;
+// addition, subtraction, multiplication, and division
+// by a constant.
+// Two QPs of 3 parameters each were generated randomly
+// and their sum, difference, and product computed by hand.
+TEST(QuasiPolynomialTest, arithmetic) {
+  QuasiPolynomial qp1(
+      3, {Fraction(1, 3), Fraction(1, 1), Fraction(1, 2)},
+      {{{Fraction(1, 1), Fraction(-1, 2), Fraction(4, 5), Fraction(0, 1)},
+        {Fraction(2, 3), Fraction(3, 4), Fraction(-1, 1), Fraction(5, 7)}},
+       {{Fraction(1, 2), Fraction(1, 1), Fraction(4, 5), Fraction(1, 1)}},
+       {{Fraction(-3, 2), Fraction(1, 1), Fraction(5, 6), Fraction(7, 5)},
+        {Fraction(1, 4), Fraction(2, 1), Fraction(6, 5), Fraction(-9, 8)},
+        {Fraction(3, 2), Fraction(2, 5), Fraction(-7, 4), Fraction(0, 1)}}});
+  QuasiPolynomial qp2(
+      3, {Fraction(1, 1), Fraction(2, 1)},
+      {{{Fraction(1, 2), Fraction(0, 1), Fraction(-1, 3), Fraction(5, 3)},
+        {Fraction(2, 1), Fraction(5, 4), Fraction(9, 7), Fraction(-1, 5)}},
+       {{Fraction(1, 3), Fraction(-2, 3), Fraction(1, 1), Fraction(0, 1)}}});
+
+  // The sum concatenates the terms of qp1 and qp2.
+  QuasiPolynomial sum = qp1 + qp2;
+  EXPECT_EQ_REPR_QUASIPOLYNOMIAL(
+      sum,
+      QuasiPolynomial(
+          3,
+          {Fraction(1, 3), Fraction(1, 1), Fraction(1, 2), Fraction(1, 1),
+           Fraction(2, 1)},
+          {{{Fraction(1, 1), Fraction(-1, 2), Fraction(4, 5), Fraction(0, 1)},
+            {Fraction(2, 3), Fraction(3, 4), Fraction(-1, 1), Fraction(5, 7)}},
+           {{Fraction(1, 2), Fraction(1, 1), Fraction(4, 5), Fraction(1, 1)}},
+           {{Fraction(-3, 2), Fraction(1, 1), Fraction(5, 6), Fraction(7, 5)},
+            {Fraction(1, 4), Fraction(2, 1), Fraction(6, 5), Fraction(-9, 8)},
+            {Fraction(3, 2), Fraction(2, 5), Fraction(-7, 4), Fraction(0, 1)}},
+           {{Fraction(1, 2), Fraction(0, 1), Fraction(-1, 3), Fraction(5, 3)},
+            {Fraction(2, 1), Fraction(5, 4), Fraction(9, 7), Fraction(-1, 5)}},
+           {{Fraction(1, 3), Fraction(-2, 3), Fraction(1, 1),
+             Fraction(0, 1)}}}));
+
+  // The difference is the sum with the coefficients of qp2 negated.
+  QuasiPolynomial diff = qp1 - qp2;
+  EXPECT_EQ_REPR_QUASIPOLYNOMIAL(
+      diff,
+      QuasiPolynomial(
+          3,
+          {Fraction(1, 3), Fraction(1, 1), Fraction(1, 2), Fraction(-1, 1),
+           Fraction(-2, 1)},
+          {{{Fraction(1, 1), Fraction(-1, 2), Fraction(4, 5), Fraction(0, 1)},
+            {Fraction(2, 3), Fraction(3, 4), Fraction(-1, 1), Fraction(5, 7)}},
+           {{Fraction(1, 2), Fraction(1, 1), Fraction(4, 5), Fraction(1, 1)}},
+           {{Fraction(-3, 2), Fraction(1, 1), Fraction(5, 6), Fraction(7, 5)},
+            {Fraction(1, 4), Fraction(2, 1), Fraction(6, 5), Fraction(-9, 8)},
+            {Fraction(3, 2), Fraction(2, 5), Fraction(-7, 4), Fraction(0, 1)}},
+           {{Fraction(1, 2), Fraction(0, 1), Fraction(-1, 3), Fraction(5, 3)},
+            {Fraction(2, 1), Fraction(5, 4), Fraction(9, 7), Fraction(-1, 5)}},
+           {{Fraction(1, 3), Fraction(-2, 3), Fraction(1, 1),
+             Fraction(0, 1)}}}));
+
+  // The product pairs every term of qp1 with every term of qp2.
+  QuasiPolynomial prod = qp1 * qp2;
+  EXPECT_EQ_REPR_QUASIPOLYNOMIAL(
+      prod,
+      QuasiPolynomial(
+          3,
+          {Fraction(1, 3), Fraction(2, 3), Fraction(1, 1), Fraction(2, 1),
+           Fraction(1, 2), Fraction(1, 1)},
+          {{{Fraction(1, 1), Fraction(-1, 2), Fraction(4, 5), Fraction(0, 1)},
+            {Fraction(2, 3), Fraction(3, 4), Fraction(-1, 1), Fraction(5, 7)},
+            {Fraction(1, 2), Fraction(0, 1), Fraction(-1, 3), Fraction(5, 3)},
+            {Fraction(2, 1), Fraction(5, 4), Fraction(9, 7), Fraction(-1, 5)}},
+           {{Fraction(1, 1), Fraction(-1, 2), Fraction(4, 5), Fraction(0, 1)},
+            {Fraction(2, 3), Fraction(3, 4), Fraction(-1, 1), Fraction(5, 7)},
+            {Fraction(1, 3), Fraction(-2, 3), Fraction(1, 1), Fraction(0, 1)}},
+           {{Fraction(1, 2), Fraction(1, 1), Fraction(4, 5), Fraction(1, 1)},
+            {Fraction(1, 2), Fraction(0, 1), Fraction(-1, 3), Fraction(5, 3)},
+            {Fraction(2, 1), Fraction(5, 4), Fraction(9, 7), Fraction(-1, 5)}},
+           {{Fraction(1, 2), Fraction(1, 1), Fraction(4, 5), Fraction(1, 1)},
+            {Fraction(1, 3), Fraction(-2, 3), Fraction(1, 1), Fraction(0, 1)}},
+           {{Fraction(-3, 2), Fraction(1, 1), Fraction(5, 6), Fraction(7, 5)},
+            {Fraction(1, 4), Fraction(2, 1), Fraction(6, 5), Fraction(-9, 8)},
+            {Fraction(3, 2), Fraction(2, 5), Fraction(-7, 4), Fraction(0, 1)},
+            {Fraction(1, 2), Fraction(0, 1), Fraction(-1, 3), Fraction(5, 3)},
+            {Fraction(2, 1), Fraction(5, 4), Fraction(9, 7), Fraction(-1, 5)}},
+           {{Fraction(-3, 2), Fraction(1, 1), Fraction(5, 6), Fraction(7, 5)},
+            {Fraction(1, 4), Fraction(2, 1), Fraction(6, 5), Fraction(-9, 8)},
+            {Fraction(3, 2), Fraction(2, 5), Fraction(-7, 4), Fraction(0, 1)},
+            {Fraction(1, 3), Fraction(-2, 3), Fraction(1, 1),
+             Fraction(0, 1)}}}));
+
+  // Division by a constant only scales the coefficients.
+  QuasiPolynomial quot = qp1 / 2;
+  EXPECT_EQ_REPR_QUASIPOLYNOMIAL(
+      quot,
+      QuasiPolynomial(
+          3, {Fraction(1, 6), Fraction(1, 2), Fraction(1, 4)},
+          {{{Fraction(1, 1), Fraction(-1, 2), Fraction(4, 5), Fraction(0, 1)},
+            {Fraction(2, 3), Fraction(3, 4), Fraction(-1, 1), Fraction(5, 7)}},
+           {{Fraction(1, 2), Fraction(1, 1), Fraction(4, 5), Fraction(1, 1)}},
+           {{Fraction(-3, 2), Fraction(1, 1), Fraction(5, 6), Fraction(7, 5)},
+            {Fraction(1, 4), Fraction(2, 1), Fraction(6, 5), Fraction(-9, 8)},
+            {Fraction(3, 2), Fraction(2, 5), Fraction(-7, 4),
+             Fraction(0, 1)}}}));
+}
+
+// Test the simplify() operation on QPs, which drops terms that are
+// identically zero. A random QP was generated and some of its terms were
+// altered to exercise each condition checked by simplify(): a zero term
+// coefficient, or an affine function in the product whose coefficients are
+// all zero.
+TEST(QuasiPolynomialTest, simplify) {
+  // Terms 1 and 4 have a zero coefficient and term 2 contains an all-zero
+  // affine function, so only terms 0 and 3 should survive.
+  QuasiPolynomial input(2,
+                        {Fraction(2, 3), Fraction(0, 1), Fraction(1, 1),
+                         Fraction(1, 2), Fraction(0, 1)},
+                        {{{Fraction(1, 1), Fraction(3, 4), Fraction(5, 3)},
+                          {Fraction(2, 1), Fraction(0, 1), Fraction(0, 1)}},
+                         {{Fraction(1, 3), Fraction(8, 5), Fraction(2, 5)}},
+                         {{Fraction(2, 7), Fraction(9, 5), Fraction(0, 1)},
+                          {Fraction(0, 1), Fraction(0, 1), Fraction(0, 1)}},
+                         {{Fraction(1, 1), Fraction(4, 5), Fraction(6, 5)}},
+                         {{Fraction(1, 3), Fraction(4, 3), Fraction(7, 8)}}});
+  QuasiPolynomial expected(
+      2, {Fraction(2, 3), Fraction(1, 2)},
+      {{{Fraction(1, 1), Fraction(3, 4), Fraction(5, 3)},
+        {Fraction(2, 1), Fraction(0, 1), Fraction(0, 1)}},
+       {{Fraction(1, 1), Fraction(4, 5), Fraction(6, 5)}}});
+  EXPECT_EQ_REPR_QUASIPOLYNOMIAL(input.simplify(), expected);
+}
\ No newline at end of file

diff  --git a/mlir/unittests/Analysis/Presburger/Utils.h b/mlir/unittests/Analysis/Presburger/Utils.h
index 544577375dd1d1..2a9966c7ce2ea5 100644
--- a/mlir/unittests/Analysis/Presburger/Utils.h
+++ b/mlir/unittests/Analysis/Presburger/Utils.h
@@ -17,6 +17,7 @@
 #include "mlir/Analysis/Presburger/Matrix.h"
 #include "mlir/Analysis/Presburger/PWMAFunction.h"
 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
+#include "mlir/Analysis/Presburger/QuasiPolynomial.h"
 #include "mlir/Analysis/Presburger/Simplex.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/Support/LLVM.h"
@@ -71,6 +72,28 @@ inline void EXPECT_EQ_FRAC_MATRIX(FracMatrix a, FracMatrix b) {
       EXPECT_EQ(a(row, col), b(row, col));
 }
 
+// Check the coefficients (in order) of two quasipolynomials.
+// Note that this is not a true equality check: quasi-polynomials that are
+// equal as functions but list their terms differently are reported unequal.
+// Container sizes are compared with ASSERT_* so that a mismatch aborts the
+// comparison before the shorter container could be indexed out of bounds.
+inline void EXPECT_EQ_REPR_QUASIPOLYNOMIAL(const QuasiPolynomial &a,
+                                           const QuasiPolynomial &b) {
+  EXPECT_EQ(a.getNumInputs(), b.getNumInputs());
+
+  const SmallVector<Fraction> &aCoeffs = a.getCoefficients();
+  const SmallVector<Fraction> &bCoeffs = b.getCoefficients();
+  ASSERT_EQ(aCoeffs.size(), bCoeffs.size());
+  for (unsigned i = 0, e = aCoeffs.size(); i < e; i++)
+    EXPECT_EQ(aCoeffs[i], bCoeffs[i]);
+
+  const std::vector<std::vector<SmallVector<Fraction>>> &aAff = a.getAffine();
+  const std::vector<std::vector<SmallVector<Fraction>>> &bAff = b.getAffine();
+  ASSERT_EQ(aAff.size(), bAff.size());
+  for (unsigned i = 0, e = aAff.size(); i < e; i++) {
+    ASSERT_EQ(aAff[i].size(), bAff[i].size());
+    for (unsigned j = 0, f = aAff[i].size(); j < f; j++) {
+      ASSERT_EQ(aAff[i][j].size(), bAff[i][j].size());
+      for (unsigned k = 0, g = aAff[i][j].size(); k < g; k++)
+        EXPECT_EQ(aAff[i][j][k], bAff[i][j][k]);
+    }
+  }
+}
+
 /// lhs and rhs represent non-negative integers or positive infinity. The
 /// infinity case corresponds to when the Optional is empty.
 inline bool infinityOrUInt64LE(std::optional<MPInt> lhs,


        


More information about the Mlir-commits mailing list