[Mlir-commits] [mlir] [MLIR][Presburger] Generating functions and quasi-polynomials for Barvinok's algorithm (PR #75702)

Arjun P llvmlistbot at llvm.org
Mon Dec 18 01:40:02 PST 2023

@@ -0,0 +1,237 @@
+//===- Barvinok.h - Barvinok's Algorithm ------------------------*- C++ -*-===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Functions and classes for Barvinok's algorithm in MLIR.
+#include "mlir/Analysis/Presburger/Fraction.h"
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
+#include "mlir/Analysis/Presburger/Matrix.h"
+#include "mlir/Analysis/Presburger/PresburgerSpace.h"
+#include "mlir/Analysis/Presburger/Utils.h"
+#include "mlir/Support/LogicalResult.h"
+#include <optional>
+#include <utility>
+namespace mlir {
+namespace presburger {
+// Inequality ("H") representation: general polyhedra and, as a special
+// case, cones are both stored as an integer relation.
+using PolyhedronH = IntegerRelation;
+using ConeH = IntegerRelation;
+// Generator ("V") representation: general polyhedra and cones are both
+// stored as a matrix holding one generator per row.
+using PolyhedronV = Matrix<MPInt>;
+using ConeV = Matrix<MPInt>;
+// A parametric point: a vector whose every element is an affine function
+// of n parameters. Each row of the matrix holds one such affine function,
+// so every row has n+1 entries.
+using ParamPoint = Matrix<Fraction>;
+// A (non-parametric) point: a plain vector of fractions.
+using Point = SmallVector<Fraction>;
+// A class to describe the type of generating function
+// used to enumerate the integer points in a polytope.
+// Consists of a set of terms, where the ith term has
+// * a sign, ±1, stored in `signs[i]`
+// * a numerator, of the form x^{n},
+//      where n, stored in `numerators[i]`,
+//      is a parametric point (a vertex).
+// * a denominator, of the form (1 - x^{d_1})...(1 - x^{d_k}),
+//      where each d_j, stored in `denominators[i][j]`,
+//      is a vector (a generator).
+// Represents functions f : Q^n -> Q of the form
+// f(x) = \sum_i s_i * x^{n_i(p)} / \prod_j (1 - x^{d_{ij}}),
+// where p \in Q^d is the vector of parameters, s_i is ±1,
+// n_i \in (Q^d -> Q)^n is an n-vector of affine functions on d parameters, and
+// d_{ij} \in Q^n are vectors.
+class GeneratingFunction {
+public:
+  // Build a generating function from parallel per-term lists: the ith term
+  // has sign `signs[i]`, numerator exponent `nums[i]` and denominator
+  // generators `dens[i]`. Every numerator must be a parametric point over
+  // exactly `numParam` parameters, i.e. have `numParam + 1` columns.
+  GeneratingFunction(unsigned numParam, SmallVector<int, 8> signs,
+                     std::vector<ParamPoint> nums,
+                     std::vector<std::vector<Point>> dens)
+      : numParam(numParam), signs(std::move(signs)),
+        numerators(std::move(nums)), denominators(std::move(dens)) {
+    // The three lists are indexed in lockstep (print() reads
+    // signs[i], numerators[i] and denominators[i] together). `this->` is
+    // required because the parameter `signs` shadows the member and has
+    // been moved from.
+    assert(this->signs.size() == numerators.size() &&
+           numerators.size() == denominators.size() &&
+           "signs, numerators and denominators must have one entry per term!");
+    // Compare against numParam + 1 rather than subtracting 1 from the
+    // unsigned column count, which would wrap for an empty matrix.
+    for (const ParamPoint &term : numerators)
+      assert(term.getNumColumns() == numParam + 1 &&
+             "dimensionality of numerator exponents does not match number of "
+             "parameters!");
+  }
+  // Return the number of parameters that the numerator exponents are affine
+  // functions of, as fixed at construction. Const so it can be queried on
+  // const generating functions (e.g. the right operand of operator+).
+  unsigned getNumParams() const { return numParam; }
+  // Return the sum of this generating function and `gf`: a generating
+  // function whose terms are the terms of `*this` followed by those of `gf`.
+  // Neither operand is modified. Both must have the same number of
+  // parameters.
+  GeneratingFunction operator+(const GeneratingFunction &gf) const {
+    assert(numParam == gf.numParam &&
+           "two generating functions with different numbers of parameters "
+           "cannot be added!");
+    // Accumulate into a copy so that `a + b` leaves `a` untouched. This also
+    // makes `a + a` safe: appending a container's own range into itself is
+    // undefined behaviour, whereas here the source and destination differ.
+    GeneratingFunction sum = *this;
+    sum.signs.append(gf.signs);
+    sum.numerators.insert(sum.numerators.end(), gf.numerators.begin(),
+                          gf.numerators.end());
+    sum.denominators.insert(sum.denominators.end(), gf.denominators.begin(),
+                            gf.denominators.end());
+    return sum;
+  }
+  // Print the generating function as a sequence of signed terms. Each term is
+  // written as `x^[[c,...],...]/(x^[g,...])...`: the numerator exponent is
+  // printed row by row (one bracketed list per affine function), followed by
+  // one parenthesized factor per denominator generator.
+  // NOTE(review): a factor is printed as `(x^[g])` although it stands for
+  // `(1 - x^{g})`; confirm whether this shorthand is intended.
+  llvm::raw_ostream &print(llvm::raw_ostream &os) const {
+    for (unsigned i = 0, e = signs.size(); i < e; ++i) {
+      // Any sign other than +1 is printed as a minus.
+      os << (signs[i] == 1 ? " + " : " - ");
+      // Numerator: ParamPoint is a Matrix, so walk it by rows and columns.
+      // Emitting separators before every non-first element avoids the
+      // unsigned `size - 1` bound, which wraps for empty rows.
+      const ParamPoint &num = numerators[i];
+      os << "x^[";
+      for (unsigned row = 0, nRows = num.getNumRows(); row < nRows; ++row) {
+        if (row != 0)
+          os << ",";
+        os << "[";
+        for (unsigned col = 0, nCols = num.getNumColumns(); col < nCols;
+             ++col) {
+          if (col != 0)
+            os << ",";
+          os << num.at(row, col);
+        }
+        os << "]";
+      }
+      os << "]/";
+      // Denominator: one factor per generator, bound by reference to avoid
+      // copying each Point; empty points print as `(x^[])`.
+      for (const Point &den : denominators[i]) {
+        os << "(x^[";
+        for (unsigned j = 0, dim = den.size(); j < dim; ++j) {
+          if (j != 0)
+            os << ",";
+          os << den[j];
+        }
+        os << "])";
+      }
+    }
+    return os;
+  }
+  // Sign of each term, expected to be +1 or -1 (print() renders any value
+  // other than 1 as a minus).
+  SmallVector<int, 8> signs;
+  // Exponent of each term's numerator x^{n}: a parametric point (a vertex).
+  std::vector<ParamPoint> numerators;
+  // Generators of each term's denominator: `denominators[i][j]` is the
+  // vector d_j of the factor (1 - x^{d_j}) in the ith term.
+  std::vector<std::vector<Point>> denominators;
Superty wrote:

A constant value would just be a function of the class, right?


More information about the Mlir-commits mailing list