[Mlir-commits] [mlir] 7817163 - [mlir] [presburger] Add IntegerRelation::rangeProduct (#148092)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jul 17 08:02:02 PDT 2025


Author: Jeremy Kun
Date: 2025-07-17T08:01:58-07:00
New Revision: 7817163663b3bb662a46a73cf1903ec900ba6146

URL: https://github.com/llvm/llvm-project/commit/7817163663b3bb662a46a73cf1903ec900ba6146
DIFF: https://github.com/llvm/llvm-project/commit/7817163663b3bb662a46a73cf1903ec900ba6146.diff

LOG: [mlir] [presburger] Add IntegerRelation::rangeProduct (#148092)

This is intended to match `isl::map`'s `flat_range_product`.

---------

Co-authored-by: Jeremy Kun <j2kun at users.noreply.github.com>

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
    mlir/lib/Analysis/Presburger/IntegerRelation.cpp
    mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index b68262f09f485..ee401cca8f552 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -707,6 +707,19 @@ class IntegerRelation {
   /// this for uniformity with `applyDomain`.
   void applyRange(const IntegerRelation &rel);
 
+  /// Let the relation `this` be R1, and the relation `rel` be R2. Requires
+  /// R1 and R2 to have the same number of domain variables.
+  ///
+  /// Let R3 be the rangeProduct of R1 and R2. Then x R3 (y, z) iff
+  /// (x R1 y and x R2 z).
+  ///
+  /// R3 is returned as a new relation; `this` is not modified. In R3's range,
+  /// the range variables of R1 come first, followed by those of R2. Symbol
+  /// variables are matched by position, so both relations are expected to
+  /// have the same symbols. Intended to match isl's `flat_range_product`.
+  ///
+  /// Example:
+  ///
+  /// R1: (i, j) -> k : f(i, j, k) = 0
+  /// R2: (i, j) -> l : g(i, j, l) = 0
+  /// R1.rangeProduct(R2): (i, j) -> (k, l) : f(i, j, k) = 0 and g(i, j, l) = 0
+  IntegerRelation rangeProduct(const IntegerRelation &rel);
+
   /// Given a relation `other: (A -> B)`, this operation merges the symbol and
   /// local variables and then takes the composition of `other` on `this: (B ->
   /// C)`. The resulting relation represents tuples of the form: `A -> C`.

diff  --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 17e48e0d069b7..5c4d4d13580a0 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -2481,6 +2481,44 @@ void IntegerRelation::applyDomain(const IntegerRelation &rel) {
 
 void IntegerRelation::applyRange(const IntegerRelation &rel) { compose(rel); }
 
+IntegerRelation IntegerRelation::rangeProduct(const IntegerRelation &rel) {
+  // R1: (i, j) -> k : f(i, j, k) = 0
+  // R2: (i, j) -> l : g(i, j, l) = 0
+  // R1.rangeProduct(R2): (i, j) -> (k, l) : f(i, j, k) = 0 and g(i, j, l) = 0
+  assert(getNumDomainVars() == rel.getNumDomainVars() &&
+         "Range product is only defined for relations with equal domains");
+  // Symbols are matched by position; there is no identifier-based alignment.
+  assert(getNumSymbolVars() == rel.getNumSymbolVars() &&
+         "Range product requires relations with the same number of symbols");
+
+  const unsigned numDomainVars = rel.getNumDomainVars();
+  const unsigned numThisRangeVars = getNumRangeVars();
+  const unsigned numThisLocalVars = getNumLocalVars();
+  const unsigned numRelRangeVars = rel.getNumRangeVars();
+  const unsigned numSymbolVars = rel.getNumSymbolVars();
+  const unsigned numRelLocalVars = rel.getNumLocalVars();
+
+  // Explicit copy of `this`; `this` itself is left unchanged. The result's
+  // columns are laid out as
+  //   [domain | this range | rel range | symbols | this locals | rel locals |
+  //    constant].
+  // `rel`'s locals are appended as fresh existentials instead of being
+  // identified with `this`'s locals: locals of two different relations are
+  // unrelated, and their counts need not match.
+  IntegerRelation result = *this;
+  result.appendVar(VarKind::Range, numRelRangeVars);
+  result.appendVar(VarKind::Local, numRelLocalVars);
+
+  const unsigned numResultCols = result.getNumCols();
+  const unsigned relRangeDst =
+      result.getVarKindOffset(VarKind::Range) + numThisRangeVars;
+  const unsigned symbolDst = result.getVarKindOffset(VarKind::Symbol);
+  const unsigned relLocalDst =
+      result.getVarKindOffset(VarKind::Local) + numThisLocalVars;
+
+  // Re-express a constraint row of `rel` (laid out as [domain | range |
+  // symbols | locals | constant]) in the column layout of `result`.
+  auto remapRow = [&](ArrayRef<DynamicAPInt> row) {
+    SmallVector<DynamicAPInt> remapped(numResultCols, DynamicAPInt(0));
+    unsigned src = 0;
+    auto copyCols = [&](unsigned dst, unsigned count) {
+      for (unsigned k = 0; k < count; ++k)
+        remapped[dst + k] = row[src++];
+    };
+    copyCols(/*dst=*/0, numDomainVars);
+    copyCols(relRangeDst, numRelRangeVars);
+    copyCols(symbolDst, numSymbolVars);
+    copyCols(relLocalDst, numRelLocalVars);
+    remapped.back() = row.back(); // Constant term.
+    return remapped;
+  };
+
+  for (unsigned i = 0, e = rel.getNumEqualities(); i < e; ++i)
+    result.addEquality(remapRow(rel.getEquality(i)));
+  for (unsigned i = 0, e = rel.getNumInequalities(); i < e; ++i)
+    result.addInequality(remapRow(rel.getInequality(i)));
+
+  return result;
+}
+
 void IntegerRelation::printSpace(raw_ostream &os) const {
   space.print(os);
   os << getNumConstraints() << " constraints\n";

diff  --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index 7df500bc9568a..dd0b09f7f05d2 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -608,3 +608,97 @@ TEST(IntegerRelationTest, convertVarKindToLocal) {
   EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
   EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[4]));
 }
+
+TEST(IntegerRelationTest, rangeProduct) {
+  // Shared 2-D domain (i, j); the 1-D ranges k and l are concatenated.
+  IntegerRelation lhs = parseRelationFromSet(
+      "(i, j, k) : (2*i + 3*k == 0, i >= 0, j >= 0, k >= 0)", 2);
+  IntegerRelation rhs = parseRelationFromSet(
+      "(i, j, l) : (4*i + 6*j + 9*l == 0, i >= 0, j >= 0, l >= 0)", 2);
+  IntegerRelation expected = parseRelationFromSet(
+      "(i, j, k, l) : (2*i + 3*k == 0, 4*i + 6*j + 9*l == 0, "
+      "i >= 0, j >= 0, k >= 0, l >= 0)",
+      2);
+
+  EXPECT_TRUE(expected.isEqual(lhs.rangeProduct(rhs)));
+}
+
+TEST(IntegerRelationTest, rangeProductMultdimRange) {
+  // 1-D domain (i); a 1-D range (k) times a 2-D range (l, m).
+  IntegerRelation lhs =
+      parseRelationFromSet("(i, k) : (2*i + 3*k == 0, i >= 0, k >= 0)", 1);
+  IntegerRelation rhs = parseRelationFromSet(
+      "(i, l, m) : (4*i + 6*m + 9*l == 0, i >= 0, l >= 0, m >= 0)", 1);
+  IntegerRelation expected = parseRelationFromSet(
+      "(i, k, l, m) : (2*i + 3*k == 0, 4*i + 6*m + 9*l == 0, "
+      "i >= 0, k >= 0, l >= 0, m >= 0)",
+      1);
+
+  EXPECT_TRUE(expected.isEqual(lhs.rangeProduct(rhs)));
+}
+
+TEST(IntegerRelationTest, rangeProductMultdimRangeSwapped) {
+  // Same operands as the previous test in the opposite order: the range of
+  // the receiver (l, m) must come before the range of the argument (k).
+  IntegerRelation lhs = parseRelationFromSet(
+      "(i, l, m) : (4*i + 6*m + 9*l == 0, i >= 0, l >= 0, m >= 0)", 1);
+  IntegerRelation rhs =
+      parseRelationFromSet("(i, k) : (2*i + 3*k == 0, i >= 0, k >= 0)", 1);
+  IntegerRelation expected = parseRelationFromSet(
+      "(i, l, m, k) : (2*i + 3*k == 0, 4*i + 6*m + 9*l == 0, "
+      "i >= 0, k >= 0, l >= 0, m >= 0)",
+      1);
+
+  EXPECT_TRUE(expected.isEqual(lhs.rangeProduct(rhs)));
+}
+
+TEST(IntegerRelationTest, rangeProductEmptyDomain) {
+  // Zero domain variables: the product is the cartesian product of ranges.
+  IntegerRelation lhs =
+      parseRelationFromSet("(i, j) : (4*i + 9*j == 0, i >= 0, j >= 0)", 0);
+  IntegerRelation rhs =
+      parseRelationFromSet("(k, l) : (2*k + 3*l == 0, k >= 0, l >= 0)", 0);
+  IntegerRelation expected = parseRelationFromSet(
+      "(i, j, k, l) : (2*k + 3*l == 0, 4*i + 9*j == 0, "
+      "i >= 0, j >= 0, k >= 0, l >= 0)",
+      0);
+  EXPECT_TRUE(expected.isEqual(lhs.rangeProduct(rhs)));
+}
+
+TEST(IntegerRelationTest, rangeProductEmptyRange) {
+  // Zero range variables: the product reduces to intersecting constraints.
+  IntegerRelation lhs =
+      parseRelationFromSet("(i, j) : (4*i + 9*j == 0, i >= 0, j >= 0)", 2);
+  IntegerRelation rhs =
+      parseRelationFromSet("(i, j) : (2*i + 3*j == 0, i >= 0, j >= 0)", 2);
+  IntegerRelation expected = parseRelationFromSet(
+      "(i, j) : (2*i + 3*j == 0, 4*i + 9*j == 0, "
+      "i >= 0, j >= 0)",
+      2);
+  EXPECT_TRUE(expected.isEqual(lhs.rangeProduct(rhs)));
+}
+
+TEST(IntegerRelationTest, rangeProductEmptyDomainAndRange) {
+  // Zero-dimensional operands yield a zero-dimensional product.
+  IntegerRelation lhs = parseRelationFromSet("() : ()", 0);
+  IntegerRelation rhs = parseRelationFromSet("() : ()", 0);
+  IntegerRelation expected = parseRelationFromSet("() : ()", 0);
+  EXPECT_TRUE(expected.isEqual(lhs.rangeProduct(rhs)));
+}
+
+TEST(IntegerRelationTest, rangeProductSymbols) {
+  // Both operands use the symbol s; it must stay shared in the product.
+  IntegerRelation lhs = parseRelationFromSet(
+      "(i, j)[s] : (2*i + 3*j + s == 0, i >= 0, j >= 0)", 1);
+  IntegerRelation rhs = parseRelationFromSet(
+      "(i, l)[s] : (3*i + 4*l + s == 0, i >= 0, l >= 0)", 1);
+  IntegerRelation expected = parseRelationFromSet(
+      "(i, j, l)[s] : (2*i + 3*j + s == 0, 3*i + 4*l + s == 0, "
+      "i >= 0, j >= 0, l >= 0)",
+      1);
+
+  EXPECT_TRUE(expected.isEqual(lhs.rangeProduct(rhs)));
+}


        


More information about the Mlir-commits mailing list