[Mlir-commits] [mlir] d222b69 - [MLIR][Presburger] Add Identifier class to store identifiers and their type

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 5 00:16:46 PDT 2023


Author: Groverkss
Date: 2023-09-05T12:46:00+05:30
New Revision: d222b69093d69e034b7f01d9ceedce28d48fbb0c

URL: https://github.com/llvm/llvm-project/commit/d222b69093d69e034b7f01d9ceedce28d48fbb0c
DIFF: https://github.com/llvm/llvm-project/commit/d222b69093d69e034b7f01d9ceedce28d48fbb0c.diff

LOG: [MLIR][Presburger] Add Identifier class to store identifiers and their type

This patch adds a new class Identifier to store identifiers in PresburgerSpace
and their types.

Identifiers were added earlier and were stored as a void pointer, and their type
in the form of mlir::TypeId in PresburgerSpace. To get an identifier, a user of
PresburgerSpace needed to know the type of identifiers. This was a problem for
users of PresburgerSpace like IntegerRelation, which want to work on
identifiers without knowing their type.

The Identifier class allows users like IntegerRelation to work on identifiers
without knowing their type, and also exposes an easier way to work with
Identifiers.

Reviewed By: arjunp

Differential Revision: https://reviews.llvm.org/D146909

Added: 
    

Modified: 
    mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
    mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
    mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
