[llvm] [NVVMReflect] Improve folding inside of the NVVMReflect pass (PR #81253)

Joseph Huber via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 9 06:39:42 PST 2024


https://github.com/jhuber6 created https://github.com/llvm/llvm-project/pull/81253

Summary:
The previous patch did very simple folding that only worked for directly
used branches. This patch improves on that by traversing the use-def chain
and simplifying every constant subexpression until it reaches a terminator
we can delete. The support should now work for all expected cases.
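
For reference, here is a condensed standalone sketch of the worklist loop this
patch adds. The wrapper name `foldReflectUsers` is invented for illustration
(the real change lives inline in `runNVVMReflect`), but the LLVM API calls are
the same ones the diff below uses:

  #include "llvm/ADT/SmallVector.h"
  #include "llvm/Analysis/ConstantFolding.h"
  #include "llvm/IR/Constants.h"
  #include "llvm/IR/Function.h"
  #include "llvm/IR/Instructions.h"
  #include "llvm/IR/Module.h"
  #include "llvm/Transforms/Utils/Local.h"

  using namespace llvm;

  // Seeded with the instructions that used the (now constant) __nvvm_reflect
  // result, repeatedly fold anything that has become a compile-time constant
  // and queue its users. Once a branch or switch whose condition has folded is
  // reached, ConstantFoldTerminator rewrites it into an unconditional branch,
  // which is what lets the dead arm be trimmed even at -O0.
  static void foldReflectUsers(Function &F,
                               SmallVectorImpl<Instruction *> &Worklist) {
    const DataLayout &DL = F.getParent()->getDataLayout();
    while (!Worklist.empty()) {
      Instruction *I = Worklist.pop_back_val();
      if (Constant *C = ConstantFoldInstruction(I, DL)) {
        // Queue the users before RAUW; afterwards the use list is empty.
        for (User *U : I->users())
          if (auto *UI = dyn_cast<Instruction>(U))
            Worklist.push_back(UI);
        I->replaceAllUsesWith(C);
        if (isInstructionTriviallyDead(I))
          I->eraseFromParent();
      } else if (I->isTerminator()) {
        // br/switch on a constant condition: fold to an unconditional branch.
        ConstantFoldTerminator(I->getParent());
      }
    }
  }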


From b00f01a46f25c38b45e1979b045d0e4ecb30555b Mon Sep 17 00:00:00 2001
From: Joseph Huber <huberjn at outlook.com>
Date: Fri, 9 Feb 2024 08:36:26 -0600
Subject: [PATCH] [NVVMReflect] Improve folding inside of the NVVMReflect pass

Summary:
The previous patch did very simple folding that only worked for directly
used branches. This patch improves on that by traversing the use-def chain
and simplifying every constant subexpression until it reaches a terminator
we can delete. The support should now work for all expected cases.
---
 llvm/docs/NVPTXUsage.rst                      |  3 +-
 llvm/lib/Target/NVPTX/NVVMReflect.cpp         | 70 ++++-------------
 .../CodeGen/NVPTX/nvvm-reflect-arch-O0.ll     | 78 +++++++++++++++----
 3 files changed, 82 insertions(+), 69 deletions(-)

diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index b5e3918e56e940..6a55b1205a7618 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -298,8 +298,7 @@ input IR module ``module.bc``, the following compilation flow is recommended:
 
 The ``NVVMReflect`` pass will attempt to remove dead code even without
 optimizations. This allows potentially incompatible instructions to be avoided
-at all optimizations levels. This currently only works for simple conditionals
-like the above example.
+at all optimization levels by using the ``__CUDA_ARCH`` argument.
 
 1. Save list of external functions in ``module.bc``
 2. Link ``module.bc`` with ``libdevice.compute_XX.YY.bc``
diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 3794ad9b126fa8..64fedf32e9a269 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -90,7 +90,7 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
   }
 
   SmallVector<Instruction *, 4> ToRemove;
-  SmallVector<ICmpInst *, 4> ToSimplify;
+  SmallVector<Instruction *, 4> ToSimplify;
 
   // Go through the calls in this function.  Each call to __nvvm_reflect or
   // llvm.nvvm.reflect should be a CallInst with a ConstantArray argument.
@@ -177,9 +177,8 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
     }
 
     // If the immediate user is a simple comparison we want to simplify it.
-    // TODO: This currently does not handle switch instructions.
     for (User *U : Call->users())
-      if (ICmpInst *I = dyn_cast<ICmpInst>(U))
+      if (Instruction *I = dyn_cast<Instruction>(U))
         ToSimplify.push_back(I);
 
     Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
@@ -190,56 +189,21 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
     I->eraseFromParent();
 
   // The code guarded by __nvvm_reflect may be invalid for the target machine.
-  // We need to do some basic dead code elimination to trim invalid code before
-  // it reaches the backend at all optimization levels.
-  SmallVector<BranchInst *> Simplified;
-  for (ICmpInst *Cmp : ToSimplify) {
-    Constant *LHS = dyn_cast<Constant>(Cmp->getOperand(0));
-    Constant *RHS = dyn_cast<Constant>(Cmp->getOperand(1));
-
-    if (!LHS || !RHS)
-      continue;
-
-    // If the comparison is a compile time constant we simply propagate it.
-    Constant *C = ConstantFoldCompareInstOperands(
-        Cmp->getPredicate(), LHS, RHS, Cmp->getModule()->getDataLayout());
-
-    if (!C)
-      continue;
-
-    for (User *U : Cmp->users())
-      if (BranchInst *I = dyn_cast<BranchInst>(U))
-        Simplified.push_back(I);
-
-    Cmp->replaceAllUsesWith(C);
-    Cmp->eraseFromParent();
-  }
-
-  // Each instruction here is a conditional branch off of a constant true or
-  // false value. Simply replace it with an unconditional branch to the
-  // appropriate basic block and delete the rest if it is trivially dead.
-  DenseSet<Instruction *> Removed;
-  for (BranchInst *Branch : Simplified) {
-    if (Removed.contains(Branch))
-      continue;
-
-    ConstantInt *C = dyn_cast<ConstantInt>(Branch->getCondition());
-    if (!C || (!C->isOne() && !C->isZero()))
-      continue;
-
-    BasicBlock *TrueBB =
-        C->isOne() ? Branch->getSuccessor(0) : Branch->getSuccessor(1);
-    BasicBlock *FalseBB =
-        C->isOne() ? Branch->getSuccessor(1) : Branch->getSuccessor(0);
-
-    // This transformation is only correct on simple edges.
-    if (!FalseBB->hasNPredecessors(1))
-      continue;
-
-    ReplaceInstWithInst(Branch, BranchInst::Create(TrueBB));
-    if (FalseBB->use_empty() && !FalseBB->getFirstNonPHIOrDbg()) {
-      Removed.insert(FalseBB->getFirstNonPHIOrDbg());
-      changeToUnreachable(FalseBB->getFirstNonPHIOrDbg());
+  // Traverse the use-def chain, continually simplifying constant expressions
+  // until we find a terminator that we can then remove.
+  while (!ToSimplify.empty()) {
+    Instruction *I = ToSimplify.pop_back_val();
+    if (Constant *C =
+            ConstantFoldInstruction(I, F.getParent()->getDataLayout())) {
+      for (User *U : I->users())
+        if (Instruction *I = dyn_cast<Instruction>(U))
+          ToSimplify.push_back(I);
+
+      I->replaceAllUsesWith(C);
+      if (isInstructionTriviallyDead(I))
+        I->eraseFromParent();
+    } else if (I->isTerminator()) {
+      ConstantFoldTerminator(I->getParent());
     }
   }
 
diff --git a/llvm/test/CodeGen/NVPTX/nvvm-reflect-arch-O0.ll b/llvm/test/CodeGen/NVPTX/nvvm-reflect-arch-O0.ll
index 9dcdf5b6825767..0088d6c64205d2 100644
--- a/llvm/test/CodeGen/NVPTX/nvvm-reflect-arch-O0.ll
+++ b/llvm/test/CodeGen/NVPTX/nvvm-reflect-arch-O0.ll
@@ -102,23 +102,24 @@ if.end:
   ret void
 }
 
-; SM_52: .visible .func  (.param .b32 func_retval0) qux()
-; SM_52: mov.u32         %[[REG1:.+]], %[[REG2:.+]];
-; SM_52: st.param.b32    [func_retval0+0], %[[REG1:.+]];
-; SM_52: ret;
-; SM_70: .visible .func  (.param .b32 func_retval0) qux()
-; SM_70: mov.u32         %[[REG1:.+]], %[[REG2:.+]];
-; SM_70: st.param.b32    [func_retval0+0], %[[REG1:.+]];
-; SM_70: ret;
-; SM_90: .visible .func  (.param .b32 func_retval0) qux()
-; SM_90: st.param.b32    [func_retval0+0], %[[REG1:.+]];
-; SM_90: ret;
+;      SM_52: .visible .func  (.param .b32 func_retval0) qux()
+;      SM_52: mov.b32         %[[REG:.+]], 3;
+; SM_52-NEXT: st.param.b32    [func_retval0+0], %[[REG:.+]];
+; SM_52-NEXT: ret;
+;
+;      SM_70: .visible .func  (.param .b32 func_retval0) qux()
+;      SM_70: mov.b32         %[[REG:.+]], 2;
+; SM_70-NEXT: st.param.b32    [func_retval0+0], %[[REG:.+]];
+; SM_70-NEXT: ret;
+;
+;      SM_90: .visible .func  (.param .b32 func_retval0) qux()
+;      SM_90: mov.b32         %[[REG:.+]], 1;
+; SM_90-NEXT: st.param.b32    [func_retval0+0], %[[REG:.+]];
+; SM_90-NEXT: ret;
 define i32 @qux() {
 entry:
   %call = call i32 @__nvvm_reflect(ptr noundef @.str)
-  %cmp = icmp uge i32 %call, 700
-  %conv = zext i1 %cmp to i32
-  switch i32 %conv, label %sw.default [
+  switch i32 %call, label %sw.default [
     i32 900, label %sw.bb
     i32 700, label %sw.bb1
     i32 520, label %sw.bb2
@@ -173,3 +174,52 @@ if.exit:
 exit:
   ret float 0.000000e+00
 }
+
+;      SM_52: .visible .func  (.param .b32 func_retval0) prop()
+;      SM_52: mov.b32         %[[REG:.+]], 3;
+; SM_52-NEXT: st.param.b32    [func_retval0+0], %[[REG:.+]];
+; SM_52-NEXT: ret;
+;
+;      SM_70: .visible .func  (.param .b32 func_retval0) prop()
+;      SM_70: mov.b32         %[[REG:.+]], 2;
+; SM_70-NEXT: st.param.b32    [func_retval0+0], %[[REG:.+]];
+; SM_70-NEXT: ret;
+;
+;      SM_90: .visible .func  (.param .b32 func_retval0) prop()
+;      SM_90: mov.b32         %[[REG:.+]], 1;
+; SM_90-NEXT: st.param.b32    [func_retval0+0], %[[REG:.+]];
+; SM_90-NEXT: ret;
+define i32 @prop() {
+entry:
+  %call = call i32 @__nvvm_reflect(ptr @.str)
+  %conv = zext i32 %call to i64
+  %div = udiv i64 %conv, 100
+  %cmp = icmp eq i64 %div, 9
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:
+  br label %return
+
+if.else:
+  %div2 = udiv i64 %conv, 100
+  %cmp3 = icmp eq i64 %div2, 7
+  br i1 %cmp3, label %if.then5, label %if.else6
+
+if.then5:
+  br label %return
+
+if.else6:
+  %div7 = udiv i64 %conv, 100
+  %cmp8 = icmp eq i64 %div7, 5
+  br i1 %cmp8, label %if.then10, label %if.else11
+
+if.then10:
+  br label %return
+
+if.else11:
+  br label %return
+
+return:
+  %retval = phi i32 [ 1, %if.then ], [ 2, %if.then5 ], [ 3, %if.then10 ], [ 4, %if.else11 ]
+  ret i32 %retval
+}
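
For context, the new `prop` test corresponds roughly to device source of the
following shape (an illustrative sketch, not code from this patch; the
`__nvvm_reflect` declaration is simplified here). With the improved folding,
the reflect call becomes a constant such as 520/700/900 depending on -mcpu,
the divides and compares fold in turn, and the untaken branches are removed
before they can reach the backend:

  // Simplified declaration of the reflect hook, for illustration only.
  extern "C" unsigned __nvvm_reflect(const char *);

  int prop() {
    // Folds to e.g. 900 when compiling for sm_90.
    unsigned Arch = __nvvm_reflect("__CUDA_ARCH");
    if (Arch / 100 == 9)
      return 1; // sm_90 path
    if (Arch / 100 == 7)
      return 2; // sm_70 path
    if (Arch / 100 == 5)
      return 3; // sm_52 path
    return 4;   // fallback
  }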


