[llvm] [NVVMReflect] Improve folding inside of the NVVMReflect pass (PR #81253)
Joseph Huber via llvm-commits
llvm-commits at lists.llvm.org
Fri Feb 9 06:39:42 PST 2024
https://github.com/jhuber6 created https://github.com/llvm/llvm-project/pull/81253
Summary:
The previous patch did very simple folding that only worked for directly
used branches. This patch improves this by traversing the use-def chain
to simplify every constant subexpression until it reaches a terminator
we can delete. The support should work for all expected cases now.
>From b00f01a46f25c38b45e1979b045d0e4ecb30555b Mon Sep 17 00:00:00 2001
From: Joseph Huber <huberjn at outlook.com>
Date: Fri, 9 Feb 2024 08:36:26 -0600
Subject: [PATCH] [NVVMReflect] Improve folding inside of the NVVMReflect pass
Summary:
The previous patch did very simple folding that only worked for directly
used branches. This patch improves this by traversing the use-def chain
to simplify every constant subexpression until it reaches a terminator
we can delete. The support should work for all expected cases now.
---
llvm/docs/NVPTXUsage.rst | 3 +-
llvm/lib/Target/NVPTX/NVVMReflect.cpp | 70 ++++-------------
.../CodeGen/NVPTX/nvvm-reflect-arch-O0.ll | 78 +++++++++++++++----
3 files changed, 82 insertions(+), 69 deletions(-)
diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index b5e3918e56e940..6a55b1205a7618 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -298,8 +298,7 @@ input IR module ``module.bc``, the following compilation flow is recommended:
The ``NVVMReflect`` pass will attempt to remove dead code even without
optimizations. This allows potentially incompatible instructions to be avoided
-at all optimizations levels. This currently only works for simple conditionals
-like the above example.
+at all optimization levels by using the ``__CUDA_ARCH`` argument.
1. Save list of external functions in ``module.bc``
2. Link ``module.bc`` with ``libdevice.compute_XX.YY.bc``
diff --git a/llvm/lib/Target/NVPTX/NVVMReflect.cpp b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
index 3794ad9b126fa8..64fedf32e9a269 100644
--- a/llvm/lib/Target/NVPTX/NVVMReflect.cpp
+++ b/llvm/lib/Target/NVPTX/NVVMReflect.cpp
@@ -90,7 +90,7 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
}
SmallVector<Instruction *, 4> ToRemove;
- SmallVector<ICmpInst *, 4> ToSimplify;
+ SmallVector<Instruction *, 4> ToSimplify;
// Go through the calls in this function. Each call to __nvvm_reflect or
// llvm.nvvm.reflect should be a CallInst with a ConstantArray argument.
@@ -177,9 +177,8 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
}
// If the immediate user is a simple comparison we want to simplify it.
- // TODO: This currently does not handle switch instructions.
for (User *U : Call->users())
- if (ICmpInst *I = dyn_cast<ICmpInst>(U))
+ if (Instruction *I = dyn_cast<Instruction>(U))
ToSimplify.push_back(I);
Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
@@ -190,56 +189,21 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
I->eraseFromParent();
// The code guarded by __nvvm_reflect may be invalid for the target machine.
- // We need to do some basic dead code elimination to trim invalid code before
- // it reaches the backend at all optimization levels.
- SmallVector<BranchInst *> Simplified;
- for (ICmpInst *Cmp : ToSimplify) {
- Constant *LHS = dyn_cast<Constant>(Cmp->getOperand(0));
- Constant *RHS = dyn_cast<Constant>(Cmp->getOperand(1));
-
- if (!LHS || !RHS)
- continue;
-
- // If the comparison is a compile time constant we simply propagate it.
- Constant *C = ConstantFoldCompareInstOperands(
- Cmp->getPredicate(), LHS, RHS, Cmp->getModule()->getDataLayout());
-
- if (!C)
- continue;
-
- for (User *U : Cmp->users())
- if (BranchInst *I = dyn_cast<BranchInst>(U))
- Simplified.push_back(I);
-
- Cmp->replaceAllUsesWith(C);
- Cmp->eraseFromParent();
- }
-
- // Each instruction here is a conditional branch off of a constant true or
- // false value. Simply replace it with an unconditional branch to the
- // appropriate basic block and delete the rest if it is trivially dead.
- DenseSet<Instruction *> Removed;
- for (BranchInst *Branch : Simplified) {
- if (Removed.contains(Branch))
- continue;
-
- ConstantInt *C = dyn_cast<ConstantInt>(Branch->getCondition());
- if (!C || (!C->isOne() && !C->isZero()))
- continue;
-
- BasicBlock *TrueBB =
- C->isOne() ? Branch->getSuccessor(0) : Branch->getSuccessor(1);
- BasicBlock *FalseBB =
- C->isOne() ? Branch->getSuccessor(1) : Branch->getSuccessor(0);
-
- // This transformation is only correct on simple edges.
- if (!FalseBB->hasNPredecessors(1))
- continue;
-
- ReplaceInstWithInst(Branch, BranchInst::Create(TrueBB));
- if (FalseBB->use_empty() && !FalseBB->getFirstNonPHIOrDbg()) {
- Removed.insert(FalseBB->getFirstNonPHIOrDbg());
- changeToUnreachable(FalseBB->getFirstNonPHIOrDbg());
+ // Traverse the use-def chain, continually simplifying constant expressions
+ // until we find a terminator that we can then remove.
+ while (!ToSimplify.empty()) {
+ Instruction *I = ToSimplify.pop_back_val();
+ if (Constant *C =
+ ConstantFoldInstruction(I, F.getParent()->getDataLayout())) {
+ for (User *U : I->users())
+ if (Instruction *I = dyn_cast<Instruction>(U))
+ ToSimplify.push_back(I);
+
+ I->replaceAllUsesWith(C);
+ if (isInstructionTriviallyDead(I))
+ I->eraseFromParent();
+ } else if (I->isTerminator()) {
+ ConstantFoldTerminator(I->getParent());
}
}
diff --git a/llvm/test/CodeGen/NVPTX/nvvm-reflect-arch-O0.ll b/llvm/test/CodeGen/NVPTX/nvvm-reflect-arch-O0.ll
index 9dcdf5b6825767..0088d6c64205d2 100644
--- a/llvm/test/CodeGen/NVPTX/nvvm-reflect-arch-O0.ll
+++ b/llvm/test/CodeGen/NVPTX/nvvm-reflect-arch-O0.ll
@@ -102,23 +102,24 @@ if.end:
ret void
}
-; SM_52: .visible .func (.param .b32 func_retval0) qux()
-; SM_52: mov.u32 %[[REG1:.+]], %[[REG2:.+]];
-; SM_52: st.param.b32 [func_retval0+0], %[[REG1:.+]];
-; SM_52: ret;
-; SM_70: .visible .func (.param .b32 func_retval0) qux()
-; SM_70: mov.u32 %[[REG1:.+]], %[[REG2:.+]];
-; SM_70: st.param.b32 [func_retval0+0], %[[REG1:.+]];
-; SM_70: ret;
-; SM_90: .visible .func (.param .b32 func_retval0) qux()
-; SM_90: st.param.b32 [func_retval0+0], %[[REG1:.+]];
-; SM_90: ret;
+; SM_52: .visible .func (.param .b32 func_retval0) qux()
+; SM_52: mov.b32 %[[REG:.+]], 3;
+; SM_52-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
+; SM_52-NEXT: ret;
+;
+; SM_70: .visible .func (.param .b32 func_retval0) qux()
+; SM_70: mov.b32 %[[REG:.+]], 2;
+; SM_70-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
+; SM_70-NEXT: ret;
+;
+; SM_90: .visible .func (.param .b32 func_retval0) qux()
+; SM_90: mov.b32 %[[REG:.+]], 1;
+; SM_90-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
+; SM_90-NEXT: ret;
define i32 @qux() {
entry:
%call = call i32 @__nvvm_reflect(ptr noundef @.str)
- %cmp = icmp uge i32 %call, 700
- %conv = zext i1 %cmp to i32
- switch i32 %conv, label %sw.default [
+ switch i32 %call, label %sw.default [
i32 900, label %sw.bb
i32 700, label %sw.bb1
i32 520, label %sw.bb2
@@ -173,3 +174,52 @@ if.exit:
exit:
ret float 0.000000e+00
}
+
+; SM_52: .visible .func (.param .b32 func_retval0) prop()
+; SM_52: mov.b32 %[[REG:.+]], 3;
+; SM_52-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
+; SM_52-NEXT: ret;
+;
+; SM_70: .visible .func (.param .b32 func_retval0) prop()
+; SM_70: mov.b32 %[[REG:.+]], 2;
+; SM_70-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
+; SM_70-NEXT: ret;
+;
+; SM_90: .visible .func (.param .b32 func_retval0) prop()
+; SM_90: mov.b32 %[[REG:.+]], 1;
+; SM_90-NEXT: st.param.b32 [func_retval0+0], %[[REG:.+]];
+; SM_90-NEXT: ret;
+define i32 @prop() {
+entry:
+ %call = call i32 @__nvvm_reflect(ptr @.str)
+ %conv = zext i32 %call to i64
+ %div = udiv i64 %conv, 100
+ %cmp = icmp eq i64 %div, 9
+ br i1 %cmp, label %if.then, label %if.else
+
+if.then:
+ br label %return
+
+if.else:
+ %div2 = udiv i64 %conv, 100
+ %cmp3 = icmp eq i64 %div2, 7
+ br i1 %cmp3, label %if.then5, label %if.else6
+
+if.then5:
+ br label %return
+
+if.else6:
+ %div7 = udiv i64 %conv, 100
+ %cmp8 = icmp eq i64 %div7, 5
+ br i1 %cmp8, label %if.then10, label %if.else11
+
+if.then10:
+ br label %return
+
+if.else11:
+ br label %return
+
+return:
+ %retval = phi i32 [ 1, %if.then ], [ 2, %if.then5 ], [ 3, %if.then10 ], [ 4, %if.else11 ]
+ ret i32 %retval
+}
More information about the llvm-commits
mailing list