[llvm] [NVPTX] Fix code generation for `trap-unreachable`. (PR #67478)

Christian Sigg via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 28 09:11:51 PDT 2023


https://github.com/chsigg updated https://github.com/llvm/llvm-project/pull/67478

>From 76f722a18debdb2362eb7cc38943bd4ec4cadc62 Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg at google.com>
Date: Wed, 27 Sep 2023 09:32:26 +0200
Subject: [PATCH 1/5] [NVPTX] Fix code generation for `trap-unreachable`.

https://reviews.llvm.org/D152789 added an `exit` op before each `unreachable`. This means we never get to the `trap` instruction.

- When `trap-unreachable` is enabled and `no-trap-after-noreturn` is not, don't insert `exit` before each `unreachable`.
- Lower ISD::TRAP to both `trap` and `exit` instead of just the former.

The fix doesn't work with `no-trap-after-noreturn`, because the `unreachable`s not following a `noreturn` call are lowered to `exit; trap; exit;`.

An alternative approach would be to insert `trap`s in `NVPTXLowerUnreachablePass`, depending on the `trap-unreachable` and `no-trap-after-noreturn` settings. I think we would then want to skip lowering ISD::TRAP, so that we don't end up with `trap; exit; trap;` sequences.
---
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td      |  4 +++-
 llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp | 10 ++--------
 llvm/test/CodeGen/NVPTX/unreachable.ll       |  6 +++++-
 3 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index ad10d7938ef12e4..fb5ae339cb11a74 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3545,7 +3545,9 @@ def Callseq_End :
             [(callseq_end timm:$amt1, timm:$amt2)]>;
 
 // trap instruction
-def trapinst : NVPTXInst<(outs), (ins), "trap;", [(trap)]>;
+// Emit an `exit` as well to convey to ptxas that `trap` exits the CFG.
+// This won't be necessary in a future version of ptxas.
+def trapinst : NVPTXInst<(outs), (ins), "trap; exit;", [(trap)]>;
 
 // Call prototype wrapper
 def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
index cad97b1f14eb2b9..0aaeff9fa76b15c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -63,13 +63,6 @@ static cl::opt<bool> UseShortPointersOpt(
         "Use 32-bit pointers for accessing const/local/shared address spaces."),
     cl::init(false), cl::Hidden);
 
-// FIXME: intended as a temporary debugging aid. Should be removed before it
-// makes it into the LLVM-17 release.
-static cl::opt<bool>
-    ExitOnUnreachable("nvptx-exit-on-unreachable",
-                      cl::desc("Lower 'unreachable' as 'exit' instruction."),
-                      cl::init(true), cl::Hidden);
-
 namespace llvm {
 
 void initializeGenericToNVVMLegacyPassPass(PassRegistry &);
@@ -410,7 +403,8 @@ void NVPTXPassConfig::addIRPasses() {
     addPass(createSROAPass());
   }
 
-  if (ExitOnUnreachable)
+  const auto &options = getNVPTXTargetMachine().Options;
+  if (!options.TrapUnreachable || options.NoTrapAfterNoreturn)
     addPass(createNVPTXLowerUnreachablePass());
 }
 
diff --git a/llvm/test/CodeGen/NVPTX/unreachable.ll b/llvm/test/CodeGen/NVPTX/unreachable.ll
index 742089df1bd4533..05d5a750c444c93 100644
--- a/llvm/test/CodeGen/NVPTX/unreachable.ll
+++ b/llvm/test/CodeGen/NVPTX/unreachable.ll
@@ -1,5 +1,7 @@
 ; RUN: llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
 ; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs -trap-unreachable \
+; RUN:     | FileCheck %s --check-prefix=CHECK --check-prefix=CHECK-TRAP
 ; RUN: %if ptxas && !ptxas-12.0 %{ llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}
 ; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}
 
@@ -11,7 +13,9 @@ define void @kernel_func() {
 ; CHECK: call.uni
 ; CHECK: throw,
   call void @throw()
-; CHECK: exit
+; CHECK-TRAP-NOT: exit;
+; CHECK-TRAP: trap;
+; CHECK: exit;
   unreachable
 }
 

>From 1ef73791d52a76a0a4c754c8b541ce914f9d87fa Mon Sep 17 00:00:00 2001
From: Christian Sigg <csigg at google.com>
Date: Wed, 27 Sep 2023 09:32:46 +0200
Subject: [PATCH 2/5] Implement the logic in `NVPTXLowerUnreachablePass`
 instead.

