[Mlir-commits] [mlir] [MLIR][Presburger] Implement PresburgerSpace::mergeAndAlignVarKind (PR #76397)

Bharathi Ramana Joshi llvmlistbot at llvm.org
Tue Dec 26 07:45:04 PST 2023


https://github.com/iambrj created https://github.com/llvm/llvm-project/pull/76397

None

>From 18b89bc1796bed3686954008b819845d6c5510b5 Mon Sep 17 00:00:00 2001
From: iambrj <joshibharathiramana at gmail.com>
Date: Fri, 22 Dec 2023 13:35:10 +0530
Subject: [PATCH] [MLIR][Presburger] Implement
 PresburgerSpace::mergeAndAlignVarKind

---
 .../Analysis/Presburger/PresburgerSpace.h     |  5 ++
 .../Analysis/Presburger/PresburgerSpace.cpp   | 38 ++++++++-
 .../Presburger/PresburgerSpaceTest.cpp        | 77 +++++++++++++++++++
 3 files changed, 118 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
index 9fe2abafd36bad..6a450ddf3ed407 100644
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
@@ -290,6 +290,11 @@ class PresburgerSpace {
   /// the symbols in two spaces are aligned.
   bool isAligned(const PresburgerSpace &other, VarKind kind) const;
 
+  /// Merge and align the `kind` identifiers of `this` and `other`, inserting
+  /// missing ones into either space. Afterwards both spaces have the same
+  /// `kind` identifiers in the same order. Both spaces must use identifiers.
+  void mergeAndAlignVarKind(VarKind kind, PresburgerSpace &other);
+
   void print(llvm::raw_ostream &os) const;
   void dump() const;
 
diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
index cf1b3befbc89f8..3c440cebeee5f7 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
@@ -18,8 +18,9 @@ using namespace presburger;
 bool Identifier::isEqual(const Identifier &other) const {
   if (value == nullptr || other.value == nullptr)
     return false;
-  assert(value == other.value && idType == other.idType &&
-         "Values of Identifiers are equal but their types do not match.");
+  // Identifiers with equal values must also have equal types.
+  assert((value != other.value || idType == other.idType) &&
+         "Values of Identifiers are equal but their types do not match.");
   return value == other.value;
 }
 
@@ -293,6 +294,39 @@ void PresburgerSpace::setVarSymbolSeperation(unsigned newSymbolCount) {
   // `identifiers` remains same.
 }
 
+void PresburgerSpace::mergeAndAlignVarKind(VarKind kind,
+                                           PresburgerSpace &other) {
+  assert(usingIds && other.usingIds &&
+         "Both spaces need to have identifers to merge & align");
+
+  // First merge & align identifiers into `other` from `this`.
+  unsigned kindBeginOffset = other.getVarKindOffset(kind);
+  unsigned i = 0;
+  for (const Identifier *identifier =
+           identifiers.begin() + getVarKindOffset(kind);
+       identifier != identifiers.begin() + getVarKindEnd(kind); ++identifier) {
+    // If the identifier exists in `other`, rotate it into position `i`, which
+    // keeps the unaligned tail of `other` in order; otherwise insert it. Only
+    // search from `i` onwards since everything left of `i` is aligned.
+    auto *findBegin = other.identifiers.begin() + kindBeginOffset + i;
+    auto *findEnd = other.identifiers.begin() + other.getVarKindEnd(kind);
+    auto *itr = std::find(findBegin, findEnd, *identifier);
+    if (itr != findEnd) {
+      std::rotate(findBegin, itr, itr + 1);
+    } else {
+      other.insertVar(kind, i);
+      other.getId(kind, i) = *identifier;
+    }
+    ++i;
+  }
+
+  // Finally add identifiers that are in `other`, but not in `this` to `this`.
+  for (unsigned e = other.getNumVarKind(kind); i < e; ++i) {
+    insertVar(kind, i);
+    getId(kind, i) = other.getId(kind, i);
+  }
+}
+
 void PresburgerSpace::print(llvm::raw_ostream &os) const {
   os << "Domain: " << getNumDomainVars() << ", "
      << "Range: " << getNumRangeVars() << ", "
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
index dd06d462f54bee..b8a578620161a8 100644
--- a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
@@ -179,3 +179,80 @@ TEST(PresburgerSpaceTest, convertVarKind2) {
   EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&identifiers[1]));
   EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&identifiers[3]));
 }
