[Mlir-commits] [mlir] c4abef2 - [MLIR][Presburger] support symbolicLexMin for IntegerRelation
Arjun P
llvmlistbot at llvm.org
Fri Jul 1 10:00:08 PDT 2022
Author: Arjun P
Date: 2022-07-01T18:00:11+01:00
New Revision: c4abef28a3bd6d37acc80b5659fe458fc0b3fc18
URL: https://github.com/llvm/llvm-project/commit/c4abef28a3bd6d37acc80b5659fe458fc0b3fc18
DIFF: https://github.com/llvm/llvm-project/commit/c4abef28a3bd6d37acc80b5659fe458fc0b3fc18.diff
LOG: [MLIR][Presburger] support symbolicLexMin for IntegerRelation
This also changes the space of the returned lexmin for IntegerPolyhedrons;
the symbols in the poly now correspond to symbols in the result rather than dims.
Reviewed By: Groverkss
Differential Revision: https://reviews.llvm.org/D128933
Added:
Modified:
mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
mlir/lib/Analysis/Presburger/IntegerRelation.cpp
mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
mlir/unittests/Analysis/Presburger/Utils.h
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 58ac2a5c440b2..720309a7c1e98 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -28,6 +28,7 @@ class IntegerRelation;
class IntegerPolyhedron;
class PresburgerSet;
class PresburgerRelation;
+struct SymbolicLexMin;
/// An IntegerRelation represents the set of points from a PresburgerSpace that
/// satisfy a list of affine constraints. Affine constraints can be inequalities
@@ -583,6 +584,30 @@ class IntegerRelation {
/// union of convex disjuncts.
PresburgerRelation computeReprWithOnlyDivLocals() const;
+ /// Compute the symbolic integer lexmin of the relation.
+ ///
+ /// This finds, for every assignment to the symbols and domain,
+ /// the lexicographically minimum value attained by the range.
+ ///
+ /// For example, the symbolic lexmin of the set
+ ///
+ /// (x)[a, b, c] : (a <= x, b <= x, x <= c)
+ ///
+ /// can be written as
+ ///
+ /// x = a if b <= a, a <= c
+ /// x = b if a < b, b <= c
+ ///
+ /// This function is stored in the `lexmin` function in the result.
+ /// Some assignments to the symbols might make the set empty.
+ /// Such points are not part of the function's domain.
+ /// In the above example, this happens when max(a, b) > c.
+ ///
+ /// For some values of the symbols, the lexmin may be unbounded.
+ /// `SymbolicLexMin` stores these parts of the symbolic domain in a separate
+ /// `PresburgerSet`, `unboundedDomain`.
+ SymbolicLexMin findSymbolicIntegerLexMin() const;
+
void print(raw_ostream &os) const;
void dump() const;
@@ -692,8 +717,6 @@ class IntegerRelation {
Matrix inequalities;
};
-struct SymbolicLexMin;
-
/// An IntegerPolyhedron represents the set of points from a PresburgerSpace
/// that satisfy a list of affine constraints. Affine constraints can be
/// inequalities or equalities in the form:
@@ -767,28 +790,6 @@ class IntegerPolyhedron : public IntegerRelation {
/// column position (i.e., not relative to the kind of variable) of the
/// first added variable.
unsigned insertVar(VarKind kind, unsigned pos, unsigned num = 1) override;
-
- /// Compute the symbolic integer lexmin of the polyhedron.
- /// This finds, for every assignment to the symbols, the lexicographically
- /// minimum value attained by the dimensions. For example, the symbolic lexmin
- /// of the set
- ///
- /// (x, y)[a, b, c] : (a <= x, b <= x, x <= c)
- ///
- /// can be written as
- ///
- /// x = a if b <= a, a <= c
- /// x = b if a < b, b <= c
- ///
- /// This function is stored in the `lexmin` function in the result.
- /// Some assignments to the symbols might make the set empty.
- /// Such points are not part of the function's domain.
- /// In the above example, this happens when max(a, b) > c.
- ///
- /// For some values of the symbols, the lexmin may be unbounded.
- /// `SymbolicLexMin` stores these parts of the symbolic domain in a separate
- /// `PresburgerSet`, `unboundedDomain`.
- SymbolicLexMin findSymbolicIntegerLexMin() const;
};
} // namespace presburger
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index d7c70a4353bc3..fd4c038c6d82e 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -226,12 +226,23 @@ PresburgerRelation IntegerRelation::computeReprWithOnlyDivLocals() const {
return result;
}
-SymbolicLexMin IntegerPolyhedron::findSymbolicIntegerLexMin() const {
+SymbolicLexMin IntegerRelation::findSymbolicIntegerLexMin() const {
+ // Symbol and Domain vars will be used as symbols for symbolic lexmin.
+ // In other words, for every value of the symbols and domain, return the
+ // lexmin value of the (range, locals).
+ llvm::SmallBitVector isSymbol(getNumVars(), false);
+ isSymbol.set(getVarKindOffset(VarKind::Symbol),
+ getVarKindEnd(VarKind::Symbol));
+ isSymbol.set(getVarKindOffset(VarKind::Domain),
+ getVarKindEnd(VarKind::Domain));
// Compute the symbolic lexmin of the range vars and locals, with the domain
// and symbol vars of this relation acting as the symbols.
SymbolicLexMin result =
- SymbolicLexSimplex(*this, IntegerPolyhedron(PresburgerSpace::getSetSpace(
- /*numDims=*/getNumSymbolVars())))
+ SymbolicLexSimplex(*this,
+ IntegerPolyhedron(PresburgerSpace::getSetSpace(
+ /*numDims=*/getNumDomainVars(),
+ /*numSymbols=*/getNumSymbolVars())),
+ isSymbol)
.computeSymbolicIntegerLexMin();
// We want to return only the lexmin over the dims, so strip the locals from
diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
index 098b58b6df6ca..581558cf9bbfb 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -1170,10 +1170,11 @@ void expectSymbolicIntegerLexMin(
PWMAFunction expectedLexmin =
parsePWMAF(/*numInputs=*/poly.getNumSymbolVars(),
- /*numOutputs=*/poly.getNumDimVars(), expectedLexminRepr);
+ /*numOutputs=*/poly.getNumDimVars(), expectedLexminRepr,
+ /*numSymbols=*/poly.getNumSymbolVars());
PresburgerSet expectedUnboundedDomain = parsePresburgerSetFromPolyStrings(
- poly.getNumSymbolVars(), expectedUnboundedDomainRepr);
+ /*numDims=*/0, expectedUnboundedDomainRepr, poly.getNumSymbolVars());
SymbolicLexMin result = poly.findSymbolicIntegerLexMin();
@@ -1200,114 +1201,116 @@ void expectSymbolicIntegerLexMin(
TEST(IntegerPolyhedronTest, findSymbolicIntegerLexMin) {
expectSymbolicIntegerLexMin("(x)[a] : (x - a >= 0)",
{
- {"(a) : ()", {{1, 0}}}, // a
+ {"()[a] : ()", {{1, 0}}}, // a
});
expectSymbolicIntegerLexMin(
"(x)[a, b] : (x - a >= 0, x - b >= 0)",
{
- {"(a, b) : (a - b >= 0)", {{1, 0, 0}}}, // a
- {"(a, b) : (b - a - 1 >= 0)", {{0, 1, 0}}}, // b
+ {"()[a, b] : (a - b >= 0)", {{1, 0, 0}}}, // a
+ {"()[a, b] : (b - a - 1 >= 0)", {{0, 1, 0}}}, // b
});
expectSymbolicIntegerLexMin(
"(x)[a, b, c] : (x -a >= 0, x - b >= 0, x - c >= 0)",
{
- {"(a, b, c) : (a - b >= 0, a - c >= 0)", {{1, 0, 0, 0}}}, // a
- {"(a, b, c) : (b - a - 1 >= 0, b - c >= 0)", {{0, 1, 0, 0}}}, // b
- {"(a, b, c) : (c - a - 1 >= 0, c - b - 1 >= 0)", {{0, 0, 1, 0}}}, // c
+ {"()[a, b, c] : (a - b >= 0, a - c >= 0)", {{1, 0, 0, 0}}}, // a
+ {"()[a, b, c] : (b - a - 1 >= 0, b - c >= 0)", {{0, 1, 0, 0}}}, // b
+ {"()[a, b, c] : (c - a - 1 >= 0, c - b - 1 >= 0)",
+ {{0, 0, 1, 0}}}, // c
});
expectSymbolicIntegerLexMin("(x, y)[a] : (x - a >= 0, x + y >= 0)",
{
- {"(a) : ()", {{1, 0}, {-1, 0}}}, // (a, -a)
+ {"()[a] : ()", {{1, 0}, {-1, 0}}}, // (a, -a)
});
expectSymbolicIntegerLexMin(
"(x, y)[a] : (x - a >= 0, x + y >= 0, y >= 0)",
{
- {"(a) : (a >= 0)", {{1, 0}, {0, 0}}}, // (a, 0)
- {"(a) : (-a - 1 >= 0)", {{1, 0}, {-1, 0}}}, // (a, -a)
+ {"()[a] : (a >= 0)", {{1, 0}, {0, 0}}}, // (a, 0)
+ {"()[a] : (-a - 1 >= 0)", {{1, 0}, {-1, 0}}}, // (a, -a)
});
expectSymbolicIntegerLexMin(
"(x, y)[a, b, c] : (x - a >= 0, y - b >= 0, c - x - y >= 0)",
{
- {"(a, b, c) : (c - a - b >= 0)",
+ {"()[a, b, c] : (c - a - b >= 0)",
{{1, 0, 0, 0}, {0, 1, 0, 0}}}, // (a, b)
});
expectSymbolicIntegerLexMin(
"(x, y, z)[a, b, c] : (c - z >= 0, b - y >= 0, x + y + z - a == 0)",
{
- {"(a, b, c) : ()",
+ {"()[a, b, c] : ()",
{{1, -1, -1, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}}}, // (a - b - c, b, c)
});
expectSymbolicIntegerLexMin(
"(x)[a, b] : (a >= 0, b >= 0, x >= 0, a + b + x - 1 >= 0)",
{
- {"(a, b) : (a >= 0, b >= 0, a + b - 1 >= 0)", {{0, 0, 0}}}, // 0
- {"(a, b) : (a == 0, b == 0)", {{0, 0, 1}}}, // 1
+ {"()[a, b] : (a >= 0, b >= 0, a + b - 1 >= 0)", {{0, 0, 0}}}, // 0
+ {"()[a, b] : (a == 0, b == 0)", {{0, 0, 1}}}, // 1
});
expectSymbolicIntegerLexMin(
"(x)[a, b] : (1 - a >= 0, a >= 0, 1 - b >= 0, b >= 0, 1 - x >= 0, x >= "
"0, a + b + x - 1 >= 0)",
{
- {"(a, b) : (1 - a >= 0, a >= 0, 1 - b >= 0, b >= 0, a + b - 1 >= 0)",
- {{0, 0, 0}}}, // 0
- {"(a, b) : (a == 0, b == 0)", {{0, 0, 1}}}, // 1
+ {"()[a, b] : (1 - a >= 0, a >= 0, 1 - b >= 0, b >= 0, a + b - 1 >= "
+ "0)",
+ {{0, 0, 0}}}, // 0
+ {"()[a, b] : (a == 0, b == 0)", {{0, 0, 1}}}, // 1
});
expectSymbolicIntegerLexMin(
"(x, y, z)[a, b] : (x - a == 0, y - b == 0, x >= 0, y >= 0, z >= 0, x + "
"y + z - 1 >= 0)",
{
- {"(a, b) : (a >= 0, b >= 0, 1 - a - b >= 0)",
+ {"()[a, b] : (a >= 0, b >= 0, 1 - a - b >= 0)",
{{1, 0, 0}, {0, 1, 0}, {-1, -1, 1}}}, // (a, b, 1 - a - b)
- {"(a, b) : (a >= 0, b >= 0, a + b - 2 >= 0)",
+ {"()[a, b] : (a >= 0, b >= 0, a + b - 2 >= 0)",
{{1, 0, 0}, {0, 1, 0}, {0, 0, 0}}}, // (a, b, 0)
});
expectSymbolicIntegerLexMin("(x)[a, b] : (x - a == 0, x - b >= 0)",
{
- {"(a, b) : (a - b >= 0)", {{1, 0, 0}}}, // a
+ {"()[a, b] : (a - b >= 0)", {{1, 0, 0}}}, // a
});
expectSymbolicIntegerLexMin(
"(q)[a] : (a - 1 - 3*q == 0, q >= 0)",
{
- {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)",
+ {"()[a] : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)",
{{0, 1, 0}}}, // a floordiv 3
});
expectSymbolicIntegerLexMin(
"(r, q)[a] : (a - r - 3*q == 0, q >= 0, 1 - r >= 0, r >= 0)",
{
- {"(a) : (a - 0 - 3*(a floordiv 3) == 0, a >= 0)",
+ {"()[a] : (a - 0 - 3*(a floordiv 3) == 0, a >= 0)",
{{0, 0, 0}, {0, 1, 0}}}, // (0, a floordiv 3)
- {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)",
+ {"()[a] : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)",
{{0, 0, 1}, {0, 1, 0}}}, // (1, a floordiv 3)
});
expectSymbolicIntegerLexMin(
"(r, q)[a] : (a - r - 3*q == 0, q >= 0, 2 - r >= 0, r - 1 >= 0)",
{
- {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)",
+ {"()[a] : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)",
{{0, 0, 1}, {0, 1, 0}}}, // (1, a floordiv 3)
- {"(a) : (a - 2 - 3*(a floordiv 3) == 0, a >= 0)",
+ {"()[a] : (a - 2 - 3*(a floordiv 3) == 0, a >= 0)",
{{0, 0, 2}, {0, 1, 0}}}, // (2, a floordiv 3)
});
expectSymbolicIntegerLexMin(
"(r, q)[a] : (a - r - 3*q == 0, q >= 0, r >= 0)",
{
- {"(a) : (a - 3*(a floordiv 3) == 0, a >= 0)",
+ {"()[a] : (a - 3*(a floordiv 3) == 0, a >= 0)",
{{0, 0, 0}, {0, 1, 0}}}, // (0, a floordiv 3)
- {"(a) : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)",
+ {"()[a] : (a - 1 - 3*(a floordiv 3) == 0, a >= 0)",
{{0, 0, 1}, {0, 1, 0}}}, // (1, a floordiv 3)
- {"(a) : (a - 2 - 3*(a floordiv 3) == 0, a >= 0)",
+ {"()[a] : (a - 2 - 3*(a floordiv 3) == 0, a >= 0)",
{{0, 0, 2}, {0, 1, 0}}}, // (2, a floordiv 3)
});
@@ -1323,11 +1326,11 @@ TEST(IntegerPolyhedronTest, findSymbolicIntegerLexMin) {
// What's the lexmin solution using exactly g true vars?
"g - x - y - z - w == 0)",
{
- {"(g) : (g - 1 == 0)",
+ {"()[g] : (g - 1 == 0)",
{{0, 0}, {0, 1}, {0, 0}, {0, 0}}}, // (0, 1, 0, 0)
- {"(g) : (g - 2 == 0)",
+ {"()[g] : (g - 2 == 0)",
{{0, 0}, {0, 0}, {0, 1}, {0, 1}}}, // (0, 0, 1, 1)
- {"(g) : (g - 3 == 0)",
+ {"()[g] : (g - 3 == 0)",
{{0, 0}, {0, 1}, {0, 1}, {0, 1}}}, // (0, 1, 1, 1)
});
@@ -1340,11 +1343,11 @@ TEST(IntegerPolyhedronTest, findSymbolicIntegerLexMin) {
// According to Bezout's lemma, 14x + 35y can take on all multiples
// of 7 and no other values. So the solution exists iff r - a is a
// multiple of 7.
- {"(a, r) : (a >= 0, r - a - 7*((r - a) floordiv 7) == 0)"});
+ {"()[a, r] : (a >= 0, r - a - 7*((r - a) floordiv 7) == 0)"});
// The lexmins are unbounded.
expectSymbolicIntegerLexMin("(x, y)[a] : (9*x - 4*y - 2*a >= 0)", {},
- {"(a) : ()"});
+ {"()[a] : ()"});
// Test cases adapted from isl.
expectSymbolicIntegerLexMin(
@@ -1352,7 +1355,7 @@ TEST(IntegerPolyhedronTest, findSymbolicIntegerLexMin) {
// So b is minimized when c = b.
"(b, c)[a] : (a - 4*b + 2*c == 0, c - b >= 0)",
{
- {"(a) : (a - 2*(a floordiv 2) == 0)",
+ {"()[a] : (a - 2*(a floordiv 2) == 0)",
{{0, 1, 0}, {0, 1, 0}}}, // (a floordiv 2, a floordiv 2)
});
@@ -1362,7 +1365,7 @@ TEST(IntegerPolyhedronTest, findSymbolicIntegerLexMin) {
"(b)[a] : (255 - b >= 0, b >= 0, a - 512*b - 1 >= 0, 512*b -a + 509 >= "
"0, b + 7 - 16*((8 + b) floordiv 16) >= 0)",
{
- {"(a) : (255 - (a floordiv 512) >= 0, a >= 0, a - 512*(a floordiv "
+ {"()[a] : (255 - (a floordiv 512) >= 0, a >= 0, a - 512*(a floordiv "
"512) - 1 >= 0, 512*(a floordiv 512) - a + 509 >= 0, (a floordiv "
"512) + 7 - 16*((8 + (a floordiv 512)) floordiv 16) >= 0)",
{{0, 1, 0, 0}}}, // a floordiv 512
@@ -1375,7 +1378,8 @@ TEST(IntegerPolyhedronTest, findSymbolicIntegerLexMin) {
"2*N - 3*K + a - b >= 0, 4*N - K + 1 - 3*b >= 0, b - N >= 0, a - x - 1 "
">= 0)",
{{
- "(K, N, x, y) : (x + 6 - 2*N >= 0, 2*N - 5 - x >= 0, x + 1 -3*K + N "
+ "()[K, N, x, y] : (x + 6 - 2*N >= 0, 2*N - 5 - x >= 0, x + 1 -3*K + "
+ "N "
">= 0, N + K - 2 - x >= 0, x - 4 >= 0)",
{{0, 0, 1, 0, 1}, {0, 1, 0, 0, 0}} // (1 + x, N)
}});
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index 51a649205bfce..c241d5d2f0e81 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -8,6 +8,7 @@
#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "./Utils.h"
+#include "mlir/Analysis/Presburger/Simplex.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@@ -122,3 +123,20 @@ TEST(IntegerRelationTest, applyDomainAndRange) {
EXPECT_TRUE(map1.isEqual(map3));
}
}
+
+TEST(IntegerRelationTest, symbolicLexmin) {
+ SymbolicLexMin lexmin =
+ parseRelationFromSet("(a, x)[b] : (x - a >= 0, x - b >= 0)", 1)
+ .findSymbolicIntegerLexMin();
+
+ PWMAFunction expectedLexmin =
+ parsePWMAF(/*numInputs=*/2,
+ /*numOutputs=*/1,
+ {
+ {"(a)[b] : (a - b >= 0)", {{1, 0, 0}}}, // a
+ {"(a)[b] : (b - a - 1 >= 0)", {{0, 1, 0}}}, // b
+ },
+ /*numSymbols=*/1);
+ EXPECT_TRUE(lexmin.unboundedDomain.isIntegerEmpty());
+ EXPECT_TRUE(lexmin.lexmin.isEqual(expectedLexmin));
+}
diff --git a/mlir/unittests/Analysis/Presburger/Utils.h b/mlir/unittests/Analysis/Presburger/Utils.h
index bf9163cf6184c..3b4a479e97ded 100644
--- a/mlir/unittests/Analysis/Presburger/Utils.h
+++ b/mlir/unittests/Analysis/Presburger/Utils.h
@@ -17,6 +17,7 @@
#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
+#include "mlir/Analysis/Presburger/Simplex.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Support/LLVM.h"
@@ -40,9 +41,10 @@ inline IntegerPolyhedron parsePoly(StringRef str) {
/// are all valid IntegerSet representations and that all of them have the same
/// number of dimensions and symbols as specified by the numDims and numSymbols
/// arguments.
inline PresburgerSet
-parsePresburgerSetFromPolyStrings(unsigned numDims, ArrayRef<StringRef> strs) {
- PresburgerSet set =
- PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(numDims));
+parsePresburgerSetFromPolyStrings(unsigned numDims, ArrayRef<StringRef> strs,
+ unsigned numSymbols = 0) {
+ PresburgerSet set = PresburgerSet::getEmpty(
+ PresburgerSpace::getSetSpace(numDims, numSymbols));
for (StringRef str : strs)
set.unionInPlace(parsePoly(str));
return set;
@@ -71,9 +73,9 @@ inline PWMAFunction parsePWMAF(
unsigned numSymbols = 0) {
static MLIRContext context;
- PWMAFunction result(
- PresburgerSpace::getSetSpace(numInputs - numSymbols, numSymbols),
- numOutputs);
+ PWMAFunction result(PresburgerSpace::getSetSpace(
+ /*numDims=*/numInputs - numSymbols, numSymbols),
+ numOutputs);
for (const auto &pair : data) {
IntegerPolyhedron domain = parsePoly(pair.first);
More information about the Mlir-commits
mailing list