[Mlir-commits] [mlir] a5a598b - [MLIR][Presburger] Use PresburgerSpace in constructors

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Apr 1 02:40:39 PDT 2022


Author: Groverkss
Date: 2022-04-01T15:07:26+05:30
New Revision: a5a598be44b12a27fa6043b20d36be6d099e385a

URL: https://github.com/llvm/llvm-project/commit/a5a598be44b12a27fa6043b20d36be6d099e385a
DIFF: https://github.com/llvm/llvm-project/commit/a5a598be44b12a27fa6043b20d36be6d099e385a.diff

LOG: [MLIR][Presburger] Use PresburgerSpace in constructors

This patch modifies the constructors of IntegerPolyhedron, IntegerRelation,
PresburgerRelation, PresburgerSet, and PWMAFunction to take a PresburgerSpace
instead of separate dimension counts. This lets the information stored in a
PresburgerSpace be carried along intact and provides a more general interface.

Reviewed By: arjunp

Differential Revision: https://reviews.llvm.org/D122842

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
    mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
    mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
    mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
    mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h
    mlir/lib/Analysis/Presburger/IntegerRelation.cpp
    mlir/lib/Analysis/Presburger/LinearTransform.cpp
    mlir/lib/Analysis/Presburger/PWMAFunction.cpp
    mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
    mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
    mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
    mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
    mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
    mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
    mlir/unittests/Analysis/Presburger/Utils.h

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index cc2d4af6c3c99..4a32dc90c4731 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -57,9 +57,8 @@ class IntegerRelation : public PresburgerSpace {
   /// of constraints and identifiers.
   IntegerRelation(unsigned numReservedInequalities,
                   unsigned numReservedEqualities, unsigned numReservedCols,
-                  unsigned numDomain, unsigned numRange, unsigned numSymbols,
-                  unsigned numLocals)
-      : PresburgerSpace(numDomain, numRange, numSymbols, numLocals),
+                  const PresburgerSpace &space)
+      : PresburgerSpace(space),
         equalities(0, getNumIds() + 1, numReservedEqualities, numReservedCols),
         inequalities(0, getNumIds() + 1, numReservedInequalities,
                      numReservedCols) {
@@ -67,20 +66,15 @@ class IntegerRelation : public PresburgerSpace {
   }
 
   /// Constructs a relation with the specified number of dimensions and symbols.
-  IntegerRelation(unsigned numDomain = 0, unsigned numRange = 0,
-                  unsigned numSymbols = 0, unsigned numLocals = 0)
+  IntegerRelation(const PresburgerSpace &space)
       : IntegerRelation(/*numReservedInequalities=*/0,
                         /*numReservedEqualities=*/0,
-                        /*numReservedCols=*/numDomain + numRange + numSymbols +
-                            numLocals + 1,
-                        numDomain, numRange, numSymbols, numLocals) {}
+                        /*numReservedCols=*/space.getNumIds() + 1, space) {}
 
   /// Return a system with no constraints, i.e., one which is satisfied by all
   /// points.
-  static IntegerRelation getUniverse(unsigned numDomain = 0,
-                                     unsigned numRange = 0,
-                                     unsigned numSymbols = 0) {
-    return IntegerRelation(numDomain, numRange, numSymbols);
+  static IntegerRelation getUniverse(const PresburgerSpace &space) {
+    return IntegerRelation(space);
   }
 
   /// Return the kind of this IntegerRelation.
@@ -562,25 +556,24 @@ class IntegerPolyhedron : public IntegerRelation {
   /// of constraints and identifiers.
   IntegerPolyhedron(unsigned numReservedInequalities,
                     unsigned numReservedEqualities, unsigned numReservedCols,
-                    unsigned numDims, unsigned numSymbols, unsigned numLocals)
+                    const PresburgerSpace &space)
       : IntegerRelation(numReservedInequalities, numReservedEqualities,
-                        numReservedCols, /*numDomain=*/0, /*numRange=*/numDims,
-                        numSymbols, numLocals) {}
+                        numReservedCols, space) {
+    assert(space.getNumDomainIds() == 0 &&
+           "Number of domain id's should be zero in Set kind space.");
+  }
 
-  /// Constructs a relation with the specified number of dimensions and symbols.
-  IntegerPolyhedron(unsigned numDims = 0, unsigned numSymbols = 0,
-                    unsigned numLocals = 0)
+  /// Constructs a relation with the specified number of dimensions and
+  /// symbols.
+  IntegerPolyhedron(const PresburgerSpace &space)
       : IntegerPolyhedron(/*numReservedInequalities=*/0,
                           /*numReservedEqualities=*/0,
-                          /*numReservedCols=*/numDims + numSymbols + numLocals +
-                              1,
-                          numDims, numSymbols, numLocals) {}
+                          /*numReservedCols=*/space.getNumIds() + 1, space) {}
 
   /// Return a system with no constraints, i.e., one which is satisfied by all
   /// points.
