[Mlir-commits] [mlir] [MLIR][Presburger] Implement IntegerRelation::mergeAndAlignSymbols (PR #76736)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 2 09:01:33 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-presburger

Author: Bharathi Ramana Joshi (iambrj)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/76736.diff


3 Files Affected:

- (modified) mlir/include/mlir/Analysis/Presburger/IntegerRelation.h (+5) 
- (modified) mlir/lib/Analysis/Presburger/IntegerRelation.cpp (+31) 
- (modified) mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp (+83) 


``````````diff
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 4c6b810f92e95a..cd957280eb740d 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -577,6 +577,11 @@ class IntegerRelation {
     convertVarKind(kind, varStart, varLimit, VarKind::Local);
   }
 
+  /// Merge and align the symbol variables of `this` and `other` by identifier;
+  /// both must be using identifiers. Missing symbols are inserted into either
+  /// relation so both end with the same symbol identifiers in the same order.
+  void mergeAndAlignSymbols(IntegerRelation &other);
+
   /// Adds additional local vars to the sets such that they both have the union
   /// of the local vars in each set, without changing the set of points that
   /// lie in `this` and `other`.
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 0109384f1689dd..af16321e69a4cc 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -1288,6 +1288,37 @@ void IntegerRelation::eliminateRedundantLocalVar(unsigned posA, unsigned posB) {
   removeVar(posB);
 }
 
+void IntegerRelation::mergeAndAlignSymbols(IntegerRelation &other) {
+  assert(space.isUsingIds() && other.space.isUsingIds() &&
+         "Both relations need to have identifers to merge & align");
+
+  // First pass: align the leading symbols of `other` with those of `this`.
+  unsigned i = 0;
+  for (const Identifier identifier : space.getIds(VarKind::Symbol)) {
+    // If `other` already has this identifier, swap it into position `i`;
+    // otherwise insert a new symbol at `i` and assign it the identifier. The
+    // search starts at `i` since symbols left of `i` are already aligned.
+    auto *findBegin = other.space.getIds(VarKind::Symbol).begin() + i;
+    auto *findEnd = other.space.getIds(VarKind::Symbol).end();
+    auto *itr = std::find(findBegin, findEnd, identifier);
+    if (itr != findEnd) {
+      other.swapVar(other.getVarKindOffset(VarKind::Symbol) + i,
+                    other.getVarKindOffset(VarKind::Symbol) + i +
+                        std::distance(findBegin, itr));
+    } else {
+      other.insertVar(VarKind::Symbol, i);
+      other.space.getId(VarKind::Symbol, i) = identifier;
+    }
+    ++i;
+  }
+
+  // Second pass: append to `this` the symbols of `other` that `this` lacks.
+  for (unsigned e = other.getNumVarKind(VarKind::Symbol); i < e; ++i) {
+    insertVar(VarKind::Symbol, i);
+    space.getId(VarKind::Symbol, i) = other.space.getId(VarKind::Symbol, i);
+  }
+}
+
 /// Adds additional local ids to the sets such that they both have the union
 /// of the local ids in each set, without changing the set of points that
 /// lie in `this` and `other`.
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index f390296da648d2..998218416d7e57 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -207,3 +207,86 @@ TEST(IntegerRelationTest, swapVar) {
   EXPECT_TRUE(swappedSpace.getId(VarKind::Symbol, 1)
                   .isEqual(space.getId(VarKind::Domain, 1)));
 }
+
+TEST(IntegerRelationTest, mergeAndAlignSymbols) {
+  IntegerRelation rel =
+      parseRelationFromSet("(x, y, z, a, b, c)[N, Q] : (a - x - y == 0, "
+                           "x >= 0, N - b >= 0, y >= 0, Q - y >= 0)",
+                           3);
+  IntegerRelation otherRel = parseRelationFromSet(
+      "(x, y, z, a, b)[N, M, P] : (z - x - y == 0, x >= 0, N - x "
+      ">= 0, y >= 0, M - y >= 0, 2 * P - 3 * a + 2 * b == 0)",
+      3);
+  PresburgerSpace space = PresburgerSpace::getRelationSpace(3, 3, 2, 0);
+  space.resetIds();
+
+  PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(3, 2, 3, 0);
+  otherSpace.resetIds();
+
+  // Attach identifiers, each built from the address of an array element.
+  int identifiers[7] = {0, 1, 2, 3, 4, 5, 6};
+  int otherIdentifiers[8] = {10, 11, 12, 13, 14, 15, 16, 17};
+
+  space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+  space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+  // Note: this identifier is also used by `otherSpace`'s Domain 2 below.
+  space.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
+  space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+  space.getId(VarKind::Range, 1) = Identifier(&identifiers[3]);
+  space.getId(VarKind::Range, 2) = Identifier(&identifiers[4]);
+  space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[5]);
+  space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[6]);
+
+  otherSpace.getId(VarKind::Domain, 0) = Identifier(&otherIdentifiers[0]);
+  otherSpace.getId(VarKind::Domain, 1) = Identifier(&otherIdentifiers[1]);
+  otherSpace.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
+  otherSpace.getId(VarKind::Range, 0) = Identifier(&otherIdentifiers[3]);
+  otherSpace.getId(VarKind::Range, 1) = Identifier(&otherIdentifiers[4]);
+  // Note: this symbol identifier is also used by `space`'s Symbol 1 above.
+  otherSpace.getId(VarKind::Symbol, 0) = Identifier(&identifiers[6]);
+  otherSpace.getId(VarKind::Symbol, 1) = Identifier(&otherIdentifiers[5]);
+  otherSpace.getId(VarKind::Symbol, 2) = Identifier(&otherIdentifiers[7]);
+
+  rel.setSpace(space);
+  otherRel.setSpace(otherSpace);
+  rel.mergeAndAlignSymbols(otherRel);
+
+  space = rel.getSpace();
+  otherSpace = otherRel.getSpace();
+
+  // Verify the result of merge & align.
+  // Both spaces must list the same four symbol identifiers in the same order.
+  EXPECT_EQ(4u, space.getNumSymbolVars());
+  EXPECT_EQ(4u, otherSpace.getNumSymbolVars());
+  EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&otherIdentifiers[5]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&otherIdentifiers[7]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 2),
+            Identifier(&otherIdentifiers[5]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 3),
+            Identifier(&otherIdentifiers[7]));
+  // Domain and range var counts and identifiers must be left untouched.
+  EXPECT_EQ(3u, space.getNumDomainVars());
+  EXPECT_EQ(3u, space.getNumRangeVars());
+  EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
+  EXPECT_EQ(space.getId(VarKind::Domain, 1), Identifier(&identifiers[1]));
+  EXPECT_EQ(space.getId(VarKind::Domain, 2), Identifier(&otherIdentifiers[2]));
+  EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
+  EXPECT_EQ(space.getId(VarKind::Range, 1), Identifier(&identifiers[3]));
+  EXPECT_EQ(space.getId(VarKind::Range, 2), Identifier(&identifiers[4]));
+  EXPECT_EQ(3u, otherSpace.getNumDomainVars());
+  EXPECT_EQ(2u, otherSpace.getNumRangeVars());
+  EXPECT_EQ(otherSpace.getId(VarKind::Domain, 0),
+            Identifier(&otherIdentifiers[0]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Domain, 1),
+            Identifier(&otherIdentifiers[1]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Domain, 2),
+            Identifier(&otherIdentifiers[2]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Range, 0),
+            Identifier(&otherIdentifiers[3]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Range, 1),
+            Identifier(&otherIdentifiers[4]));
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/76736


More information about the Mlir-commits mailing list