[Mlir-commits] [mlir] ca21398 - [MLIR][Presburger] Implement findSymbolicIntegerLexMin/Max for PresburgerRelation

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 7 23:23:55 PDT 2023


Author: iambrj
Date: 2023-08-08T11:46:27+05:30
New Revision: ca213983ccab7ea3571b0b50727457e1287b3f12

URL: https://github.com/llvm/llvm-project/commit/ca213983ccab7ea3571b0b50727457e1287b3f12
DIFF: https://github.com/llvm/llvm-project/commit/ca213983ccab7ea3571b0b50727457e1287b3f12.diff

LOG: [MLIR][Presburger] Implement findSymbolicIntegerLexMin/Max for PresburgerRelation

This patch implements findSymbolicIntegerLexMin/Max for PresburgerRelation by computing the symbolic lexopt of each disjunct and merging the per-disjunct results.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D156623

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
    mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
    mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
index 8c7e0746c84610..ae94d2d96162cd 100644
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
@@ -118,6 +118,16 @@ class PresburgerRelation {
   /// Same as compose, provided for uniformity with applyDomain.
   void applyRange(const PresburgerRelation &rel);
 
+  /// Compute the symbolic integer lexmin of the relation, i.e. for every
+  /// assignment of the symbols and domain, the lexicographically minimum value
+  /// attained by the range.
+  ///
+  /// Each disjunct of the relation is optimized separately and the
+  /// per-disjunct results are merged. The returned SymbolicLexOpt holds the
+  /// merged optimum as a piecewise function in `lexopt` and, in
+  /// `unboundedDomain`, the symbol/domain values for which the optimum is
+  /// reported as unbounded.
+  SymbolicLexOpt findSymbolicIntegerLexMin() const;
+
+  /// Compute the symbolic integer lexmax of the relation, i.e. for every
+  /// assignment of the symbols and domain, the lexicographically maximum value
+  /// attained by the range.
+  ///
+  /// Each disjunct of the relation is optimized separately and the
+  /// per-disjunct results are merged. The returned SymbolicLexOpt holds the
+  /// merged optimum as a piecewise function in `lexopt` and, in
+  /// `unboundedDomain`, the symbol/domain values for which the optimum is
+  /// reported as unbounded.
+  SymbolicLexOpt findSymbolicIntegerLexMax() const;
+
   /// Return true if the set contains the given point, and false otherwise.
   bool containsPoint(ArrayRef<MPInt> point) const;
   bool containsPoint(ArrayRef<int64_t> point) const {

diff  --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index c692069618ad60..df801b0e41d0f5 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
 #include "mlir/Analysis/Presburger/IntegerRelation.h"
+#include "mlir/Analysis/Presburger/PWMAFunction.h"
 #include "mlir/Analysis/Presburger/Simplex.h"
 #include "mlir/Analysis/Presburger/Utils.h"
 #include "llvm/ADT/STLExtras.h"
@@ -203,6 +204,33 @@ void PresburgerRelation::applyRange(const PresburgerRelation &rel) {
   compose(rel);
 }
 
+/// Shared implementation of PresburgerRelation::findSymbolicIntegerLexMin and
+/// findSymbolicIntegerLexMax; `isMin` selects which optimum is computed.
+///
+/// The relation is optimized disjunct by disjunct. Each disjunct's symbolic
+/// optimum is folded into the running result with PWMAFunction::unionLexMin
+/// (resp. unionLexMax).
+///
+/// Unbounded domains are *united*: the optimum of the union may be unbounded
+/// wherever any single disjunct is unbounded. The result's unboundedDomain
+/// presumably starts out empty (SymbolicLexOpt(space) constructor), so
+/// intersecting into it would always yield the empty set and silently discard
+/// every unbounded point. Uniting is conservative: a point is flagged as soon
+/// as one disjunct is unbounded there, even if another bounded disjunct
+/// happens to attain a lexicographically better value.
+///
+/// NOTE: `lexopt` is merged without consulting `unboundedDomain`, so at a
+/// point in `unboundedDomain` it may still carry a finite value contributed by
+/// a bounded disjunct; callers should check `unboundedDomain` first.
+static SymbolicLexOpt findSymbolicIntegerLexOpt(const PresburgerRelation &rel,
+                                                bool isMin) {
+  SymbolicLexOpt result(rel.getSpace());
+  PWMAFunction &lexopt = result.lexopt;
+  PresburgerSet &unboundedDomain = result.unboundedDomain;
+  for (const IntegerRelation &cs : rel.getAllDisjuncts()) {
+    // Optimum of this disjunct alone.
+    SymbolicLexOpt s = isMin ? cs.findSymbolicIntegerLexMin()
+                             : cs.findSymbolicIntegerLexMax();
+    lexopt = isMin ? lexopt.unionLexMin(s.lexopt)
+                   : lexopt.unionLexMax(s.lexopt);
+    unboundedDomain = unboundedDomain.unionSet(s.unboundedDomain);
+  }
+  return result;
+}
+
+/// Lexmin variant: forwards to the shared disjunct-merging implementation.
+SymbolicLexOpt PresburgerRelation::findSymbolicIntegerLexMin() const {
+  return findSymbolicIntegerLexOpt(*this, /*isMin=*/true);
+}
+
+/// Lexmax variant: forwards to the shared disjunct-merging implementation.
+SymbolicLexOpt PresburgerRelation::findSymbolicIntegerLexMax() const {
+  return findSymbolicIntegerLexOpt(*this, /*isMin=*/false);
+}
+
 /// Return the coefficients of the ineq in `rel` specified by  `idx`.
 /// `idx` can refer not only to an actual inequality of `rel`, but also
 /// to either of the inequalities that make up an equality in `rel`.

diff  --git a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
index 7054ed0f360577..012a9bd5d084c1 100644
--- a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
 #include "Parser.h"
+#include "mlir/Analysis/Presburger/Simplex.h"
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
@@ -189,3 +190,109 @@ TEST(PresburgerRelationTest, inverse) {
     EXPECT_TRUE(rel.isEqual(inverseRel));
   }
 }
+
+// Symbolic lexmin/lexmax of multi-disjunct PresburgerRelations with one domain
+// variable. The expected functions are the pointwise optimum over disjuncts.
+TEST(PresburgerRelationTest, symbolicLexOpt) {
+  PresburgerRelation rel1 = parsePresburgerRelationFromPresburgerSet(
+      {"(x, y)[N, M] : (x >= 0, y >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1>= 0, "
+       "2 * N - x >= 0, 2 * N - y >= 0)",
+       "(x, y)[N, M] : (x >= 0, y >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1>= 0, "
+       "x - N >= 0, M - x >= 0, y - 2 * N >= 0, M - y >= 0)"},
+      1);
+
+  SymbolicLexOpt lexmin1 = rel1.findSymbolicIntegerLexMin();
+
+  PWMAFunction expectedLexMin1 = parsePWMAF({
+      {"(x)[N, M] : (x >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1 >= 0, "
+       "2 * N - x >= 0)",
+       "(x)[N, M] -> (0)"},
+      {"(x)[N, M] : (x >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1 >= 0, "
+       "x - 2 * N- 1 >= 0, M - x >= 0)",
+       "(x)[N, M] -> (2 * N)"},
+  });
+
+  SymbolicLexOpt lexmax1 = rel1.findSymbolicIntegerLexMax();
+
+  PWMAFunction expectedLexMax1 = parsePWMAF({
+      {"(x)[N, M] : (x >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1 >= 0, "
+       "N - 1 - x  >= 0)",
+       "(x)[N, M] -> (2 * N)"},
+      {"(x)[N, M] : (x >= 0, N - 1 >= 0, M >= 0, M - 2 * N - 1 >= 0, "
+       "x - N >= 0, M - x >= 0)",
+       "(x)[N, M] -> (M)"},
+  });
+
+  EXPECT_TRUE(lexmin1.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmin1.lexopt.isEqual(expectedLexMin1));
+  EXPECT_TRUE(lexmax1.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmax1.lexopt.isEqual(expectedLexMax1));
+
+  PresburgerRelation rel2 = parsePresburgerRelationFromPresburgerSet(
+      // x or y or z
+      // lexmin = (x, 0, 1 - x)
+      // lexmax = (x, 1, 1)
+      {"(x, y, z) : (x >= 0, y >= 0, z >= 0, 1 - x >= 0, 1 - y >= 0, "
+       "1 - z >= 0, x + y + z - 1 >= 0)",
+       // (x or y) and (y or z) and (z or x)
+       // lexmin = (x, 1 - x, 1)
+       // lexmax = (x, 1, 1)
+       "(x, y, z) : (x >= 0, y >= 0, z >= 0, 1 - x >= 0, 1 - y >= 0, "
+       "1 - z >= 0, x + y - 1 >= 0, y + z - 1 >= 0, z + x - 1 >= 0)",
+       // x => (not y) or (not z)
+       // lexmin = (x, 0, 0)
+       // lexmax = (x, 1, 1 - x)
+       "(x, y, z) : (x >= 0, y >= 0, z >= 0, 1 - x >= 0, 1 - y >= 0, "
+       "1 - z >= 0, 2 - x - y - z >= 0)"},
+      1);
+
+  SymbolicLexOpt lexmin2 = rel2.findSymbolicIntegerLexMin();
+
+  // Overall lexmin = (x, 0, 0), attained by the third disjunct.
+  PWMAFunction expectedLexMin2 =
+      parsePWMAF({{"(x) : (x >= 0, 1 - x >= 0)", "(x) -> (0, 0)"}});
+
+  SymbolicLexOpt lexmax2 = rel2.findSymbolicIntegerLexMax();
+
+  // Overall lexmax = (x, 1, 1).
+  PWMAFunction expectedLexMax2 =
+      parsePWMAF({{"(x) : (x >= 0, 1 - x >= 0)", "(x) -> (1, 1)"}});
+
+  EXPECT_TRUE(lexmin2.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmin2.lexopt.isEqual(expectedLexMin2));
+  EXPECT_TRUE(lexmax2.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmax2.lexopt.isEqual(expectedLexMax2));
+
+  PresburgerRelation rel3 = parsePresburgerRelationFromPresburgerSet(
+      // (x => u or v or w) and (x or v) and (x or (not w))
+      // lexmin = (x, 0, 1 - x, x)
+      // lexmax = (x, 1, 1, x)
+      {"(x, u, v, w) : (x >= 0, u >= 0, v >= 0, w >= 0, 1 - x >= 0, "
+       "1 - u >= 0, 1 - v >= 0, 1 - w >= 0, -x + u + v + w >= 0, "
+       "x + v - 1 >= 0, x - w >= 0)",
+       // x => (u => (v => w)) and (x or (not v)) and (x or (not w))
+       // lexmin = (x, 0, 0, 0)
+       // lexmax = (x, 1, x, x)
+       "(x, u, v, w) : (x >= 0, u >= 0, v >= 0, w >= 0, 1 - x >= 0, "
+       "1 - u >= 0, 1 - v >= 0, 1 - w >= 0, -x - u - v + w + 2 >= 0, "
+       "x - v >= 0, x - w >= 0)",
+       // (x or (u or (not v))) and (x or ((not u) or w))
+       // and (x or (not v)) and (x or (not w))
+       // lexmin = (x, 0, 0, 0)
+       // lexmax = (x, x, x, x)
+       "(x, u, v, w) : (x >= 0, u >= 0, v >= 0, w >= 0, 1 - x >= 0, "
+       "1 - u >= 0, 1 - v >= 0, 1 - w >= 0, x + u - v >= 0, x - u + w >= 0, "
+       "x - v >= 0, x - w >= 0)"},
+      1);
+
+  SymbolicLexOpt lexmin3 = rel3.findSymbolicIntegerLexMin();
+
+  // Overall lexmin = (x, 0, 0, 0), attained by the second and third disjuncts.
+  PWMAFunction expectedLexMin3 =
+      parsePWMAF({{"(x) : (x >= 0, 1 - x >= 0)", "(x) -> (0, 0, 0)"}});
+
+  SymbolicLexOpt lexmax3 = rel3.findSymbolicIntegerLexMax();
+
+  // Overall lexmax = (x, 1, 1, x), attained by the first disjunct.
+  PWMAFunction expectedLexMax3 =
+      parsePWMAF({{"(x) : (x >= 0, 1 - x >= 0)", "(x) -> (1, 1, x)"}});
+
+  EXPECT_TRUE(lexmin3.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmin3.lexopt.isEqual(expectedLexMin3));
+  EXPECT_TRUE(lexmax3.unboundedDomain.isIntegerEmpty());
+  EXPECT_TRUE(lexmax3.lexopt.isEqual(expectedLexMax3));
+}


        


More information about the Mlir-commits mailing list