[Mlir-commits] [mlir] [MLIR][Presburger] Make IntegerRelation::convertVarKind consistent with PresburgerSpace::convertVarKind (PR #67323)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 25 05:41:50 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

<details>
<summary>Changes</summary>

Make IntegerRelation::convertVarKind consistent with PresburgerSpace::convertVarKind.

---
Full diff: https://github.com/llvm/llvm-project/pull/67323.diff


8 Files Affected:

- (modified) mlir/include/mlir/Analysis/Presburger/IntegerRelation.h (+11-12) 
- (modified) mlir/lib/Analysis/Presburger/IntegerRelation.cpp (+26-14) 
- (modified) mlir/lib/Analysis/Presburger/PresburgerRelation.cpp (+2-2) 
- (modified) mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp (+3-7) 
- (modified) mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp (+7-7) 
- (modified) mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp (+47) 
- (modified) mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp (+2-2) 
- (modified) mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp (+3-3) 


``````````diff
diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 56484622ec980cd..75289079dab4080 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -545,20 +545,19 @@ class IntegerRelation {
 
   void removeDuplicateDivs();
 
-  /// Converts variables of kind srcKind in the range [varStart, varLimit) to
-  /// variables of kind dstKind. If `pos` is given, the variables are placed at
-  /// position `pos` of dstKind, otherwise they are placed after all the other
-  /// variables of kind dstKind. The internal ordering among the moved variables
-  /// is preserved.
-  void convertVarKind(VarKind srcKind, unsigned varStart, unsigned varLimit,
-                      VarKind dstKind, unsigned pos);
-  void convertVarKind(VarKind srcKind, unsigned varStart, unsigned varLimit,
+  /// Converts variables of the specified kind in the column range [srcPos,
+  /// srcPos + num) to variables of the specified kind at position dstPos. The
+  /// ranges are relative to the kind of variable.
+  ///
+  /// srcKind and dstKind must be different.
+  void convertVarKind(VarKind srcKind, unsigned srcPos, unsigned num,
+                      VarKind dstKind, unsigned dstPos);
+  void convertVarKind(VarKind srcKind, unsigned srcPos, unsigned num,
                       VarKind dstKind) {
-    convertVarKind(srcKind, varStart, varLimit, dstKind,
-                   getNumVarKind(dstKind));
+    convertVarKind(srcKind, srcPos, num, dstKind, getNumVarKind(dstKind));
   }
-  void convertToLocal(VarKind kind, unsigned varStart, unsigned varLimit) {
-    convertVarKind(kind, varStart, varLimit, VarKind::Local);
+  void convertToLocal(VarKind kind, unsigned varStart, unsigned num) {
+    convertVarKind(kind, varStart, num, VarKind::Local);
   }
 
   /// Adds additional local vars to the sets such that they both have the union
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index be764bd7c9176b9..6b9e165f0fd2f08 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Analysis/Presburger/LinearTransform.h"
 #include "mlir/Analysis/Presburger/PWMAFunction.h"
 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
+#include "mlir/Analysis/Presburger/PresburgerSpace.h"
 #include "mlir/Analysis/Presburger/Simplex.h"
 #include "mlir/Analysis/Presburger/Utils.h"
 #include "llvm/ADT/DenseMap.h"
@@ -1317,31 +1318,43 @@ void IntegerRelation::removeRedundantLocalVars() {
   }
 }
 
-void IntegerRelation::convertVarKind(VarKind srcKind, unsigned varStart,
-                                     unsigned varLimit, VarKind dstKind,
-                                     unsigned pos) {
-  assert(varLimit <= getNumVarKind(srcKind) && "Invalid id range");
+void IntegerRelation::convertVarKind(VarKind srcKind, unsigned srcPos,
+                                     unsigned num, VarKind dstKind,
+                                     unsigned dstPos) {
+  unsigned varLimit = srcPos + num;
+  assert(srcKind != dstKind && "cannot convert variables to the same kind");
+  assert(varLimit <= getNumVarKind(srcKind) &&
+         "invalid range for source variables");
+  assert(dstPos <= getNumVarKind(dstKind) &&
+         "invalid position for destination variables");
 
-  if (varStart >= varLimit)
+  if (srcPos >= varLimit)
     return;
 
+  // Save the space as the insert/delete vars operations affect the identifier
+  // information in the space.
+  PresburgerSpace oldSpace = space;
+
   // Append new local variables corresponding to the dimensions to be converted.
-  unsigned convertCount = varLimit - varStart;
-  unsigned newVarsBegin = insertVar(dstKind, pos, convertCount);
+  unsigned newVarsBegin = insertVar(dstKind, dstPos, num);
 
   // Swap the new local variables with dimensions.
   //
-  // Essentially, this moves the information corresponding to the specified ids
+  // Essentially, this moves the constraints corresponding to the specified ids
   // of kind `srcKind` to the `convertCount` newly created ids of kind
   // `dstKind`. In particular, this moves the columns in the constraint
   // matrices, and zeros out the initially occupied columns (because the newly
   // created ids we're swapping with were zero-initialized).
   unsigned offset = getVarKindOffset(srcKind);
-  for (unsigned i = 0; i < convertCount; ++i)
-    swapVar(offset + varStart + i, newVarsBegin + i);
+  for (unsigned i = 0; i < num; ++i)
+    swapVar(offset + srcPos + i, newVarsBegin + i);
+
+  // Delete the initially occupied columns.
+  removeVarRange(srcKind, srcPos, varLimit);
 
-  // Complete the move by deleting the initially occupied columns.
-  removeVarRange(srcKind, varStart, varLimit);
+  // Complete the move by updating the space.
+  space = oldSpace;
+  space.convertVarKind(srcKind, srcPos, num, dstKind, dstPos);
 }
 
 void IntegerRelation::addBound(BoundType type, unsigned pos,
@@ -2260,8 +2273,7 @@ void IntegerRelation::intersectRange(const IntegerPolyhedron &poly) {
 
 void IntegerRelation::inverse() {
   unsigned numRangeVars = getNumVarKind(VarKind::Range);
-  convertVarKind(VarKind::Domain, 0, getVarKindEnd(VarKind::Domain),
-                 VarKind::Range);
+  convertVarKind(VarKind::Domain, 0, getNumDomainVars(), VarKind::Range);
   convertVarKind(VarKind::Range, 0, numRangeVars, VarKind::Domain);
 }
 
diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index 0b3f6a39128858e..8074e80c34908b7 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -43,7 +43,7 @@ void PresburgerRelation::convertVarKind(VarKind srcKind, unsigned srcPos,
                                         unsigned num, VarKind dstKind,
                                         unsigned dstPos) {
   assert(srcKind != VarKind::Local && dstKind != VarKind::Local &&
-      "srcKind/dstKind cannot be local");
+         "srcKind/dstKind cannot be local");
   assert(srcKind != dstKind && "cannot convert variables to the same kind");
   assert(srcPos + num <= space.getNumVarKind(srcKind) &&
          "invalid range for source variables");
@@ -53,7 +53,7 @@ void PresburgerRelation::convertVarKind(VarKind srcKind, unsigned srcPos,
   space.convertVarKind(srcKind, srcPos, num, dstKind, dstPos);
 
   for (IntegerRelation &disjunct : disjuncts)
-    disjunct.convertVarKind(srcKind, srcPos, srcPos + num, dstKind, dstPos);
+    disjunct.convertVarKind(srcKind, srcPos, num, dstKind, dstPos);
 }
 
 unsigned PresburgerRelation::getNumDisjuncts() const {
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
index 5f32505690263fc..8fa89da0e14d588 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
@@ -34,7 +34,6 @@ using namespace mlir;
 using namespace affine;
 using namespace presburger;
 
-
 void FlatAffineValueConstraints::addInductionVarOrTerminalSymbol(Value val) {
   if (containsVar(val))
     return;
@@ -357,8 +356,7 @@ void FlatAffineValueConstraints::getIneqAsAffineValueMap(
 FlatAffineValueConstraints FlatAffineRelation::getDomainSet() const {
   FlatAffineValueConstraints domain = *this;
   // Convert all range variables to local variables.
-  domain.convertToLocal(VarKind::SetDim, getNumDomainDims(),
-                        getNumDomainDims() + getNumRangeDims());
+  domain.convertToLocal(VarKind::SetDim, getNumDomainDims(), getNumRangeDims());
   return domain;
 }
 
@@ -397,13 +395,11 @@ void FlatAffineRelation::compose(const FlatAffineRelation &other) {
   // Convert `rel` from [otherDomain] -> [otherRange thisRange] to
   // [otherDomain] -> [thisRange] by converting first otherRange range vars
   // to local vars.
-  rel.convertToLocal(VarKind::SetDim, rel.getNumDomainDims(),
-                     rel.getNumDomainDims() + removeDims);
+  rel.convertToLocal(VarKind::SetDim, rel.getNumDomainDims(), removeDims);
   // Convert `this` from [otherDomain thisDomain] -> [thisRange] to
   // [otherDomain] -> [thisRange] by converting last thisDomain domain vars
   // to local vars.
-  convertToLocal(VarKind::SetDim, getNumDomainDims() - removeDims,
-                 getNumDomainDims());
+  convertToLocal(VarKind::SetDim, getNumDomainDims() - removeDims, removeDims);
 
   auto thisMaybeValues = getMaybeValues(VarKind::SetDim);
   auto relMaybeValues = rel.getMaybeValues(VarKind::SetDim);
diff --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
index ba035e84ff1fd70..331a41d208c7523 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -728,7 +728,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprTightUpperBound) {
     IntegerPolyhedron poly = parseIntegerPolyhedron(
         "(i, j, q) : (4*q - i - j + 2 >= 0, -4*q + i + j >= 0)");
     // Convert `q` to a local variable.
-    poly.convertToLocal(VarKind::SetDim, 2, 3);
+    poly.convertToLocal(VarKind::SetDim, 2, 1);
 
     std::vector<SmallVector<int64_t, 8>> divisions = {{1, 1, 0, 1}};
     SmallVector<int64_t, 8> denoms = {4};
@@ -743,7 +743,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprFromEquality) {
     IntegerPolyhedron poly =
         parseIntegerPolyhedron("(i, j, q) : (-4*q + i + j == 0)");
     // Convert `q` to a local variable.
-    poly.convertToLocal(VarKind::SetDim, 2, 3);
+    poly.convertToLocal(VarKind::SetDim, 2, 1);
 
     std::vector<SmallVector<int64_t, 8>> divisions = {{1, 1, 0, 0}};
     SmallVector<int64_t, 8> denoms = {4};
@@ -754,7 +754,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprFromEquality) {
     IntegerPolyhedron poly =
         parseIntegerPolyhedron("(i, j, q) : (4*q - i - j == 0)");
     // Convert `q` to a local variable.
-    poly.convertToLocal(VarKind::SetDim, 2, 3);
+    poly.convertToLocal(VarKind::SetDim, 2, 1);
 
     std::vector<SmallVector<int64_t, 8>> divisions = {{1, 1, 0, 0}};
     SmallVector<int64_t, 8> denoms = {4};
@@ -765,7 +765,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprFromEquality) {
     IntegerPolyhedron poly =
         parseIntegerPolyhedron("(i, j, q) : (3*q + i + j - 2 == 0)");
     // Convert `q` to a local variable.
-    poly.convertToLocal(VarKind::SetDim, 2, 3);
+    poly.convertToLocal(VarKind::SetDim, 2, 1);
 
     std::vector<SmallVector<int64_t, 8>> divisions = {{-1, -1, 0, 2}};
     SmallVector<int64_t, 8> denoms = {3};
@@ -780,7 +780,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprFromEqualityAndInequality) {
         parseIntegerPolyhedron("(i, j, q, k) : (-3*k + i + j == 0, 4*q - "
                                "i - j + 2 >= 0, -4*q + i + j >= 0)");
     // Convert `q` and `k` to local variables.
-    poly.convertToLocal(VarKind::SetDim, 2, 4);
+    poly.convertToLocal(VarKind::SetDim, 2, 2);
 
     std::vector<SmallVector<int64_t, 8>> divisions = {{1, 1, 0, 0, 1},
                                                       {1, 1, 0, 0, 0}};
@@ -794,7 +794,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprNoRepr) {
   IntegerPolyhedron poly =
       parseIntegerPolyhedron("(x, q) : (x - 3 * q >= 0, -x + 3 * q + 3 >= 0)");
   // Convert q to a local variable.
-  poly.convertToLocal(VarKind::SetDim, 1, 2);
+  poly.convertToLocal(VarKind::SetDim, 1, 1);
 
   std::vector<SmallVector<int64_t, 8>> divisions = {{0, 0, 0}};
   SmallVector<int64_t, 8> denoms = {0};
@@ -807,7 +807,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprNegConstNormalize) {
   IntegerPolyhedron poly = parseIntegerPolyhedron(
       "(x, q) : (-1 - 3*x - 6 * q >= 0, 6 + 3*x + 6*q >= 0)");
   // Convert q to a local variable.
-  poly.convertToLocal(VarKind::SetDim, 1, 2);
+  poly.convertToLocal(VarKind::SetDim, 1, 1);
 
   // q = floor((-1/3 - x)/2)
   //   = floor((1/3) + (-1 - x)/2)
diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
index 287f7c7c56549ff..067f508da16a586 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/Analysis/Presburger/IntegerRelation.h"
 #include "Parser.h"
+#include "mlir/Analysis/Presburger/PresburgerSpace.h"
 #include "mlir/Analysis/Presburger/Simplex.h"
 
 #include <gmock/gmock.h>
@@ -167,3 +168,49 @@ TEST(IntegerRelationTest, symbolicLexmax) {
   EXPECT_TRUE(lexmax3.unboundedDomain.isIntegerEmpty());
   EXPECT_TRUE(lexmax3.lexopt.isEqual(expectedLexmax3));
 }
+
+TEST(IntegerRelationTest, convertVarKind) {
+  PresburgerSpace space = PresburgerSpace::getSetSpace(3, 3, 0);
+  space.resetIds();
+
+  // Attach identifiers.
+  int identifiers[6] = {0, 1, 2, 3, 4, 5};
+  space.getId(VarKind::SetDim, 0) = Identifier(&identifiers[0]);
+  space.getId(VarKind::SetDim, 1) = Identifier(&identifiers[1]);
+  space.getId(VarKind::SetDim, 2) = Identifier(&identifiers[2]);
+  space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
+  space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
+  space.getId(VarKind::Symbol, 2) = Identifier(&identifiers[5]);
+
+  // Cannot call parseIntegerRelation to test convertVarKind as
+  // parseIntegerRelation uses convertVarKind.
+  IntegerRelation rel = parseIntegerPolyhedron(
+      // 0  1  2  3  4  5
+      "(x, y, a)[U, V, W] : (x - U == 0, y + a - W == 0, U - V >= 0, y - a >= "
+      "0)");
+  rel.setSpace(space);
+
+  // Make a few kind conversions.
+  rel.convertVarKind(VarKind::Symbol, 1, 1, VarKind::Domain, 0);
+  rel.convertVarKind(VarKind::Range, 2, 1, VarKind::Domain, 0);
+  rel.convertVarKind(VarKind::Range, 0, 2, VarKind::Symbol, 1);
+  rel.convertVarKind(VarKind::Domain, 1, 1, VarKind::Range, 0);
+  rel.convertVarKind(VarKind::Domain, 0, 1, VarKind::Range, 1);
+
+  space = rel.getSpace();
+
+  // Expected rel.
+  IntegerRelation expectedRel = parseIntegerPolyhedron(
+      "(V, a)[U, x, y, W] : (x - U == 0, y + a - W == 0, U - V >= 0,"
+      "y - a >= 0)");
+  expectedRel.setSpace(space);
+
+  EXPECT_TRUE(rel.isEqual(expectedRel));
+
+  EXPECT_EQ(space.getId(VarKind::SetDim, 0), Identifier(&identifiers[4]));
+  EXPECT_EQ(space.getId(VarKind::SetDim, 1), Identifier(&identifiers[2]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[3]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[0]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 2), Identifier(&identifiers[1]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 3), Identifier(&identifiers[5]));
+}
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
index ad71bb32a06880f..a7c1e7501850add 100644
--- a/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp
@@ -299,10 +299,10 @@ TEST(PresburgerRelationTest, convertVarKind) {
 
   // Expected rel.
   disj1.convertVarKind(VarKind::Domain, 0, 1, VarKind::Range, 0);
-  disj1.convertVarKind(VarKind::Symbol, 1, 3, VarKind::Domain, 1);
+  disj1.convertVarKind(VarKind::Symbol, 1, 2, VarKind::Domain, 1);
   disj1.convertVarKind(VarKind::Symbol, 0, 1, VarKind::Range, 1);
   disj2.convertVarKind(VarKind::Domain, 0, 1, VarKind::Range, 0);
-  disj2.convertVarKind(VarKind::Symbol, 1, 3, VarKind::Domain, 1);
+  disj2.convertVarKind(VarKind::Symbol, 1, 2, VarKind::Domain, 1);
   disj2.convertVarKind(VarKind::Symbol, 0, 1, VarKind::Range, 1);
 
   PresburgerRelation expectedRel(disj1);
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
index 8e31a8bb2030b6c..51042c6e5a85304 100644
--- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
@@ -455,7 +455,7 @@ TEST(SetTest, divisions) {
 
 void convertSuffixDimsToLocals(IntegerPolyhedron &poly, unsigned numLocals) {
   poly.convertVarKind(VarKind::SetDim, poly.getNumDimVars() - numLocals,
-                      poly.getNumDimVars(), VarKind::Local);
+                      numLocals, VarKind::Local);
 }
 
 inline IntegerPolyhedron
@@ -815,7 +815,7 @@ void testComputeReprAtPoints(IntegerPolyhedron poly,
                              ArrayRef<SmallVector<int64_t, 4>> points,
                              unsigned numToProject) {
   poly.convertVarKind(VarKind::SetDim, poly.getNumDimVars() - numToProject,
-                      poly.getNumDimVars(), VarKind::Local);
+                      numToProject, VarKind::Local);
   PresburgerRelation repr = poly.computeReprWithOnlyDivLocals();
   EXPECT_TRUE(repr.hasOnlyDivLocals());
   EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace()));
@@ -828,7 +828,7 @@ void testComputeReprAtPoints(IntegerPolyhedron poly,
 void testComputeRepr(IntegerPolyhedron poly, const PresburgerSet &expected,
                      unsigned numToProject) {
   poly.convertVarKind(VarKind::SetDim, poly.getNumDimVars() - numToProject,
-                      poly.getNumDimVars(), VarKind::Local);
+                      numToProject, VarKind::Local);
   PresburgerRelation repr = poly.computeReprWithOnlyDivLocals();
   EXPECT_TRUE(repr.hasOnlyDivLocals());
   EXPECT_TRUE(repr.getSpace().isCompatible(poly.getSpace()));

``````````

</details>


https://github.com/llvm/llvm-project/pull/67323


More information about the Mlir-commits mailing list