[Mlir-commits] [mlir] [MLIR][Presburger] Fix bug in PresburgerSpace::convertVarKind (PR #67267)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Sep 24 07:54:48 PDT 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

<details>
<summary>Changes</summary>

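This patch fixes a bug in PresburgerSpace::convertVarKind where the identifiers were not moved correctly because their offsets were invalidated: the offsets were computed after the variable counts had already been updated, and the source offset was not adjusted after new slots were inserted before it.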

---
Full diff: https://github.com/llvm/llvm-project/pull/67267.diff


2 Files Affected:

- (modified) mlir/lib/Analysis/Presburger/PresburgerSpace.cpp (+20-16) 
- (modified) mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp (+21) 


``````````diff
diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
index c4d01c551b43795..e62ba9c7b90806b 100644
--- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
+++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
@@ -161,6 +161,26 @@ void PresburgerSpace::convertVarKind(VarKind srcKind, unsigned srcPos,
   assert(dstPos <= getNumVarKind(dstKind) &&
          "invalid position for destination variables");
 
+  // Move identifiers (non-locals only); offsets use the pre-conversion layout.
+  unsigned srcOffset = getVarKindOffset(srcKind) + srcPos;
+  unsigned dstOffset = getVarKindOffset(dstKind) + dstPos;
+  if (isUsingIds() && srcKind != VarKind::Local && dstKind != VarKind::Local) {
+    identifiers.insert(identifiers.begin() + dstOffset, num, Identifier());
+    // Inserting at or before srcOffset shifts the source range right by num.
+    if (dstOffset <= srcOffset)
+      srcOffset += num;
+    std::move(identifiers.begin() + srcOffset,
+              identifiers.begin() + srcOffset + num,
+              identifiers.begin() + dstOffset);
+    identifiers.erase(identifiers.begin() + srcOffset,
+                      identifiers.begin() + srcOffset + num);
+  } else if (isUsingIds() && srcKind != VarKind::Local) {
+    identifiers.erase(identifiers.begin() + srcOffset,
+                      identifiers.begin() + srcOffset + num);
+  } else if (isUsingIds() && dstKind != VarKind::Local) {
+    identifiers.insert(identifiers.begin() + dstOffset, num, Identifier());
+  }
+
   auto addVars = [&](VarKind kind, int num) {
     switch (kind) {
     case VarKind::Domain:
@@ -180,22 +200,6 @@ void PresburgerSpace::convertVarKind(VarKind srcKind, unsigned srcPos,
 
   addVars(srcKind, -(signed)num);
   addVars(dstKind, num);
-
-  // Move identifiers if `usingIds` and variables moved are not locals.
-  unsigned srcOffset = getVarKindOffset(srcKind) + srcPos;
-  unsigned dstOffset = getVarKindOffset(dstKind) + dstPos;
-  if (isUsingIds() && srcKind != VarKind::Local && dstKind != VarKind::Local) {
-    identifiers.insert(identifiers.begin() + dstOffset, num, Identifier());
-    for (unsigned i = 0; i < num; ++i)
-      identifiers[dstOffset + i] = identifiers[srcOffset + i];
-    identifiers.erase(identifiers.begin() + srcOffset,
-                      identifiers.begin() + srcOffset + num);
-  } else if (isUsingIds() && srcKind != VarKind::Local) {
-    identifiers.erase(identifiers.begin() + srcOffset,
-                      identifiers.begin() + srcOffset + num);
-  } else if (isUsingIds() && dstKind != VarKind::Local) {
-    identifiers.insert(identifiers.begin() + dstOffset, num, Identifier());
-  }
 }
 
 void PresburgerSpace::swapVar(VarKind kindA, VarKind kindB, unsigned posA,
diff --git a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
index cb23174b939c38b..dd06d462f54bee7 100644
--- a/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/PresburgerSpaceTest.cpp
@@ -158,3 +158,24 @@ TEST(PresburgerSpaceTest, convertVarKindLocals) {
   EXPECT_FALSE(space.getId(VarKind::Range, 0).hasValue());
   EXPECT_FALSE(space.getId(VarKind::Range, 1).hasValue());
 }
+
+TEST(PresburgerSpaceTest, convertVarKind2) {
+  // Two range variables followed by two symbols; no domain vars or locals.
+  PresburgerSpace space = PresburgerSpace::getRelationSpace(0, 2, 2, 0);
+  space.resetIds();
+
+  // Tag the range vars with ids[0..1] and the symbols with ids[2..3].
+  int ids[4] = {0, 1, 2, 3};
+  for (unsigned i = 0; i < 2; ++i) {
+    space.getId(VarKind::Range, i) = Identifier(&ids[i]);
+    space.getId(VarKind::Symbol, i) = Identifier(&ids[2 + i]);
+  }
+
+  // Turn both range variables into symbols inserted at symbol position 1.
+  space.convertVarKind(VarKind::Range, 0, 2, VarKind::Symbol, 1);
+  // The moved identifiers must land between the two original symbols.
+  const unsigned expectedOrder[4] = {2, 0, 1, 3};
+  for (unsigned i = 0; i < 4; ++i)
+    EXPECT_EQ(space.getId(VarKind::Symbol, i),
+              Identifier(&ids[expectedOrder[i]]));
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/67267


More information about the Mlir-commits mailing list