[llvm] [NVVMReflect] Force dead branch elimination in NVVMReflect (PR #81189)

Joseph Huber via llvm-commits llvm-commits at lists.llvm.org
Thu Feb 8 14:26:11 PST 2024


https://github.com/jhuber6 updated https://github.com/llvm/llvm-project/pull/81189

>From 92d9b7838b07951ad6f6933871eaf47ed80a830a Mon Sep 17 00:00:00 2001
From: Joseph Huber <huberjn at outlook.com>
Date: Thu, 8 Feb 2024 14:59:17 -0600
Subject: [PATCH 1/3] [NVVMReflect] Force dead branch elimination in
 NVVMReflect

Summary:
The `__nvvm_reflect` function is used to guard code that is only valid
on certain architectures. One problem with this feature is that, when
it is used without optimizations, the invalid code is left in the
module and makes it to the backend. The `__nvvm_reflect` pass is
already mandatory, so it should do some trivial branch removal to
ensure that the resulting constants are handled correctly. This dead
branch elimination only handles the trivial case of a compare feeding a
branch, and it does not touch any conditionals unrelated to the
`__nvvm_reflect` call in order to preserve `O0` semantics as much as
possible. This should allow the following to work on NVPTX targets:

```c
int foo() {
  if (__nvvm_reflect("__CUDA_ARCH") >= 700)
    asm("valid;\n");
}
```
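
For reference, here is a rough IR-level sketch (not part of the patch) of
what the pass now does at `O0`, assuming `sm_52` so that `__CUDA_ARCH`
reflects to 520. It mirrors the new `@baz` test below, and the comments
describe each step of the rewrite:

```llvm
; Simplified IR for the C example above.
@.str = private unnamed_addr constant [12 x i8] c"__CUDA_ARCH\00"

declare i32 @__nvvm_reflect(ptr)

define void @foo() {
entry:
  ; 1. The reflect call is replaced by the constant 520 (SmVersion * 10).
  %call = call i32 @__nvvm_reflect(ptr @.str)
  ; 2. The compare now has two constant operands and folds to false.
  %cmp = icmp uge i32 %call, 700
  ; 3. The conditional branch becomes an unconditional branch to %if.end,
  ;    so the asm that is invalid on sm_52 never reaches the backend.
  br i1 %cmp, label %if.then, label %if.end

if.then:
  call void asm sideeffect "valid;\0A", ""()
  br label %if.end

if.end:
  ret void
}
```

On `sm_70` and newer the compare folds to true instead, so the guarded
asm is kept and emitted as before.
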
---
 llvm/lib/Target/NVPTX/NVVMReflect.cpp         |  61 +++++++++++
 .../CodeGen/NVPTX/nvvm-reflect-arch-O0.ll     | 102 ++++++++++++++++++
 llvm/test/CodeGen/NVPTX/nvvm-reflect-arch.ll  |   1 -
 3 files changed, 163 insertions(+), 1 deletion(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/nvvm-reflect-arch-O0.ll

diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 7d2678ae592748..78b50c07e1c422 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -20,6 +20,7 @@
 
 #include "NVPTX.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Function.h"
@@ -36,6 +37,8 @@
 #include "llvm/Support/raw_os_ostream.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
 #include <sstream>
 #include <string>
 #define NVVM_REFLECT_FUNCTION "__nvvm_reflect"
@@ -87,6 +90,7 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
   }
 
   SmallVector<Instruction *, 4> ToRemove;
+  SmallVector<ICmpInst *, 4> ToSimplify;
 
   // Go through the calls in this function.  Each call to __nvvm_reflect or
   // llvm.nvvm.reflect should be a CallInst with a ConstantArray argument.
@@ -171,6 +175,12 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
     } else if (ReflectArg == "__CUDA_ARCH") {
       ReflectVal = SmVersion * 10;
     }
+
+    // If the immediate user is a simple comparison we want to simplify it.
+    for (User *U : Call->users())
+      if (ICmpInst *I = dyn_cast<ICmpInst>(U))
+        ToSimplify.push_back(I);
+
     Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
     ToRemove.push_back(Call);
   }
@@ -178,6 +188,57 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
   for (Instruction *I : ToRemove)
     I->eraseFromParent();
 
+  // The code guarded by __nvvm_reflect may be invalid for the target machine.
+  // We need to do some basic dead code elimination to trim invalid code before
+  // it reaches the backend at all optimization levels.
+  SmallVector<BranchInst *> Simplified;
+  for (ICmpInst *Cmp : ToSimplify) {
+    Constant *LHS = dyn_cast<Constant>(Cmp->getOperand(0));
+    Constant *RHS = dyn_cast<Constant>(Cmp->getOperand(1));
+
+    if (!LHS || !RHS)
+      continue;
+
+    // If the comparison is a compile time constat we sipmly propagate it.
+    Constant *C = ConstantFoldCompareInstOperands(
+        Cmp->getPredicate(), LHS, RHS, Cmp->getModule()->getDataLayout());
+
+    if (!C)
+      continue;
+
+    for (User *U : Cmp->users())
+      if (BranchInst *I = dyn_cast<BranchInst>(U))
+        Simplified.push_back(I);
+
+    Cmp->replaceAllUsesWith(C);
+    Cmp->eraseFromParent();
+  }
+
+  // Each instruction here is a conditional branch off of a constant true or
+  // false value. Simply replace it with an unconditional branch to the
+  // appropriate basic block and delete the rest if it is trivally dead.
+  DenseSet<Instruction *> Removed;
+  for (BranchInst *Branch : Simplified) {
+    if (Removed.contains(Branch))
+      continue;
+
+    ConstantInt *C = dyn_cast<ConstantInt>(Branch->getCondition());
+    if (!C || (!C->isOne() && !C->isZero()))
+      continue;
+
+    BasicBlock *TrueBB =
+        C->isOne() ? Branch->getSuccessor(0) : Branch->getSuccessor(1);
+    BasicBlock *FalseBB =
+        C->isOne() ? Branch->getSuccessor(1) : Branch->getSuccessor(0);
+
+    ReplaceInstWithInst(Branch, BranchInst::Create(TrueBB));
+    if (FalseBB->use_empty() && FalseBB->hasNPredecessors(0) &&
+        FalseBB->getFirstNonPHIOrDbg()) {
+      Removed.insert(FalseBB->getFirstNonPHIOrDbg());
+      changeToUnreachable(FalseBB->getFirstNonPHIOrDbg());
+    }
+  }
+
   return ToRemove.size() > 0;
 }
 
diff --git a/llvm/test/CodeGen/NVPTX/nvvm-reflect-arch-O0.ll b/llvm/test/CodeGen/NVPTX/nvvm-reflect-arch-O0.ll
new file mode 100644
index 00000000000000..3e480aaf89e50f
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/nvvm-reflect-arch-O0.ll
@@ -0,0 +1,102 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_52 -mattr=+ptx64 -O0 | FileCheck %s --check-prefix=SM_52
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_70 -mattr=+ptx64 -O0 | FileCheck %s --check-prefix=SM_70
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 -mattr=+ptx72 -O0 | FileCheck %s --check-prefix=SM_90
+
+ at .str = private unnamed_addr constant [12 x i8] c"__CUDA_ARCH\00"
+
+declare i32 @__nvvm_reflect(ptr)
+
+;      SM_52: .visible .func  (.param .b32 func_retval0) foo()
+;      SM_52: mov.b32         %[[REG:.+]], 3;
+; SM_52-NEXT: st.param.b32    [func_retval0+0], %[[REG:.+]];
+; SM_52-NEXT: ret;
+;
+;      SM_70: .visible .func  (.param .b32 func_retval0) foo()
+;      SM_70: mov.b32         %[[REG:.+]], 2;
+; SM_70-NEXT: st.param.b32    [func_retval0+0], %[[REG:.+]];
+; SM_70-NEXT: ret;
+;
+;      SM_90: .visible .func  (.param .b32 func_retval0) foo()
+;      SM_90: mov.b32         %[[REG:.+]], 1;
+; SM_90-NEXT: st.param.b32    [func_retval0+0], %[[REG:.+]];
+; SM_90-NEXT: ret;
+define i32 @foo() {
+entry:
+  %call = call i32 @__nvvm_reflect(ptr @.str)
+  %cmp = icmp uge i32 %call, 900
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:
+  br label %return
+
+if.else:
+  %call1 = call i32 @__nvvm_reflect(ptr @.str)
+  %cmp2 = icmp uge i32 %call1, 700
+  br i1 %cmp2, label %if.then3, label %if.else4
+
+if.then3:
+  br label %return
+
+if.else4:
+  %call5 = call i32 @__nvvm_reflect(ptr @.str)
+  %cmp6 = icmp uge i32 %call5, 520
+  br i1 %cmp6, label %if.then7, label %if.else8
+
+if.then7:
+  br label %return
+
+if.else8:
+  br label %return
+
+return:
+  %retval.0 = phi i32 [ 1, %if.then ], [ 2, %if.then3 ], [ 3, %if.then7 ], [ 4, %if.else8 ]
+  ret i32 %retval.0
+}
+
+;      SM_52: .visible .func  (.param .b32 func_retval0) bar()
+;      SM_52: mov.b32         %[[REG:.+]], 2;
+; SM_52-NEXT: st.param.b32    [func_retval0+0], %[[REG:.+]];
+; SM_52-NEXT: ret;
+;
+;      SM_70: .visible .func  (.param .b32 func_retval0) bar()
+;      SM_70: mov.b32         %[[REG:.+]], 1;
+; SM_70-NEXT: st.param.b32    [func_retval0+0], %[[REG:.+]];
+; SM_70-NEXT: ret;
+;
+;      SM_90: .visible .func  (.param .b32 func_retval0) bar()
+;      SM_90: mov.b32         %[[REG:.+]], 1;
+; SM_90-NEXT: st.param.b32    [func_retval0+0], %[[REG:.+]];
+; SM_90-NEXT: ret;
+define i32 @bar() {
+entry:
+  %call = call i32 @__nvvm_reflect(ptr @.str)
+  %cmp = icmp uge i32 %call, 700
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:
+  br label %if.end
+
+if.else:
+  br label %if.end
+
+if.end:
+  %x = phi i32 [ 1, %if.then ], [ 2, %if.else ]
+  ret i32 %x
+}
+
+; SM_52-NOT: valid;
+; SM_70: valid;
+; SM_90: valid;
+define void @baz() {
+entry:
+  %call = call i32 @__nvvm_reflect(ptr @.str)
+  %cmp = icmp uge i32 %call, 700
+  br i1 %cmp, label %if.then, label %if.end
+
+if.then:
+  call void asm sideeffect "valid;\0A", ""()
+  br label %if.end
+
+if.end:
+  ret void
+}
diff --git a/llvm/test/CodeGen/NVPTX/nvvm-reflect-arch.ll b/llvm/test/CodeGen/NVPTX/nvvm-reflect-arch.ll
index e8c554c9ed5289..ac5875c6ab1043 100644
--- a/llvm/test/CodeGen/NVPTX/nvvm-reflect-arch.ll
+++ b/llvm/test/CodeGen/NVPTX/nvvm-reflect-arch.ll
@@ -18,4 +18,3 @@ define i32 @foo(float %a, float %b) {
 ; SM35: ret i32 350  
   ret i32 %reflect
 }
-

>From 37513713a1594505f4e9e72af521c2a9e54f2700 Mon Sep 17 00:00:00 2001
From: Joseph Huber <huberjn at outlook.com>
Date: Thu, 8 Feb 2024 15:28:49 -0600
Subject: [PATCH 2/3] Grammar

---
 llvm/lib/Target/NVPTX/NVVMReflect.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 78b50c07e1c422..c966731ad9dd96 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -199,7 +199,7 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
     if (!LHS || !RHS)
       continue;
 
-    // If the comparison is a compile time constat we sipmly propagate it.
+    // If the comparison is a compile time constant we simply propagate it.
     Constant *C = ConstantFoldCompareInstOperands(
         Cmp->getPredicate(), LHS, RHS, Cmp->getModule()->getDataLayout());
 
@@ -216,7 +216,7 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
 
   // Each instruction here is a conditional branch off of a constant true or
   // false value. Simply replace it with an unconditional branch to the
-  // appropriate basic block and delete the rest if it is trivally dead.
+  // appropriate basic block and delete the rest if it is trivially dead.
   DenseSet<Instruction *> Removed;
   for (BranchInst *Branch : Simplified) {
     if (Removed.contains(Branch))

>From 6ea67cb4dba9dcbf7238750299e92adc95730db0 Mon Sep 17 00:00:00 2001
From: Joseph Huber <huberjn at outlook.com>
Date: Thu, 8 Feb 2024 16:26:01 -0600
Subject: [PATCH 3/3] Document lack of switch support

---
 llvm/docs/NVPTXUsage.rst                      |  5 +++
 llvm/lib/Target/NVPTX/NVVMReflect.cpp         |  1 +
 .../CodeGen/NVPTX/nvvm-reflect-arch-O0.ll     | 39 +++++++++++++++++++
 3 files changed, 45 insertions(+)

diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index 22acc6c9cb37f5..b5e3918e56e940 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -296,6 +296,11 @@ pipeline, immediately after the link stage. The ``internalize`` pass is also
 recommended to remove unused math functions from the resulting PTX. For an
 input IR module ``module.bc``, the following compilation flow is recommended:
 
+The ``NVVMReflect`` pass will attempt to remove dead code even without
+optimizations. This allows potentially incompatible instructions to be avoided
+at all optimization levels. This currently only works for simple conditionals
+like the above example.
+
 1. Save list of external functions in ``module.bc``
 2. Link ``module.bc`` with ``libdevice.compute_XX.YY.bc``
 3. Internalize all functions not in list from (1)
diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index c966731ad9dd96..5283c2fff2c6c2 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -177,6 +177,7 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
     }
 
     // If the immediate user is a simple comparison we want to simplify it.
+    // TODO: This currently does not handle switch instructions.
     for (User *U : Call->users())
       if (ICmpInst *I = dyn_cast<ICmpInst>(U))
         ToSimplify.push_back(I);
diff --git a/llvm/test/CodeGen/NVPTX/nvvm-reflect-arch-O0.ll b/llvm/test/CodeGen/NVPTX/nvvm-reflect-arch-O0.ll
index 3e480aaf89e50f..c9586d5688f809 100644
--- a/llvm/test/CodeGen/NVPTX/nvvm-reflect-arch-O0.ll
+++ b/llvm/test/CodeGen/NVPTX/nvvm-reflect-arch-O0.ll
@@ -100,3 +100,42 @@ if.then:
 if.end:
   ret void
 }
+
+; SM_52: .visible .func  (.param .b32 func_retval0) qux()
+; SM_52: mov.u32         %[[REG1:.+]], %[[REG2:.+]];
+; SM_52: st.param.b32    [func_retval0+0], %[[REG1:.+]];
+; SM_52: ret;
+; SM_70: .visible .func  (.param .b32 func_retval0) qux()
+; SM_70: mov.u32         %[[REG1:.+]], %[[REG2:.+]];
+; SM_70: st.param.b32    [func_retval0+0], %[[REG1:.+]];
+; SM_70: ret;
+; SM_90: .visible .func  (.param .b32 func_retval0) qux()
+; SM_90: st.param.b32    [func_retval0+0], %[[REG1:.+]];
+; SM_90: ret;
+define i32 @qux() {
+entry:
+  %call = call i32 @__nvvm_reflect(ptr noundef @.str)
+  %cmp = icmp uge i32 %call, 700
+  %conv = zext i1 %cmp to i32
+  switch i32 %conv, label %sw.default [
+    i32 900, label %sw.bb
+    i32 700, label %sw.bb1
+    i32 520, label %sw.bb2
+  ]
+
+sw.bb:
+  br label %return
+
+sw.bb1:
+  br label %return
+
+sw.bb2:
+  br label %return
+
+sw.default:
+  br label %return
+
+return:
+  %retval = phi i32 [ 4, %sw.default ], [ 3, %sw.bb2 ], [ 2, %sw.bb1 ], [ 1, %sw.bb ]
+  ret i32 %retval
+}


