[Mlir-commits] [mlir] ca21398 - [MLIR][Presburger] Implement findSymbolicIntegerLexMin/Max for PresburgerRelation
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Aug 7 23:23:55 PDT 2023
Author: iambrj
Date: 2023-08-08T11:46:27+05:30
New Revision: ca213983ccab7ea3571b0b50727457e1287b3f12
URL: https://github.com/llvm/llvm-project/commit/ca213983ccab7ea3571b0b50727457e1287b3f12
DIFF: https://github.com/llvm/llvm-project/commit/ca213983ccab7ea3571b0b50727457e1287b3f12.diff
LOG: [MLIR][Presburger] Implement findSymbolicIntegerLexMin/Max for PresburgerRelation
This patch implements findSymbolicIntegerLexMin/Max for PresburgerRelation
Reviewed By: Groverkss
Differential Revision: https://reviews.llvm.org/D156623
Added:
Modified:
mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
index 8c7e0746c84610..ae94d2d96162cd 100644
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
@@ -118,6 +118,16 @@ class PresburgerRelation {
/// Same as compose, provided for uniformity with applyDomain.
void applyRange(const PresburgerRelation &rel);
+ /// Compute the symbolic integer lexmin of the relation, i.e., for every
+ /// assignment of the symbols and domain, the lexicographically minimum value
+ /// attained by the range. The result is a SymbolicLexOpt; see its
+ /// documentation for the meaning of its `lexopt` and `unboundedDomain`
+ /// fields.
+ SymbolicLexOpt findSymbolicIntegerLexMin() const;
+
+ /// Compute the symbolic integer lexmax of the relation, i.e., for every
+ /// assignment of the symbols and domain, the lexicographically maximum value
+ /// attained by the range. The result is a SymbolicLexOpt; see its
+ /// documentation for the meaning of its `lexopt` and `unboundedDomain`
+ /// fields.
+ SymbolicLexOpt findSymbolicIntegerLexMax() const;
+
/// Return true if the set contains the given point, and false otherwise.
bool containsPoint(ArrayRef<MPInt> point) const;
bool containsPoint(ArrayRef<int64_t> point) const {
diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index c692069618ad60..df801b0e41d0f5 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -8,6 +8,7 @@
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Analysis/Presburger/IntegerRelation.h"
+#include "mlir/Analysis/Presburger/PWMAFunction.h"
#include "mlir/Analysis/Presburger/Simplex.h"
#include "mlir/Analysis/Presburger/Utils.h"
#include "llvm/ADT/STLExtras.h"
@@ -203,6 +204,33 @@ void PresburgerRelation::applyRange(const PresburgerRelation &rel) {
compose(rel);
}
+/// Shared implementation of PresburgerRelation::findSymbolicIntegerLexMin and
+/// PresburgerRelation::findSymbolicIntegerLexMax.
+///
+/// A PresburgerRelation is a union of disjuncts, so its optimum at a given
+/// assignment of the symbols and domain is the optimum over the optima of the
+/// individual disjuncts at that assignment. Each disjunct is therefore
+/// optimized on its own and the per-disjunct functions are folded together with
+/// PWMAFunction::unionLexMin / PWMAFunction::unionLexMax.
+///
+/// \param rel   The relation whose range is optimized.
+/// \param isMin If true, compute the lexmin; otherwise compute the lexmax.
+static SymbolicLexOpt findSymbolicIntegerLexOpt(const PresburgerRelation &rel,
+                                                bool isMin) {
+  SymbolicLexOpt result(rel.getSpace());
+  PWMAFunction &lexopt = result.lexopt;
+  PresburgerSet &unboundedDomain = result.unboundedDomain;
+  for (const IntegerRelation &cs : rel.getAllDisjuncts()) {
+    SymbolicLexOpt s = isMin ? cs.findSymbolicIntegerLexMin()
+                             : cs.findSymbolicIntegerLexMax();
+    lexopt = isMin ? lexopt.unionLexMin(s.lexopt)
+                   : lexopt.unionLexMax(s.lexopt);
+    // The optimum of the union is unbounded wherever the optimum of *any*
+    // disjunct is unbounded, so the unbounded domains must be unioned.
+    // (Intersecting, starting from the initially empty set, would always
+    // yield an empty unbounded domain.)
+    unboundedDomain = unboundedDomain.unionSet(s.unboundedDomain);
+  }
+  // NOTE(review): `lexopt` may still be defined at points of `unboundedDomain`
+  // where some other disjunct is bounded; if SymbolicLexOpt requires the two
+  // to be disjoint, those points should be removed from `lexopt`'s domain.
+  return result;
+}
+
+SymbolicLexOpt PresburgerRelation::findSymbolicIntegerLexMin() const {
+  // Minimization is the `isMin == true` flavour of the shared helper.
+  return findSymbolicIntegerLexOpt(*this, /*isMin=*/true);
+}
+
+SymbolicLexOpt PresburgerRelation::findSymbolicIntegerLexMax() const {
+  // Maximization is the `isMin == false` flavour of the shared helper.
+  return findSymbolicIntegerLexOpt(*this, /*isMin=*/false);
+}
+
/// Return the coefficients of the ineq in `rel` specified by `idx`.
/// `idx` can refer not only to an actual inequality of `rel`, but also
/// to either of the inequalities that make up an equality in `rel`.
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
index 7054ed0f360577..012a9bd5d084c1 100644
--- a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "Parser.h"
+#include "mlir/Analysis/Presburger/Simplex.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@@ -189,3 +190,109 @@ TEST(PresburgerRelationTest, inverse) {
EXPECT_TRUE(rel.isEqual(inverseRel));
}
}
+
+// The symbolic lexopt of a union is, pointwise over the domain, the lexopt of
+// the per-disjunct lexopts. Each case below documents the per-disjunct optima
+// (with the domain variable listed first) and checks the optimum of the union.
+TEST(PresburgerRelationTest, symbolicLexOpt) {
+  // First disjunct: for 0 <= x <= 2N, lexmin = (x, 0), lexmax = (x, 2N).
+  // Second disjunct: for N <= x <= M, lexmin = (x, 2N), lexmax = (x, M).
+  PresburgerRelation rel1 = parsePresburgerRelationFromPresburgerSet(
+      {"(x, y)[N, M] : (x >= 0, y >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1>= 0, "
+       "2 * N - x >= 0, 2 * N - y >= 0)",
+       "(x, y)[N, M] : (x >= 0, y >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1>= 0, "
+       "x - N >= 0, M - x >= 0, y - 2 * N >= 0, M - y >= 0)"},
+      1);
+
+  SymbolicLexOpt lexmin1 = rel1.findSymbolicIntegerLexMin();
+
+  PWMAFunction expectedLexMin1 = parsePWMAF({
+      {"(x)[N, M] : (x >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1 >= 0, "
+       "2 * N - x >= 0)",
+       "(x)[N, M] -> (0)"},
+      {"(x)[N, M] : (x >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1 >= 0, "
+       "x - 2 * N- 1 >= 0, M - x >= 0)",
+       "(x)[N, M] -> (2 * N)"},
+  });
+
+  SymbolicLexOpt lexmax1 = rel1.findSymbolicIntegerLexMax();
+
+  PWMAFunction expectedLexMax1 = parsePWMAF({
+      {"(x)[N, M] : (x >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1 >= 0, "
+       "N - 1 - x >= 0)",
+       "(x)[N, M] -> (2 * N)"},
+      {"(x)[N, M] : (x >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1 >= 0, "
+       "x - N >= 0, M - x >= 0)",
+       "(x)[N, M] -> (M)"},
+  });
+
+  EXPECT_TRUE(lexmin1.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmin1.lexopt.isEqual(expectedLexMin1));
+  EXPECT_TRUE(lexmax1.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmax1.lexopt.isEqual(expectedLexMax1));
+
+  PresburgerRelation rel2 = parsePresburgerRelationFromPresburgerSet(
+      // x or y or z
+      // lexmin = (x, 0, 1 - x)
+      // lexmax = (x, 1, 1)
+      {"(x, y, z) : (x >= 0, y >= 0, z >= 0, 1 - x >= 0, 1 - y >= 0, "
+       "1 - z >= 0, x + y + z - 1 >= 0)",
+       // (x or y) and (y or z) and (z or x)
+       // lexmin = (x, 1 - x, 1)
+       // lexmax = (x, 1, 1)
+       "(x, y, z) : (x >= 0, y >= 0, z >= 0, 1 - x >= 0, 1 - y >= 0, "
+       "1 - z >= 0, x + y - 1 >= 0, y + z - 1 >= 0, z + x - 1 >= 0)",
+       // x => (not y) or (not z)
+       // lexmin = (x, 0, 0)
+       // lexmax = (x, 1, 1 - x)
+       "(x, y, z) : (x >= 0, y >= 0, z >= 0, 1 - x >= 0, 1 - y >= 0, "
+       "1 - z >= 0, 2 - x - y - z >= 0)"},
+      1);
+
+  SymbolicLexOpt lexmin2 = rel2.findSymbolicIntegerLexMin();
+
+  PWMAFunction expectedLexMin2 =
+      parsePWMAF({{"(x) : (x >= 0, 1 - x >= 0)", "(x) -> (0, 0)"}});
+
+  SymbolicLexOpt lexmax2 = rel2.findSymbolicIntegerLexMax();
+
+  PWMAFunction expectedLexMax2 =
+      parsePWMAF({{"(x) : (x >= 0, 1 - x >= 0)", "(x) -> (1, 1)"}});
+
+  EXPECT_TRUE(lexmin2.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmin2.lexopt.isEqual(expectedLexMin2));
+  EXPECT_TRUE(lexmax2.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmax2.lexopt.isEqual(expectedLexMax2));
+
+  PresburgerRelation rel3 = parsePresburgerRelationFromPresburgerSet(
+      // (x => u or v or w) and (x or v) and (x or (not w))
+      // lexmin = (x, 0, 1 - x, x)
+      // lexmax = (x, 1, 1, x)
+      {"(x, u, v, w) : (x >= 0, u >= 0, v >= 0, w >= 0, 1 - x >= 0, "
+       "1 - u >= 0, 1 - v >= 0, 1 - w >= 0, -x + u + v + w >= 0, "
+       "x + v - 1 >= 0, x - w >= 0)",
+       // x => (u => (v => w)) and (x or (not v)) and (x or (not w))
+       // lexmin = (x, 0, 0, 0)
+       // lexmax = (x, 1, x, x)
+       "(x, u, v, w) : (x >= 0, u >= 0, v >= 0, w >= 0, 1 - x >= 0, "
+       "1 - u >= 0, 1 - v >= 0, 1 - w >= 0, -x - u - v + w + 2 >= 0, "
+       "x - v >= 0, x - w >= 0)",
+       // (x or (u or (not v))) and (x or ((not u) or w))
+       // and (x or (not v)) and (x or (not w))
+       // lexmin = (x, 0, 0, 0)
+       // lexmax = (x, x, x, x)
+       "(x, u, v, w) : (x >= 0, u >= 0, v >= 0, w >= 0, 1 - x >= 0, "
+       "1 - u >= 0, 1 - v >= 0, 1 - w >= 0, x + u - v >= 0, x - u + w >= 0, "
+       "x - v >= 0, x - w >= 0)"},
+      1);
+
+  SymbolicLexOpt lexmin3 = rel3.findSymbolicIntegerLexMin();
+
+  PWMAFunction expectedLexMin3 =
+      parsePWMAF({{"(x) : (x >= 0, 1 - x >= 0)", "(x) -> (0, 0, 0)"}});
+
+  SymbolicLexOpt lexmax3 = rel3.findSymbolicIntegerLexMax();
+
+  PWMAFunction expectedLexMax3 =
+      parsePWMAF({{"(x) : (x >= 0, 1 - x >= 0)", "(x) -> (1, 1, x)"}});
+
+  EXPECT_TRUE(lexmin3.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmin3.lexopt.isEqual(expectedLexMin3));
+  EXPECT_TRUE(lexmax3.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmax3.lexopt.isEqual(expectedLexMax3));
+}
More information about the Mlir-commits
mailing list