[Mlir-commits] [mlir] 4db6e14 - [MLIR][Presburger] Implement composition for PresburgerRelation

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 7 05:13:50 PDT 2023


Author: iambrj
Date: 2023-07-07T17:43:41+05:30
New Revision: 4db6e149624cd9a0fab2fba9c49f9a20c1068cee

URL: https://github.com/llvm/llvm-project/commit/4db6e149624cd9a0fab2fba9c49f9a20c1068cee
DIFF: https://github.com/llvm/llvm-project/commit/4db6e149624cd9a0fab2fba9c49f9a20c1068cee.diff

LOG: [MLIR][Presburger] Implement composition for PresburgerRelation

This patch implements inversion, as well as range and domain composition, for PresburgerRelation.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D154444

Added: 
    mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp

Modified: 
    mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
    mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
    mlir/unittests/Analysis/Presburger/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
index 57ffbd38bd5c48..278d70a12c22bc 100644
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
@@ -83,6 +83,27 @@ class PresburgerRelation {
   /// Return the intersection of this set and the given set.
   PresburgerRelation intersect(const PresburgerRelation &set) const;
 
+  /// Invert the relation in-place, i.e. swap its domain and range.
+  ///
+  /// Formally, if `this`: A -> B then `inverse` updates `this` in-place to
+  /// `this`: B -> A.
+  void inverse();
+
+  /// Compose `this` relation with the given relation `rel` in-place.
+  ///
+  /// Formally, if `this`: A -> B, and `rel`: B -> C, then this function updates
+  /// `this` to `result`: A -> C where a point (a, c) belongs to `result`
+  /// iff there exists b such that (a, b) is in `this` and (b, c) is in `rel`.
+  ///
+  /// The range of `this` must be compatible with the domain of `rel`
+  /// (checked by an assertion).
+  void compose(const PresburgerRelation &rel);
+
+  /// Apply the given relation `rel` to the domain of `this` relation,
+  /// in-place.
+  ///
+  /// Formally, if `this`: A -> B and `rel`: A -> C, then this function updates
+  /// `this` to `result`: C -> B where a point (c, b) belongs to `result` iff
+  /// there exists a such that (a, c) is in `rel` and (a, b) is in `this`.
+  /// Written out-of-place, R1.applyDomain(R2) is the inverse of R2 composed
+  /// with R1.
+  ///
+  /// The domain of `this` must be compatible with the domain of `rel`
+  /// (checked by an assertion).
+  void applyDomain(const PresburgerRelation &rel);
+
+  /// Apply the given relation `rel` to the range of `this` relation,
+  /// in-place. Same as compose, provided for uniformity with applyDomain.
+  void applyRange(const PresburgerRelation &rel);
+
   /// Return true if the set contains the given point, and false otherwise.
   bool containsPoint(ArrayRef<MPInt> point) const;
   bool containsPoint(ArrayRef<int64_t> point) const {

diff  --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index 9c956709d8c36b..440de3c12faf3b 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
+#include "mlir/Analysis/Presburger/IntegerRelation.h"
 #include "mlir/Analysis/Presburger/Simplex.h"
 #include "mlir/Analysis/Presburger/Utils.h"
 #include "llvm/ADT/STLExtras.h"
@@ -108,6 +109,47 @@ PresburgerRelation::intersect(const PresburgerRelation &set) const {
   return result;
 }
 
+/// Swap the domain and range of every disjunct and of the relation's own
+/// space, in-place.
+void PresburgerRelation::inverse() {
+  for (IntegerRelation &cs : disjuncts)
+    cs.inverse();
+
+  // With at least one disjunct, take the (now inverted) space from it.
+  if (getNumDisjuncts()) {
+    setSpace(getDisjunct(0).getSpaceWithoutLocals());
+    return;
+  }
+
+  // An empty relation has no disjunct to read the inverted space from, but its
+  // space must still be inverted: an empty A -> B relation must become an
+  // empty B -> A relation, otherwise later compose/applyDomain calls see the
+  // wrong domain and range. The space is rebuilt from variable counts only.
+  setSpace(PresburgerSpace::getRelationSpace(
+      getNumRangeVars(), getNumDomainVars(), getNumSymbolVars()));
+}
+
+/// Compose `this` with `rel` in-place: if `this`: A -> B and `rel`: B -> C,
+/// `this` becomes A -> C, the union of the compositions of every pair of
+/// disjuncts (one from `this`, one from `rel`).
+void PresburgerRelation::compose(const PresburgerRelation &rel) {
+  assert(getSpace().getRangeSpace().isCompatible(
+             rel.getSpace().getDomainSpace()) &&
+         "Range of `this` should be compatible with domain of `rel`");
+
+  // The result space is built from variable counts: the domain of `this`, the
+  // range of `rel`, and the symbols of `this`.
+  PresburgerRelation result =
+      PresburgerRelation::getEmpty(PresburgerSpace::getRelationSpace(
+          getNumDomainVars(), rel.getNumRangeVars(), getNumSymbolVars()));
+  for (const IntegerRelation &csA : disjuncts) {
+    for (const IntegerRelation &csB : rel.disjuncts) {
+      IntegerRelation composition = csA;
+      composition.compose(csB);
+      // Skip empty pairwise compositions so they add no useless disjuncts.
+      if (!composition.isEmpty())
+        result.unionInPlace(composition);
+    }
+  }
+  // `result` is not used afterwards: move it instead of deep-copying every
+  // disjunct. This is also safe when `rel` aliases `*this`, since `rel` is no
+  // longer read at this point.
+  *this = std::move(result);
+}
+
+/// Map the domain of `this` through `rel`: for `this`: A -> B and
+/// `rel`: A -> C, `this` becomes C -> B. Implemented by applying `rel` to the
+/// range of the inverted relation and inverting back.
+void PresburgerRelation::applyDomain(const PresburgerRelation &rel) {
+  assert(rel.getSpace().getDomainSpace().isCompatible(
+             getSpace().getDomainSpace()) &&
+         "Domain of `this` should be compatible with domain of `rel`");
+
+  inverse();       // this: B -> A
+  applyRange(rel); // this: B -> C
+  inverse();       // this: C -> B
+}
+
+/// Map the range of `this` through `rel`; this is exactly `compose(rel)`.
+void PresburgerRelation::applyRange(const PresburgerRelation &rel) {
+  this->compose(rel);
+}
+
 /// Return the coefficients of the ineq in `rel` specified by  `idx`.
 /// `idx` can refer not only to an actual inequality of `rel`, but also
 /// to either of the inequalities that make up an equality in `rel`.

diff  --git a/mlir/unittests/Analysis/Presburger/CMakeLists.txt b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
index 4bfaf95f1a89b4..7b0124ee24c352 100644
--- a/mlir/unittests/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/unittests/Analysis/Presburger/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_unittest(MLIRPresburgerTests
   Parser.h
   ParserTest.cpp
   PresburgerSetTest.cpp
+  PresburgerRelationTest.cpp
   PresburgerSpaceTest.cpp
   PWMAFunctionTest.cpp
   SimplexTest.cpp

diff  --git a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
new file mode 100644
index 00000000000000..8b6e3c4fe3ed8e
--- /dev/null
+++ b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
@@ -0,0 +1,124 @@
+//===- PresburgerRelationTest.cpp - Tests for PresburgerRelation class ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Analysis/Presburger/PresburgerRelation.h"
+#include "Parser.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <iostream>
+
+using namespace mlir;
+using namespace presburger;
+
+/// Build a PresburgerRelation as the union of the sets in `strs`, treating the
+/// first `numDomain` set dimensions of each as domain variables and the rest
+/// as range variables. `strs` must be non-empty.
+static PresburgerRelation
+parsePresburgerRelationFromPresburgerSet(ArrayRef<StringRef> strs,
+                                         unsigned numDomain) {
+  assert(!strs.empty() && "strs should not be empty");
+
+  // Parse one set and reinterpret its leading `numDomain` dims as the domain.
+  auto toRelation = [numDomain](StringRef str) {
+    IntegerRelation disjunct = parseIntegerPolyhedron(str);
+    disjunct.convertVarKind(VarKind::SetDim, 0, numDomain, VarKind::Domain);
+    return disjunct;
+  };
+
+  PresburgerRelation result(toRelation(strs.front()));
+  for (StringRef str : strs.drop_front())
+    result.unionInPlace(toRelation(str));
+  return result;
+}
+
+/// Check applyRange and applyDomain on unions of affine maps against the
+/// hand-computed composition of every pair of disjuncts.
+TEST(PresburgerRelationTest, applyDomainAndRange) {
+  {
+    PresburgerRelation map1 = parsePresburgerRelationFromPresburgerSet(
+        {// (x, y) -> (x + N, y - N)
+         "(x, y, a, b)[N] : (x - a + N == 0, y - b - N == 0)",
+         // (x, y) -> (y, x)
+         "(x, y, a, b)[N] : (a - y == 0, b - x == 0)",
+         // (x, y) -> (x + y, x - y)
+         "(x, y, a, b)[N] : (a - x - y == 0, b - x + y == 0)"},
+        2);
+    PresburgerRelation map2 = parsePresburgerRelationFromPresburgerSet(
+        {// (x, y) -> (x + y)
+         "(x, y, r)[N] : (r - x - y == 0)",
+         // (x, y) -> (N)
+         "(x, y, r)[N] : (r - N == 0)",
+         // (x, y) -> (y - x)
+         "(x, y, r)[N] : (r + x - y == 0)"},
+        2);
+
+    map1.applyRange(map2);
+
+    // Expected: map2 applied after each disjunct of map1 (duplicates merged).
+    PresburgerRelation map3 = parsePresburgerRelationFromPresburgerSet(
+        {
+            // (x, y) -> (x + y)
+            "(x, y, r)[N] : (r - x - y == 0)",
+            // (x, y) -> (N)
+            "(x, y, r)[N] : (r - N == 0)",
+            // (x, y) -> (y - x - 2N)
+            "(x, y, r)[N] : (r - y + x + 2 * N == 0)",
+            // (x, y) -> (x - y)
+            "(x, y, r)[N] : (r - x + y == 0)",
+            // (x, y) -> (2x)
+            "(x, y, r)[N] : (r - 2 * x == 0)",
+            // (x, y) -> (-2y)
+            "(x, y, r)[N] : (r + 2 * y == 0)",
+        },
+        2);
+
+    EXPECT_TRUE(map1.isEqual(map3));
+  }
+
+  {
+    PresburgerRelation map1 = parsePresburgerRelationFromPresburgerSet(
+        {// (x, y) -> (y, x)
+         "(x, y, a, b)[N] : (y - a == 0, x - b == 0)",
+         // (x, y) -> (x + N, y - N)
+         "(x, y, a, b)[N] : (x - a + N == 0, y - b - N == 0)"},
+        2);
+    PresburgerRelation map2 = parsePresburgerRelationFromPresburgerSet(
+        {// (x, y) -> (x - y)
+         "(x, y, r)[N] : (x - y - r == 0)",
+         // (x, y) -> (N)
+         "(x, y, r)[N] : (N - r == 0)"},
+        2);
+
+    map1.applyDomain(map2);
+
+    // Expected: (r) -> (x, y), where (x, y) names the range of map1.
+    PresburgerRelation map3 = parsePresburgerRelationFromPresburgerSet(
+        {// (y - x) -> (x, y)
+         "(r, x, y)[N] : (y - x - r == 0)",
+         // (x - y - 2N) -> (x, y)
+         "(r, x, y)[N] : (x - y - 2 * N - r == 0)",
+         // (N) -> (x, y)
+         "(r, x, y)[N] : (N - r == 0)"},
+        1);
+
+    EXPECT_TRUE(map1.isEqual(map3));
+  }
+}
+
+/// Check that inverse() swaps domain and range of every disjunct.
+TEST(PresburgerRelationTest, inverse) {
+  // (x, y) -> (-y, -x) is its own inverse, while (x, y) -> (x + N, y - N)
+  // inverts to (x, y) -> (x - N, y + N).
+  PresburgerRelation expected = parsePresburgerRelationFromPresburgerSet(
+      {"(x, y, a, b)[N] : (y + a == 0, x + b == 0)",
+       "(x, y, a, b)[N] : (x - N - a == 0, y + N - b == 0)"},
+      2);
+
+  PresburgerRelation rel = parsePresburgerRelationFromPresburgerSet(
+      {"(x, y, a, b)[N] : (y + a == 0, x + b == 0)",
+       "(x, y, a, b)[N] : (x - a + N == 0, y - b - N == 0)"},
+      2);
+  rel.inverse();
+
+  EXPECT_TRUE(rel.isEqual(expected));
+}


        


More information about the Mlir-commits mailing list