+
+TEST(PresburgerSpaceTest, mergeSymbols) {
+  PresburgerSpace space = PresburgerSpace::getRelationSpace(3, 3, 2, 0);
+  space.resetIds();
+
+  PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(3, 2, 3, 0);
+  otherSpace.resetIds();
+
+  // Attach identifiers.
+  int identifiers[7] = {0, 1, 2, 3, 4, 5, 6};
+  int otherIdentifiers[8] = {10, 11, 12, 13, 14, 15, 16, 17};
+
+  space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+  space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+  // Note the common identifier
+  space.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
+  space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+  space.getId(VarKind::Range, 1) = Identifier(&identifiers[3]);
+  space.getId(VarKind::Range, 2) = Identifier(&identifiers[4]);
+  space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[5]);
+  space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[6]);
+
+  otherSpace.getId(VarKind::Domain, 0) = Identifier(&otherIdentifiers[0]);
+  otherSpace.getId(VarKind::Domain, 1) = Identifier(&otherIdentifiers[1]);
+  otherSpace.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
+  otherSpace.getId(VarKind::Range, 0) = Identifier(&otherIdentifiers[3]);
+  otherSpace.getId(VarKind::Range, 1) = Identifier(&otherIdentifiers[4]);
+  // Note the common identifier
+  otherSpace.getId(VarKind::Symbol, 0) = Identifier(&identifiers[6]);
+  otherSpace.getId(VarKind::Symbol, 1) = Identifier(&otherIdentifiers[5]);
+  otherSpace.getId(VarKind::Symbol, 2) = Identifier(&otherIdentifiers[7]);
+
+  space.mergeAndAlignVarKind(VarKind::Domain, otherSpace);
+  space.mergeAndAlignVarKind(VarKind::Range, otherSpace);
+  space.mergeAndAlignVarKind(VarKind::Symbol, otherSpace);
+
+  // Check if merge & align is successful
+  // Check domain var identifiers
+  EXPECT_EQ(5u, space.getNumDomainVars());
+  EXPECT_EQ(5u, otherSpace.getNumDomainVars());
+  for (unsigned pos = 0; pos < 5; ++pos)
+    EXPECT_EQ(space.getId(VarKind::Domain, pos),
+              otherSpace.getId(VarKind::Domain, pos));
+  EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
+  EXPECT_EQ(space.getId(VarKind::Domain, 1), Identifier(&identifiers[1]));
+  EXPECT_EQ(space.getId(VarKind::Domain, 2), Identifier(&otherIdentifiers[2]));
+  // The two identifiers only in `otherSpace` may appear in either order.
+  Identifier dom3 = space.getId(VarKind::Domain, 3);
+  Identifier dom4 = space.getId(VarKind::Domain, 4);
+  Identifier other0(&otherIdentifiers[0]), other1(&otherIdentifiers[1]);
+  EXPECT_TRUE((dom3 == other0 && dom4 == other1) ||
+              (dom3 == other1 && dom4 == other0));
+  // Check range var identifiers
+  EXPECT_EQ(5u, space.getNumRangeVars());
+  EXPECT_EQ(5u, otherSpace.getNumRangeVars());
+  for (unsigned pos = 0; pos < 5; ++pos)
+    EXPECT_EQ(space.getId(VarKind::Range, pos),
+              otherSpace.getId(VarKind::Range, pos));
+  EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
+  EXPECT_EQ(space.getId(VarKind::Range, 1), Identifier(&identifiers[3]));
+  EXPECT_EQ(space.getId(VarKind::Range, 2), Identifier(&identifiers[4]));
+  EXPECT_EQ(space.getId(VarKind::Range, 3), Identifier(&otherIdentifiers[3]));
+  EXPECT_EQ(space.getId(VarKind::Range, 4), Identifier(&otherIdentifiers[4]));
+  // Check symbol var identifiers
+  EXPECT_EQ(4u, space.getNumSymbolVars());
+  EXPECT_EQ(4u, otherSpace.getNumSymbolVars());
+  EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&otherIdentifiers[5]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&otherIdentifiers[7]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 2),
+            Identifier(&otherIdentifiers[5]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 3),
+            Identifier(&otherIdentifiers[7]));
+}



More information about the Mlir-commits mailing list