[Mlir-commits] [mlir] 15650b3 - [MLIR][Presburger] Remove inheritance in MultiAffineFunction

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Apr 18 12:58:08 PDT 2022


Author: Groverkss
Date: 2022-04-19T01:25:13+05:30
New Revision: 15650b320bf6a1ce5b7e00147d1cf2725946aab2

URL: https://github.com/llvm/llvm-project/commit/15650b320bf6a1ce5b7e00147d1cf2725946aab2
DIFF: https://github.com/llvm/llvm-project/commit/15650b320bf6a1ce5b7e00147d1cf2725946aab2.diff

LOG: [MLIR][Presburger] Remove inheritance in MultiAffineFunction

This patch removes the inheritance of MultiAffineFunction from IntegerPolyhedron
and instead makes IntegerPolyhedron a member.

This patch removes the virtual function overrides in MultiAffineFunction and also
removes the unnecessary functions it inherited from IntegerPolyhedron.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D123921

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
    mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
    mlir/include/mlir/Analysis/Presburger/Utils.h
    mlir/lib/Analysis/Presburger/IntegerRelation.cpp
    mlir/lib/Analysis/Presburger/PWMAFunction.cpp
    mlir/lib/Analysis/Presburger/Utils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
index 5b4bdc9f27973..b6f127784b18b 100644
--- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
+++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h
@@ -49,7 +49,6 @@ class IntegerRelation {
   enum class Kind {
     FlatAffineConstraints,
     FlatAffineValueConstraints,
-    MultiAffineFunction,
     IntegerRelation,
     IntegerPolyhedron,
   };

diff  --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
index 3240dde15c569..ecb5d34cc3803 100644
--- a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
+++ b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
@@ -42,54 +42,35 @@ namespace presburger {
 ///
 /// Checking equality of two such functions is supported, as well as finding the
 /// value of the function at a specified point.
-class MultiAffineFunction : protected IntegerPolyhedron {
+class MultiAffineFunction {
 public:
-  /// We use protected inheritance to avoid inheriting the whole public
-  /// interface of IntegerPolyhedron. These using declarations explicitly make
-  /// only the relevant functions part of the public interface.
-  using IntegerPolyhedron::getNumDimAndSymbolIds;
-  using IntegerPolyhedron::getNumDimIds;
-  using IntegerPolyhedron::getNumIds;
-  using IntegerPolyhedron::getNumLocalIds;
-  using IntegerPolyhedron::getNumSymbolIds;
-  using IntegerPolyhedron::getSpace;
-
   MultiAffineFunction(const IntegerPolyhedron &domain, const Matrix &output)
-      : IntegerPolyhedron(domain), output(output) {}
+      : domainSet(domain), output(output) {}
   MultiAffineFunction(const Matrix &output, const PresburgerSpace &space)
-      : IntegerPolyhedron(space), output(output) {}
-
-  ~MultiAffineFunction() override = default;
-  Kind getKind() const override { return Kind::MultiAffineFunction; }
-  bool classof(const IntegerRelation *rel) const {
-    return rel->getKind() == Kind::MultiAffineFunction;
-  }
+      : domainSet(space), output(output) {}
 
-  unsigned getNumInputs() const { return getNumDimAndSymbolIds(); }
+  unsigned getNumInputs() const { return domainSet.getNumDimAndSymbolIds(); }
   unsigned getNumOutputs() const { return output.getNumRows(); }
   bool isConsistent() const {
-    return output.getNumColumns() == getNumIds() + 1;
+    return output.getNumColumns() == domainSet.getNumIds() + 1;
   }
-  const IntegerPolyhedron &getDomain() const { return *this; }
+  const IntegerPolyhedron &getDomain() const { return domainSet; }
+  const PresburgerSpace &getDomainSpace() const { return domainSet.getSpace(); }
 
   /// Insert `num` identifiers of the specified kind at position `pos`.
   /// Positions are relative to the kind of identifier. The coefficient columns
   /// corresponding to the added identifiers are initialized to zero. Return the
   /// absolute column position (i.e., not relative to the kind of identifier)
   /// of the first added identifier.
-  unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override;
-
-  /// Swap the posA^th identifier with the posB^th identifier.
-  void swapId(unsigned posA, unsigned posB) override;
+  unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1);
 
   /// Remove the specified range of ids.
-  void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit) override;
-  using IntegerRelation::removeIdRange;
+  void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit);
 
-  /// Eliminate the `posB^th` local identifier, replacing every instance of it
-  /// with the `posA^th` local identifier. This should be used when the two
-  /// local variables are known to always take the same values.
-  void eliminateRedundantLocalId(unsigned posA, unsigned posB) override;
+  /// Given a MAF `other`, merges local identifiers such that both functions
+  /// have the union of local ids, without changing the set of points in the
+  /// domain or the output.
+  void mergeLocalIds(MultiAffineFunction &other);
 
   /// Return whether the outputs of `this` and `other` agree wherever both
   /// functions are defined, i.e., the outputs should be equal for all points in
@@ -114,6 +95,10 @@ class MultiAffineFunction : protected IntegerPolyhedron {
   void dump() const;
 
 private:
+  /// The IntegerPolyhedron representing the domain over which the function is
+  /// defined.
+  IntegerPolyhedron domainSet;
+
   /// The function's output is a tuple of integers, with the ith element of the
   /// tuple defined by the affine expression given by the ith row of this output
   /// matrix.

diff  --git a/mlir/include/mlir/Analysis/Presburger/Utils.h b/mlir/include/mlir/Analysis/Presburger/Utils.h
index 732f093c79dc1..28ccb836dfc8e 100644
--- a/mlir/include/mlir/Analysis/Presburger/Utils.h
+++ b/mlir/include/mlir/Analysis/Presburger/Utils.h
@@ -130,6 +130,21 @@ void removeDuplicateDivs(
     SmallVectorImpl<unsigned> &denoms, unsigned localOffset,
     llvm::function_ref<bool(unsigned i, unsigned j)> merge);
 
+/// Given two relations, A and B, add additional local ids to the sets such
+/// that both have the union of the local ids in each set, without changing
+/// the set of points that lie in A and B.
+///
+/// While taking the union, if a local id in either set has a division
+/// representation that duplicates that of another local id in either set, it
+/// is not added to the final union of local ids and is instead merged.
+///
+/// On every possible merge, `merge(i, j)` is called. `i`, `j` are positions
+/// of local identifiers in both sets which are being merged. If `merge(i, j)`
+/// returns true, the divisions are merged, otherwise the divisions are not
+/// merged.
+void mergeLocalIds(IntegerRelation &relA, IntegerRelation &relB,
+                   llvm::function_ref<bool(unsigned i, unsigned j)> merge);
+
 /// Compute the gcd of the range.
 int64_t gcdRange(ArrayRef<int64_t> range);
 

diff  --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 619f9d54b81e2..14f05e0513f82 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -1092,36 +1092,11 @@ void IntegerRelation::eliminateRedundantLocalId(unsigned posA, unsigned posB) {
 /// obtained, and thus these local ids are not considered for detecting
 /// duplicates.
 unsigned IntegerRelation::mergeLocalIds(IntegerRelation &other) {
-  assert(space.isCompatible(other.getSpace()) &&
-         "Spaces should be compatible.");
-
   IntegerRelation &relA = *this;
   IntegerRelation &relB = other;
 
   unsigned oldALocals = relA.getNumLocalIds();
 
-  // Merge local ids of relA and relB without using division information,
-  // i.e. append local ids of `relB` to `relA` and insert local ids of `relA`
-  // to `relB` at start of its local ids.
-  unsigned initLocals = relA.getNumLocalIds();
-  insertId(IdKind::Local, relA.getNumLocalIds(), relB.getNumLocalIds());
-  relB.insertId(IdKind::Local, 0, initLocals);
-
-  // Get division representations from each rel.
-  std::vector<SmallVector<int64_t, 8>> divsA, divsB;
-  SmallVector<unsigned, 4> denomsA, denomsB;
-  relA.getLocalReprs(divsA, denomsA);
-  relB.getLocalReprs(divsB, denomsB);
-
-  // Copy division information for relB into `divsA` and `denomsA`, so that
-  // these have the combined division information of both rels. Since newly
-  // added local variables in relA and relB have no constraints, they will not
-  // have any division representation.
-  std::copy(divsB.begin() + initLocals, divsB.end(),
-            divsA.begin() + initLocals);
-  std::copy(denomsB.begin() + initLocals, denomsB.end(),
-            denomsA.begin() + initLocals);
-
   // Merge function that merges the local variables in both sets by treating
   // them as the same identifier.
   auto merge = [&relA, &relB, oldALocals](unsigned i, unsigned j) -> bool {
@@ -1140,9 +1115,7 @@ unsigned IntegerRelation::mergeLocalIds(IntegerRelation &other) {
     return true;
   };
 
-  // Merge all divisions by removing duplicate divisions.
-  unsigned localOffset = getIdKindOffset(IdKind::Local);
-  presburger::removeDuplicateDivs(divsA, denomsA, localOffset, merge);
+  presburger::mergeLocalIds(*this, other, merge);
 
   // Since we do not remove duplicate divisions in relA, this is guranteed to be
   // non-negative.

diff  --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
index 851099e303e2a..a9e0f8d2ba789 100644
--- a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
+++ b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
@@ -35,7 +35,7 @@ PresburgerSet PWMAFunction::getDomain() const {
 
 Optional<SmallVector<int64_t, 8>>
 MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
-  assert(point.size() == getNumDimAndSymbolIds() &&
+  assert(point.size() == domainSet.getNumDimAndSymbolIds() &&
          "Point has incorrect dimensionality!");
 
   Optional<SmallVector<int64_t, 8>> maybeLocalValues =
@@ -74,7 +74,7 @@ PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
 
 void MultiAffineFunction::print(raw_ostream &os) const {
   os << "Domain:";
-  IntegerPolyhedron::print(os);
+  domainSet.print(os);
   os << "Output:\n";
   output.print(os);
   os << "\n";
@@ -83,36 +83,24 @@ void MultiAffineFunction::print(raw_ostream &os) const {
 void MultiAffineFunction::dump() const { print(llvm::errs()); }
 
 bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
-  return space.isCompatible(other.getSpace()) &&
+  return getDomainSpace().isCompatible(other.getDomainSpace()) &&
          getDomain().isEqual(other.getDomain()) &&
          isEqualWhereDomainsOverlap(other);
 }
 
 unsigned MultiAffineFunction::insertId(IdKind kind, unsigned pos,
                                        unsigned num) {
-  assert((kind != IdKind::Domain || num == 0) &&
-         "Domain has to be zero in a set");
-  unsigned absolutePos = getIdKindOffset(kind) + pos;
+  assert(kind != IdKind::Domain && "Domain has to be zero in a set");
+  unsigned absolutePos = domainSet.getIdKindOffset(kind) + pos;
   output.insertColumns(absolutePos, num);
-  return IntegerPolyhedron::insertId(kind, pos, num);
-}
-
-void MultiAffineFunction::swapId(unsigned posA, unsigned posB) {
-  output.swapColumns(posA, posB);
-  IntegerPolyhedron::swapId(posA, posB);
+  return domainSet.insertId(kind, pos, num);
 }
 
 void MultiAffineFunction::removeIdRange(IdKind kind, unsigned idStart,
                                         unsigned idLimit) {
-  output.removeColumns(idStart + getIdKindOffset(kind), idLimit - idStart);
-  IntegerPolyhedron::removeIdRange(kind, idStart, idLimit);
-}
-
-void MultiAffineFunction::eliminateRedundantLocalId(unsigned posA,
-                                                    unsigned posB) {
-  unsigned localOffset = getIdKindOffset(IdKind::Local);
-  output.addToColumn(localOffset + posB, localOffset + posA, /*scale=*/1);
-  IntegerPolyhedron::eliminateRedundantLocalId(posA, posB);
+  output.removeColumns(idStart + domainSet.getIdKindOffset(kind),
+                       idLimit - idStart);
+  domainSet.removeIdRange(kind, idStart, idLimit);
 }
 
 void MultiAffineFunction::truncateOutput(unsigned count) {
@@ -127,9 +115,37 @@ void PWMAFunction::truncateOutput(unsigned count) {
   numOutputs = count;
 }
 
+void MultiAffineFunction::mergeLocalIds(MultiAffineFunction &other) {
+  // Merge output local ids of both functions without using division
+  // information, i.e. append local ids of `other` to `this` and insert
+  // local ids of `this` into `other` at the start of its local ids.
+  output.insertColumns(domainSet.getIdKindEnd(IdKind::Local),
+                       other.domainSet.getNumLocalIds());
+  other.output.insertColumns(other.domainSet.getIdKindOffset(IdKind::Local),
+                             domainSet.getNumLocalIds());
+
+  auto merge = [this, &other](unsigned i, unsigned j) -> bool {
+    // Merge local at position j into local at position i in both domains.
+    domainSet.eliminateRedundantLocalId(i, j);
+    other.domainSet.eliminateRedundantLocalId(i, j);
+
+    unsigned localOffset = domainSet.getIdKindOffset(IdKind::Local);
+
+    // Merge local at position j into local at position i in both outputs.
+    output.addToColumn(localOffset + j, localOffset + i, 1);
+    output.removeColumn(localOffset + j);
+    other.output.addToColumn(localOffset + j, localOffset + i, 1);
+    other.output.removeColumn(localOffset + j);
+
+    return true;
+  };
+
+  presburger::mergeLocalIds(domainSet, other.domainSet, merge);
+}
+
 bool MultiAffineFunction::isEqualWhereDomainsOverlap(
     MultiAffineFunction other) const {
-  if (!space.isCompatible(other.getSpace()))
+  if (!getDomainSpace().isCompatible(other.getDomainSpace()))
     return false;
 
   // `commonFunc` has the same output as `this`.
@@ -139,7 +155,7 @@ bool MultiAffineFunction::isEqualWhereDomainsOverlap(
   commonFunc.mergeLocalIds(other);
   // After this, the domain of `commonFunc` will be the intersection of the
   // domains of `this` and `other`.
-  commonFunc.IntegerPolyhedron::append(other);
+  commonFunc.domainSet.append(other.domainSet);
 
   // `commonDomainMatching` contains the subset of the common domain
   // where the outputs of `this` and `other` match.
@@ -180,7 +196,7 @@ bool PWMAFunction::isEqual(const PWMAFunction &other) const {
 }
 
 void PWMAFunction::addPiece(const MultiAffineFunction &piece) {
-  assert(space.isCompatible(piece.getSpace()) &&
+  assert(space.isCompatible(piece.getDomainSpace()) &&
          "Piece to be added is not compatible with this PWMAFunction!");
   assert(piece.isConsistent() && "Piece is internally inconsistent!");
   assert(this->getDomain()

diff  --git a/mlir/lib/Analysis/Presburger/Utils.cpp b/mlir/lib/Analysis/Presburger/Utils.cpp
index d25d03447e9ab..df52bb9cb34a6 100644
--- a/mlir/lib/Analysis/Presburger/Utils.cpp
+++ b/mlir/lib/Analysis/Presburger/Utils.cpp
@@ -304,6 +304,39 @@ void presburger::removeDuplicateDivs(
   }
 }
 
+void presburger::mergeLocalIds(
+    IntegerRelation &relA, IntegerRelation &relB,
+    llvm::function_ref<bool(unsigned i, unsigned j)> merge) {
+  assert(relA.getSpace().isCompatible(relB.getSpace()) &&
+         "Spaces should be compatible.");
+
+  // Merge local ids of relA and relB without using division information,
+  // i.e. append local ids of `relB` to `relA` and insert local ids of `relA`
+  // to `relB` at start of its local ids.
+  unsigned initLocals = relA.getNumLocalIds();
+  relA.insertId(IdKind::Local, relA.getNumLocalIds(), relB.getNumLocalIds());
+  relB.insertId(IdKind::Local, 0, initLocals);
+
+  // Get division representations from each rel.
+  std::vector<SmallVector<int64_t, 8>> divsA, divsB;
+  SmallVector<unsigned, 4> denomsA, denomsB;
+  relA.getLocalReprs(divsA, denomsA);
+  relB.getLocalReprs(divsB, denomsB);
+
+  // Copy division information for relB into `divsA` and `denomsA`, so that
+  // these have the combined division information of both rels. Since newly
+  // added local variables in relA and relB have no constraints, they will not
+  // have any division representation.
+  std::copy(divsB.begin() + initLocals, divsB.end(),
+            divsA.begin() + initLocals);
+  std::copy(denomsB.begin() + initLocals, denomsB.end(),
+            denomsA.begin() + initLocals);
+
+  // Merge all divisions by removing duplicate divisions.
+  unsigned localOffset = relA.getIdKindOffset(IdKind::Local);
+  presburger::removeDuplicateDivs(divsA, denomsA, localOffset, merge);
+}
+
 int64_t presburger::gcdRange(ArrayRef<int64_t> range) {
   int64_t gcd = 0;
   for (int64_t elem : range) {


        


More information about the Mlir-commits mailing list