index 998a70c677bf5e8..9fe2abafd36badb 100644
--- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
+++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
@@ -28,6 +28,90 @@ namespace presburger {
 /// as relations with zero domain vars.
 enum class VarKind { Symbol, Local, Domain, Range, SetDim = Range };
 
+/// An Identifier stores a pointer to an object, such as a Value or an
+/// Operation. Identifiers are intended to be attached to a variable in a
+/// PresburgerSpace and can be used to check if two variables correspond to the
+/// same object.
+///
+/// Take for example the following code:
+///
+/// for i = 0 to 100
+///   for j = 0 to 100
+///     S0: A[j] = 0
+///   for k = 0 to 100
+///     S1: A[k] = 1
+///
+/// If we represent the space of iteration variables surrounding S0, S1 we have:
+/// space(S0): {d0, d1}
+/// space(S1): {d0, d1}
+///
+/// Since the variables are in different spaces, without an identifier, there
+/// is no way to distinguish if the variables in the two spaces correspond to
+/// different SSA values in the program. So, we attach an Identifier
+/// corresponding to the loop iteration variable to them. Now,
+///
+/// space(S0) = {d0(id = i), d1(id = j)}
+/// space(S1) = {d0(id = i), d1(id = k)}.
+///
+/// Using the identifier, we can check that the first iteration variable in
+/// both the spaces correspond to the same variable in the program, while they
+/// are different for second iteration variable.
+///
+/// The equality of Identifiers is checked by comparing the stored pointers.
+/// Checking equality asserts that the type of the equal identifiers is same.
+/// Identifiers storing null pointers are treated as having no attachment and
+/// are considered unequal to any other identifier, including other identifiers
+/// with no attachments.
+///
+/// The type of the pointer stored must have an `llvm::PointerLikeTypeTraits`
+/// specialization.
+class Identifier {
+public:
+  Identifier() = default;
+
+  // Create an identifier from a pointer.
+  template <typename T>
+  explicit Identifier(T value)
+      : value(llvm::PointerLikeTypeTraits<T>::getAsVoidPointer(value)) {
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+    idType = TypeID::get<T>();
+#endif
+  }
+
+  /// Get the value of the identifier casted to type `T`. `T` here should match
+  /// the type of the identifier used to create it.
+  template <typename T>
+  T getValue() const {
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+    assert(TypeID::get<T>() == idType &&
+           "Identifier was initialized with a different type than the one used "
+           "to retrieve it.");
+#endif
+    return llvm::PointerLikeTypeTraits<T>::getFromVoidPointer(value);
+  }
+
+  bool hasValue() const { return value != nullptr; }
+
+  /// Check if the two identifiers are equal. Null identifiers are considered
+  /// not equal. Asserts if two identifiers are equal but their types are not.
+  bool isEqual(const Identifier &other) const;
+
+  bool operator==(const Identifier &other) const { return isEqual(other); }
+  bool operator!=(const Identifier &other) const { return !isEqual(other); }
+
+  void print(llvm::raw_ostream &os) const;
+  void dump() const;
+
+private:
+  /// The value of the identifier.
+  void *value = nullptr;
+
+#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
+  /// TypeID of the identifiers in space. This should be used in asserts only.
+  TypeID idType = TypeID::get<void>();
+#endif
+};
+
 /// PresburgerSpace is the space of all possible values of a tuple of integer
 /// valued variables/variables. Each variable has one of the three types:
 ///
@@ -66,14 +150,12 @@ enum class VarKind { Symbol, Local, Domain, Range, SetDim = Range };
 /// other than Locals are equal. Equality of two spaces implies that number of
 /// variables of each kind are equal.
 ///
-/// PresburgerSpace optionally also supports attaching some information to each
-/// variable in space, called "identifier" of that variable. `resetIds<IdType>`
-/// is used to enable/reset these identifiers. All identifiers must be of the
-/// same type, `IdType`. `IdType` must have a `llvm::PointerLikeTypeTraits`
-/// specialization available and should be supported via `mlir::TypeID`.
-///
-/// These identifiers can be used to check if two variables in two different
-/// spaces are actually same variable.
+/// PresburgerSpace optionally also supports attaching an Identifier with each
+/// non-local variable in the space. This is disabled by default. `resetIds` is
+/// used to enable/reset these identifiers. The user can identify each variable
+/// in the space as corresponding to some Identifier. Some example use cases
+/// are described in the `Identifier` documentation above. The type attached to
+/// the Identifier can be different for different variables in the space.
 class PresburgerSpace {
 public:
   static PresburgerSpace getRelationSpace(unsigned numDomain = 0,
@@ -142,6 +224,20 @@ class PresburgerSpace {
   /// varLimit). The range is relative to the kind of variable.
   void removeVarRange(VarKind kind, unsigned varStart, unsigned varLimit);
 
+  /// Converts variables of the specified kind in the column range [srcPos,
+  /// srcPos + num) to variables of the specified kind at position dstPos. The
+  /// ranges are relative to the kind of variable.
+  ///
+  /// srcKind and dstKind must be different.
+  void convertVarKind(VarKind srcKind, unsigned srcPos, unsigned num,
+                      VarKind dstKind, unsigned dstPos);
+
+  /// Changes the partition between dimensions and symbols. Depending on the new
+  /// symbol count, either a chunk of dimensional variables immediately before
+  /// the split become symbols, or some of the symbols immediately after the
+  /// split become dimensions.
+  void setVarSymbolSeperation(unsigned newSymbolCount);
+
   /// Swaps the posA^th variable of kindA and posB^th variable of kindB.
   void swapVar(VarKind kindA, VarKind kindB, unsigned posA, unsigned posB);
 
@@ -154,77 +250,29 @@ class PresburgerSpace {
   /// locals).
   bool isEqual(const PresburgerSpace &other) const;
 
-  /// Changes the partition between dimensions and symbols. Depending on the new
-  /// symbol count, either a chunk of dimensional variables immediately before
-  /// the split become symbols, or some of the symbols immediately after the
-  /// split become dimensions.
-  void setVarSymbolSeperation(unsigned newSymbolCount);
-
-  void print(llvm::raw_ostream &os) const;
-  void dump() const;
-
-  //===--------------------------------------------------------------------===//
-  //     Identifier Interactions
-  //===--------------------------------------------------------------------===//
-
-  /// Set the identifier for `i^th` variable to `id`. `T` here should match the
-  /// type used to enable identifiers.
-  template <typename T>
-  void setId(VarKind kind, unsigned i, T id) {
-#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
-    assert(TypeID::get<T>() == idType && "Type mismatch");
-#endif
-    atId(kind, i) = llvm::PointerLikeTypeTraits<T>::getAsVoidPointer(id);
-  }
-
-  /// Get the identifier for `i^th` variable casted to type `T`. `T` here
-  /// should match the type used to enable identifiers.
-  template <typename T>
-  T getId(VarKind kind, unsigned i) const {
-#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
-    assert(TypeID::get<T>() == idType && "Type mismatch");
-#endif
-    return llvm::PointerLikeTypeTraits<T>::getFromVoidPointer(atId(kind, i));
+  /// Get the identifier of the specified variable.
+  Identifier &getId(VarKind kind, unsigned pos) {
+    assert(kind != VarKind::Local && "Local variables have no identifiers");
+    return identifiers[getVarKindOffset(kind) + pos];
   }
-
-  /// Check if the i^th variable of the specified kind has a non-null
-  /// identifier.
-  bool hasId(VarKind kind, unsigned i) const {
-    return atId(kind, i) != nullptr;
+  Identifier getId(VarKind kind, unsigned pos) const {
+    assert(kind != VarKind::Local && "Local variables have no identifiers");
+    return identifiers[getVarKindOffset(kind) + pos];
   }
 
-  /// Check if the spaces are compatible, as well as have the same identifiers
-  /// for each variable.
-  bool isAligned(const PresburgerSpace &other) const;
-  /// Check if the number of variables of the specified kind match, and have
-  /// same identifiers with the other space.
-  bool isAligned(const PresburgerSpace &other, VarKind kind) const;
-
-  /// Find the variable of the specified kind with identifier `id`.
-  /// Returns PresburgerSpace::kIdNotFound if identifier is not found.
-  template <typename T>
-  unsigned findId(VarKind kind, T id) const {
-    unsigned i = 0;
-    for (unsigned e = getNumVarKind(kind); i < e; ++i)
-      if (hasId(kind, i) && getId<T>(kind, i) == id)
-        return i;
-    return kIdNotFound;
+  ArrayRef<Identifier> getIds(VarKind kind) const {
+    assert(kind != VarKind::Local && "Local variables have no identifiers");
+    return {identifiers.data() + getVarKindOffset(kind), getNumVarKind(kind)};
   }
-  static const unsigned kIdNotFound = UINT_MAX;
 
   /// Returns if identifiers are being used.
   bool isUsingIds() const { return usingIds; }
 
   /// Reset the stored identifiers in the space. Enables `usingIds` if it was
   /// `false` before.
-  template <typename T>
   void resetIds() {
     identifiers.clear();
     identifiers.resize(getNumDimAndSymbolVars());
-#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
-    idType = TypeID::get<T>();
-#endif
-
     usingIds = true;
   }
 
@@ -234,26 +282,23 @@ class PresburgerSpace {
     usingIds = false;
   }
 
+  /// Check if the spaces are compatible, and the non-local variables having
+  /// same identifiers are in the same positions. If the space is not using
+  /// Identifiers, this check is same as isCompatible.
+  bool isAligned(const PresburgerSpace &other) const;
+  /// Same as above but only check the specified VarKind. Useful to check if
+  /// the symbols in two spaces are aligned.
+  bool isAligned(const PresburgerSpace &other, VarKind kind) const;
+
+  void print(llvm::raw_ostream &os) const;
+  void dump() const;
+
 protected:
-  PresburgerSpace(unsigned numDomain = 0, unsigned numRange = 0,
-                  unsigned numSymbols = 0, unsigned numLocals = 0)
+  PresburgerSpace(unsigned numDomain, unsigned numRange, unsigned numSymbols,
+                  unsigned numLocals)
       : numDomain(numDomain), numRange(numRange), numSymbols(numSymbols),
         numLocals(numLocals) {}
 
-  void *&atId(VarKind kind, unsigned i) {
-    assert(usingIds && "Cannot access identifiers when `usingIds` is false.");
-    assert(kind != VarKind::Local &&
-           "Local variables cannot have identifiers.");
-    return identifiers[getVarKindOffset(kind) + i];
-  }
-
-  void *atId(VarKind kind, unsigned i) const {
-    assert(usingIds && "Cannot access identifiers when `usingIds` is false.");
-    assert(kind != VarKind::Local &&
-           "Local variables cannot have identifiers.");
-    return identifiers[getVarKindOffset(kind) + i];
-  }
-
 private:
   // Number of variables corresponding to domain variables.
   unsigned numDomain;
@@ -272,13 +317,8 @@ class PresburgerSpace {
   /// Stores whether or not identifiers are being used in this space.
   bool usingIds = false;
 
-#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
-  /// TypeID of the identifiers in space. This should be used in asserts only.
-  TypeID idType;
-#endif
-
   /// Stores an identifier for each non-local variable as a `void` pointer.
-  SmallVector<void *, 0> identifiers;
+  SmallVector<Identifier, 0> identifiers;
 };
 
 } // namespace presburger

diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
index e15db1edf8cb489..c4d01c551b43795 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
@@ -13,18 +13,40 @@
 using namespace mlir;
 using namespace presburger;
 
+bool Identifier::isEqual(const Identifier &other) const {
+  if (value == nullptr || other.value == nullptr)
+    return false;
+  assert(value == other.value && idType == other.idType &&
+         "Values of Identifiers are equal but their types do not match.");
+  return value == other.value;
+}
+
+void Identifier::print(llvm::raw_ostream &os) const {
+  os << "Id<" << value << ">";
+}
+
+void Identifier::dump() const {
+  print(llvm::errs());
+  llvm::errs() << "\n";
+}
+
 PresburgerSpace PresburgerSpace::getDomainSpace() const {
-  // TODO: Preserve identifiers here.
-  return PresburgerSpace::getSetSpace(numDomain, numSymbols, numLocals);
+  PresburgerSpace newSpace = *this;
+  newSpace.removeVarRange(VarKind::Range, 0, getNumRangeVars());
+  newSpace.convertVarKind(VarKind::Domain, 0, getNumDomainVars(),
+                          VarKind::SetDim, 0);
+  return newSpace;
 }
 
 PresburgerSpace PresburgerSpace::getRangeSpace() const {
-  return PresburgerSpace::getSetSpace(numRange, numSymbols, numLocals);
+  PresburgerSpace newSpace = *this;
+  newSpace.removeVarRange(VarKind::Domain, 0, getNumDomainVars());
+  return newSpace;
 }
 
 PresburgerSpace PresburgerSpace::getSpaceWithoutLocals() const {
   PresburgerSpace space = *this;
-  space.removeVarRange(VarKind::Local, 0, numLocals);
+  space.removeVarRange(VarKind::Local, 0, getNumLocalVars());
   return space;
 }
 
@@ -36,7 +58,7 @@ unsigned PresburgerSpace::getNumVarKind(VarKind kind) const {
   if (kind == VarKind::Symbol)
     return getNumSymbolVars();
   if (kind == VarKind::Local)
-    return numLocals;
+    return getNumLocalVars();
   llvm_unreachable("VarKind does not exist!");
 }
 
@@ -101,7 +123,7 @@ unsigned PresburgerSpace::insertVar(VarKind kind, unsigned pos, unsigned num) {
   // Insert NULL identifiers if `usingIds` and variables inserted are
   // not locals.
   if (usingIds && kind != VarKind::Local)
-    identifiers.insert(identifiers.begin() + absolutePos, num, nullptr);
+    identifiers.insert(identifiers.begin() + absolutePos, num, Identifier());
 
   return absolutePos;
 }
@@ -130,26 +152,71 @@ void PresburgerSpace::removeVarRange(VarKind kind, unsigned varStart,
                       identifiers.begin() + getVarKindOffset(kind) + varLimit);
 }
 
+void PresburgerSpace::convertVarKind(VarKind srcKind, unsigned srcPos,
+                                     unsigned num, VarKind dstKind,
+                                     unsigned dstPos) {
+  assert(srcKind != dstKind && "cannot convert variables to the same kind");
+  assert(srcPos + num <= getNumVarKind(srcKind) &&
+         "invalid range for source variables");
+  assert(dstPos <= getNumVarKind(dstKind) &&
+         "invalid position for destination variables");
+
+  auto addVars = [&](VarKind kind, int num) {
+    switch (kind) {
+    case VarKind::Domain:
+      numDomain += num;
+      break;
+    case VarKind::Range:
+      numRange += num;
+      break;
+    case VarKind::Symbol:
+      numSymbols += num;
+      break;
+    case VarKind::Local:
+      numLocals += num;
+      break;
+    }
+  };
+
+  addVars(srcKind, -(signed)num);
+  addVars(dstKind, num);
+
+  // Move identifiers if `usingIds` and variables moved are not locals.
+  unsigned srcOffset = getVarKindOffset(srcKind) + srcPos;
+  unsigned dstOffset = getVarKindOffset(dstKind) + dstPos;
+  if (isUsingIds() && srcKind != VarKind::Local && dstKind != VarKind::Local) {
+    identifiers.insert(identifiers.begin() + dstOffset, num, Identifier());
+    for (unsigned i = 0; i < num; ++i)
+      identifiers[dstOffset + i] = identifiers[srcOffset + i];
+    identifiers.erase(identifiers.begin() + srcOffset,
+                      identifiers.begin() + srcOffset + num);
+  } else if (isUsingIds() && srcKind != VarKind::Local) {
+    identifiers.erase(identifiers.begin() + srcOffset,
+                      identifiers.begin() + srcOffset + num);
+  } else if (isUsingIds() && dstKind != VarKind::Local) {
+    identifiers.insert(identifiers.begin() + dstOffset, num, Identifier());
+  }
+}
+
 void PresburgerSpace::swapVar(VarKind kindA, VarKind kindB, unsigned posA,
                               unsigned posB) {
-
-  if (!usingIds)
+  if (!isUsingIds())
     return;
 
   if (kindA == VarKind::Local && kindB == VarKind::Local)
     return;
 
   if (kindA == VarKind::Local) {
-    atId(kindB, posB) = nullptr;
+    getId(kindB, posB) = Identifier();
     return;
   }
 
   if (kindB == VarKind::Local) {
-    atId(kindA, posA) = nullptr;
+    getId(kindA, posA) = Identifier();
     return;
   }
 
-  std::swap(atId(kindA, posA), atId(kindB, posB));
+  std::swap(getId(kindA, posA), getId(kindB, posB));
 }
 
 bool PresburgerSpace::isCompatible(const PresburgerSpace &other) const {
@@ -162,25 +229,53 @@ bool PresburgerSpace::isEqual(const PresburgerSpace &other) const {
   return isCompatible(other) && getNumLocalVars() == other.getNumLocalVars();
 }
 
+/// Checks if the number of ids of the given kind in the two spaces are
+/// equal and if the ids are equal. Assumes that both spaces are using
+/// ids.
+static bool areIdsEqual(const PresburgerSpace &spaceA,
+                        const PresburgerSpace &spaceB, VarKind kind) {
+  assert(spaceA.isUsingIds() && spaceB.isUsingIds() &&
+         "Both spaces should be using ids");
+  if (spaceA.getNumVarKind(kind) != spaceB.getNumVarKind(kind))
+    return false;
+  if (kind == VarKind::Local)
+    return true; // No ids.
+  return spaceA.getIds(kind) == spaceB.getIds(kind);
+}
+
 bool PresburgerSpace::isAligned(const PresburgerSpace &other) const {
-  assert(isUsingIds() && other.isUsingIds() &&
-         "Both spaces should be using identifiers to check for "
-         "alignment.");
-  return isCompatible(other) && identifiers == other.identifiers;
+  // If only one of the spaces is using identifiers, then they are
+  // not aligned.
+  if (isUsingIds() != other.isUsingIds())
+    return false;
+  // If both spaces are using identifiers, then they are aligned if
+  // their identifiers are equal. Identifiers being equal implies
+  // that the number of variables of each kind is same, which implies
+  // compatiblity, so we do not check for that.
+  if (isUsingIds())
+    return areIdsEqual(*this, other, VarKind::Domain) &&
+           areIdsEqual(*this, other, VarKind::Range) &&
+           areIdsEqual(*this, other, VarKind::Symbol);
+  // If neither space is using identifiers, then they are aligned if
+  // they are compatible.
+  return isCompatible(other);
 }
 
 bool PresburgerSpace::isAligned(const PresburgerSpace &other,
                                 VarKind kind) const {
-  assert(isUsingIds() && other.isUsingIds() &&
-         "Both spaces should be using identifiers to check for "
-         "alignment.");
-
-  ArrayRef<void *> kindAttachments =
-      ArrayRef(identifiers).slice(getVarKindOffset(kind), getNumVarKind(kind));
-  ArrayRef<void *> otherKindAttachments =
-      ArrayRef(other.identifiers)
-          .slice(other.getVarKindOffset(kind), other.getNumVarKind(kind));
-  return kindAttachments == otherKindAttachments;
+  // If only one of the spaces is using identifiers, then they are
+  // not aligned.
+  if (isUsingIds() != other.isUsingIds())
+    return false;
+  // If both spaces are using identifiers, then they are aligned if
+  // their identifiers are equal. Identifiers being equal implies
+  // that the number of variables of each kind is same, which implies
+  // compatiblity, so we do not check for that
+  if (isUsingIds())
+    return areIdsEqual(*this, other, kind);
+  // If neither space is using identifiers, then they are aligned if
+  // the number of variable kind is equal.
+  return getNumVarKind(kind) == other.getNumVarKind(kind);
 }
 
 void PresburgerSpace::setVarSymbolSeperation(unsigned newSymbolCount) {
@@ -198,16 +293,29 @@ void PresburgerSpace::print(llvm::raw_ostream &os) const {
      << "Symbols: " << getNumSymbolVars() << ", "
      << "Locals: " << getNumLocalVars() << "\n";
 
-  if (usingIds) {
-#ifdef LLVM_ENABLE_ABI_BREAKING_CHECKS
-    os << "TypeID of identifiers: " << idType.getAsOpaquePointer() << "\n";
-#endif
+  if (isUsingIds()) {
+    auto printIds = [&](VarKind kind) {
+      os << " ";
+      for (Identifier id : getIds(kind)) {
+        if (id.hasValue())
+          id.print(os);
+        else
+          os << "None";
+        os << " ";
+      }
+    };
 
     os << "(";
-    for (void *identifier : identifiers)
-      os << identifier << " ";
-    os << ")\n";
+    printIds(VarKind::Domain);
+    os << ") -> (";
+    printIds(VarKind::Range);
+    os << ") : [";
+    printIds(VarKind::Symbol);
+    os << "]";
   }
 }
 
-void PresburgerSpace::dump() const { print(llvm::errs()); }
+void PresburgerSpace::dump() const {
+  print(llvm::errs());
+  llvm::errs() << "\n";
+}

diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
index 9966954ed69bf69..cb23174b939c38b 100644
--- a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
@@ -52,12 +52,13 @@ TEST(PresburgerSpaceTest, removeIdRange) {
 
 TEST(PresburgerSpaceTest, insertVarIdentifier) {
   PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 2, 1, 0);
-  space.resetIds<int *>();
+  space.resetIds();
 
-  // Attach identifiers to domain ids.
   int identifiers[2] = {0, 1};
-  space.setId<int *>(VarKind::Domain, 0, &identifiers[0]);
-  space.setId<int *>(VarKind::Domain, 1, &identifiers[1]);
+
+  // Attach identifiers to domain ids.
+  space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+  space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
 
   // Try inserting 2 domain ids.
   space.insertVar(VarKind::Domain, 0, 2);
@@ -68,27 +69,27 @@ TEST(PresburgerSpaceTest, insertVarIdentifier) {
   EXPECT_EQ(space.getNumRangeVars(), 3u);
 
   // Check if the identifiers for the old ids are still attached properly.
-  EXPECT_EQ(*space.getId<int *>(VarKind::Domain, 2), identifiers[0]);
-  EXPECT_EQ(*space.getId<int *>(VarKind::Domain, 3), identifiers[1]);
+  EXPECT_EQ(space.getId(VarKind::Domain, 2), Identifier(&identifiers[0]));
+  EXPECT_EQ(space.getId(VarKind::Domain, 3), Identifier(&identifiers[1]));
 }
 
 TEST(PresburgerSpaceTest, removeVarRangeIdentifier) {
   PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 1, 3, 0);
-  space.resetIds<int *>();
+  space.resetIds();
 
   int identifiers[6] = {0, 1, 2, 3, 4, 5};
 
   // Attach identifiers to domain identifiers.
-  space.setId<int *>(VarKind::Domain, 0, &identifiers[0]);
-  space.setId<int *>(VarKind::Domain, 1, &identifiers[1]);
+  space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+  space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
 
   // Attach identifiers to range identifiers.
-  space.setId<int *>(VarKind::Range, 0, &identifiers[2]);
+  space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
 
   // Attach identifiers to symbol identifiers.
-  space.setId<int *>(VarKind::Symbol, 0, &identifiers[3]);
-  space.setId<int *>(VarKind::Symbol, 1, &identifiers[4]);
-  space.setId<int *>(VarKind::Symbol, 2, &identifiers[5]);
+  space.getId(VarKind::Symbol, 0) = Identifier(&identifiers[3]);
+  space.getId(VarKind::Symbol, 1) = Identifier(&identifiers[4]);
+  space.getId(VarKind::Symbol, 2) = Identifier(&identifiers[5]);
 
   // Remove 1 domain identifier.
   space.removeVarRange(VarKind::Domain, 0, 1);
@@ -102,9 +103,58 @@ TEST(PresburgerSpaceTest, removeVarRangeIdentifier) {
   EXPECT_EQ(space.getNumSymbolVars(), 2u);
 
   // Check if domain identifiers are attached properly.
-  EXPECT_EQ(*space.getId<int *>(VarKind::Domain, 0), identifiers[1]);
+  EXPECT_EQ(space.getId(VarKind::Domain, 0), Identifier(&identifiers[1]));
 
   // Check if symbol identifiers are attached properly.
-  EXPECT_EQ(*space.getId<int *>(VarKind::Range, 0), identifiers[4]);
-  EXPECT_EQ(*space.getId<int *>(VarKind::Range, 1), identifiers[5]);
+  EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[4]));
+  EXPECT_EQ(space.getId(VarKind::Range, 1), Identifier(&identifiers[5]));
+}
+
+TEST(PresburgerSpaceTest, convertVarKind) {
+  PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 2, 0, 0);
+  space.resetIds();
+
+  // Attach identifiers.
+  int identifiers[4] = {0, 1, 2, 3};
+  space.getId(VarKind::Domain, 0) = Identifier(&identifiers[0]);
+  space.getId(VarKind::Domain, 1) = Identifier(&identifiers[1]);
+  space.getId(VarKind::Range, 0) = Identifier(&identifiers[2]);
+  space.getId(VarKind::Range, 1) = Identifier(&identifiers[3]);
+
+  // Convert Range variables to symbols.
+  space.convertVarKind(VarKind::Range, 0, 2, VarKind::Symbol, 0);
+
+  // Check if the identifiers are moved to symbols.
+  EXPECT_EQ(space.getId(VarKind::Symbol, 0), Identifier(&identifiers[2]));
+  EXPECT_EQ(space.getId(VarKind::Symbol, 1), Identifier(&identifiers[3]));
+
+  // Convert 1 symbol to range identifier.
+  space.convertVarKind(VarKind::Symbol, 1, 1, VarKind::Range, 0);
+
+  // Check if the identifier is moved to range.
+  EXPECT_EQ(space.getId(VarKind::Range, 0), Identifier(&identifiers[3]));
+}
+
+TEST(PresburgerSpaceTest, convertVarKindLocals) {
+  PresburgerSpace space = PresburgerSpace::getRelationSpace(2, 2, 0, 0);
+  space.resetIds();
+
+  // Attach identifiers to range variables.
+  int identifiers[4] = {0, 1};
+  space.getId(VarKind::Range, 0) = Identifier(&identifiers[0]);
+  space.getId(VarKind::Range, 1) = Identifier(&identifiers[1]);
+
+  // Convert Range variables to locals i.e. project them out.
+  space.convertVarKind(VarKind::Range, 0, 2, VarKind::Local, 0);
+
+  // Check if the variables were moved.
+  EXPECT_EQ(space.getNumVarKind(VarKind::Range), 0u);
+  EXPECT_EQ(space.getNumVarKind(VarKind::Local), 2u);
+
+  // Convert the Local variables back to Range variables.
+  space.convertVarKind(VarKind::Local, 0, 2, VarKind::Range, 0);
+
+  // The identifier information should be lost.
+  EXPECT_FALSE(space.getId(VarKind::Range, 0).hasValue());
+  EXPECT_FALSE(space.getId(VarKind::Range, 1).hasValue());
 }


        


More information about the Mlir-commits mailing list