[Mlir-commits] [mlir] [MLIR][Presburger] Implement IntegerRelation::mergeAndAlignSymbols (PR #76736)

Kunwar Grover llvmlistbot at llvm.org
Wed Jan 3 08:24:35 PST 2024


================
@@ -207,3 +207,243 @@ TEST(IntegerRelationTest, swapVar) {
   EXPECT_TRUE(swappedSpace.getId(VarKind::Symbol, 1)
                   .isEqual(space.getId(VarKind::Domain, 1)));
 }
+
+TEST(IntegerRelationTest, mergeAndAlignSymbols) { // Symbol lists share exactly one identifier.
+  IntegerRelation rel =
+      parseRelationFromSet("(x, y, z, a, b, c)[N, Q] : (a - x - y == 0, "
+                           "x >= 0, N - b >= 0, y >= 0, Q - y >= 0)",
+                           3);
+  IntegerRelation otherRel = parseRelationFromSet(
+      "(x, y, z, a, b)[N, M, P] : (z - x - y == 0, x >= 0, N - x "
+      ">= 0, y >= 0, M - y >= 0, 2 * P - 3 * a + 2 * b == 0)",
+      3);
+  PresburgerSpace space = PresburgerSpace::getRelationSpace(3, 3, 2, 0);
+  space.resetIds();
+
+  PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(3, 2, 3, 0);
+  otherSpace.resetIds();
+
+  // Attach identifiers; each one is built from the address of an array element.
+  int identifiers[7] = {0, 1, 2, 3, 4, 5, 6};
+  int otherIdentifiers[8] = {10, 11, 12, 13, 14, 15, 16, 17};
+
+  space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+  space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+  // Note the common identifier: this domain id is taken from otherIdentifiers.
+  space.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
+  space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+  space.getId(VarKind::Range, 1) = Identifier(&identifiers[3]);
+  space.getId(VarKind::Range, 2) = Identifier(&identifiers[4]);
+  space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[5]);
+  space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[6]);
+
+  otherSpace.getId(VarKind::Domain, 0) = Identifier(&otherIdentifiers[0]);
+  otherSpace.getId(VarKind::Domain, 1) = Identifier(&otherIdentifiers[1]);
+  otherSpace.getId(VarKind::Domain, 2) = Identifier(&otherIdentifiers[2]);
+  otherSpace.getId(VarKind::Range, 0) = Identifier(&otherIdentifiers[3]);
+  otherSpace.getId(VarKind::Range, 1) = Identifier(&otherIdentifiers[4]);
+  // Note the common identifier: symbol 0 here is the same id as symbol 1 of `space`.
+  otherSpace.getId(VarKind::Symbol, 0) = Identifier(&identifiers[6]);
+  otherSpace.getId(VarKind::Symbol, 1) = Identifier(&otherIdentifiers[5]);
+  otherSpace.getId(VarKind::Symbol, 2) = Identifier(&otherIdentifiers[7]);
+
+  rel.setSpace(space);
+  otherRel.setSpace(otherSpace);
+  rel.mergeAndAlignSymbols(otherRel);
+
+  space = rel.getSpace(); // Re-read both spaces to inspect the post-merge state.
+  otherSpace = otherRel.getSpace();
+
+  // Check if merge and align is successful.
+  // Check symbol var identifiers: both spaces must list the same four ids in the same order.
+  EXPECT_EQ(4u, space.getNumSymbolVars());
+  EXPECT_EQ(4u, otherSpace.getNumSymbolVars());
+  EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&otherIdentifiers[5]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&otherIdentifiers[7]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 0), Identifier(&identifiers[5]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 1), Identifier(&identifiers[6]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 2),
+            Identifier(&otherIdentifiers[5]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 3),
+            Identifier(&otherIdentifiers[7]));
+  // Check that domain and range var identifiers are not affected.
+  EXPECT_EQ(3u, space.getNumDomainVars());
+  EXPECT_EQ(3u, space.getNumRangeVars());
+  EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
+  EXPECT_EQ(space.getId(VarKind::Domain, 1), Identifier(&identifiers[1]));
+  EXPECT_EQ(space.getId(VarKind::Domain, 2), Identifier(&otherIdentifiers[2]));
+  EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
+  EXPECT_EQ(space.getId(VarKind::Range, 1), Identifier(&identifiers[3]));
+  EXPECT_EQ(space.getId(VarKind::Range, 2), Identifier(&identifiers[4]));
+  EXPECT_EQ(3u, otherSpace.getNumDomainVars());
+  EXPECT_EQ(2u, otherSpace.getNumRangeVars());
+  EXPECT_EQ(otherSpace.getId(VarKind::Domain, 0),
+            Identifier(&otherIdentifiers[0]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Domain, 1),
+            Identifier(&otherIdentifiers[1]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Domain, 2),
+            Identifier(&otherIdentifiers[2]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Range, 0),
+            Identifier(&otherIdentifiers[3]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Range, 1),
+            Identifier(&otherIdentifiers[4]));
+}
+
+TEST(IntegerRelationTest, mergeAndAlignSymbols2) { // The two relations share no identifier.
+  IntegerRelation rel = parseRelationFromSet(
+      "(x, y, z)[A, B, C, D] : (x + A - C - y + D - z >= 0)", 2);
+  IntegerRelation otherRel = parseRelationFromSet(
+      "(u, v, a, b)[E, F, G, H] : (E - u + v == 0, v - G - H >= 0)", 2);
+  PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 1, 4, 0);
+  space.resetIds();
+
+  PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(2, 2, 4, 0);
+  otherSpace.resetIds();
+
+  // Attach identifiers; array values mirror the variable names used in the set strings above.
+  int identifiers[7] = {'x', 'y', 'z', 'A', 'B', 'C', 'D'};
+  int otherIdentifiers[8] = {'u', 'v', 'a', 'b', 'E', 'F', 'G', 'H'};
+
+  space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+  space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+  space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+  space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
+  space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
+  space.getId(VarKind::Symbol, 2) = Identifier(&identifiers[5]);
+  space.getId(VarKind::Symbol, 3) = Identifier(&identifiers[6]);
+
+  otherSpace.getId(VarKind::Domain, 0) = Identifier(&otherIdentifiers[0]);
+  otherSpace.getId(VarKind::Domain, 1) = Identifier(&otherIdentifiers[1]);
+  otherSpace.getId(VarKind::Range, 0) = Identifier(&otherIdentifiers[2]);
+  otherSpace.getId(VarKind::Range, 1) = Identifier(&otherIdentifiers[3]);
+  otherSpace.getId(VarKind::Symbol, 0) = Identifier(&otherIdentifiers[4]);
+  otherSpace.getId(VarKind::Symbol, 1) = Identifier(&otherIdentifiers[5]);
+  otherSpace.getId(VarKind::Symbol, 2) = Identifier(&otherIdentifiers[6]);
+  otherSpace.getId(VarKind::Symbol, 3) = Identifier(&otherIdentifiers[7]);
+
+  rel.setSpace(space);
+  otherRel.setSpace(otherSpace);
+  rel.mergeAndAlignSymbols(otherRel);
+
+  space = rel.getSpace(); // Re-read both spaces to inspect the post-merge state.
+  otherSpace = otherRel.getSpace();
+
+  // Check if merge and align is successful.
+  // Check symbol var identifiers: rel's four symbols come first, then otherRel's four.
+  EXPECT_EQ(8u, space.getNumSymbolVars());
+  EXPECT_EQ(8u, otherSpace.getNumSymbolVars());
+  EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&identifiers[5]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&identifiers[6]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 4), Identifier(&otherIdentifiers[4]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 5), Identifier(&otherIdentifiers[5]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 6), Identifier(&otherIdentifiers[6]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 7), Identifier(&otherIdentifiers[7]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 2), Identifier(&identifiers[5]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 3), Identifier(&identifiers[6]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 4),
+            Identifier(&otherIdentifiers[4]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 5),
+            Identifier(&otherIdentifiers[5]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 6),
+            Identifier(&otherIdentifiers[6]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Symbol, 7),
+            Identifier(&otherIdentifiers[7]));
+  // Check that domain and range var identifiers are not affected.
+  EXPECT_EQ(2u, space.getNumDomainVars());
+  EXPECT_EQ(1u, space.getNumRangeVars());
+  EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[0]));
+  EXPECT_EQ(space.getId(VarKind::Domain, 1), Identifier(&identifiers[1]));
+  EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[2]));
+  EXPECT_EQ(2u, otherSpace.getNumDomainVars());
+  EXPECT_EQ(2u, otherSpace.getNumRangeVars());
+  EXPECT_EQ(otherSpace.getId(VarKind::Domain, 0),
+            Identifier(&otherIdentifiers[0]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Domain, 1),
+            Identifier(&otherIdentifiers[1]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Range, 0),
+            Identifier(&otherIdentifiers[2]));
+  EXPECT_EQ(otherSpace.getId(VarKind::Range, 1),
+            Identifier(&otherIdentifiers[3]));
+}
+
+TEST(IntegerRelationTest, mergeAndAlignSymbols3) {
+  IntegerRelation rel = parseRelationFromSet(
+      "(x, y, z)[A, B, C, D] : (x + A - C - y + D - z >= 0)", 2);
+  IntegerRelation otherRel = parseRelationFromSet(
+      "(u, v, a, b)[E, F, C, D] : (E - u + v == 0, v - C - D >= 0)", 2);
+  PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 1, 4, 0);
+  space.resetIds();
+
+  PresburgerSpace otherSpace = PresburgerSpace::getRelationSpace(2, 2, 4, 0);
+  otherSpace.resetIds();
+
+  // Attach identifiers.
+  int identifiers[7] = {'x', 'y', 'z', 'A', 'B', 'C', 'D'};
+  int otherIdentifiers[6] = {'u', 'v', 'a', 'b', 'E', 'F'};
+
+  space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+  space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+  space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+  space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
+  space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
+  space.getId(VarKind::Symbol, 2) = Identifier(&identifiers[5]);
+  space.getId(VarKind::Symbol, 3) = Identifier(&identifiers[6]);
+
+  otherSpace.getId(VarKind::Domain, 0) = Identifier(&otherIdentifiers[0]);
+  otherSpace.getId(VarKind::Domain, 1) = Identifier(&otherIdentifiers[1]);
+  otherSpace.getId(VarKind::Range, 0) = Identifier(&otherIdentifiers[2]);
+  otherSpace.getId(VarKind::Range, 1) = Identifier(&otherIdentifiers[3]);
+  otherSpace.getId(VarKind::Symbol, 0) = Identifier(&otherIdentifiers[4]);
+  otherSpace.getId(VarKind::Symbol, 1) = Identifier(&otherIdentifiers[5]);
+  // Note common identifiers
+  otherSpace.getId(VarKind::Symbol, 2) = Identifier(&identifiers[5]);
+  otherSpace.getId(VarKind::Symbol, 3) = Identifier(&identifiers[6]);
+
+  rel.setSpace(space);
+  otherRel.setSpace(otherSpace);
+  rel.mergeAndAlignSymbols(otherRel);
+
+  space = rel.getSpace();
+  otherSpace = otherRel.getSpace();
+
+  // Check if merge & align is successful.
----------------
Groverkss wrote:

nit: & -> and

https://github.com/llvm/llvm-project/pull/76736


More information about the Mlir-commits mailing list