[Mlir-commits] [mlir] 6607bd9 - [MLIR] AffineStructures::removeIdRange: support specifying a range within an IdKind

Arjun P llvmlistbot at llvm.org
Fri Sep 17 03:55:30 PDT 2021


Author: Arjun P
Date: 2021-09-17T16:25:26+05:30
New Revision: 6607bd9fd819de1a5872dce47ce1a67bbb9a12e8

URL: https://github.com/llvm/llvm-project/commit/6607bd9fd819de1a5872dce47ce1a67bbb9a12e8
DIFF: https://github.com/llvm/llvm-project/commit/6607bd9fd819de1a5872dce47ce1a67bbb9a12e8.diff

LOG: [MLIR] AffineStructures::removeIdRange: support specifying a range within an IdKind

Reviewed By: Groverkss, grosser

Differential Revision: https://reviews.llvm.org/D109896

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/AffineStructures.h
    mlir/lib/Analysis/AffineStructures.cpp
    mlir/unittests/Analysis/AffineStructuresTest.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h
index c2dbc5c89da73..b5676186d87d7 100644
--- a/mlir/include/mlir/Analysis/AffineStructures.h
+++ b/mlir/include/mlir/Analysis/AffineStructures.h
@@ -282,6 +282,12 @@ class FlatAffineConstraints {
   void projectOut(unsigned pos, unsigned num);
   inline void projectOut(unsigned pos) { return projectOut(pos, 1); }
 
+  /// Removes identifiers of the specified kind with the specified pos (or
+  /// within the specified range) from the system. The specified location is
+  /// relative to the first identifier of the specified kind.
+  void removeId(IdKind kind, unsigned pos);
+  void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit);
+
   /// Removes the specified identifier from the system.
   void removeId(unsigned pos);
 
@@ -423,6 +429,12 @@ class FlatAffineConstraints {
   void dump() const;
 
 protected:
+  /// Return the index at which the specified kind of id starts.
+  unsigned getIdKindOffset(IdKind kind) const;
+
+  /// Assert that `value` is at most the number of ids of the specified kind.
+  void assertAtMostNumIdKind(unsigned value, IdKind kind) const;
+
   /// Returns false if the fields corresponding to various identifier counts, or
   /// equality/inequality buffer sizes aren't consistent; true otherwise. This
   /// is meant to be used within an assert internally.

diff  --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp
index 43de94407db2b..4453e85318846 100644
--- a/mlir/lib/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Analysis/AffineStructures.cpp
@@ -311,23 +311,13 @@ unsigned FlatAffineConstraints::insertLocalId(unsigned pos, unsigned num) {
 
 unsigned FlatAffineConstraints::insertId(IdKind kind, unsigned pos,
                                          unsigned num) {
-  if (kind == IdKind::Dimension)
-    assert(pos <= getNumDimIds());
-  else if (kind == IdKind::Symbol)
-    assert(pos <= getNumSymbolIds());
-  else
-    assert(pos <= getNumLocalIds());
+  assertAtMostNumIdKind(pos, kind);
 
-  unsigned absolutePos;
-  if (kind == IdKind::Dimension) {
-    absolutePos = pos;
+  unsigned absolutePos = getIdKindOffset(kind) + pos;
+  if (kind == IdKind::Dimension)
     numDims += num;
-  } else if (kind == IdKind::Symbol) {
-    absolutePos = pos + getNumDimIds();
+  else if (kind == IdKind::Symbol)
     numSymbols += num;
-  } else {
-    absolutePos = pos + getNumDimIds() + getNumSymbolIds();
-  }
   numIds += num;
 
   inequalities.insertColumns(absolutePos, num);
@@ -336,6 +326,28 @@ unsigned FlatAffineConstraints::insertId(IdKind kind, unsigned pos,
   return absolutePos;
 }
 
+void FlatAffineConstraints::assertAtMostNumIdKind(unsigned val,
+                                                  IdKind kind) const {
+  if (kind == IdKind::Dimension)
+    assert(val <= getNumDimIds());
+  else if (kind == IdKind::Symbol)
+    assert(val <= getNumSymbolIds());
+  else if (kind == IdKind::Local)
+    assert(val <= getNumLocalIds());
+  else
+    llvm_unreachable("IdKind expected to be Dimension, Symbol or Local!");
+}
+
+unsigned FlatAffineConstraints::getIdKindOffset(IdKind kind) const {
+  if (kind == IdKind::Dimension)
+    return 0;
+  if (kind == IdKind::Symbol)
+    return getNumDimIds();
+  if (kind == IdKind::Local)
+    return getNumDimAndSymbolIds();
+  llvm_unreachable("IdKind expected to be Dimension, Symbol or Local!");
+}
+
 unsigned FlatAffineValueConstraints::insertId(IdKind kind, unsigned pos,
                                               unsigned num) {
   unsigned absolutePos = FlatAffineConstraints::insertId(kind, pos, num);
@@ -365,6 +377,17 @@ bool FlatAffineValueConstraints::hasValues() const {
          }) != values.end();
 }
 
+void FlatAffineConstraints::removeId(IdKind kind, unsigned pos) {
+  removeIdRange(kind, pos, pos + 1);
+}
+
+void FlatAffineConstraints::removeIdRange(IdKind kind, unsigned idStart,
+                                          unsigned idLimit) {
+  assertAtMostNumIdKind(idLimit, kind);
+  removeIdRange(getIdKindOffset(kind) + idStart,
+                getIdKindOffset(kind) + idLimit);
+}
+
 /// Checks if two constraint systems are in the same space, i.e., if they are
 /// associated with the same set of identifiers, appearing in the same order.
 static bool areIdsAligned(const FlatAffineValueConstraints &a,

diff  --git a/mlir/unittests/Analysis/AffineStructuresTest.cpp b/mlir/unittests/Analysis/AffineStructuresTest.cpp
index 56f0dbdc950c4..d5a88b684b9e5 100644
--- a/mlir/unittests/Analysis/AffineStructuresTest.cpp
+++ b/mlir/unittests/Analysis/AffineStructuresTest.cpp
@@ -711,4 +711,22 @@ TEST(FlatAffineConstraintsTest, computeLocalReprRecursive) {
   checkDivisionRepresentation(fac, divisions, denoms);
 }
 
+TEST(FlatAffineConstraintsTest, removeIdRange) {
+  FlatAffineConstraints fac(3, 2, 1);
+
+  fac.addInequality({10, 11, 12, 20, 21, 30, 40});
+  fac.removeId(FlatAffineConstraints::IdKind::Symbol, 1);
+  EXPECT_THAT(fac.getInequality(0),
+              testing::ElementsAre(10, 11, 12, 20, 30, 40));
+
+  fac.removeIdRange(FlatAffineConstraints::IdKind::Dimension, 0, 2);
+  EXPECT_THAT(fac.getInequality(0), testing::ElementsAre(12, 20, 30, 40));
+
+  fac.removeIdRange(FlatAffineConstraints::IdKind::Local, 1, 1);
+  EXPECT_THAT(fac.getInequality(0), testing::ElementsAre(12, 20, 30, 40));
+
+  fac.removeIdRange(FlatAffineConstraints::IdKind::Local, 0, 1);
+  EXPECT_THAT(fac.getInequality(0), testing::ElementsAre(12, 20, 40));
+}
+
 } // namespace mlir


        


More information about the Mlir-commits mailing list