[llvm] Add Dead Block Elimination to NVVMReflect (PR #144171)

Yonah Goldberg via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 13 15:56:09 PDT 2025


https://github.com/YonahGoldberg created https://github.com/llvm/llvm-project/pull/144171

Currently, NVVMReflect replaces calls to __nvvm_reflect with a constant and then propagates and folds that constant, but it does not eliminate the blocks that become dead as a result.

The most common use case of reflect calls is to query the arch number and select valid code depending on the arch. Therefore, the blocks that become dead after reflect replacement need to be deleted as a matter of correctness.
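As a rough illustration of that pattern (not code from this patch), the sketch below models an arch query in plain C++. `reflectStub`, `kAssumedArch`, `selectPath`, and the threshold 800 are all hypothetical stand-ins chosen for the example; the real `__nvvm_reflect` is resolved by the pass at compile time rather than called at run time.

```cpp
#include <cstring>

// Hypothetical value the pass would substitute for the arch query.
// The number is an assumption made up for this example.
constexpr int kAssumedArch = 700;

// Hypothetical host-side stand-in for __nvvm_reflect: returns the assumed
// arch value for the "__CUDA_ARCH" key and 0 for any other key.
static int reflectStub(const char *Key) {
  return std::strcmp(Key, "__CUDA_ARCH") == 0 ? kAssumedArch : 0;
}

// Picks a code path based on the arch query. Once the query is replaced by
// a constant, exactly one of the two branches is statically unreachable.
int selectPath() {
  if (reflectStub("__CUDA_ARCH") >= 800)
    return 2; // stands for instructions valid only on the newer arch
  return 1;   // stands for the fallback valid on older archs
}
```

With the assumed value, the `>= 800` branch can never run. If that branch contained instructions the target does not support, leaving it in the module would be a correctness problem rather than a missed optimization, which is the motivation given above.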

The way this gets cleaned up now in llc is with UnreachableBlockElim followed by CodegenPrepare, which I've observed work together to delete the dead blocks. It's better to just have this pass handle deleting the dead blocks right away.

This PR introduces some additional code to handle the dead block deletion. I think what I've written is actually pretty general; it's kind of like a lightweight version of SCCP. Is this already implemented somewhere in LLVM that I missed? I'd like to avoid duplicating code. If it isn't, would it be useful to put this somewhere more general where others can use it, instead of in NVVMReflect?
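For readers who want the worklist idea in isolation, here is a minimal standalone sketch on a toy CFG. It is not the patch's code: `CFG`, `predCount`, and `transitivelyDead` are hypothetical names, and the graph is a plain map rather than LLVM IR. It mirrors the single-predecessor rule used by `findTransitivelyDeadBlocks` in the diff below, so a block whose several predecessors are all dead is not picked up by this rule.

```cpp
#include <map>
#include <set>
#include <string>
#include <vector>

// Toy CFG: block name -> successor names. Hypothetical stand-in for an
// llvm::Function; edges already removed by terminator folding are absent.
using CFG = std::map<std::string, std::vector<std::string>>;

// Count the predecessors of block B by scanning every edge in the graph.
static int predCount(const CFG &G, const std::string &B) {
  int N = 0;
  for (const auto &Entry : G)
    for (const auto &S : Entry.second)
      if (S == B)
        ++N;
  return N;
}

// Starting from a block known to be dead, collect every block that becomes
// dead with it. A successor joins the set only if the dead block being
// processed is its sole predecessor.
std::set<std::string> transitivelyDead(const CFG &G, const std::string &Root) {
  std::vector<std::string> Worklist{Root};
  std::set<std::string> Dead{Root};
  while (!Worklist.empty()) {
    std::string BB = Worklist.back();
    Worklist.pop_back();
    auto It = G.find(BB);
    if (It == G.end())
      continue; // block has no recorded successors
    for (const auto &Succ : It->second)
      if (predCount(G, Succ) == 1 && Dead.insert(Succ).second)
        Worklist.push_back(Succ);
  }
  return Dead;
}
```

For example, with edges `entry -> {live, merge}`, `dead -> {deadchild, merge}`, and `live -> {merge}`, calling `transitivelyDead(G, "dead")` collects `dead` and `deadchild` but leaves `merge` alone, because `merge` still has live predecessors.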

Note that I also removed running simplifycfg in two test cases, which shows that this pass is now able to handle the dead block elimination without simplifycfg.

From e9c4b6b2a9f4b1a6627181554c9d2ce5578764a7 Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Fri, 13 Jun 2025 22:20:23 +0000
Subject: [PATCH 1/6] updates

---
 llvm/lib/Target/NVPTX/NVVMReflect.cpp         | 113 +++++++++++++++---
 .../test/CodeGen/NVPTX/nvvm-reflect-opaque.ll |   4 +-
 llvm/test/CodeGen/NVPTX/nvvm-reflect.ll       |   4 +-
 3 files changed, 100 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 208bab52284a3..1c17852503660 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -39,6 +39,8 @@
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SetVector.h"
 #define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
 #define NVVM_REFLECT_OCL_FUNCTION "__nvvm_reflect_ocl"
 // Argument of reflect call to retrive arch number
@@ -59,7 +61,10 @@ class NVVMReflect {
   StringMap<unsigned> ReflectMap;
   bool handleReflectFunction(Module &M, StringRef ReflectName);
   void populateReflectMap(Module &M);
-  void foldReflectCall(CallInst *Call, Constant *NewValue);
+  void replaceReflectCalls(
+      SmallVector<std::pair<CallInst *, Constant *>, 8> &ReflectReplacements,
+      const DataLayout &DL);
+  SetVector<BasicBlock *> findTransitivelyDeadBlocks(BasicBlock *DeadBB);
 
 public:
   // __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module
@@ -138,6 +143,8 @@ bool NVVMReflect::handleReflectFunction(Module &M, StringRef ReflectName) {
   assert(F->getReturnType()->isIntegerTy() &&
          "_reflect's return type should be integer");
 
+  SmallVector<std::pair<CallInst *, Constant *>, 8> ReflectReplacements;
+
   const bool Changed = !F->use_empty();
   for (User *U : make_early_inc_range(F->users())) {
     // Reflect function calls look like:
@@ -178,38 +185,110 @@ bool NVVMReflect::handleReflectFunction(Module &M, StringRef ReflectName) {
                       << "(" << ReflectArg << ") with value " << ReflectVal
                       << "\n");
     auto *NewValue = ConstantInt::get(Call->getType(), ReflectVal);
-    foldReflectCall(Call, NewValue);
-    Call->eraseFromParent();
+    ReflectReplacements.push_back({Call, NewValue});
   }
 
-  // Remove the __nvvm_reflect function from the module
+  replaceReflectCalls(ReflectReplacements, M.getDataLayout());
   F->eraseFromParent();
   return Changed;
 }
 
-void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) {
+/// Find all blocks that become dead transitively from an initial dead block.
+/// Returns the complete set including the original dead block and any blocks
+/// that lose all their predecessors due to the deletion cascade.
+SetVector<BasicBlock *>
+NVVMReflect::findTransitivelyDeadBlocks(BasicBlock *DeadBB) {
+  SmallVector<BasicBlock *, 8> Worklist({DeadBB});
+  SetVector<BasicBlock *> DeadBlocks;
+  while (!Worklist.empty()) {
+    auto *BB = Worklist.pop_back_val();
+    DeadBlocks.insert(BB);
+
+    for (BasicBlock *Succ : successors(BB))
+      if (pred_size(Succ) == 1 && DeadBlocks.insert(Succ))
+        Worklist.push_back(Succ);
+  }
+  return DeadBlocks;
+}
+
+/// Replace calls to __nvvm_reflect with corresponding constant values. Then
+/// clean up through constant folding and propagation and dead block
+/// elimination.
+///
+/// The purpose of this cleanup is not optimization because that could be
+/// handled by later passes
+/// (i.e. SCCP, SimplifyCFG, etc.), but for correctness. Reflect calls are most
+/// commonly used to query the arch number and select a valid instruction for
+/// the arch. Therefore, you need to eliminate blocks that become dead because
+/// they may contain invalid instructions for the arch. The purpose of the
+/// cleanup is to do the minimal amount of work to leave the code in a valid
+/// state.
+void NVVMReflect::replaceReflectCalls(
+    SmallVector<std::pair<CallInst *, Constant *>, 8> &ReflectReplacements,
+    const DataLayout &DL) {
   SmallVector<Instruction *, 8> Worklist;
-  // Replace an instruction with a constant and add all users of the instruction
-  // to the worklist
+  SetVector<BasicBlock *> DeadBlocks;
+
+  // Replace an instruction with a constant and add all users to the worklist,
+  // then delete the instruction
   auto ReplaceInstructionWithConst = [&](Instruction *I, Constant *C) {
     for (auto *U : I->users())
       if (auto *UI = dyn_cast<Instruction>(U))
         Worklist.push_back(UI);
     I->replaceAllUsesWith(C);
+    I->eraseFromParent();
   };
 
-  ReplaceInstructionWithConst(Call, NewValue);
+  for (auto &[Call, NewValue] : ReflectReplacements)
+    ReplaceInstructionWithConst(Call, NewValue);
 
-  auto &DL = Call->getModule()->getDataLayout();
-  while (!Worklist.empty()) {
-    auto *I = Worklist.pop_back_val();
-    if (auto *C = ConstantFoldInstruction(I, DL)) {
-      ReplaceInstructionWithConst(I, C);
-      if (isInstructionTriviallyDead(I))
-        I->eraseFromParent();
-    } else if (I->isTerminator()) {
-      ConstantFoldTerminator(I->getParent());
+  // Alternate between constant folding/propagation and dead block elimination.
+  // Terminator folding may create new dead blocks. When those dead blocks are
+  // deleted, their live successors may have PHIs that can be simplified, which
+  // may yield more work for folding/propagation.
+  while (true) {
+    // Iterate folding and propagating constants until the worklist is empty.
+    while (!Worklist.empty()) {
+      auto *I = Worklist.pop_back_val();
+      if (auto *C = ConstantFoldInstruction(I, DL)) {
+        ReplaceInstructionWithConst(I, C);
+      } else if (I->isTerminator()) {
+        BasicBlock *BB = I->getParent();
+        SmallVector<BasicBlock *, 8> Succs(successors(BB));
+        // Some blocks may become dead if the terminator is folded because
+        // a conditional branch is turned into a direct branch.
+        if (ConstantFoldTerminator(BB)) {
+          for (BasicBlock *Succ : Succs) {
+            if (pred_empty(Succ) &&
+                Succ != &Succ->getParent()->getEntryBlock()) {
+              SetVector<BasicBlock *> TransitivelyDead =
+                  findTransitivelyDeadBlocks(Succ);
+              DeadBlocks.insert(TransitivelyDead.begin(),
+                                TransitivelyDead.end());
+            }
+          }
+        }
+      }
     }
+    // No more constants to fold and no more dead blocks
+    // to create more work. We're done.
+    if (DeadBlocks.empty())
+      break;
+    // PHI nodes of live successors of dead blocks get eliminated when the dead
+    // blocks are eliminated. Their users can now be simplified further, so add
+    // them to the worklist.
+    for (BasicBlock *DeadBB : DeadBlocks)
+      for (BasicBlock *Succ : successors(DeadBB))
+        if (!DeadBlocks.contains(Succ))
+          for (PHINode &PHI : Succ->phis())
+            for (auto *U : PHI.users())
+              if (auto *UI = dyn_cast<Instruction>(U))
+                Worklist.push_back(UI);
+    // Delete all dead blocks in order
+    for (BasicBlock *DeadBB : DeadBlocks)
+      DeleteDeadBlock(DeadBB);
+
+    DeadBlocks.clear();
   }
 }
 
diff --git a/llvm/test/CodeGen/NVPTX/nvvm-reflect-opaque.ll b/llvm/test/CodeGen/NVPTX/nvvm-reflect-opaque.ll
index 19c74df303702..7bb1af707001a 100644
--- a/llvm/test/CodeGen/NVPTX/nvvm-reflect-opaque.ll
+++ b/llvm/test/CodeGen/NVPTX/nvvm-reflect-opaque.ll
@@ -3,12 +3,12 @@
 
 ; RUN: cat %s > %t.noftz
 ; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 0}' >> %t.noftz
-; RUN: opt %t.noftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect,simplifycfg' \
+; RUN: opt %t.noftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' \
 ; RUN:   | FileCheck %s --check-prefix=USE_FTZ_0 --check-prefix=CHECK
 
 ; RUN: cat %s > %t.ftz
 ; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 1}' >> %t.ftz
-; RUN: opt %t.ftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect,simplifycfg' \
+; RUN: opt %t.ftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' \
 ; RUN:   | FileCheck %s --check-prefix=USE_FTZ_1 --check-prefix=CHECK
 
 @str = private unnamed_addr addrspace(4) constant [11 x i8] c"__CUDA_FTZ\00"
diff --git a/llvm/test/CodeGen/NVPTX/nvvm-reflect.ll b/llvm/test/CodeGen/NVPTX/nvvm-reflect.ll
index 244b44fea9b83..581dbf353c1ff 100644
--- a/llvm/test/CodeGen/NVPTX/nvvm-reflect.ll
+++ b/llvm/test/CodeGen/NVPTX/nvvm-reflect.ll
@@ -3,12 +3,12 @@
 
 ; RUN: cat %s > %t.noftz
 ; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 0}' >> %t.noftz
-; RUN: opt %t.noftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect,simplifycfg' \
+; RUN: opt %t.noftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' \
 ; RUN:   | FileCheck %s --check-prefix=USE_FTZ_0 --check-prefix=CHECK
 
 ; RUN: cat %s > %t.ftz
 ; RUN: echo '!0 = !{i32 4, !"nvvm-reflect-ftz", i32 1}' >> %t.ftz
-; RUN: opt %t.ftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect,simplifycfg' \
+; RUN: opt %t.ftz -S -mtriple=nvptx-nvidia-cuda -passes='nvvm-reflect' \
 ; RUN:   | FileCheck %s --check-prefix=USE_FTZ_1 --check-prefix=CHECK
 
 @str = private unnamed_addr addrspace(4) constant [11 x i8] c"__CUDA_FTZ\00"

From d5e1cfc8bfaaf736493743d17af464a743c34cb2 Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Fri, 13 Jun 2025 22:20:31 +0000
Subject: [PATCH 2/6] format

---
 llvm/lib/Target/NVPTX/NVVMReflect.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 1c17852503660..5b24864ab586f 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -19,6 +19,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "NVPTX.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Analysis/ConstantFolding.h"
@@ -39,8 +41,6 @@
 #include "llvm/Transforms/Scalar.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
-#include "llvm/ADT/MapVector.h"
-#include "llvm/ADT/SetVector.h"
 #define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
 #define NVVM_REFLECT_OCL_FUNCTION "__nvvm_reflect_ocl"
 // Argument of reflect call to retrive arch number

From 2c951c83fd46a0d046dd72fa397de678d0c834bd Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Fri, 13 Jun 2025 22:23:10 +0000
Subject: [PATCH 3/6] format

---
 llvm/lib/Target/NVPTX/NVVMReflect.cpp | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 5b24864ab586f..b0f69598972ce 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -259,12 +259,9 @@ void NVVMReflect::replaceReflectCalls(
         // a conditional branch is turned into a direct branch.
         if (ConstantFoldTerminator(BB)) {
           for (BasicBlock *Succ : Succs) {
-            if (pred_empty(Succ) &&
-                Succ != &Succ->getParent()->getEntryBlock()) {
-              SetVector<BasicBlock *> TransitivelyDead =
-                  findTransitivelyDeadBlocks(Succ);
-              DeadBlocks.insert(TransitivelyDead.begin(),
-                                TransitivelyDead.end());
+            if (pred_empty(Succ) && Succ != &Succ->getParent()->getEntryBlock()) {
+              SetVector<BasicBlock *> TransitivelyDead = findTransitivelyDeadBlocks(Succ);
+              DeadBlocks.insert(TransitivelyDead.begin(), TransitivelyDead.end());
             }
           }
         }
@@ -284,7 +281,7 @@ void NVVMReflect::replaceReflectCalls(
             for (auto *U : PHI.users())
               if (auto *UI = dyn_cast<Instruction>(U))
                 Worklist.push_back(UI);
-    // Delete all dead blocks in order
+    // Delete all dead blocks
     for (BasicBlock *DeadBB : DeadBlocks)
       DeleteDeadBlock(DeadBB);
 

From 8fbef2a34339f66e7aaeab113372edde945c1fb1 Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Fri, 13 Jun 2025 22:23:15 +0000
Subject: [PATCH 4/6] format

---
 llvm/lib/Target/NVPTX/NVVMReflect.cpp | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index b0f69598972ce..74c3efd18ad89 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -259,9 +259,12 @@ void NVVMReflect::replaceReflectCalls(
         // a conditional branch is turned into a direct branch.
         if (ConstantFoldTerminator(BB)) {
           for (BasicBlock *Succ : Succs) {
-            if (pred_empty(Succ) && Succ != &Succ->getParent()->getEntryBlock()) {
-              SetVector<BasicBlock *> TransitivelyDead = findTransitivelyDeadBlocks(Succ);
-              DeadBlocks.insert(TransitivelyDead.begin(), TransitivelyDead.end());
+            if (pred_empty(Succ) &&
+                Succ != &Succ->getParent()->getEntryBlock()) {
+              SetVector<BasicBlock *> TransitivelyDead =
+                  findTransitivelyDeadBlocks(Succ);
+              DeadBlocks.insert(TransitivelyDead.begin(),
+                                TransitivelyDead.end());
             }
           }
         }

From f9eedaadebddf6cc60cf3d6a5424e30bfdda7fb5 Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Fri, 13 Jun 2025 22:30:54 +0000
Subject: [PATCH 5/6] cleanup

---
 llvm/lib/Target/NVPTX/NVVMReflect.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 74c3efd18ad89..fd9225838b243 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -19,7 +19,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "NVPTX.h"
-#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringExtras.h"

From 4ba9826586f37cc6f3e0d5eb3261b5132e69c003 Mon Sep 17 00:00:00 2001
From: Yonah Goldberg <ygoldberg at nvidia.com>
Date: Fri, 13 Jun 2025 22:33:19 +0000
Subject: [PATCH 6/6] added back isInstructionTriviallyDead check

---
 llvm/lib/Target/NVPTX/NVVMReflect.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index fd9225838b243..2585ff45bde4c 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -235,7 +235,8 @@ void NVVMReflect::replaceReflectCalls(
       if (auto *UI = dyn_cast<Instruction>(U))
         Worklist.push_back(UI);
     I->replaceAllUsesWith(C);
-    I->eraseFromParent();
+    if (isInstructionTriviallyDead(I))
+      I->eraseFromParent();
   };
 
   for (auto &[Call, NewValue] : ReflectReplacements)


