[Mlir-commits] [mlir] [MLIR][Presburger] Implement PresburgerRelation::convertVarKind and a… (PR #66104)

Bharathi Ramana Joshi llvmlistbot at llvm.org
Tue Sep 12 09:22:33 PDT 2023


https://github.com/iambrj created https://github.com/llvm/llvm-project/pull/66104:

Implement PresburgerRelation::convertVarKind and add unit test for IntegerRelation::convertVarKind.

>From a15e7bdd56d85121691eb9b54177539c81de5984 Mon Sep 17 00:00:00 2001
From: iambrj <joshibharathiramana at gmail.com>
Date: Tue, 12 Sep 2023 13:23:37 +0530
Subject: [PATCH] [MLIR][Presburger] Implement
 PresburgerRelation::convertVarKind and add unit test for
 IntegerRelation::convertVarKind

---
 .../Analysis/Presburger/PresburgerRelation.h  |  8 +++
 .../Presburger/PresburgerRelation.cpp         | 15 +++++
 .../Presburger/IntegerRelationTest.cpp        | 25 +++++---
 mlir/unittests/Analysis/Presburger/Parser.h   | 21 +++++++
 .../Presburger/PresburgerRelationTest.cpp     | 57 +++++++++++++------
 5 files changed, 102 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
index f54878272d86d09..0c9c5cf67b4c3c1 100644
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
@@ -66,6 +66,14 @@ class PresburgerRelation {
 
   void insertVarInPlace(VarKind kind, unsigned pos, unsigned num = 1);
 
+  /// Converts the `num` variables of kind `srcKind` in the range [srcPos,
+  /// srcPos + num) to variables of kind `dstKind`, placed at position dstPos.
+  /// Positions are relative to their kind; every disjunct is updated.
+  ///
+  /// srcKind and dstKind must be different, and both ranges must be in bounds.
+  void convertVarKind(VarKind srcKind, unsigned srcPos, unsigned num,
+                      VarKind dstKind, unsigned dstPos);
+
   /// Return a reference to the list of disjuncts.
   ArrayRef<IntegerRelation> getAllDisjuncts() const;
 
diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index 09e7563d58b4898..997afafa596193f 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -38,6 +38,21 @@ void PresburgerRelation::insertVarInPlace(VarKind kind, unsigned pos,
   space.insertVar(kind, pos, num);
 }
 
+void PresburgerRelation::convertVarKind(VarKind srcKind, unsigned srcPos,
+                                        unsigned num, VarKind dstKind,
+                                        unsigned dstPos) {
+  assert(srcKind != dstKind && "cannot convert variables to the same kind");
+  assert(srcPos + num <= space.getNumVarKind(srcKind) &&
+         "invalid range for source variables");
+  assert(dstPos <= space.getNumVarKind(dstKind) &&
+         "invalid position for destination variables");
+  // Update this relation's own space first; it is passed the count `num`.
+  space.convertVarKind(srcKind, srcPos, num, dstKind, dstPos);
+  // Each disjunct is passed the exclusive end srcPos + num, not the count.
+  for (IntegerRelation &disjunct : disjuncts)
+    disjunct.convertVarKind(srcKind, srcPos, srcPos + num, dstKind, dstPos);
+}
+
 unsigned PresburgerRelation::getNumDisjuncts() const {
   return disjuncts.size();
 }
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index dd20e058e358ddc..7bca5d9b5be8f4e 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -16,14 +16,6 @@
 using namespace mlir;
 using namespace presburger;
 
-static IntegerRelation parseRelationFromSet(StringRef set, unsigned numDomain) {
-  IntegerRelation rel = parseIntegerPolyhedron(set);
-
-  rel.convertVarKind(VarKind::SetDim, 0, numDomain, VarKind::Domain);
-
-  return rel;
-}
-
 TEST(IntegerRelationTest, getDomainAndRangeSet) {
   IntegerRelation rel = parseRelationFromSet(
       "(x, xr)[N] : (xr - x - 10 == 0, xr >= 0, N - xr >= 0)", 1);
@@ -175,3 +167,20 @@ TEST(IntegerRelationTest, symbolicLexmax) {
   EXPECT_TRUE(lexmax3.unboundedDomain.isIntegerEmpty());
   EXPECT_TRUE(lexmax3.lexopt.isEqual(expectedLexmax3));
 }
+
+TEST(IntegerRelationTest, convertVarKind) {
+  IntegerRelation rel = parseRelationFromSet(
+      "(x, y, a)[U, V, W] : (x - U == 0, y + a == W, U - V >= 0, y - a >= 0)",
+      2);
+  // rel starts with domain (x, y), range (a) and symbols (U, V, W).
+  // Move x to the range, then V and W to the domain, then U to the range.
+  rel.convertVarKind(VarKind::Domain, 0, 1, VarKind::Range, 0);
+  rel.convertVarKind(VarKind::Symbol, 1, 3, VarKind::Domain, 1);
+  rel.convertVarKind(VarKind::Symbol, 0, 1, VarKind::Range, 1);
+  // Expected layout: domain (y, V, W), range (x, U, a) and no symbols left.
+  IntegerRelation expectedRel = parseRelationFromSet(
+      "(y, V, W, x, U, a)[] : (x - U == 0, y + a == W, U - V >= 0, y - a >= 0)",
+      3);
+
+  EXPECT_TRUE(rel.isEqual(expectedRel));
+}
diff --git a/mlir/unittests/Analysis/Presburger/Parser.h b/mlir/unittests/Analysis/Presburger/Parser.h
index c2c63730056e7fe..f64a72f302dde2b 100644
--- a/mlir/unittests/Analysis/Presburger/Parser.h
+++ b/mlir/unittests/Analysis/Presburger/Parser.h
@@ -80,6 +80,27 @@ parsePWMAF(ArrayRef<std::pair<StringRef, StringRef>> pieces) {
   return func;
 }
 
+inline IntegerRelation parseRelationFromSet(StringRef set, unsigned numDomain) {
+  IntegerRelation rel = parseIntegerPolyhedron(set);
+  // Turn the first numDomain set dimensions into the relation's domain.
+  rel.convertVarKind(VarKind::SetDim, 0, numDomain, VarKind::Domain);
+
+  return rel;
+}
+
+inline PresburgerRelation
+parsePresburgerRelationFromPresburgerSet(ArrayRef<StringRef> strs,
+                                         unsigned numDomain) {
+  assert(!strs.empty() && "strs should not be empty");
+  // Union every string as a set first, then convert all disjuncts in one call.
+  IntegerRelation rel = parseIntegerPolyhedron(strs[0]);
+  PresburgerRelation result(rel);
+  for (unsigned i = 1, e = strs.size(); i < e; ++i)
+    result.unionInPlace(parseIntegerPolyhedron(strs[i]));
+  result.convertVarKind(VarKind::Range, 0, numDomain, VarKind::Domain, 0);
+  return result;
+}
+
 } // namespace presburger
 } // namespace mlir
 
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
index c882a516bc29dfe..52bcd2de0928a8e 100644
--- a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
 #include "Parser.h"
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
 #include "mlir/Analysis/Presburger/Simplex.h"
 
 #include <gmock/gmock.h>
@@ -16,22 +17,6 @@
 using namespace mlir;
 using namespace presburger;
 
-static PresburgerRelation
-parsePresburgerRelationFromPresburgerSet(ArrayRef<StringRef> strs,
-                                         unsigned numDomain) {
-  assert(!strs.empty() && "strs should not be empty");
-
-  IntegerRelation rel = parseIntegerPolyhedron(strs[0]);
-  rel.convertVarKind(VarKind::SetDim, 0, numDomain, VarKind::Domain);
-  PresburgerRelation result(rel);
-  for (unsigned i = 1, e = strs.size(); i < e; ++i) {
-    rel = parseIntegerPolyhedron(strs[i]);
-    rel.convertVarKind(VarKind::SetDim, 0, numDomain, VarKind::Domain);
-    result.unionInPlace(rel);
-  }
-  return result;
-}
-
 TEST(PresburgerRelationTest, intersectDomainAndRange) {
   {
     PresburgerRelation rel = parsePresburgerRelationFromPresburgerSet(
@@ -291,3 +276,43 @@ TEST(PresburgerRelationTest, getDomainAndRangeSet) {
 
   EXPECT_TRUE(rangeSet.isEqual(expectedRangeSet));
 }
+
+TEST(PresburgerRelationTest, convertVarKind) {
+  // Two different disjuncts over domain (x, y), range (a), symbols (U, V, W),
+  // so that the conversion is exercised on a genuine two-piece union.
+  IntegerRelation disj1 = parseRelationFromSet(
+                      "(x, y, a)[U, V, W] : (x - U == 0, y + a == W, U - V >= "
+                      "0, y - a >= 0)",
+                      2),
+                  disj2 = parseRelationFromSet(
+                      "(x, y, a)[U, V, W] : (x + y - U == 0, x - a + V == 0, "
+                      "V - U >= 0, y + a >= 0)",
+                      2);
+
+  PresburgerRelation rel(disj1);
+  rel.unionInPlace(disj2);
+
+  // Move x to the range, then V and W to the domain, then U to the range.
+  rel.convertVarKind(VarKind::Domain, 0, 1, VarKind::Range, 0);
+  rel.convertVarKind(VarKind::Symbol, 1, 2, VarKind::Domain, 1);
+  rel.convertVarKind(VarKind::Symbol, 0, 1, VarKind::Range, 1);
+  // Expected rel: apply the same conversions to each disjunct directly. Here
+  // the third argument is an exclusive end (3) where `rel` took a count (2).
+  disj1.convertVarKind(VarKind::Domain, 0, 1, VarKind::Range, 0);
+  disj1.convertVarKind(VarKind::Symbol, 1, 3, VarKind::Domain, 1);
+  disj1.convertVarKind(VarKind::Symbol, 0, 1, VarKind::Range, 1);
+  disj2.convertVarKind(VarKind::Domain, 0, 1, VarKind::Range, 0);
+  disj2.convertVarKind(VarKind::Symbol, 1, 3, VarKind::Domain, 1);
+  disj2.convertVarKind(VarKind::Symbol, 0, 1, VarKind::Range, 1);
+
+  PresburgerRelation expectedRel(disj1);
+  expectedRel.unionInPlace(disj2);
+
+  // Check if var counts are correct.
+  EXPECT_EQ(rel.getNumDomainVars(), 3u);
+  EXPECT_EQ(rel.getNumRangeVars(), 3u);
+  EXPECT_EQ(rel.getNumSymbolVars(), 0u);
+
+  // Check if identifiers are transferred correctly.
+  EXPECT_TRUE(expectedRel.isEqual(rel));
+}



More information about the Mlir-commits mailing list