[Mlir-commits] [mlir] 4807587 - [MLIR][Presburger] Factor out space information to PresburgerSpace

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Feb 10 04:59:28 PST 2022


Author: Groverkss
Date: 2022-02-10T18:24:40+05:30
New Revision: 4807587cf2fec24a1076264b8ada78a7ed9ce531

URL: https://github.com/llvm/llvm-project/commit/4807587cf2fec24a1076264b8ada78a7ed9ce531
DIFF: https://github.com/llvm/llvm-project/commit/4807587cf2fec24a1076264b8ada78a7ed9ce531.diff

LOG: [MLIR][Presburger] Factor out space information to PresburgerSpace

This patch factors out space information from IntegerPolyhedron, PresburgerSet
and PWMAFunction to PresburgerSpace and its extension with local variables,
PresburgerLocalSpace.

Generally, any new data structure added to the Presburger library will require
space information. This patch removes the need to duplicate that information
in each data structure.

Reviewed By: arjunp

Differential Revision: https://reviews.llvm.org/D119280

Added: 
    mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
    mlir/lib/Analysis/Presburger/PresburgerSpace.cpp

Modified: 
    mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h
    mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
    mlir/include/mlir/Analysis/Presburger/PresburgerSet.h
    mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h
    mlir/lib/Analysis/Presburger/CMakeLists.txt
    mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
    mlir/lib/Analysis/Presburger/PresburgerSet.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h b/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h
index a061ae8ab4cf6..5a1d6df84f736 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Analysis/Presburger/Fraction.h"
 #include "mlir/Analysis/Presburger/Matrix.h"
+#include "mlir/Analysis/Presburger/PresburgerSpace.h"
 #include "mlir/Analysis/Presburger/Utils.h"
 #include "mlir/Support/LogicalResult.h"
 
