[Mlir-commits] [mlir] [MLIR][Presburger] Implement IntegerRelation::mergeAndAlignSymbols (PR #76736)
Bharathi Ramana Joshi
llvmlistbot at llvm.org
Wed Jan 3 07:56:06 PST 2024
https://github.com/iambrj updated https://github.com/llvm/llvm-project/pull/76736
>From 8d338e25b9c5a709b75448c1de2a1023217987db Mon Sep 17 00:00:00 2001
From: iambrj <joshibharathiramana at gmail.com>
Date: Wed, 3 Jan 2024 21:21:46 +0530
Subject: [PATCH] [MLIR][Presburger] Implement
IntegerRelation::mergeAndAlignSymbols
---
.../Analysis/Presburger/IntegerRelation.h | 5 +
.../Analysis/Presburger/IntegerRelation.cpp | 31 +++
.../Presburger/IntegerRelationTest.cpp | 240 ++++++++++++++++++
3 files changed, 276 insertions(+)
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 4c6b810f92e95a..cd957280eb740d 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -577,6 +577,11 @@ class IntegerRelation {
convertVarKind(kind, varStart, varLimit, VarKind::Local);
}
+ /// Merge and align the symbol variables of `this` and `other` by identifier.
+ /// Both relations must be using identifiers, and both may be modified. After
+ /// the call both hold the symbols of `this`, then those only in `other`.
+ void mergeAndAlignSymbols(IntegerRelation &other);
+
/// Adds additional local vars to the sets such that they both have the union
/// of the local vars in each set, without changing the set of points that
/// lie in `this` and `other`.
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 0109384f1689dd..af16321e69a4cc 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -1288,6 +1288,37 @@ void IntegerRelation::eliminateRedundantLocalVar(unsigned posA, unsigned posB) {
removeVar(posB);
}
+/// Reorders and extends the symbols of `this` and `other` so that both end up
+/// with the same symbol identifiers in the same order. Both must use ids.
+void IntegerRelation::mergeAndAlignSymbols(IntegerRelation &other) {
+  assert(space.isUsingIds() && other.space.isUsingIds() &&
+         "Both relations need to have identifiers to merge & align");
+
+  // First merge & align identifiers into `other` from `this`.
+  unsigned i = 0;
+  for (const Identifier identifier : space.getIds(VarKind::Symbol)) {
+    // Search `other` from position `i`; everything left of `i` is aligned. A
+    // match is swapped into position `i`, a miss inserts a new symbol there.
+    auto *findBegin = other.space.getIds(VarKind::Symbol).begin() + i;
+    auto *findEnd = other.space.getIds(VarKind::Symbol).end();
+    auto *itr = std::find(findBegin, findEnd, identifier);
+    if (itr != findEnd) {
+      unsigned pos = other.getVarKindOffset(VarKind::Symbol) + i;
+      other.swapVar(pos, pos + std::distance(findBegin, itr));
+    } else {
+      other.insertVar(VarKind::Symbol, i);
+      other.space.getId(VarKind::Symbol, i) = identifier;
+    }
+    ++i;
+  }
+
+  // Finally append to `this` the symbols found only in `other` (from `i` on).
+  for (unsigned e = other.getNumVarKind(VarKind::Symbol); i < e; ++i) {
+    insertVar(VarKind::Symbol, i);
+    space.getId(VarKind::Symbol, i) = other.space.getId(VarKind::Symbol, i);
+  }
+}
+
/// Adds additional local ids to the sets such that they both have the union
/// of the local ids in each set, without changing the set of points that
/// lie in `this` and `other`.
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index f390296da648d2..63b7548ea92dab 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -207,3 +207,243 @@ TEST(IntegerRelationTest, swapVar) {
EXPECT_TRUE(swappedSpace.getId(VarKind::Symbol, 1)
.isEqual(space.getId(VarKind::Domain, 1)));
}
+
+TEST(IntegerRelationTest, mergeAndAlignSymbols) { // One shared symbol id.
+ IntegerRelation rel =
+ parseRelationFromSet("(x, y, z, a, b, c)[N, Q] : (a - x - y == 0, "
+ "x >= 0, N - b >= 0, y >= 0, Q - y >= 0)",
+ 3);
+ IntegerRelation otherRel = parseRelationFromSet(
+ "(x, y, z, a, b)[N, M, P] : (z - x - y == 0, x >= 0, N - x "
+ ">= 0, y >= 0, M - y >= 0, 2 * P - 3 * a + 2 * b == 0)",
+ 3);
+ PresburgerSpace space = PresburgerSpace::getRelationSpace(3, 3, 2, 0);
+ space.resetIds();
+
+ PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(3, 2, 3, 0);
+ otherSpace.resetIds();
+
+ // Attach identifiers, built from the addresses of the array elements below.
+ int identifiers[7] = {0, 1, 2, 3, 4, 5, 6};
+ int otherIdentifiers[8] = {10, 11, 12, 13, 14, 15, 16, 17};
+
+ space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+ space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+ // Note the common identifier: this domain id is also used by `otherSpace`.
+ space.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
+ space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+ space.getId(VarKind::Range, 1) = Identifier(&identifiers[3]);
+ space.getId(VarKind::Range, 2) = Identifier(&identifiers[4]);
+ space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[5]);
+ space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[6]);
+
+ otherSpace.getId(VarKind::Domain, 0) = Identifier(&otherIdentifiers[0]);
+ otherSpace.getId(VarKind::Domain, 1) = Identifier(&otherIdentifiers[1]);
+ otherSpace.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
+ otherSpace.getId(VarKind::Range, 0) = Identifier(&otherIdentifiers[3]);
+ otherSpace.getId(VarKind::Range, 1) = Identifier(&otherIdentifiers[4]);
+ // Note the common identifier: this symbol id is symbol 1 of `space`.
+ otherSpace.getId(VarKind::Symbol, 0) = Identifier(&identifiers[6]);
+ otherSpace.getId(VarKind::Symbol, 1) = Identifier(&otherIdentifiers[5]);
+ otherSpace.getId(VarKind::Symbol, 2) = Identifier(&otherIdentifiers[7]);
+
+ rel.setSpace(space);
+ otherRel.setSpace(otherSpace);
+ rel.mergeAndAlignSymbols(otherRel);
+
+ space = rel.getSpace();
+ otherSpace = otherRel.getSpace();
+
+ // Check that merge & align succeeded: both relations should have the four
+ // symbols identifiers[5], identifiers[6], otherIdentifiers[5] and [7].
+ EXPECT_EQ(4u, space.getNumSymbolVars());
+ EXPECT_EQ(4u, otherSpace.getNumSymbolVars());
+ EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&otherIdentifiers[5]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&otherIdentifiers[7]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 2),
+ Identifier(&otherIdentifiers[5]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 3),
+ Identifier(&otherIdentifiers[7]));
+ // Check that domain and range var identifiers are not affected.
+ EXPECT_EQ(3u, space.getNumDomainVars());
+ EXPECT_EQ(3u, space.getNumRangeVars());
+ EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
+ EXPECT_EQ(space.getId(VarKind::Domain, 1), Identifier(&identifiers[1]));
+ EXPECT_EQ(space.getId(VarKind::Domain, 2), Identifier(&otherIdentifiers[2]));
+ EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
+ EXPECT_EQ(space.getId(VarKind::Range, 1), Identifier(&identifiers[3]));
+ EXPECT_EQ(space.getId(VarKind::Range, 2), Identifier(&identifiers[4]));
+ EXPECT_EQ(3u, otherSpace.getNumDomainVars());
+ EXPECT_EQ(2u, otherSpace.getNumRangeVars());
+ EXPECT_EQ(otherSpace.getId(VarKind::Domain, 0),
+ Identifier(&otherIdentifiers[0]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Domain, 1),
+ Identifier(&otherIdentifiers[1]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Domain, 2),
+ Identifier(&otherIdentifiers[2]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Range, 0),
+ Identifier(&otherIdentifiers[3]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Range, 1),
+ Identifier(&otherIdentifiers[4]));
+}
+
+TEST(IntegerRelationTest, mergeAndAlignSymbols2) { // No shared symbol ids.
+ IntegerRelation rel = parseRelationFromSet(
+ "(x, y, z)[A, B, C, D] : (x + A - C - y + D - z >= 0)", 2);
+ IntegerRelation otherRel = parseRelationFromSet(
+ "(u, v, a, b)[E, F, G, H] : (E - u + v == 0, v - G - H >= 0)", 2);
+ PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 1, 4, 0);
+ space.resetIds();
+
+ PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(2, 2, 4, 0);
+ otherSpace.resetIds();
+
+ // Attach identifiers, built from the addresses of the array elements below.
+ int identifiers[7] = {'x', 'y', 'z', 'A', 'B', 'C', 'D'};
+ int otherIdentifiers[8] = {'u', 'v', 'a', 'b', 'E', 'F', 'G', 'H'};
+
+ space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+ space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+ space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+ space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
+ space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
+ space.getId(VarKind::Symbol, 2) = Identifier(&identifiers[5]);
+ space.getId(VarKind::Symbol, 3) = Identifier(&identifiers[6]);
+
+ otherSpace.getId(VarKind::Domain, 0) = Identifier(&otherIdentifiers[0]);
+ otherSpace.getId(VarKind::Domain, 1) = Identifier(&otherIdentifiers[1]);
+ otherSpace.getId(VarKind::Range, 0) = Identifier(&otherIdentifiers[2]);
+ otherSpace.getId(VarKind::Range, 1) = Identifier(&otherIdentifiers[3]);
+ otherSpace.getId(VarKind::Symbol, 0) = Identifier(&otherIdentifiers[4]);
+ otherSpace.getId(VarKind::Symbol, 1) = Identifier(&otherIdentifiers[5]);
+ otherSpace.getId(VarKind::Symbol, 2) = Identifier(&otherIdentifiers[6]);
+ otherSpace.getId(VarKind::Symbol, 3) = Identifier(&otherIdentifiers[7]);
+
+ rel.setSpace(space);
+ otherRel.setSpace(otherSpace);
+ rel.mergeAndAlignSymbols(otherRel);
+
+ space = rel.getSpace();
+ otherSpace = otherRel.getSpace();
+
+ // Check that merge & align succeeded: no symbol is shared, so both relations
+ // should have the symbols A, B, C, D of `rel` followed by E, F, G, H.
+ EXPECT_EQ(8u, space.getNumSymbolVars());
+ EXPECT_EQ(8u, otherSpace.getNumSymbolVars());
+ EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&identifiers[5]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&identifiers[6]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 4), Identifier(&otherIdentifiers[4]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 5), Identifier(&otherIdentifiers[5]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 6), Identifier(&otherIdentifiers[6]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 7), Identifier(&otherIdentifiers[7]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 2), Identifier(&identifiers[5]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 3), Identifier(&identifiers[6]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 4),
+ Identifier(&otherIdentifiers[4]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 5),
+ Identifier(&otherIdentifiers[5]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 6),
+ Identifier(&otherIdentifiers[6]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 7),
+ Identifier(&otherIdentifiers[7]));
+ // Check that domain and range var identifiers are not affected.
+ EXPECT_EQ(2u, space.getNumDomainVars());
+ EXPECT_EQ(1u, space.getNumRangeVars());
+ EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
+ EXPECT_EQ(space.getId(VarKind::Domain, 1), Identifier(&identifiers[1]));
+ EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
+ EXPECT_EQ(2u, otherSpace.getNumDomainVars());
+ EXPECT_EQ(2u, otherSpace.getNumRangeVars());
+ EXPECT_EQ(otherSpace.getId(VarKind::Domain, 0),
+ Identifier(&otherIdentifiers[0]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Domain, 1),
+ Identifier(&otherIdentifiers[1]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Range, 0),
+ Identifier(&otherIdentifiers[2]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Range, 1),
+ Identifier(&otherIdentifiers[3]));
+}
+
+TEST(IntegerRelationTest, mergeAndAlignSymbols3) { // Two shared symbol ids.
+ IntegerRelation rel = parseRelationFromSet(
+ "(x, y, z)[A, B, C, D] : (x + A - C - y + D - z >= 0)", 2);
+ IntegerRelation otherRel = parseRelationFromSet(
+ "(u, v, a, b)[E, F, C, D] : (E - u + v == 0, v - C - D >= 0)", 2);
+ PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 1, 4, 0);
+ space.resetIds();
+
+ PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(2, 2, 4, 0);
+ otherSpace.resetIds();
+
+ // Attach identifiers, built from the addresses of the array elements below.
+ int identifiers[7] = {'x', 'y', 'z', 'A', 'B', 'C', 'D'};
+ int otherIdentifiers[6] = {'u', 'v', 'a', 'b', 'E', 'F'};
+
+ space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+ space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+ space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+ space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
+ space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
+ space.getId(VarKind::Symbol, 2) = Identifier(&identifiers[5]);
+ space.getId(VarKind::Symbol, 3) = Identifier(&identifiers[6]);
+
+ otherSpace.getId(VarKind::Domain, 0) = Identifier(&otherIdentifiers[0]);
+ otherSpace.getId(VarKind::Domain, 1) = Identifier(&otherIdentifiers[1]);
+ otherSpace.getId(VarKind::Range, 0) = Identifier(&otherIdentifiers[2]);
+ otherSpace.getId(VarKind::Range, 1) = Identifier(&otherIdentifiers[3]);
+ otherSpace.getId(VarKind::Symbol, 0) = Identifier(&otherIdentifiers[4]);
+ otherSpace.getId(VarKind::Symbol, 1) = Identifier(&otherIdentifiers[5]);
+ // Note the common identifiers: these are symbols 2 and 3 (C, D) of `space`.
+ otherSpace.getId(VarKind::Symbol, 2) = Identifier(&identifiers[5]);
+ otherSpace.getId(VarKind::Symbol, 3) = Identifier(&identifiers[6]);
+
+ rel.setSpace(space);
+ otherRel.setSpace(otherSpace);
+ rel.mergeAndAlignSymbols(otherRel);
+
+ space = rel.getSpace();
+ otherSpace = otherRel.getSpace();
+
+ // Check that merge & align succeeded: C and D are shared, so both relations
+ // should have the symbols A, B, C, D of `rel` followed by E, F.
+ EXPECT_EQ(6u, space.getNumSymbolVars());
+ EXPECT_EQ(6u, otherSpace.getNumSymbolVars());
+ EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&identifiers[5]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&identifiers[6]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 4), Identifier(&otherIdentifiers[4]));
+ EXPECT_EQ(space.getId(VarKind::Symbol, 5), Identifier(&otherIdentifiers[5]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 2), Identifier(&identifiers[5]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 3), Identifier(&identifiers[6]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 4),
+ Identifier(&otherIdentifiers[4]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 5),
+ Identifier(&otherIdentifiers[5]));
+ // Check that domain and range var identifiers are not affected.
+ EXPECT_EQ(2u, space.getNumDomainVars());
+ EXPECT_EQ(1u, space.getNumRangeVars());
+ EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
+ EXPECT_EQ(space.getId(VarKind::Domain, 1), Identifier(&identifiers[1]));
+ EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
+ EXPECT_EQ(2u, otherSpace.getNumDomainVars());
+ EXPECT_EQ(2u, otherSpace.getNumRangeVars());
+ EXPECT_EQ(otherSpace.getId(VarKind::Domain, 0),
+ Identifier(&otherIdentifiers[0]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Domain, 1),
+ Identifier(&otherIdentifiers[1]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Range, 0),
+ Identifier(&otherIdentifiers[2]));
+ EXPECT_EQ(otherSpace.getId(VarKind::Range, 1),
+ Identifier(&otherIdentifiers[3]));
+}
More information about the Mlir-commits
mailing list