[llvm] [SandboxVec][DAG] Fix DAG when old interval is mem free (PR #126983)

via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 12 14:36:45 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-vectorizers

Author: vporpo (vporpo)

<details>
<summary>Changes</summary>

This patch fixes a bug in `DependencyGraph::extend()` that occurs when the old interval is non-empty but contains no memory instructions. In that case, we should do a full dependency scan of the new interval.

---
Full diff: https://github.com/llvm/llvm-project/pull/126983.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp (+6-5) 
- (modified) llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp (+39) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
index 06a5e3bed7f03..098b296c30ab8 100644
--- a/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
+++ b/llvm/lib/Transforms/Vectorize/SandboxVectorizer/DependencyGraph.cpp
@@ -122,6 +122,8 @@ MemDGNodeIntervalBuilder::getBotMemDGNode(const Interval<Instruction> &Intvl,
 Interval<MemDGNode>
 MemDGNodeIntervalBuilder::make(const Interval<Instruction> &Instrs,
                                DependencyGraph &DAG) {
+  if (Instrs.empty())
+    return {};
   auto *TopMemN = getTopMemDGNode(Instrs, DAG);
   // If we couldn't find a mem node in range TopN - BotN then it's empty.
   if (TopMemN == nullptr)
@@ -529,8 +531,8 @@ Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
       }
     }
   };
-  if (DAGInterval.empty()) {
-    assert(NewInterval == InstrsInterval && "Expected empty DAGInterval!");
+  auto MemDAGInterval = MemDGNodeIntervalBuilder::make(DAGInterval, *this);
+  if (MemDAGInterval.empty()) {
     FullScan(NewInterval);
   }
   // 2. The new section is below the old section.
@@ -550,8 +552,7 @@ Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
   // range including both NewInterval and DAGInterval until DstN, for each DstN.
   else if (DAGInterval.bottom()->comesBefore(NewInterval.top())) {
     auto DstRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
-    auto SrcRangeFull = MemDGNodeIntervalBuilder::make(
-        DAGInterval.getUnionInterval(NewInterval), *this);
+    auto SrcRangeFull = MemDAGInterval.getUnionInterval(DstRange);
     for (MemDGNode &DstN : DstRange) {
       auto SrcRange =
           Interval<MemDGNode>(SrcRangeFull.top(), DstN.getPrevNode());
@@ -589,7 +590,7 @@ Interval<Instruction> DependencyGraph::extend(ArrayRef<Instruction *> Instrs) {
     // When scanning for deps with destination in DAGInterval we need to
     // consider sources from the NewInterval only, because all intra-DAGInterval
     // dependencies have already been created.
-    auto DstRangeOld = MemDGNodeIntervalBuilder::make(DAGInterval, *this);
+    auto DstRangeOld = MemDAGInterval;
     auto SrcRange = MemDGNodeIntervalBuilder::make(NewInterval, *this);
     for (MemDGNode &DstN : DstRangeOld)
       scanAndAddDeps(DstN, SrcRange);
diff --git a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
index f1e9afefb4531..37f29428e900a 100644
--- a/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
+++ b/llvm/unittests/Transforms/Vectorize/SandboxVectorizer/DependencyGraphTest.cpp
@@ -1013,3 +1013,42 @@ define void @foo(ptr %ptr, i8 %v1, i8 %v2, i8 %arg) {
   EXPECT_EQ(S2N->getNextNode(), S1N);
   EXPECT_EQ(S1N->getNextNode(), nullptr);
 }
+
+// Extending an "Old" interval with no mem instructions.
+TEST_F(DependencyGraphTest, ExtendDAGWithNoMem) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, i8 %v, i8 %v0, i8 %v1, i8 %v2, i8 %v3) {
+  store i8 %v0, ptr %ptr
+  store i8 %v1, ptr %ptr
+  %zext1 = zext i8 %v to i32
+  %zext2 = zext i8 %v to i32
+  store i8 %v2, ptr %ptr
+  store i8 %v3, ptr %ptr
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *BB = &*F->begin();
+  auto It = BB->begin();
+  auto *S0 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S1 = cast<sandboxir::StoreInst>(&*It++);
+  auto *Z1 = cast<sandboxir::CastInst>(&*It++);
+  auto *Z2 = cast<sandboxir::CastInst>(&*It++);
+  auto *S2 = cast<sandboxir::StoreInst>(&*It++);
+  auto *S3 = cast<sandboxir::StoreInst>(&*It++);
+
+  sandboxir::DependencyGraph DAG(getAA(*LLVMF), Ctx);
+  // Create a non-empty DAG that contains no memory instructions.
+  DAG.extend({Z1, Z2});
+  // Now extend it downwards.
+  DAG.extend({S2, S3});
+  EXPECT_TRUE(memDependency(DAG.getNode(S2), DAG.getNode(S3)));
+
+  // Same but upwards.
+  DAG.clear();
+  DAG.extend({Z1, Z2});
+  DAG.extend({S0, S1});
+  EXPECT_TRUE(memDependency(DAG.getNode(S0), DAG.getNode(S1)));
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/126983


More information about the llvm-commits mailing list