-  static IntegerPolyhedron getUniverse(unsigned numDims = 0,
-                                       unsigned numSymbols = 0) {
-    return IntegerPolyhedron(numDims, numSymbols);
+  static IntegerPolyhedron getUniverse(const PresburgerSpace &space) {
+    return IntegerPolyhedron(space);
   }
 
   /// Return the kind of this IntegerRelation.

diff  --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
index 8d1199f1b247c..ce0d77da9bc2c 100644
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -57,9 +57,8 @@ class MultiAffineFunction : protected IntegerPolyhedron {
 
   MultiAffineFunction(const IntegerPolyhedron &domain, const Matrix &output)
       : IntegerPolyhedron(domain), output(output) {}
-  MultiAffineFunction(const Matrix &output, unsigned numDims,
-                      unsigned numSymbols = 0, unsigned numLocals = 0)
-      : IntegerPolyhedron(numDims, numSymbols, numLocals), output(output) {}
+  MultiAffineFunction(const Matrix &output, const PresburgerSpace &space)
+      : IntegerPolyhedron(space), output(output) {}
 
   ~MultiAffineFunction() override = default;
   Kind getKind() const override { return Kind::MultiAffineFunction; }
@@ -137,10 +136,10 @@ class MultiAffineFunction : protected IntegerPolyhedron {
 /// finding the value of the function at a point.
 class PWMAFunction : public PresburgerSpace {
 public:
-  PWMAFunction(unsigned numDims, unsigned numSymbols, unsigned numOutputs)
-      : PresburgerSpace(/*numDomain=*/0, /*numRange=*/numDims, numSymbols,
-                        /*numLocals=*/0),
-        numOutputs(numOutputs) {
+  PWMAFunction(const PresburgerSpace &space, unsigned numOutputs)
+      : PresburgerSpace(space), numOutputs(numOutputs) {
+    assert(getNumDomainIds() == 0 && "Set type space should zero domain ids.");
+    assert(getNumLocalIds() == 0 && "PWMAFunction cannot have local ids.");
     assert(numOutputs >= 1 && "The function must output something!");
   }
 

diff  --git a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
index 69730b5dcf13b..8093a710c5512 100644
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
@@ -37,13 +37,10 @@ class SetCoalescer;
 class PresburgerRelation : public PresburgerSpace {
 public:
   /// Return a universe set of the specified type that contains all points.
-  static PresburgerRelation getUniverse(unsigned numDomain, unsigned numRange,
-                                        unsigned numSymbols);
+  static PresburgerRelation getUniverse(const PresburgerSpace &space);
 
   /// Return an empty set of the specified type that contains no points.
-  static PresburgerRelation getEmpty(unsigned numDomain = 0,
-                                     unsigned numRange = 0,
-                                     unsigned numSymbols = 0);
+  static PresburgerRelation getEmpty(const PresburgerSpace &space);
 
   explicit PresburgerRelation(const IntegerRelation &disjunct);
 
@@ -119,9 +116,10 @@ class PresburgerRelation : public PresburgerSpace {
 protected:
   /// Construct an empty PresburgerRelation with the specified number of
   /// dimension and symbols.
-  PresburgerRelation(unsigned numDomain = 0, unsigned numRange = 0,
-                     unsigned numSymbols = 0)
-      : PresburgerSpace(numDomain, numRange, numSymbols, /*numLocals=*/0) {}
+  PresburgerRelation(const PresburgerSpace &space) : PresburgerSpace(space) {
+    assert(space.getNumLocalIds() == 0 &&
+           "PresburgerRelation cannot have local ids.");
+  }
 
   /// The list of disjuncts that this set is the union of.
   SmallVector<IntegerRelation, 2> integerRelations;
@@ -132,11 +130,10 @@ class PresburgerRelation : public PresburgerSpace {
 class PresburgerSet : public PresburgerRelation {
 public:
   /// Return a universe set of the specified type that contains all points.
-  static PresburgerSet getUniverse(unsigned numDims = 0,
-                                   unsigned numSymbols = 0);
+  static PresburgerSet getUniverse(const PresburgerSpace &space);
 
   /// Return an empty set of the specified type that contains no points.
-  static PresburgerSet getEmpty(unsigned numDims = 0, unsigned numSymbols = 0);
+  static PresburgerSet getEmpty(const PresburgerSpace &space);
 
   /// Create a set from a relation.
   explicit PresburgerSet(const IntegerPolyhedron &disjunct);
@@ -154,8 +151,11 @@ class PresburgerSet : public PresburgerRelation {
 protected:
   /// Construct an empty PresburgerRelation with the specified number of
   /// dimension and symbols.
-  PresburgerSet(unsigned numDims = 0, unsigned numSymbols = 0)
-      : PresburgerRelation(/*numDomain=*/0, numDims, numSymbols) {}
+  PresburgerSet(const PresburgerSpace &space) : PresburgerRelation(space) {
+    assert(space.getNumDomainIds() == 0 && "Set type cannot have domain ids.");
+    assert(space.getNumLocalIds() == 0 &&
+           "PresburgerRelation cannot have local ids.");
+  }
 };
 
 } // namespace presburger

diff  --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
index 4a587a21f95b3..936151908419e 100644
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
@@ -64,10 +64,24 @@ enum class IdKind { Symbol, Local, Domain, Range, SetDim = Range };
 /// identifiers of each kind are equal.
 class PresburgerSpace {
 public:
-  PresburgerSpace(unsigned numDomain = 0, unsigned numRange = 0,
-                  unsigned numSymbols = 0, unsigned numLocals = 0)
-      : numDomain(numDomain), numRange(numRange), numSymbols(numSymbols),
-        numLocals(numLocals) {}
+  static PresburgerSpace getRelationSpace(unsigned numDomain = 0,
+                                          unsigned numRange = 0,
+                                          unsigned numSymbols = 0,
+                                          unsigned numLocals = 0) {
+    return PresburgerSpace(numDomain, numRange, numSymbols, numLocals);
+  }
+
+  static PresburgerSpace getSetSpace(unsigned numDims = 0,
+                                     unsigned numSymbols = 0,
+                                     unsigned numLocals = 0) {
+    return PresburgerSpace(/*numDomain=*/0, /*numRange=*/numDims, numSymbols,
+                           numLocals);
+  }
+
+  PresburgerSpace getSpace() const { return *this; }
+  PresburgerSpace getCompatibleSpace() const {
+    return PresburgerSpace(numDomain, numRange, numSymbols);
+  }
 
   virtual ~PresburgerSpace() = default;
 
@@ -99,6 +113,9 @@ class PresburgerSpace {
   unsigned getIdKindOverlap(IdKind kind, unsigned idStart,
                             unsigned idLimit) const;
 
+  /// Return the IdKind of the id at the specified position.
+  IdKind getIdKindAt(unsigned pos) const;
+
   /// Insert `num` identifiers of the specified kind at position `pos`.
   /// Positions are relative to the kind of identifier. Return the absolute
   /// column position (i.e., not relative to the kind of identifier) of the
@@ -131,6 +148,12 @@ class PresburgerSpace {
   void print(llvm::raw_ostream &os) const;
   void dump() const;
 
+protected:
+  PresburgerSpace(unsigned numDomain = 0, unsigned numRange = 0,
+                  unsigned numSymbols = 0, unsigned numLocals = 0)
+      : numDomain(numDomain), numRange(numRange), numSymbols(numSymbols),
+        numLocals(numLocals) {}
+
 private:
   // Number of identifiers corresponding to domain identifiers.
   unsigned numDomain;

diff  --git a/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h b/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h
index 8c256181ecfd3..fc0364747cafb 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h
@@ -65,18 +65,19 @@ class FlatAffineConstraints : public presburger::IntegerPolyhedron {
                         unsigned numReservedEqualities,
                         unsigned numReservedCols, unsigned numDims,
                         unsigned numSymbols, unsigned numLocals)
-      : IntegerPolyhedron(numReservedInequalities, numReservedEqualities,
-                          numReservedCols, numDims, numSymbols, numLocals) {}
+      : IntegerPolyhedron(
+            numReservedInequalities, numReservedEqualities, numReservedCols,
+            PresburgerSpace::getSetSpace(numDims, numSymbols, numLocals)) {}
 
   /// Constructs a constraint system with the specified number of
   /// dimensions and symbols.
   FlatAffineConstraints(unsigned numDims = 0, unsigned numSymbols = 0,
                         unsigned numLocals = 0)
-      : IntegerPolyhedron(/*numReservedInequalities=*/0,
-                          /*numReservedEqualities=*/0,
-                          /*numReservedCols=*/numDims + numSymbols + numLocals +
-                              1,
-                          numDims, numSymbols, numLocals) {}
+      : FlatAffineConstraints(/*numReservedInequalities=*/0,
+                              /*numReservedEqualities=*/0,
+                              /*numReservedCols=*/numDims + numSymbols +
+                                  numLocals + 1,
+                              numDims, numSymbols, numLocals) {}
 
   explicit FlatAffineConstraints(const IntegerPolyhedron &poly)
       : IntegerPolyhedron(poly) {}

diff  --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index de95511dec6d2..2ae384d2bbcf6 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -1702,20 +1702,14 @@ void IntegerRelation::fourierMotzkinEliminate(unsigned pos, bool darkShadow,
     }
   }
 
-  // Set the number of dimensions, symbols, locals in the resulting system.
-  unsigned newNumDomain =
-      getNumDomainIds() - getIdKindOverlap(IdKind::Domain, pos, pos + 1);
-  unsigned newNumRange =
-      getNumRangeIds() - getIdKindOverlap(IdKind::Range, pos, pos + 1);
-  unsigned newNumSymbols =
-      getNumSymbolIds() - getIdKindOverlap(IdKind::Symbol, pos, pos + 1);
-  unsigned newNumLocals =
-      getNumLocalIds() - getIdKindOverlap(IdKind::Local, pos, pos + 1);
+  PresburgerSpace newSpace = getSpace();
+  IdKind idKindRemove = newSpace.getIdKindAt(pos);
+  unsigned relativePos = pos - newSpace.getIdKindOffset(idKindRemove);
+  newSpace.removeIdRange(idKindRemove, relativePos, relativePos + 1);
 
   /// Create the new system which has one identifier less.
   IntegerRelation newRel(lbIndices.size() * ubIndices.size() + nbIndices.size(),
-                         getNumEqualities(), getNumCols() - 1, newNumDomain,
-                         newNumRange, newNumSymbols, newNumLocals);
+                         getNumEqualities(), getNumCols() - 1, newSpace);
 
   // This will be used to check if the elimination was integer exact.
   unsigned lcmProducts = 1;
@@ -1866,8 +1860,7 @@ static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
 // Returns constraints that are common to both A & B.
 static void getCommonConstraints(const IntegerRelation &a,
                                  const IntegerRelation &b, IntegerRelation &c) {
-  c = IntegerRelation(a.getNumDomainIds(), a.getNumRangeIds(),
-                      a.getNumSymbolIds(), a.getNumLocalIds());
+  c = IntegerRelation(a.getSpace());
   // a naive O(n^2) check should be enough here given the input sizes.
   for (unsigned r = 0, e = a.getNumInequalities(); r < e; ++r) {
     for (unsigned s = 0, f = b.getNumInequalities(); s < f; ++s) {
@@ -1896,7 +1889,7 @@ IntegerRelation::unionBoundingBox(const IntegerRelation &otherCst) {
 
   // Get the constraints common to both systems; these will be added as is to
   // the union.
-  IntegerRelation commonCst;
+  IntegerRelation commonCst(PresburgerSpace::getRelationSpace());
   getCommonConstraints(*this, otherCst, commonCst);
 
   std::vector<SmallVector<int64_t, 8>> boundingLbs;

diff  --git a/mlir/lib/Analysis/Presburger/LinearTransform.cpp b/mlir/lib/Analysis/Presburger/LinearTransform.cpp
index 6c399fe4f4923..ba9d34b998ac0 100644
--- a/mlir/lib/Analysis/Presburger/LinearTransform.cpp
+++ b/mlir/lib/Analysis/Presburger/LinearTransform.cpp
@@ -113,8 +113,7 @@ LinearTransform::makeTransformToColumnEchelon(Matrix m) {
 }
 
 IntegerRelation LinearTransform::applyTo(const IntegerRelation &rel) const {
-  IntegerRelation result(rel.getNumDomainIds(), rel.getNumRangeIds(),
-                         rel.getNumSymbolIds(), rel.getNumLocalIds());
+  IntegerRelation result(rel.getSpace());
 
   for (unsigned i = 0, e = rel.getNumEqualities(); i < e; ++i) {
     ArrayRef<int64_t> eq = rel.getEquality(i);

diff  --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
index d00bc2d7f580b..b995bc00a19c8 100644
--- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
+++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
@@ -27,8 +27,7 @@ static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
 }
 
 PresburgerSet PWMAFunction::getDomain() const {
-  PresburgerSet domain =
-      PresburgerSet::getEmpty(getNumDimIds(), getNumSymbolIds());
+  PresburgerSet domain = PresburgerSet::getEmpty(getSpace());
   for (const MultiAffineFunction &piece : pieces)
     domain.unionInPlace(piece.getDomain());
   return domain;

diff  --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
index f5f02455bd4b4..c99fc81bd3c8c 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
@@ -17,7 +17,7 @@ using namespace mlir;
 using namespace presburger;
 
 PresburgerRelation::PresburgerRelation(const IntegerRelation &disjunct)
-    : PresburgerSpace(disjunct) {
+    : PresburgerSpace(disjunct.getCompatibleSpace()) {
   unionInPlace(disjunct);
 }
 
@@ -67,19 +67,15 @@ bool PresburgerRelation::containsPoint(ArrayRef<int64_t> point) const {
   });
 }
 
-PresburgerRelation PresburgerRelation::getUniverse(unsigned numDomain,
-                                                   unsigned numRange,
-                                                   unsigned numSymbols) {
-  PresburgerRelation result(numDomain, numRange, numSymbols);
-  result.unionInPlace(
-      IntegerRelation::getUniverse(numDomain, numRange, numSymbols));
+PresburgerRelation
+PresburgerRelation::getUniverse(const PresburgerSpace &space) {
+  PresburgerRelation result(space);
+  result.unionInPlace(IntegerRelation::getUniverse(space));
   return result;
 }
 
-PresburgerRelation PresburgerRelation::getEmpty(unsigned numDomain,
-                                                unsigned numRange,
-                                                unsigned numSymbols) {
-  return PresburgerRelation(numDomain, numRange, numSymbols);
+PresburgerRelation PresburgerRelation::getEmpty(const PresburgerSpace &space) {
+  return PresburgerRelation(space);
 }
 
 // Return the intersection of this set with the given set.
@@ -93,8 +89,7 @@ PresburgerRelation
 PresburgerRelation::intersect(const PresburgerRelation &set) const {
   assert(isSpaceCompatible(set) && "Spaces should match");
 
-  PresburgerRelation result(getNumDomainIds(), getNumRangeIds(),
-                            getNumSymbolIds());
+  PresburgerRelation result(getSpace());
   for (const IntegerRelation &csA : integerRelations) {
     for (const IntegerRelation &csB : set.integerRelations) {
       IntegerRelation intersection = csA.intersect(csB);
@@ -283,13 +278,10 @@ static PresburgerRelation getSetDifference(IntegerRelation disjunct,
                                            const PresburgerRelation &set) {
   assert(disjunct.isSpaceCompatible(set) && "Spaces should match");
   if (disjunct.isEmptyByGCDTest())
-    return PresburgerRelation::getEmpty(disjunct.getNumDomainIds(),
-                                        disjunct.getNumRangeIds(),
-                                        disjunct.getNumSymbolIds());
+    return PresburgerRelation::getEmpty(disjunct.getCompatibleSpace());
 
-  PresburgerRelation result = PresburgerRelation::getEmpty(
-      disjunct.getNumDomainIds(), disjunct.getNumRangeIds(),
-      disjunct.getNumSymbolIds());
+  PresburgerRelation result =
+      PresburgerRelation::getEmpty(disjunct.getCompatibleSpace());
   Simplex simplex(disjunct);
   subtractRecursively(disjunct, simplex, set, 0, result);
   return result;
@@ -297,10 +289,7 @@ static PresburgerRelation getSetDifference(IntegerRelation disjunct,
 
 /// Return the complement of this set.
 PresburgerRelation PresburgerRelation::complement() const {
-  return getSetDifference(IntegerRelation::getUniverse(getNumDomainIds(),
-                                                       getNumRangeIds(),
-                                                       getNumSymbolIds()),
-                          *this);
+  return getSetDifference(IntegerRelation::getUniverse(getSpace()), *this);
 }
 
 /// Return the result of subtract the given set from this set, i.e.,
@@ -308,8 +297,7 @@ PresburgerRelation PresburgerRelation::complement() const {
 PresburgerRelation
 PresburgerRelation::subtract(const PresburgerRelation &set) const {
   assert(isSpaceCompatible(set) && "Spaces should match");
-  PresburgerRelation result(getNumDomainIds(), getNumRangeIds(),
-                            getNumSymbolIds());
+  PresburgerRelation result(getSpace());
   // We compute (U_i t_i) \ (U_i set_i) as U_i (t_i \ V_i set_i).
   for (const IntegerRelation &disjunct : integerRelations)
     result.unionInPlace(getSetDifference(disjunct, set));
@@ -505,7 +493,8 @@ PresburgerRelation SetCoalescer::coalesce() {
   }
 
   PresburgerRelation newSet =
-      PresburgerRelation::getEmpty(numDomainIds, numRangeIds, numSymbolIds);
+      PresburgerRelation::getEmpty(PresburgerSpace::getRelationSpace(
+          numDomainIds, numRangeIds, numSymbolIds));
   for (unsigned i = 0, e = disjuncts.size(); i < e; ++i)
     newSet.unionInPlace(disjuncts[i]);
 
@@ -584,8 +573,7 @@ LogicalResult SetCoalescer::coalescePairCutCase(unsigned i, unsigned j) {
         return !isFacetContained(curr, simp);
       }))
     return failure();
-  IntegerRelation newSet(disjunct.getNumDomainIds(), disjunct.getNumRangeIds(),
-                         disjunct.getNumSymbolIds(), disjunct.getNumLocalIds());
+  IntegerRelation newSet(disjunct.getSpace());
 
   for (ArrayRef<int64_t> curr : redundantIneqsA)
     newSet.addInequality(curr);
@@ -707,15 +695,14 @@ void PresburgerRelation::print(raw_ostream &os) const {
 
 void PresburgerRelation::dump() const { print(llvm::errs()); }
 
-PresburgerSet PresburgerSet::getUniverse(unsigned numDims,
-                                         unsigned numSymbols) {
-  PresburgerSet result(numDims, numSymbols);
-  result.unionInPlace(IntegerPolyhedron::getUniverse(numDims, numSymbols));
+PresburgerSet PresburgerSet::getUniverse(const PresburgerSpace &space) {
+  PresburgerSet result(space);
+  result.unionInPlace(IntegerPolyhedron::getUniverse(space));
   return result;
 }
 
-PresburgerSet PresburgerSet::getEmpty(unsigned numDims, unsigned numSymbols) {
-  return PresburgerSet(numDims, numSymbols);
+PresburgerSet PresburgerSet::getEmpty(const PresburgerSpace &space) {
+  return PresburgerSet(space);
 }
 
 PresburgerSet::PresburgerSet(const IntegerPolyhedron &disjunct)

diff  --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
index 255e6c17d7c6b..4b1a22039aee5 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
@@ -56,6 +56,19 @@ unsigned PresburgerSpace::getIdKindOverlap(IdKind kind, unsigned idStart,
   return overlapEnd - overlapStart;
 }
 
+IdKind PresburgerSpace::getIdKindAt(unsigned pos) const {
+  assert(pos < getNumIds() && "`pos` should represent a valid id position");
+  if (pos < getIdKindEnd(IdKind::Domain))
+    return IdKind::Domain;
+  if (pos < getIdKindEnd(IdKind::Range))
+    return IdKind::Range;
+  if (pos < getIdKindEnd(IdKind::Symbol))
+    return IdKind::Symbol;
+  if (pos < getIdKindEnd(IdKind::Local))
+    return IdKind::Local;
+  llvm_unreachable("`pos` should represent a valid id position");
+}
+
 unsigned PresburgerSpace::insertId(IdKind kind, unsigned pos, unsigned num) {
   assert(pos <= getNumIdKind(kind));
 

diff  --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
index a54820dcbb7aa..2530e239acc86 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
@@ -158,8 +158,9 @@ FlatAffineValueConstraints::clone() const {
 FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
     : IntegerPolyhedron(set.getNumInequalities(), set.getNumEqualities(),
                         set.getNumDims() + set.getNumSymbols() + 1,
-                        set.getNumDims(), set.getNumSymbols(),
-                        /*numLocals=*/0) {
+                        PresburgerSpace::getSetSpace(set.getNumDims(),
+                                                     set.getNumSymbols(),
+                                                     /*numLocals=*/0)) {
 
   // Flatten expressions and add them to the constraint system.
   std::vector<SmallVector<int64_t, 8>> flatExprs;

diff  --git a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
index 4d55036e5fbf7..4149d85d8759f 100644
--- a/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/IntegerPolyhedronTest.cpp
@@ -28,8 +28,9 @@ static IntegerPolyhedron
 makeSetFromConstraints(unsigned ids, ArrayRef<SmallVector<int64_t, 4>> ineqs,
                        ArrayRef<SmallVector<int64_t, 4>> eqs,
                        unsigned syms = 0) {
-  IntegerPolyhedron set(ineqs.size(), eqs.size(), ids + 1, ids - syms, syms,
-                        /*numLocals=*/0);
+  IntegerPolyhedron set(
+      ineqs.size(), eqs.size(), ids + 1,
+      PresburgerSpace::getSetSpace(ids - syms, syms, /*numLocals=*/0));
   for (const auto &eq : eqs)
     set.addEquality(eq);
   for (const auto &ineq : ineqs)
@@ -178,7 +179,7 @@ TEST(IntegerPolyhedronTest, clearConstraints) {
 }
 
 TEST(IntegerPolyhedronTest, removeIdRange) {
-  IntegerPolyhedron set(3, 2, 1);
+  IntegerPolyhedron set(PresburgerSpace::getSetSpace(3, 2, 1));
 
   set.addInequality({10, 11, 12, 20, 21, 30, 40});
   set.removeId(IdKind::Symbol, 1);
@@ -572,7 +573,7 @@ TEST(IntegerPolyhedronTest, removeRedundantConstraintsTest) {
 }
 
 TEST(IntegerPolyhedronTest, addConstantUpperBound) {
-  IntegerPolyhedron poly(2);
+  IntegerPolyhedron poly(PresburgerSpace::getSetSpace(2));
   poly.addBound(IntegerPolyhedron::UB, 0, 1);
   EXPECT_EQ(poly.atIneq(0, 0), -1);
   EXPECT_EQ(poly.atIneq(0, 1), 0);
@@ -585,7 +586,7 @@ TEST(IntegerPolyhedronTest, addConstantUpperBound) {
 }
 
 TEST(IntegerPolyhedronTest, addConstantLowerBound) {
-  IntegerPolyhedron poly(2);
+  IntegerPolyhedron poly(PresburgerSpace::getSetSpace(2));
   poly.addBound(IntegerPolyhedron::LB, 0, 1);
   EXPECT_EQ(poly.atIneq(0, 0), 1);
   EXPECT_EQ(poly.atIneq(0, 1), 0);
@@ -626,7 +627,7 @@ static void checkDivisionRepresentation(
 }
 
 TEST(IntegerPolyhedronTest, computeLocalReprSimple) {
-  IntegerPolyhedron poly(1);
+  IntegerPolyhedron poly(PresburgerSpace::getSetSpace(1));
 
   poly.addLocalFloorDiv({1, 4}, 10);
   poly.addLocalFloorDiv({1, 0, 100}, 10);
@@ -641,7 +642,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprSimple) {
 }
 
 TEST(IntegerPolyhedronTest, computeLocalReprConstantFloorDiv) {
-  IntegerPolyhedron poly(4);
+  IntegerPolyhedron poly(PresburgerSpace::getSetSpace(4));
 
   poly.addInequality({1, 0, 3, 1, 2});
   poly.addInequality({1, 2, -8, 1, 10});
@@ -659,7 +660,7 @@ TEST(IntegerPolyhedronTest, computeLocalReprConstantFloorDiv) {
 }
 
 TEST(IntegerPolyhedronTest, computeLocalReprRecursive) {
-  IntegerPolyhedron poly(4);
+  IntegerPolyhedron poly(PresburgerSpace::getSetSpace(4));
   poly.addInequality({1, 0, 3, 1, 2});
   poly.addInequality({1, 2, -8, 1, 10});
   poly.addEquality({1, 2, -4, 1, 10});
@@ -795,14 +796,14 @@ TEST(IntegerPolyhedronTest, computeLocalReprNegConstNormalize) {
 
 TEST(IntegerPolyhedronTest, simplifyLocalsTest) {
   // (x) : (exists y: 2x + y = 1 and y = 2).
-  IntegerPolyhedron poly(1, 0, 1);
+  IntegerPolyhedron poly(PresburgerSpace::getSetSpace(1, 0, 1));
   poly.addEquality({2, 1, -1});
   poly.addEquality({0, 1, -2});
 
   EXPECT_TRUE(poly.isEmpty());
 
   // (x) : (exists y, z, w: 3x + y = 1 and 2y = z and 3y = w and z = w).
-  IntegerPolyhedron poly2(1, 0, 3);
+  IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1, 0, 3));
   poly2.addEquality({3, 1, 0, 0, -1});
   poly2.addEquality({0, 2, -1, 0, 0});
   poly2.addEquality({0, 3, 0, -1, 0});
@@ -811,7 +812,7 @@ TEST(IntegerPolyhedronTest, simplifyLocalsTest) {
   EXPECT_TRUE(poly2.isEmpty());
 
   // (x) : (exists y: x >= y + 1 and 2x + y = 0 and y >= -1).
-  IntegerPolyhedron poly3(1, 0, 1);
+  IntegerPolyhedron poly3(PresburgerSpace::getSetSpace(1, 0, 1));
   poly3.addInequality({1, -1, -1});
   poly3.addInequality({0, 1, 1});
   poly3.addEquality({2, 1, 0});
@@ -822,13 +823,13 @@ TEST(IntegerPolyhedronTest, simplifyLocalsTest) {
 TEST(IntegerPolyhedronTest, mergeDivisionsSimple) {
   {
     // (x) : (exists z, y  = [x / 2] : x = 3y and x + z + 1 >= 0).
-    IntegerPolyhedron poly1(1, 0, 1);
+    IntegerPolyhedron poly1(PresburgerSpace::getSetSpace(1, 0, 1));
     poly1.addLocalFloorDiv({1, 0, 0}, 2); // y = [x / 2].
     poly1.addEquality({1, 0, -3, 0});     // x = 3y.
     poly1.addInequality({1, 1, 0, 1});    // x + z + 1 >= 0.
 
     // (x) : (exists y = [x / 2], z : x = 5y).
-    IntegerPolyhedron poly2(1);
+    IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1));
     poly2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2].
     poly2.addEquality({1, -5, 0});     // x = 5y.
     poly2.appendId(IdKind::Local);     // Add local id z.
@@ -845,13 +846,13 @@ TEST(IntegerPolyhedronTest, mergeDivisionsSimple) {
 
   {
     // (x) : (exists z = [x / 5], y = [x / 2] : x = 3y).
-    IntegerPolyhedron poly1(1);
+    IntegerPolyhedron poly1(PresburgerSpace::getSetSpace(1));
     poly1.addLocalFloorDiv({1, 0}, 5);    // z = [x / 5].
     poly1.addLocalFloorDiv({1, 0, 0}, 2); // y = [x / 2].
     poly1.addEquality({1, 0, -3, 0});     // x = 3y.
 
     // (x) : (exists y = [x / 2], z = [x / 5]: x = 5z).
-    IntegerPolyhedron poly2(1);
+    IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1));
     poly2.addLocalFloorDiv({1, 0}, 2);    // y = [x / 2].
     poly2.addLocalFloorDiv({1, 0, 0}, 5); // z = [x / 5].
     poly2.addEquality({1, 0, -5, 0});     // x = 5z.
@@ -869,14 +870,14 @@ TEST(IntegerPolyhedronTest, mergeDivisionsSimple) {
   {
     // Division Normalization test.
     // (x) : (exists z, y  = [x / 2] : x = 3y and x + z + 1 >= 0).
-    IntegerPolyhedron poly1(1, 0, 1);
+    IntegerPolyhedron poly1(PresburgerSpace::getSetSpace(1, 0, 1));
     // This division would be normalized.
     poly1.addLocalFloorDiv({3, 0, 0}, 6); // y = [3x / 6] -> [x/2].
     poly1.addEquality({1, 0, -3, 0});     // x = 3y.
     poly1.addInequality({1, 1, 0, 1});    // x + z + 1 >= 0.
 
     // (x) : (exists y = [x / 2], z : x = 5y).
-    IntegerPolyhedron poly2(1);
+    IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1));
     poly2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2].
     poly2.addEquality({1, -5, 0});     // x = 5y.
     poly2.appendId(IdKind::Local);     // Add local id z.
@@ -895,13 +896,13 @@ TEST(IntegerPolyhedronTest, mergeDivisionsSimple) {
 TEST(IntegerPolyhedronTest, mergeDivisionsNestedDivsions) {
   {
     // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z >= x).
-    IntegerPolyhedron poly1(1);
+    IntegerPolyhedron poly1(PresburgerSpace::getSetSpace(1));
     poly1.addLocalFloorDiv({1, 0}, 2);    // y = [x / 2].
     poly1.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3].
     poly1.addInequality({-1, 1, 1, 0});   // y + z >= x.
 
     // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z <= x).
-    IntegerPolyhedron poly2(1);
+    IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1));
     poly2.addLocalFloorDiv({1, 0}, 2);    // y = [x / 2].
     poly2.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3].
     poly2.addInequality({1, -1, -1, 0});  // y + z <= x.
@@ -918,14 +919,14 @@ TEST(IntegerPolyhedronTest, mergeDivisionsNestedDivsions) {
 
   {
     // (x) : (exists y = [x / 2], z = [x + y / 3], w = [z + 1 / 5]: y + z >= x).
-    IntegerPolyhedron poly1(1);
+    IntegerPolyhedron poly1(PresburgerSpace::getSetSpace(1));
     poly1.addLocalFloorDiv({1, 0}, 2);       // y = [x / 2].
     poly1.addLocalFloorDiv({1, 1, 0}, 3);    // z = [x + y / 3].
     poly1.addLocalFloorDiv({0, 0, 1, 1}, 5); // w = [z + 1 / 5].
     poly1.addInequality({-1, 1, 1, 0, 0});   // y + z >= x.
 
     // (x) : (exists y = [x / 2], z = [x + y / 3], w = [z + 1 / 5]: y + z <= x).
-    IntegerPolyhedron poly2(1);
+    IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1));
     poly2.addLocalFloorDiv({1, 0}, 2);       // y = [x / 2].
     poly2.addLocalFloorDiv({1, 1, 0}, 3);    // z = [x + y / 3].
     poly2.addLocalFloorDiv({0, 0, 1, 1}, 5); // w = [z + 1 / 5].
@@ -942,13 +943,13 @@ TEST(IntegerPolyhedronTest, mergeDivisionsNestedDivsions) {
   }
   {
     // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z >= x).
-    IntegerPolyhedron poly1(1);
+    IntegerPolyhedron poly1(PresburgerSpace::getSetSpace(1));
     poly1.addLocalFloorDiv({2, 0}, 4);    // y = [2x / 4] -> [x / 2].
     poly1.addLocalFloorDiv({1, 1, 0}, 3); // z = [x + y / 3].
     poly1.addInequality({-1, 1, 1, 0});   // y + z >= x.
 
     // (x) : (exists y = [x / 2], z = [x + y / 3]: y + z <= x).
-    IntegerPolyhedron poly2(1);
+    IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1));
     poly2.addLocalFloorDiv({1, 0}, 2); // y = [x / 2].
     // This division would be normalized.
     poly2.addLocalFloorDiv({3, 3, 0}, 9); // z = [3x + 3y / 9] -> [x + y / 3].
@@ -968,13 +969,13 @@ TEST(IntegerPolyhedronTest, mergeDivisionsNestedDivsions) {
 TEST(IntegerPolyhedronTest, mergeDivisionsConstants) {
   {
     // (x) : (exists y = [x + 1 / 2], z = [x + 2 / 3]: y + z >= x).
-    IntegerPolyhedron poly1(1);
+    IntegerPolyhedron poly1(PresburgerSpace::getSetSpace(1));
     poly1.addLocalFloorDiv({1, 1}, 2);    // y = [x + 1 / 2].
     poly1.addLocalFloorDiv({1, 0, 2}, 3); // z = [x + 2 / 3].
     poly1.addInequality({-1, 1, 1, 0});   // y + z >= x.
 
     // (x) : (exists y = [x + 1 / 2], z = [x + 2 / 3]: y + z <= x).
-    IntegerPolyhedron poly2(1);
+    IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1));
     poly2.addLocalFloorDiv({1, 1}, 2);    // y = [x + 1 / 2].
     poly2.addLocalFloorDiv({1, 0, 2}, 3); // z = [x + 2 / 3].
     poly2.addInequality({1, -1, -1, 0});  // y + z <= x.
@@ -990,14 +991,14 @@ TEST(IntegerPolyhedronTest, mergeDivisionsConstants) {
   }
   {
     // (x) : (exists y = [x + 1 / 2], z = [x + 2 / 3]: y + z >= x).
-    IntegerPolyhedron poly1(1);
+    IntegerPolyhedron poly1(PresburgerSpace::getSetSpace(1));
     poly1.addLocalFloorDiv({1, 1}, 2); // y = [x + 1 / 2].
     // Normalization test.
     poly1.addLocalFloorDiv({3, 0, 6}, 9); // z = [3x + 6 / 9] -> [x + 2 / 3].
     poly1.addInequality({-1, 1, 1, 0});   // y + z >= x.
 
     // (x) : (exists y = [x + 1 / 2], z = [x + 2 / 3]: y + z <= x).
-    IntegerPolyhedron poly2(1);
+    IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1));
     // Normalization test.
     poly2.addLocalFloorDiv({2, 2}, 4);    // y = [2x + 2 / 4] -> [x + 1 / 2].
     poly2.addLocalFloorDiv({1, 0, 2}, 3); // z = [x + 2 / 3].
@@ -1016,14 +1017,14 @@ TEST(IntegerPolyhedronTest, mergeDivisionsConstants) {
 
 TEST(IntegerPolyhedronTest, negativeDividends) {
   // (x) : (exists y = [-x + 1 / 2], z = [-x - 2 / 3]: y + z >= x).
-  IntegerPolyhedron poly1(1);
+  IntegerPolyhedron poly1(PresburgerSpace::getSetSpace(1));
   poly1.addLocalFloorDiv({-1, 1}, 2); // y = [-x + 1 / 2].
   // Normalization test with negative dividends
   poly1.addLocalFloorDiv({-3, 0, -6}, 9); // z = [-3x - 6 / 9] -> [-x - 2 / 3].
   poly1.addInequality({-1, 1, 1, 0});     // y + z >= x.
 
   // (x) : (exists y = [-x + 1 / 2], z = [-x - 2 / 3]: y + z <= x).
-  IntegerPolyhedron poly2(1);
+  IntegerPolyhedron poly2(PresburgerSpace::getSetSpace(1));
   // Normalization test.
   poly2.addLocalFloorDiv({-2, 2}, 4);     // y = [-2x + 2 / 4] -> [-x + 1 / 2].
   poly2.addLocalFloorDiv({-1, 0, -2}, 3); // z = [-x - 2 / 3].
@@ -1206,7 +1207,7 @@ TEST(IntegerPolyhedronTest, containsPointNoLocal) {
 TEST(IntegerPolyhedronTest, truncateEqualityRegressionTest) {
   // IntegerRelation::truncate was truncating inequalities to the number of
   // equalities.
-  IntegerRelation set(1);
+  IntegerRelation set(PresburgerSpace::getSetSpace(1));
   IntegerRelation::CountsSnapshot snapshot = set.getCounts();
   set.addEquality({1, 0});
   set.truncate(snapshot);

diff  --git a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
index c2077ee26639f..548ba133c2854 100644
--- a/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSetTest.cpp
@@ -89,7 +89,8 @@ static void testComplementAtPoints(const PresburgerSet &s,
 /// local ids.
 static PresburgerSet makeSetFromPoly(unsigned numDims,
                                      ArrayRef<IntegerPolyhedron> polys) {
-  PresburgerSet set = PresburgerSet::getEmpty(numDims);
+  PresburgerSet set =
+      PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(numDims));
   for (const IntegerPolyhedron &poly : polys)
     set.unionInPlace(poly);
   return set;
@@ -131,23 +132,26 @@ TEST(SetTest, Union) {
       {"(x) : (x - 2 >= 0, -x + 8 >= 0)", "(x) : (x - 10 >= 0, -x + 20 >= 0)"});
 
   // Universe union set.
-  testUnionAtPoints(PresburgerSet::getUniverse(1), set,
-                    {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
+  testUnionAtPoints(PresburgerSet::getUniverse(PresburgerSpace::getSetSpace(1)),
+                    set, {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
 
   // empty set union set.
-  testUnionAtPoints(PresburgerSet::getEmpty(1), set,
-                    {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
+  testUnionAtPoints(PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(1)),
+                    set, {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
 
   // empty set union Universe.
-  testUnionAtPoints(PresburgerSet::getEmpty(1), PresburgerSet::getUniverse(1),
+  testUnionAtPoints(PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(1)),
+                    PresburgerSet::getUniverse(PresburgerSpace::getSetSpace(1)),
                     {{1}, {2}, {0}, {-1}});
 
   // Universe union empty set.
-  testUnionAtPoints(PresburgerSet::getUniverse(1), PresburgerSet::getEmpty(1),
+  testUnionAtPoints(PresburgerSet::getUniverse(PresburgerSpace::getSetSpace(1)),
+                    PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(1)),
                     {{1}, {2}, {0}, {-1}});
 
   // empty set union empty set.
-  testUnionAtPoints(PresburgerSet::getEmpty(1), PresburgerSet::getEmpty(1),
+  testUnionAtPoints(PresburgerSet::getEmpty(PresburgerSpace::getSetSpace((1))),
+                    PresburgerSet::getEmpty(PresburgerSpace::getSetSpace((1))),
                     {{1}, {2}, {0}, {-1}});
 }
 
@@ -157,24 +161,32 @@ TEST(SetTest, Intersect) {
       {"(x) : (x - 2 >= 0, -x + 8 >= 0)", "(x) : (x - 10 >= 0, -x + 20 >= 0)"});
 
   // Universe intersection set.
-  testIntersectAtPoints(PresburgerSet::getUniverse(1), set,
-                        {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
+  testIntersectAtPoints(
+      PresburgerSet::getUniverse(PresburgerSpace::getSetSpace((1))), set,
+      {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
 
   // empty set intersection set.
-  testIntersectAtPoints(PresburgerSet::getEmpty(1), set,
-                        {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
+  testIntersectAtPoints(
+      PresburgerSet::getEmpty(PresburgerSpace::getSetSpace((1))), set,
+      {{1}, {2}, {8}, {9}, {10}, {20}, {21}});
 
   // empty set intersection Universe.
-  testIntersectAtPoints(PresburgerSet::getEmpty(1),
-                        PresburgerSet::getUniverse(1), {{1}, {2}, {0}, {-1}});
+  testIntersectAtPoints(
+      PresburgerSet::getEmpty(PresburgerSpace::getSetSpace((1))),
+      PresburgerSet::getUniverse(PresburgerSpace::getSetSpace((1))),
+      {{1}, {2}, {0}, {-1}});
 
   // Universe intersection empty set.
-  testIntersectAtPoints(PresburgerSet::getUniverse(1),
-                        PresburgerSet::getEmpty(1), {{1}, {2}, {0}, {-1}});
+  testIntersectAtPoints(
+      PresburgerSet::getUniverse(PresburgerSpace::getSetSpace((1))),
+      PresburgerSet::getEmpty(PresburgerSpace::getSetSpace((1))),
+      {{1}, {2}, {0}, {-1}});
 
   // Universe intersection Universe.
-  testIntersectAtPoints(PresburgerSet::getUniverse(1),
-                        PresburgerSet::getUniverse(1), {{1}, {2}, {0}, {-1}});
+  testIntersectAtPoints(
+      PresburgerSet::getUniverse(PresburgerSpace::getSetSpace((1))),
+      PresburgerSet::getUniverse(PresburgerSpace::getSetSpace((1))),
+      {{1}, {2}, {0}, {-1}});
 }
 
 TEST(SetTest, Subtract) {
@@ -329,12 +341,12 @@ TEST(SetTest, Subtract) {
 TEST(SetTest, Complement) {
   // Complement of universe.
   testComplementAtPoints(
-      PresburgerSet::getUniverse(1),
+      PresburgerSet::getUniverse(PresburgerSpace::getSetSpace((1))),
       {{-1}, {-2}, {-8}, {1}, {2}, {8}, {9}, {10}, {20}, {21}});
 
   // Complement of empty set.
   testComplementAtPoints(
-      PresburgerSet::getEmpty(1),
+      PresburgerSet::getEmpty(PresburgerSpace::getSetSpace((1))),
       {{-1}, {-2}, {-8}, {1}, {2}, {8}, {9}, {10}, {20}, {21}});
 
   testComplementAtPoints(
@@ -356,8 +368,10 @@ TEST(SetTest, Complement) {
 
 TEST(SetTest, isEqual) {
   // set = [2, 8] U [10, 20].
-  PresburgerSet universe = PresburgerSet::getUniverse(1);
-  PresburgerSet emptySet = PresburgerSet::getEmpty(1);
+  PresburgerSet universe =
+      PresburgerSet::getUniverse(PresburgerSpace::getSetSpace((1)));
+  PresburgerSet emptySet =
+      PresburgerSet::getEmpty(PresburgerSpace::getSetSpace((1)));
   PresburgerSet set = parsePresburgerSetFromPolyStrings(
       1,
       {"(x) : (x - 2 >= 0, -x + 8 >= 0)", "(x) : (x - 10 >= 0, -x + 20 >= 0)"});
@@ -431,7 +445,8 @@ TEST(SetTest, divisions) {
   // evens /\ odds = empty.
   expectEmpty(PresburgerSet(evens).intersect(PresburgerSet(odds)));
   // evens U odds = universe.
-  expectEqual(evens.unionSet(odds), PresburgerSet::getUniverse(1));
+  expectEqual(evens.unionSet(odds),
+              PresburgerSet::getUniverse(PresburgerSpace::getSetSpace((1))));
   expectEqual(evens.complement(), odds);
   expectEqual(odds.complement(), evens);
   // even multiples of 3 = multiples of 6.

diff  --git a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
index b88f83c5ef4e8..8f793747f2a5d 100644
--- a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
@@ -14,7 +14,7 @@ using namespace mlir;
 using namespace presburger;
 
 TEST(PresburgerSpaceTest, insertId) {
-  PresburgerSpace space(2, 2, 1);
+  PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 2, 1);
 
   // Try inserting 2 domain ids.
   space.insertId(IdKind::Domain, 0, 2);
@@ -26,7 +26,7 @@ TEST(PresburgerSpaceTest, insertId) {
 }
 
 TEST(PresburgerSpaceTest, insertIdSet) {
-  PresburgerSpace space(0, 2, 1);
+  PresburgerSpace space = PresburgerSpace::getSetSpace(2, 1);
 
   // Try inserting 2 dimension ids. The space should have 4 range ids since
   // spaces which do not distinguish between domain, range are implemented like
@@ -36,7 +36,7 @@ TEST(PresburgerSpaceTest, insertIdSet) {
 }
 
 TEST(PresburgerSpaceTest, removeIdRange) {
-  PresburgerSpace space(2, 1, 3);
+  PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 1, 3);
 
   // Remove 1 domain identifier.
   space.removeIdRange(IdKind::Domain, 0, 1);

diff  --git a/mlir/unittests/Analysis/Presburger/Utils.h b/mlir/unittests/Analysis/Presburger/Utils.h
index 4ed03c5b97e75..6b24ca0d576db 100644
--- a/mlir/unittests/Analysis/Presburger/Utils.h
+++ b/mlir/unittests/Analysis/Presburger/Utils.h
@@ -41,7 +41,8 @@ inline IntegerPolyhedron parsePoly(StringRef str) {
 /// number of dimensions as is specified by the numDims argument.
 inline PresburgerSet
 parsePresburgerSetFromPolyStrings(unsigned numDims, ArrayRef<StringRef> strs) {
-  PresburgerSet set = PresburgerSet::getEmpty(numDims);
+  PresburgerSet set =
+      PresburgerSet::getEmpty(PresburgerSpace::getSetSpace(numDims));
   for (StringRef str : strs)
     set.unionInPlace(parsePoly(str));
   return set;
@@ -70,7 +71,9 @@ inline PWMAFunction parsePWMAF(
     unsigned numSymbols = 0) {
   static MLIRContext context;
 
-  PWMAFunction result(numInputs - numSymbols, numSymbols, numOutputs);
+  PWMAFunction result(
+      PresburgerSpace::getSetSpace(numInputs - numSymbols, numSymbols),
+      numOutputs);
   for (const auto &pair : data) {
     IntegerPolyhedron domain = parsePoly(pair.first);
 


        


More information about the Mlir-commits mailing list