@@ -50,7 +51,7 @@ namespace mlir {
 /// example, `q` is existentially quantified. This can be thought of as the
 /// result of projecting out `q` from the previous example, i.e. we obtained {2,
 /// 4, 6} by projecting out the second dimension from {(2, 1), (4, 2), (6, 2)}.
-class IntegerPolyhedron {
+class IntegerPolyhedron : public PresburgerLocalSpace {
 public:
   /// All derived classes of IntegerPolyhedron.
   enum class Kind {
@@ -60,19 +61,16 @@ class IntegerPolyhedron {
     IntegerPolyhedron
   };
 
-  /// Kind of identifier (column).
-  enum IdKind { Dimension, Symbol, Local };
-
   /// Constructs a constraint system reserving memory for the specified number
   /// of constraints and identifiers.
   IntegerPolyhedron(unsigned numReservedInequalities,
                     unsigned numReservedEqualities, unsigned numReservedCols,
                     unsigned numDims, unsigned numSymbols, unsigned numLocals)
-      : numIds(numDims + numSymbols + numLocals), numDims(numDims),
-        numSymbols(numSymbols),
-        equalities(0, numIds + 1, numReservedEqualities, numReservedCols),
-        inequalities(0, numIds + 1, numReservedInequalities, numReservedCols) {
-    assert(numReservedCols >= numIds + 1);
+      : PresburgerLocalSpace(numDims, numSymbols, numLocals),
+        equalities(0, getNumIds() + 1, numReservedEqualities, numReservedCols),
+        inequalities(0, getNumIds() + 1, numReservedInequalities,
+                     numReservedCols) {
+    assert(numReservedCols >= getNumIds() + 1);
   }
 
   /// Constructs a constraint system with the specified number of
@@ -92,8 +90,6 @@ class IntegerPolyhedron {
     return IntegerPolyhedron(numDims, numSymbols);
   }
 
-  virtual ~IntegerPolyhedron() = default;
-
   /// Return the kind of this IntegerPolyhedron.
   virtual Kind getKind() const { return Kind::IntegerPolyhedron; }
 
@@ -139,16 +135,9 @@ class IntegerPolyhedron {
   unsigned getNumConstraints() const {
     return getNumInequalities() + getNumEqualities();
   }
-  inline unsigned getNumIds() const { return numIds; }
-  inline unsigned getNumDimIds() const { return numDims; }
-  inline unsigned getNumSymbolIds() const { return numSymbols; }
-  inline unsigned getNumDimAndSymbolIds() const { return numDims + numSymbols; }
-  inline unsigned getNumLocalIds() const {
-    return numIds - numDims - numSymbols;
-  }
 
   /// Returns the number of columns in the constraint system.
-  inline unsigned getNumCols() const { return numIds + 1; }
+  inline unsigned getNumCols() const { return getNumIds() + 1; }
 
   inline unsigned getNumEqualities() const { return equalities.getNumRows(); }
 
@@ -180,7 +169,7 @@ class IntegerPolyhedron {
   unsigned insertDimId(unsigned pos, unsigned num = 1);
   unsigned insertSymbolId(unsigned pos, unsigned num = 1);
   unsigned insertLocalId(unsigned pos, unsigned num = 1);
-  virtual unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1);
+  unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override;
 
   /// Append `num` identifiers of the specified kind after the last identifier.
   /// of that kind. Return the position of the first appended column. The
@@ -330,11 +319,6 @@ class IntegerPolyhedron {
   void projectOut(unsigned pos, unsigned num);
   inline void projectOut(unsigned pos) { return projectOut(pos, 1); }
 
-  /// Changes the partition between dimensions and symbols. Depending on the new
-  /// symbol count, either a chunk of trailing dimensional identifiers becomes
-  /// symbols, or some of the leading symbols become dimensions.
-  void setDimSymbolSeparation(unsigned newSymbolCount);
-
   /// Tries to fold the specified identifier to a constant using a trivial
   /// equality detection; if successful, the constant is substituted for the
   /// identifier everywhere in the constraint system and then removed from the
@@ -508,24 +492,10 @@ class IntegerPolyhedron {
   /// IntegerPolyhedron.
   virtual void printSpace(raw_ostream &os) const;
 
-  /// Return the index at which the specified kind of id starts.
-  unsigned getIdKindOffset(IdKind kind) const;
-
-  /// Return the index at which the specified kind of id ends.
-  unsigned getIdKindEnd(IdKind kind) const;
-
-  /// Get the number of ids of the specified kind.
-  unsigned getNumIdKind(IdKind kind) const;
-
-  /// Get the number of elements of the specified kind in the range
-  /// [idStart, idLimit).
-  unsigned getIdKindOverlap(IdKind kind, unsigned idStart,
-                            unsigned idLimit) const;
-
   /// Removes identifiers in the column range [idStart, idLimit), and copies any
   /// remaining valid data into place, updates member variables, and resizes
   /// arrays as needed.
-  virtual void removeIdRange(unsigned idStart, unsigned idLimit);
+  void removeIdRange(unsigned idStart, unsigned idLimit) override;
 
   /// A parameter that controls detection of an unrealistic number of
   /// constraints. If the number of constraints is this many times the number of
@@ -539,16 +509,6 @@ class IntegerPolyhedron {
   // constraints. This is conservatively set low and can be raised if needed.
   constexpr static unsigned kExplosionFactor = 32;
 
-  /// Total number of identifiers.
-  unsigned numIds;
-
-  /// Number of identifiers corresponding to real dimensions.
-  unsigned numDims;
-
-  /// Number of identifiers corresponding to symbols (unknown but constant for
-  /// analysis).
-  unsigned numSymbols;
-
   /// Coefficients of affine equalities (in == 0 form).
   Matrix equalities;
 

diff  --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
index a01c3ef7a3e61..26958e4308f13 100644
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -67,7 +67,9 @@ class MultiAffineFunction : protected IntegerPolyhedron {
 
   unsigned getNumInputs() const { return getNumDimAndSymbolIds(); }
   unsigned getNumOutputs() const { return output.getNumRows(); }
-  bool isConsistent() const { return output.getNumColumns() == numIds + 1; }
+  bool isConsistent() const {
+    return output.getNumColumns() == getNumIds() + 1;
+  }
   const IntegerPolyhedron &getDomain() const { return *this; }
 
   bool hasCompatibleDimensions(const MultiAffineFunction &f) const;
@@ -136,10 +138,10 @@ class MultiAffineFunction : protected IntegerPolyhedron {
 /// Support is provided to compare equality of two such functions as well as
 /// finding the value of the function at a point. Note that local ids in the
 /// piece are not supported for the latter.
-class PWMAFunction {
+class PWMAFunction : public PresburgerSpace {
 public:
   PWMAFunction(unsigned numDims, unsigned numSymbols, unsigned numOutputs)
-      : numDims(numDims), numSymbols(numSymbols), numOutputs(numOutputs) {
+      : PresburgerSpace(numDims, numSymbols), numOutputs(numOutputs) {
     assert(numOutputs >= 1 && "The function must output something!");
   }
@@ -149,9 +151,7 @@ class PWMAFunction {
   const MultiAffineFunction &getPiece(unsigned i) const { return pieces[i]; }
   unsigned getNumPieces() const { return pieces.size(); }
   unsigned getNumOutputs() const { return numOutputs; }
-  unsigned getNumInputs() const { return numDims + numSymbols; }
-  unsigned getNumDimIds() const { return numDims; }
-  unsigned getNumSymbolIds() const { return numSymbols; }
+  unsigned getNumInputs() const { return getNumDimAndSymbolIds(); }
   MultiAffineFunction &getPiece(unsigned i) { return pieces[i]; }
 
   /// Return the domain of this piece-wise MultiAffineFunction. This is the
@@ -182,10 +182,6 @@ class PWMAFunction {
   /// The list of pieces in this piece-wise MultiAffineFunction.
   SmallVector<MultiAffineFunction, 4> pieces;
 
-  /// The number of dimensions ids in the domains.
-  unsigned numDims;
-  /// The number of symbol ids in the domains.
-  unsigned numSymbols;
   /// The number of output ids.
   unsigned numOutputs;
 };

diff  --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSet.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSet.h
index af3e0178d5c25..0bf7c8d7daad4 100644
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerSet.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSet.h
@@ -28,19 +28,13 @@ namespace mlir {
 /// Note that there are no invariants guaranteed on the list of Poly other than
 /// that they are all in the same space, i.e., they all have the same number of
 /// dimensions and symbols. For example, the Polys may overlap each other.
-class PresburgerSet {
+class PresburgerSet : public PresburgerSpace {
 public:
   explicit PresburgerSet(const IntegerPolyhedron &poly);
 
   /// Return the number of Polys in the union.
   unsigned getNumPolys() const;
 
-  /// Return the number of real dimensions.
-  unsigned getNumDimIds() const;
-
-  /// Return the number of symbolic dimensions.
-  unsigned getNumSymbolIds() const;
-
   /// Return a reference to the list of IntegerPolyhedrons.
   ArrayRef<IntegerPolyhedron> getAllIntegerPolyhedron() const;
 
@@ -117,19 +111,12 @@ class PresburgerSet {
 private:
   /// Construct an empty PresburgerSet.
   PresburgerSet(unsigned numDims = 0, unsigned numSymbols = 0)
-      : numDims(numDims), numSymbols(numSymbols) {}
+      : PresburgerSpace(numDims, numSymbols) {}
 
   /// Return the set difference poly \ set.
   static PresburgerSet getSetDifference(IntegerPolyhedron poly,
                                         const PresburgerSet &set);
 
-  /// Number of identifiers corresponding to real dimensions.
-  unsigned numDims;
-
-  /// Number of symbolic dimensions, unknown but constant for analysis, as in
-  /// IntegerPolyhedron.
-  unsigned numSymbols;
-
   /// The list of integerPolyhedrons that this set is the union of.
   SmallVector<IntegerPolyhedron, 2> integerPolyhedrons;
 };

diff  --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
new file mode 100644
index 0000000000000..b97c8f67b28af
--- /dev/null
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
@@ -0,0 +1,118 @@
+//===- PresburgerSpace.h - MLIR PresburgerSpace Class -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Classes representing space information like number of identifiers and kind of
+// identifiers.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_PRESBURGER_PRESBURGERSPACE_H
+#define MLIR_ANALYSIS_PRESBURGER_PRESBURGERSPACE_H
+
+#include "llvm/Support/ErrorHandling.h"
+
+namespace mlir {
+
+class PresburgerLocalSpace;
+
+/// PresburgerSpace is a tuple of identifiers with information about what kind
+/// they correspond to. The identifiers can be split into three types:
+///
+/// Dimension: Ordinary variables over which the space is represented.
+///
+/// Symbol: Symbol identifiers correspond to fixed but unknown values.
+/// Mathematically, a space with symbolic identifiers is like a
+/// family of spaces indexed by the symbolic identifiers.
+///
+/// Local: Local identifiers correspond to existentially quantified variables.
+///
+/// PresburgerSpace only supports identifiers of kind Dimension and Symbol.
+class PresburgerSpace {
+  friend PresburgerLocalSpace;
+
+public:
+  /// Kind of identifier (column).
+  enum IdKind { Dimension, Symbol, Local };
+
+  PresburgerSpace(unsigned numDims, unsigned numSymbols)
+      : numDims(numDims), numSymbols(numSymbols), numLocals(0) {}
+
+  virtual ~PresburgerSpace() = default;
+
+  unsigned getNumIds() const { return numDims + numSymbols + numLocals; }
+  unsigned getNumDimIds() const { return numDims; }
+  unsigned getNumSymbolIds() const { return numSymbols; }
+  unsigned getNumDimAndSymbolIds() const { return numDims + numSymbols; }
+
+  /// Get the number of ids of the specified kind.
+  unsigned getNumIdKind(IdKind kind) const;
+
+  /// Return the index at which the specified kind of id starts.
+  unsigned getIdKindOffset(IdKind kind) const;
+
+  /// Return the index at which the specified kind of id ends.
+  unsigned getIdKindEnd(IdKind kind) const;
+
+  /// Get the number of elements of the specified kind in the range
+  /// [idStart, idLimit).
+  unsigned getIdKindOverlap(IdKind kind, unsigned idStart,
+                            unsigned idLimit) const;
+
+  /// Insert `num` identifiers of the specified kind at position `pos`.
+  /// Positions are relative to the kind of identifier. Return the absolute
+  /// column position (i.e., not relative to the kind of identifier) of the
+  /// first added identifier.
+  virtual unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1);
+
+  /// Removes identifiers in the column range [idStart, idLimit).
+  virtual void removeIdRange(unsigned idStart, unsigned idLimit);
+
+  /// Changes the partition between dimensions and symbols. Depending on the new
+  /// symbol count, either a chunk of dimensional identifiers immediately before
+  /// the split become symbols, or some of the symbols immediately after the
+  /// split become dimensions.
+  void setDimSymbolSeparation(unsigned newSymbolCount);
+
+private:
+  PresburgerSpace(unsigned numDims, unsigned numSymbols, unsigned numLocals)
+      : numDims(numDims), numSymbols(numSymbols), numLocals(numLocals) {}
+
+  /// Number of identifiers corresponding to real dimensions.
+  unsigned numDims;
+
+  /// Number of identifiers corresponding to symbols (unknown but constant for
+  /// analysis).
+  unsigned numSymbols;
+
+  /// Number of identifiers corresponding to locals (existentially quantified
+  /// variables). Only expected to be nonzero for a PresburgerLocalSpace.
+  unsigned numLocals;
+};
+
+/// Extension of PresburgerSpace supporting Local identifiers.
+class PresburgerLocalSpace : public PresburgerSpace {
+public:
+  PresburgerLocalSpace(unsigned numDims, unsigned numSymbols,
+                       unsigned numLocals)
+      : PresburgerSpace(numDims, numSymbols, numLocals) {}
+
+  unsigned getNumLocalIds() const { return numLocals; }
+
+  /// Insert `num` identifiers of the specified kind at position `pos`.
+  /// Positions are relative to the kind of identifier. Return the absolute
+  /// column position (i.e., not relative to the kind of identifier) of the
+  /// first added identifier.
+  unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override;
+
+  /// Removes identifiers in the column range [idStart, idLimit).
+  void removeIdRange(unsigned idStart, unsigned idLimit) override;
+};
+
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_PRESBURGER_PRESBURGERSPACE_H

diff  --git a/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h b/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h
index e53dc74e73eb3..3f3f854e434d5 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/AffineStructures.h
@@ -188,11 +188,11 @@ class FlatAffineValueConstraints : public FlatAffineConstraints {
                              ArrayRef<Optional<Value>> valArgs = {})
       : FlatAffineConstraints(numReservedInequalities, numReservedEqualities,
                               numReservedCols, numDims, numSymbols, numLocals) {
-    assert(numReservedCols >= numIds + 1);
-    assert(valArgs.empty() || valArgs.size() == numIds);
+    assert(numReservedCols >= getNumIds() + 1);
+    assert(valArgs.empty() || valArgs.size() == getNumIds());
     values.reserve(numReservedCols);
     if (valArgs.empty())
-      values.resize(numIds, None);
+      values.resize(getNumIds(), None);
     else
       values.append(valArgs.begin(), valArgs.end());
   }
@@ -211,9 +211,9 @@ class FlatAffineValueConstraints : public FlatAffineConstraints {
   FlatAffineValueConstraints(const FlatAffineConstraints &fac,
                              ArrayRef<Optional<Value>> valArgs = {})
       : FlatAffineConstraints(fac) {
-    assert(valArgs.empty() || valArgs.size() == numIds);
+    assert(valArgs.empty() || valArgs.size() == getNumIds());
     if (valArgs.empty())
-      values.resize(numIds, None);
+      values.resize(getNumIds(), None);
     else
       values.append(valArgs.begin(), valArgs.end());
   }
@@ -463,15 +463,15 @@ class FlatAffineValueConstraints : public FlatAffineConstraints {
   /// Asserts if no Value was associated with one of these identifiers.
   inline void getValues(unsigned start, unsigned end,
                         SmallVectorImpl<Value> *values) const {
-    assert((start < numIds || start == end) && "invalid start position");
-    assert(end <= numIds && "invalid end position");
+    assert((start < getNumIds() || start == end) && "invalid start position");
+    assert(end <= getNumIds() && "invalid end position");
     values->clear();
     values->reserve(end - start);
     for (unsigned i = start; i < end; i++)
       values->push_back(getValue(i));
   }
   inline void getAllValues(SmallVectorImpl<Value> *values) const {
-    getValues(0, numIds, values);
+    getValues(0, getNumIds(), values);
   }
 
   inline ArrayRef<Optional<Value>> getMaybeValues() const {
@@ -492,14 +492,14 @@ class FlatAffineValueConstraints : public FlatAffineConstraints {
 
   /// Sets the Value associated with the pos^th identifier.
   inline void setValue(unsigned pos, Value val) {
-    assert(pos < numIds && "invalid id position");
+    assert(pos < getNumIds() && "invalid id position");
     values[pos] = val;
   }
 
   /// Sets the Values associated with the identifiers in the range [start, end).
   void setValues(unsigned start, unsigned end, ArrayRef<Value> values) {
-    assert((start < numIds || end == start) && "invalid start position");
-    assert(end <= numIds && "invalid end position");
+    assert((start < getNumIds() || end == start) && "invalid start position");
+    assert(end <= getNumIds() && "invalid end position");
     assert(values.size() == end - start);
     for (unsigned i = start; i < end; ++i)
       setValue(i, values[i - start]);

diff  --git a/mlir/lib/Analysis/Presburger/CMakeLists.txt b/mlir/lib/Analysis/Presburger/CMakeLists.txt
index 313742f7e3d8b..042c089553ea3 100644
--- a/mlir/lib/Analysis/Presburger/CMakeLists.txt
+++ b/mlir/lib/Analysis/Presburger/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_library(MLIRPresburger
   LinearTransform.cpp
   Matrix.cpp
   PresburgerSet.cpp
+  PresburgerSpace.cpp
   PWMAFunction.cpp
   Simplex.cpp
   Utils.cpp

diff  --git a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
index f68bd1c61012c..60a1361ec81aa 100644
--- a/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
@@ -108,16 +108,10 @@ unsigned IntegerPolyhedron::insertId(IdKind kind, unsigned pos, unsigned num) {
   assert(pos <= getNumIdKind(kind));
 
   unsigned absolutePos = getIdKindOffset(kind) + pos;
-  if (kind == IdKind::Dimension)
-    numDims += num;
-  else if (kind == IdKind::Symbol)
-    numSymbols += num;
-  numIds += num;
-
   inequalities.insertColumns(absolutePos, num);
   equalities.insertColumns(absolutePos, num);
 
-  return absolutePos;
+  return PresburgerLocalSpace::insertId(kind, pos, num);
 }
 
 unsigned IntegerPolyhedron::appendDimId(unsigned num) {
@@ -166,27 +160,17 @@ void IntegerPolyhedron::removeIdRange(IdKind kind, unsigned idStart,
 }
 
 void IntegerPolyhedron::removeIdRange(unsigned idStart, unsigned idLimit) {
-  assert(idLimit < getNumCols() && "invalid id limit");
-
-  if (idStart >= idLimit)
-    return;
-
-  // We are going to be removing one or more identifiers from the range.
-  assert(idStart < getNumIds() && "invalid idStart position");
+  // Update space parameters. This also validates the range.
+  PresburgerLocalSpace::removeIdRange(idStart, idLimit);
+
+  // An empty range removes nothing. Return before computing
+  // `idLimit - idStart` below, which would wrap around if idStart > idLimit.
+  if (idStart >= idLimit)
+    return;
 
   // Remove eliminated identifiers from the constraints..
   equalities.removeColumns(idStart, idLimit - idStart);
   inequalities.removeColumns(idStart, idLimit - idStart);
-
-  // Update members numDims, numSymbols and numIds.
-  unsigned numDimsEliminated =
-      getIdKindOverlap(IdKind::Dimension, idStart, idLimit);
-  unsigned numSymbolsEliminated =
-      getIdKindOverlap(IdKind::Symbol, idStart, idLimit);
-
-  numDims -= numDimsEliminated;
-  numSymbols -= numSymbolsEliminated;
-  numIds -= (idLimit - idStart);
 }
 
 void IntegerPolyhedron::removeEquality(unsigned pos) {
@@ -222,45 +206,6 @@ void IntegerPolyhedron::swapId(unsigned posA, unsigned posB) {
     std::swap(atEq(r, posA), atEq(r, posB));
 }
 
-unsigned IntegerPolyhedron::getIdKindOffset(IdKind kind) const {
-  if (kind == IdKind::Dimension)
-    return 0;
-  if (kind == IdKind::Symbol)
-    return getNumDimIds();
-  if (kind == IdKind::Local)
-    return getNumDimAndSymbolIds();
-  llvm_unreachable("IdKind expected to be Dimension, Symbol or Local!");
-}
-
-unsigned IntegerPolyhedron::getIdKindEnd(IdKind kind) const {
-  return getIdKindOffset(kind) + getNumIdKind(kind);
-}
-
-unsigned IntegerPolyhedron::getNumIdKind(IdKind kind) const {
-  if (kind == IdKind::Dimension)
-    return getNumDimIds();
-  if (kind == IdKind::Symbol)
-    return getNumSymbolIds();
-  if (kind == IdKind::Local)
-    return getNumLocalIds();
-  llvm_unreachable("IdKind expected to be Dimension, Symbol or Local!");
-}
-
-unsigned IntegerPolyhedron::getIdKindOverlap(IdKind kind, unsigned idStart,
-                                             unsigned idLimit) const {
-  unsigned idRangeStart = getIdKindOffset(kind);
-  unsigned idRangeEnd = getIdKindEnd(kind);
-
-  // Compute number of elements in intersection of the ranges [idStart, idLimit)
-  // and [idRangeStart, idRangeEnd).
-  unsigned overlapStart = std::max(idStart, idRangeStart);
-  unsigned overlapEnd = std::min(idLimit, idRangeEnd);
-
-  if (overlapStart > overlapEnd)
-    return 0;
-  return overlapEnd - overlapStart;
-}
-
 void IntegerPolyhedron::clearConstraints() {
   equalities.resizeVertically(0);
   inequalities.resizeVertically(0);
@@ -1308,13 +1253,6 @@ void IntegerPolyhedron::addLocalFloorDiv(ArrayRef<int64_t> dividend,
   addInequality(bound);
 }
 
-void IntegerPolyhedron::setDimSymbolSeparation(unsigned newSymbolCount) {
-  assert(newSymbolCount <= getNumDimAndSymbolIds() &&
-         "invalid separation position");
-  numDims = numDims + numSymbols - newSymbolCount;
-  numSymbols = newSymbolCount;
-}
-
 /// Finds an equality that equates the specified identifier to a constant.
 /// Returns the position of the equality row. If 'symbolic' is set to true,
 /// symbols are also treated like a constant, i.e., an affine function of the

diff  --git a/mlir/lib/Analysis/Presburger/PresburgerSet.cpp b/mlir/lib/Analysis/Presburger/PresburgerSet.cpp
index 198eb8d62a593..b75cd1d51fddd 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerSet.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSet.cpp
@@ -16,7 +16,9 @@ using namespace mlir;
 using namespace presburger_utils;
 
+// Only the dimension and symbol counts form the set's space; copying the whole
+// PresburgerSpace base of `poly` would also copy its number of local ids.
 PresburgerSet::PresburgerSet(const IntegerPolyhedron &poly)
-    : numDims(poly.getNumDimIds()), numSymbols(poly.getNumSymbolIds()) {
+    : PresburgerSpace(poly.getNumDimIds(), poly.getNumSymbolIds()) {
   unionPolyInPlace(poly);
 }
 
@@ -24,10 +26,6 @@ unsigned PresburgerSet::getNumPolys() const {
   return integerPolyhedrons.size();
 }
 
-unsigned PresburgerSet::getNumDimIds() const { return numDims; }
-
-unsigned PresburgerSet::getNumSymbolIds() const { return numSymbols; }
-
 ArrayRef<IntegerPolyhedron> PresburgerSet::getAllIntegerPolyhedron() const {
   return integerPolyhedrons;
 }

diff  --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
new file mode 100644
index 0000000000000..8e63eb24717ae
--- /dev/null
+++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
@@ -0,0 +1,127 @@
+//===- PresburgerSpace.cpp - MLIR PresburgerSpace Class -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/PresburgerSpace.h"
+#include <algorithm>
+#include <cassert>
+
+using namespace mlir;
+
+unsigned PresburgerSpace::getNumIdKind(IdKind kind) const {
+  if (kind == IdKind::Dimension)
+    return getNumDimIds();
+  if (kind == IdKind::Symbol)
+    return getNumSymbolIds();
+  if (kind == IdKind::Local)
+    return numLocals;
+  llvm_unreachable("IdKind does not exist!");
+}
+
+unsigned PresburgerSpace::getIdKindOffset(IdKind kind) const {
+  if (kind == IdKind::Dimension)
+    return 0;
+  if (kind == IdKind::Symbol)
+    return getNumDimIds();
+  if (kind == IdKind::Local)
+    return getNumDimAndSymbolIds();
+  llvm_unreachable("IdKind does not exist!");
+}
+
+unsigned PresburgerSpace::getIdKindEnd(IdKind kind) const {
+  return getIdKindOffset(kind) + getNumIdKind(kind);
+}
+
+unsigned PresburgerSpace::getIdKindOverlap(IdKind kind, unsigned idStart,
+                                           unsigned idLimit) const {
+  unsigned idRangeStart = getIdKindOffset(kind);
+  unsigned idRangeEnd = getIdKindEnd(kind);
+
+  // Compute number of elements in intersection of the ranges [idStart, idLimit)
+  // and [idRangeStart, idRangeEnd).
+  unsigned overlapStart = std::max(idStart, idRangeStart);
+  unsigned overlapEnd = std::min(idLimit, idRangeEnd);
+
+  if (overlapStart > overlapEnd)
+    return 0;
+  return overlapEnd - overlapStart;
+}
+
+unsigned PresburgerSpace::insertId(IdKind kind, unsigned pos, unsigned num) {
+  assert(pos <= getNumIdKind(kind));
+
+  unsigned absolutePos = getIdKindOffset(kind) + pos;
+
+  if (kind == IdKind::Dimension)
+    numDims += num;
+  else if (kind == IdKind::Symbol)
+    numSymbols += num;
+  else
+    llvm_unreachable(
+        "PresburgerSpace only supports Dimensions and Symbol identifiers!");
+
+  return absolutePos;
+}
+
+void PresburgerSpace::removeIdRange(unsigned idStart, unsigned idLimit) {
+  assert(idLimit <= getNumIds() && "invalid id limit");
+
+  if (idStart >= idLimit)
+    return;
+
+  // We are going to be removing one or more identifiers from the range.
+  assert(idStart < getNumIds() && "invalid idStart position");
+
+  // Update members numDims and numSymbols.
+  unsigned numDimsEliminated =
+      getIdKindOverlap(IdKind::Dimension, idStart, idLimit);
+  unsigned numSymbolsEliminated =
+      getIdKindOverlap(IdKind::Symbol, idStart, idLimit);
+
+  numDims -= numDimsEliminated;
+  numSymbols -= numSymbolsEliminated;
+}
+
+unsigned PresburgerLocalSpace::insertId(IdKind kind, unsigned pos,
+                                        unsigned num) {
+  if (kind == IdKind::Local) {
+    assert(pos <= getNumLocalIds());
+    // Locals come after dimensions and symbols, so their offset is unaffected
+    // by the increment below.
+    numLocals += num;
+    return getIdKindOffset(IdKind::Local) + pos;
+  }
+  return PresburgerSpace::insertId(kind, pos, num);
+}
+
+void PresburgerLocalSpace::removeIdRange(unsigned idStart, unsigned idLimit) {
+  assert(idLimit <= getNumIds() && "invalid id limit");
+
+  if (idStart >= idLimit)
+    return;
+
+  // We are going to be removing one or more identifiers from the range.
+  assert(idStart < getNumIds() && "invalid idStart position");
+
+  unsigned numLocalsEliminated =
+      getIdKindOverlap(IdKind::Local, idStart, idLimit);
+
+  // Update the dimension and symbol counts. The base class only tracks those,
+  // so clamp the range to the dimension and symbol columns.
+  PresburgerSpace::removeIdRange(idStart,
+                                 std::min(idLimit, getNumDimAndSymbolIds()));
+
+  // Update local ids.
+  numLocals -= numLocalsEliminated;
+}
+
+void PresburgerSpace::setDimSymbolSeparation(unsigned newSymbolCount) {
+  assert(newSymbolCount <= getNumDimAndSymbolIds() &&
+         "invalid separation position");
+  numDims = numDims + numSymbols - newSymbolCount;
+  numSymbols = newSymbolCount;
+}


        


More information about the Mlir-commits mailing list