---
 .../SelectionDAG/SelectionDAGBuilder.cpp      |  7 +--
 llvm/lib/Target/NVPTX/NVPTX.h                 |  3 +-
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td       |  4 +-
 .../Target/NVPTX/NVPTXLowerUnreachable.cpp    | 45 ++++++++++++++++---
 llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp  |  6 +--
 llvm/test/CodeGen/NVPTX/unreachable.ll        |  9 +++-
 6 files changed, 52 insertions(+), 22 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index f39b62abdd87790..ba0ab3586f75825 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -3226,14 +3226,9 @@ void SelectionDAGBuilder::visitUnreachable(const UnreachableInst &I) {
 
   // We may be able to ignore unreachable behind a noreturn call.
   if (DAG.getTarget().Options.NoTrapAfterNoreturn) {
-    const BasicBlock &BB = *I.getParent();
-    if (&I != &BB.front()) {
-      BasicBlock::const_iterator PredI =
-        std::prev(BasicBlock::const_iterator(&I));
-      if (const CallInst *Call = dyn_cast<CallInst>(&*PredI)) {
+    if (const CallInst *Call = dyn_cast_or_null<CallInst>(I.getPrevNode())) {
         if (Call->doesNotReturn())
           return;
-      }
     }
   }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index c5816b9266dfd9e..8dc68911fff0c05 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -47,7 +47,8 @@ MachineFunctionPass *createNVPTXReplaceImageHandlesPass();
 FunctionPass *createNVPTXImageOptimizerPass();
 FunctionPass *createNVPTXLowerArgsPass();
 FunctionPass *createNVPTXLowerAllocaPass();
-FunctionPass *createNVPTXLowerUnreachablePass();
+FunctionPass *createNVPTXLowerUnreachablePass(bool TrapUnreachable,
+                                              bool NoTrapAfterNoreturn);
 MachineFunctionPass *createNVPTXPeephole();
 MachineFunctionPass *createNVPTXProxyRegErasurePass();
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index fb5ae339cb11a74..ad10d7938ef12e4 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3545,9 +3545,7 @@ def Callseq_End :
             [(callseq_end timm:$amt1, timm:$amt2)]>;
 
 // trap instruction
-// Emit an `exit` as well to convey to ptxas that `trap` exits the CFG.
-// This won't be necessary in a future version of ptxas.
-def trapinst : NVPTXInst<(outs), (ins), "trap; exit;", [(trap)]>;
+def trapinst : NVPTXInst<(outs), (ins), "trap;", [(trap)]>;
 
 // Call prototype wrapper
 def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerUnreachable.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerUnreachable.cpp
index 1d312f82e6c061c..efafd909b93be37 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerUnreachable.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerUnreachable.cpp
@@ -72,6 +72,7 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/InlineAsm.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/Type.h"
 #include "llvm/Pass.h"
 
@@ -83,14 +84,19 @@ void initializeNVPTXLowerUnreachablePass(PassRegistry &);
 
 namespace {
 class NVPTXLowerUnreachable : public FunctionPass {
+  StringRef getPassName() const override;
   bool runOnFunction(Function &F) override;
+  bool shouldEmitTrap(const UnreachableInst &I) const;
 
 public:
   static char ID; // Pass identification, replacement for typeid
-  NVPTXLowerUnreachable() : FunctionPass(ID) {}
-  StringRef getPassName() const override {
-    return "add an exit instruction before every unreachable";
-  }
+  NVPTXLowerUnreachable(bool TrapUnreachable, bool NoTrapAfterNoreturn)
+      : FunctionPass(ID), TrapUnreachable(TrapUnreachable),
+        NoTrapAfterNoreturn(NoTrapAfterNoreturn) {}
+
+private:
+  bool TrapUnreachable;
+  bool NoTrapAfterNoreturn;
 };
 } // namespace
 
@@ -99,6 +105,24 @@ char NVPTXLowerUnreachable::ID = 1;
 INITIALIZE_PASS(NVPTXLowerUnreachable, "nvptx-lower-unreachable",
                 "Lower Unreachable", false, false)
 
+StringRef NVPTXLowerUnreachable::getPassName() const {
+  return "add an exit instruction before every unreachable";
+}
+
+// =============================================================================
+// Returns whether a `trap` intrinsic should be emitted before I.
+//
+// This is a copy of the logic in SelectionDAGBuilder::visitUnreachable().
+// =============================================================================
+bool NVPTXLowerUnreachable::shouldEmitTrap(const UnreachableInst &I) const {
+  if (!TrapUnreachable)
+    return false;
+  if (!NoTrapAfterNoreturn)
+    return true;
+  const CallInst *Call = dyn_cast_or_null<CallInst>(I.getPrevNode());
+  return !Call || !Call->doesNotReturn(); // No trap after a noreturn call.
+}
+
 // =============================================================================
 // Main function for this pass.
 // =============================================================================
@@ -109,18 +133,25 @@ bool NVPTXLowerUnreachable::runOnFunction(Function &F) {
   LLVMContext &C = F.getContext();
   FunctionType *ExitFTy = FunctionType::get(Type::getVoidTy(C), false);
   InlineAsm *Exit = InlineAsm::get(ExitFTy, "exit;", "", true);
+  Function *Trap = nullptr;
 
   bool Changed = false;
   for (auto &BB : F)
     for (auto &I : BB) {
       if (auto unreachableInst = dyn_cast<UnreachableInst>(&I)) {
-        Changed = true;
+        if (shouldEmitTrap(*unreachableInst)) {
+          if (!Trap)
+            Trap = Intrinsic::getDeclaration(F.getParent(), Intrinsic::trap);
+          CallInst::Create(Trap, "", unreachableInst);
+        }
         CallInst::Create(ExitFTy, Exit, "", unreachableInst);
+        Changed = true;
       }
     }
   return Changed;
 }
 
-FunctionPass *llvm::createNVPTXLowerUnreachablePass() {
-  return new NVPTXLowerUnreachable();
+FunctionPass *llvm::createNVPTXLowerUnreachablePass(bool TrapUnreachable,
+                                                    bool NoTrapAfterNoreturn) {
+  return new NVPTXLowerUnreachable(TrapUnreachable, NoTrapAfterNoreturn);
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
index 0aaeff9fa76b15c..8d895762fbe1d9d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -403,9 +403,9 @@ void NVPTXPassConfig::addIRPasses() {
     addPass(createSROAPass());
   }
 
-  const auto &options = getNVPTXTargetMachine().Options;
-  if (!options.TrapUnreachable || options.NoTrapAfterNoreturn)
-    addPass(createNVPTXLowerUnreachablePass());
+  const auto &Options = getNVPTXTargetMachine().Options;
+  addPass(createNVPTXLowerUnreachablePass(Options.TrapUnreachable,
+                                          Options.NoTrapAfterNoreturn));
 }
 
 bool NVPTXPassConfig::addInstSelector() {
diff --git a/llvm/test/CodeGen/NVPTX/unreachable.ll b/llvm/test/CodeGen/NVPTX/unreachable.ll
index 05d5a750c444c93..011497c4e23401a 100644
--- a/llvm/test/CodeGen/NVPTX/unreachable.ll
+++ b/llvm/test/CodeGen/NVPTX/unreachable.ll
@@ -1,5 +1,9 @@
-; RUN: llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
-; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs \
+; RUN:     | FileCheck %s  --check-prefix=CHECK --check-prefix=CHECK-NOTRAP
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs \
+; RUN:     | FileCheck %s  --check-prefix=CHECK --check-prefix=CHECK-NOTRAP
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs -trap-unreachable \
+; RUN:     | FileCheck %s --check-prefix=CHECK --check-prefix=CHECK-TRAP
 ; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs -trap-unreachable \
 ; RUN:     | FileCheck %s --check-prefix=CHECK --check-prefix=CHECK-TRAP
 ; RUN: %if ptxas && !ptxas-12.0 %{ llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}
@@ -15,6 +19,7 @@ define void @kernel_func() {
   call void @throw()
 ; CHECK-TRAP-NOT: exit;
 ; CHECK-TRAP: trap;
+; CHECK-NOTRAP-NOT: trap;
 ; CHECK: exit;
   unreachable
 }

>From f0e772fe2cb2ff2c587803d5d091c7db6bdf8af5 Mon Sep 17 00:00:00 2001
From: Christian Sigg <chsigg at users.noreply.github.com>
Date: Thu, 28 Sep 2023 17:49:36 +0200
Subject: [PATCH 3/5] Skip adding `exit` before `unreachable` if it's lowered
 to `trap`

---
 .../lib/Target/NVPTX/NVPTXLowerUnreachable.cpp | 18 +++++++-----------
 1 file changed, 7 insertions(+), 11 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXLowerUnreachable.cpp b/llvm/lib/Target/NVPTX/NVPTXLowerUnreachable.cpp
index efafd909b93be37..fd9922774fe8e65 100644
--- a/llvm/lib/Target/NVPTX/NVPTXLowerUnreachable.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXLowerUnreachable.cpp
@@ -63,8 +63,9 @@
 // `bar.sync` instruction happen divergently.
 //
 // To work around this, we add an `exit` instruction before every `unreachable`,
-// as `ptxas` understands that exit terminates the CFG. Note that `trap` is not
-// equivalent, and only future versions of `ptxas` will model it like `exit`.
+// as `ptxas` understands that exit terminates the CFG. We only do this if
+// `unreachable` is not lowered to `trap`, which has the same effect (although
+// with current `ptxas` versions only because it is emitted as `trap; exit;`).
 //
 //===----------------------------------------------------------------------===//
 
@@ -72,7 +73,6 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/InlineAsm.h"
 #include "llvm/IR/Instructions.h"
-#include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/Type.h"
 #include "llvm/Pass.h"
 
@@ -86,7 +86,7 @@ namespace {
 class NVPTXLowerUnreachable : public FunctionPass {
   StringRef getPassName() const override;
   bool runOnFunction(Function &F) override;
-  bool shouldEmitTrap(const UnreachableInst &I) const;
+  bool isLoweredToTrap(const UnreachableInst &I) const;
 
 public:
   static char ID; // Pass identification, replacement for typeid
@@ -114,7 +114,7 @@ StringRef NVPTXLowerUnreachable::getPassName() const {
 //
 // This is a copy of the logic in SelectionDAGBuilder::visitUnreachable().
 // =============================================================================
-bool NVPTXLowerUnreachable::shouldEmitTrap(const UnreachableInst &I) const {
+bool NVPTXLowerUnreachable::isLoweredToTrap(const UnreachableInst &I) const {
   if (!TrapUnreachable)
     return false;
   if (!NoTrapAfterNoreturn)
@@ -133,17 +133,13 @@ bool NVPTXLowerUnreachable::runOnFunction(Function &F) {
   LLVMContext &C = F.getContext();
   FunctionType *ExitFTy = FunctionType::get(Type::getVoidTy(C), false);
   InlineAsm *Exit = InlineAsm::get(ExitFTy, "exit;", "", true);
-  Function *Trap = nullptr;
 
   bool Changed = false;
   for (auto &BB : F)
     for (auto &I : BB) {
       if (auto unreachableInst = dyn_cast<UnreachableInst>(&I)) {
-        if (shouldEmitTrap(*unreachableInst)) {
-          if (!Trap)
-            Trap = Intrinsic::getDeclaration(F.getParent(), Intrinsic::trap);
-          CallInst::Create(Trap, "", unreachableInst);
-        }
+        if (isLoweredToTrap(*unreachableInst))
+          continue; // trap is emitted as `trap; exit;`.
         CallInst::Create(ExitFTy, Exit, "", unreachableInst);
         Changed = true;
       }

>From 5c165f71afb92e9ba6535fed1c4803908821261a Mon Sep 17 00:00:00 2001
From: Christian Sigg <chsigg at users.noreply.github.com>
Date: Thu, 28 Sep 2023 17:51:51 +0200
Subject: [PATCH 4/5] Emit trap as `trap; exit;`.

---
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index ad10d7938ef12e4..fb5ae339cb11a74 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3545,7 +3545,9 @@ def Callseq_End :
             [(callseq_end timm:$amt1, timm:$amt2)]>;
 
 // trap instruction
-def trapinst : NVPTXInst<(outs), (ins), "trap;", [(trap)]>;
+// Emit an `exit` as well to convey to ptxas that `trap` exits the CFG.
+// This won't be necessary in a future version of ptxas.
+def trapinst : NVPTXInst<(outs), (ins), "trap; exit;", [(trap)]>;
 
 // Call prototype wrapper
 def SDTCallPrototype : SDTypeProfile<0, 1, [SDTCisInt<0>]>;

>From e4d6a915d654e929dbb8ae6077f081f53bbfd42f Mon Sep 17 00:00:00 2001
From: Christian Sigg <chsigg at users.noreply.github.com>
Date: Thu, 28 Sep 2023 18:11:42 +0200
Subject: [PATCH 5/5] Fix formatting.

---
 llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index ba0ab3586f75825..c5fd56795a5201a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -3227,8 +3227,8 @@ void SelectionDAGBuilder::visitUnreachable(const UnreachableInst &I) {
   // We may be able to ignore unreachable behind a noreturn call.
   if (DAG.getTarget().Options.NoTrapAfterNoreturn) {
     if (const CallInst *Call = dyn_cast_or_null<CallInst>(I.getPrevNode())) {
-        if (Call->doesNotReturn())
-          return;
+      if (Call->doesNotReturn())
+        return;
     }
   }
 



More information about the llvm-commits mailing list