[llvm] [NVPTX] support switch statement with brx.idx (reland) (PR #102550)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 8 16:20:58 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-selectiondag
Author: Alex MacLean (AlexMaclean)
<details>
<summary>Changes</summary>
Add custom lowering for `BR_JT` DAG nodes to the `brx.idx` PTX instruction ([PTX ISA 9.7.13.4. Control Flow Instructions: brx.idx](https://docs.nvidia.com/cuda/parallel-thread-execution/#control-flow-instructions-brx-idx)).
Depending on the switch-lowering heuristics applied while building the SelectionDAG, `switch` statements may now be lowered to jump tables that branch through `brx.idx`.
Note: this reland fixes the issue found in the original PR (#102400) by adding the `isBarrier` attribute to `BRX_END`.
---
Full diff: https://github.com/llvm/llvm-project/pull/102550.diff
6 Files Affected:
- (modified) llvm/include/llvm/CodeGen/TargetLowering.h (+4)
- (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+6-5)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+42-3)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+10)
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+40)
- (added) llvm/test/CodeGen/NVPTX/jump-table.ll (+69)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 9ccdbab008aec8..5b2214fa66c40b 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3843,6 +3843,10 @@ class TargetLowering : public TargetLoweringBase {
/// returned value is a member of the MachineJumpTableInfo::JTEntryKind enum.
virtual unsigned getJumpTableEncoding() const;
+  /// Return the value type of the jump table index register that feeds
+  /// ISD::BR_JT. SelectionDAGBuilder zero-extends or truncates the switch
+  /// value to this type, creates the jump table virtual register with it and
+  /// reads it back with this type when emitting BR_JT. Defaults to the pointer
+  /// type; a target may override it when its indirect-branch instruction
+  /// expects a different index width.
+  virtual MVT getJumpTableRegTy(const DataLayout &DL) const {
+    return getPointerTy(DL);
+  }
+
virtual const MCExpr *
LowerCustomJumpTableEntry(const MachineJumpTableInfo * /*MJTI*/,
const MachineBasicBlock * /*MBB*/, unsigned /*uid*/,
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 1f4436fb3a4966..37ba62911ec70b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -2977,7 +2977,7 @@ void SelectionDAGBuilder::visitJumpTable(SwitchCG::JumpTable &JT) {
// Emit the code for the jump table
assert(JT.SL && "Should set SDLoc for SelectionDAG!");
assert(JT.Reg != -1U && "Should lower JT Header first!");
- EVT PTy = DAG.getTargetLoweringInfo().getPointerTy(DAG.getDataLayout());
+ EVT PTy = DAG.getTargetLoweringInfo().getJumpTableRegTy(DAG.getDataLayout());
SDValue Index = DAG.getCopyFromReg(getControlRoot(), *JT.SL, JT.Reg, PTy);
SDValue Table = DAG.getJumpTable(JT.JTI, PTy);
SDValue BrJumpTable = DAG.getNode(ISD::BR_JT, *JT.SL, MVT::Other,
@@ -3005,12 +3005,13 @@ void SelectionDAGBuilder::visitJumpTableHeader(SwitchCG::JumpTable &JT,
// This value may be smaller or larger than the target's pointer type, and
// therefore require extension or truncating.
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- SwitchOp = DAG.getZExtOrTrunc(Sub, dl, TLI.getPointerTy(DAG.getDataLayout()));
+ SwitchOp =
+ DAG.getZExtOrTrunc(Sub, dl, TLI.getJumpTableRegTy(DAG.getDataLayout()));
unsigned JumpTableReg =
- FuncInfo.CreateReg(TLI.getPointerTy(DAG.getDataLayout()));
- SDValue CopyTo = DAG.getCopyToReg(getControlRoot(), dl,
- JumpTableReg, SwitchOp);
+ FuncInfo.CreateReg(TLI.getJumpTableRegTy(DAG.getDataLayout()));
+ SDValue CopyTo =
+ DAG.getCopyToReg(getControlRoot(), dl, JumpTableReg, SwitchOp);
JT.Reg = JumpTableReg;
if (!JTH.FallthroughUnreachable) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 516fc7339a4bf3..bf647c88f00e28 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -25,6 +25,7 @@
#include "llvm/CodeGen/Analysis.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/MachineMemOperand.h"
#include "llvm/CodeGen/SelectionDAG.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
@@ -582,9 +583,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::ROTR, MVT::i8, Expand);
setOperationAction(ISD::BSWAP, MVT::i16, Expand);
- // Indirect branch is not supported.
- // This also disables Jump Table creation.
- setOperationAction(ISD::BR_JT, MVT::Other, Expand);
+ setOperationAction(ISD::BR_JT, MVT::Other, Custom);
setOperationAction(ISD::BRIND, MVT::Other, Expand);
setOperationAction(ISD::GlobalAddress, MVT::i32, Custom);
@@ -945,6 +944,9 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::Dummy)
MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
+ MAKE_CASE(NVPTXISD::BrxEnd)
+ MAKE_CASE(NVPTXISD::BrxItem)
+ MAKE_CASE(NVPTXISD::BrxStart)
MAKE_CASE(NVPTXISD::Tex1DFloatS32)
MAKE_CASE(NVPTXISD::Tex1DFloatFloat)
MAKE_CASE(NVPTXISD::Tex1DFloatFloatLevel)
@@ -2785,6 +2787,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerFP_ROUND(Op, DAG);
case ISD::FP_EXTEND:
return LowerFP_EXTEND(Op, DAG);
+ case ISD::BR_JT:
+ return LowerBR_JT(Op, DAG);
case ISD::VAARG:
return LowerVAARG(Op, DAG);
case ISD::VASTART:
@@ -2810,6 +2814,41 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
}
}
+/// Lower ISD::BR_JT to an inline PTX jump table consumed by brx.idx.
+///
+/// The table is emitted as a glued sequence of target nodes:
+///   BrxStart(id)              -> "$L_brx_<id>: .branchtargets"
+///   BrxItem(bb), N-1 times    -> "<bb>,"
+///   BrxEnd(bb, index, id)     -> "<bb>;" then "brx.idx <index>, $L_brx_<id>;"
+/// The jump table index is used as the label suffix, so BrxStart (which
+/// defines the label) and BrxEnd (which branches through it) agree.
+SDValue NVPTXTargetLowering::LowerBR_JT(SDValue Op, SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  SDValue Chain = Op.getOperand(0);
+  const auto *JT = cast<JumpTableSDNode>(Op.getOperand(1));
+  SDValue Index = Op.getOperand(2);
+
+  unsigned JId = JT->getIndex();
+  MachineJumpTableInfo *MJTI = DAG.getMachineFunction().getJumpTableInfo();
+  ArrayRef<MachineBasicBlock *> MBBs = MJTI->getJumpTables()[JId].MBBs;
+  assert(!MBBs.empty() && "BR_JT jump table has no targets");
+
+  // Shared by BrxStart and BrxEnd so both print the same $L_brx_<id> label.
+  SDValue IdV = DAG.getConstant(JId, DL, MVT::i32);
+
+  // BrxStart and each BrxItem produce (chain, glue); the glue links the next
+  // node of the target list to them.
+  SDVTList VTs = DAG.getVTList(MVT::Other, MVT::Glue);
+  Chain = DAG.getNode(NVPTXISD::BrxStart, DL, VTs, Chain, IdV);
+
+  // One BrxItem per target except the last, which BrxEnd prints.
+  for (MachineBasicBlock *MBB : MBBs.drop_back())
+    Chain = DAG.getNode(NVPTXISD::BrxItem, DL, VTs, Chain.getValue(0),
+                        DAG.getBasicBlock(MBB), Chain.getValue(1));
+
+  // BrxEnd consumes the incoming glue but only produces a chain: the brx_end
+  // SDNode is declared without SDNPOutGlue and nothing uses a glue result.
+  SDValue EndOps[] = {Chain.getValue(0), DAG.getBasicBlock(MBBs.back()), Index,
+                      IdV, Chain.getValue(1)};
+  return DAG.getNode(NVPTXISD::BrxEnd, DL, MVT::Other, EndOps);
+}
+
+// Jump table entries are printed inline by the BRX_START/BRX_ITEM/BRX_END
+// instructions, so report EK_Inline to keep the AsmPrinter from trying to
+// print the jump tables itself.
+unsigned NVPTXTargetLowering::getJumpTableEncoding() const {
+  const unsigned Encoding = MachineJumpTableInfo::EK_Inline;
+  return Encoding;
+}
+
// This function is almost a copy of SelectionDAG::expandVAArg().
// The only diff is that this one produces loads from local address space.
SDValue NVPTXTargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 63262961b363ed..32e6b044b0de1f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -62,6 +62,9 @@ enum NodeType : unsigned {
BFI,
PRMT,
DYNAMIC_STACKALLOC,
+ BrxStart,
+ BrxItem,
+ BrxEnd,
Dummy,
LoadV2 = ISD::FIRST_TARGET_MEMORY_OPCODE,
@@ -580,6 +583,11 @@ class NVPTXTargetLowering : public TargetLowering {
return true;
}
+  // brx.idx only accepts an i32 index, so the jump table index register is
+  // i32 here instead of the default (the pointer type).
+  MVT getJumpTableRegTy(const DataLayout &) const override {
+    return MVT::i32;
+  }
+
+ unsigned getJumpTableEncoding() const override;
+
bool enableAggressiveFMAFusion(EVT VT) const override { return true; }
// The default is to transform llvm.ctlz(x, false) (where false indicates that
@@ -637,6 +645,8 @@ class NVPTXTargetLowering : public TargetLowering {
SDValue LowerSelect(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerBR_JT(SDValue Op, SelectionDAG &DAG) const;
+
SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 6a096fa5acea7c..d75dc8781f7802 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3880,6 +3880,46 @@ def DYNAMIC_STACKALLOC64 :
[(set Int64Regs:$ptr, (dyn_alloca Int64Regs:$size, (i32 timm:$align)))]>,
Requires<[hasPTX<73>, hasSM<52>]>;
+
+//-----------------------------------
+// BRX: inline jump tables branched through with brx.idx
+//-----------------------------------
+
+// BrxStart(id): opens the branch target list.
+def SDTBrxStartProfile : SDTypeProfile<0, 1, [SDTCisInt<0>]>;
+// BrxItem(bb): one non-final branch target.
+def SDTBrxItemProfile : SDTypeProfile<0, 1, [SDTCisVT<0, OtherVT>]>;
+// BrxEnd(bb, index, id): final branch target followed by the brx.idx itself.
+def SDTBrxEndProfile : SDTypeProfile<0, 3, [SDTCisVT<0, OtherVT>,
+                                            SDTCisInt<1>, SDTCisInt<2>]>;
+
+// All three nodes are chained; glue ties BrxStart -> BrxItem* -> BrxEnd
+// together so the target list is printed as one contiguous sequence.
+def brx_start : SDNode<"NVPTXISD::BrxStart", SDTBrxStartProfile,
+                       [SDNPHasChain, SDNPOutGlue, SDNPSideEffect]>;
+def brx_item : SDNode<"NVPTXISD::BrxItem", SDTBrxItemProfile,
+                      [SDNPHasChain, SDNPOutGlue, SDNPInGlue, SDNPSideEffect]>;
+def brx_end : SDNode<"NVPTXISD::BrxEnd", SDTBrxEndProfile,
+                     [SDNPHasChain, SDNPInGlue, SDNPSideEffect]>;
+
+// Every piece sits in the block's terminator sequence and must not be
+// duplicated: each copy would print another definition of $L_brx_<id>.
+let isTerminator = 1, isBranch = 1, isIndirectBranch = 1,
+    isNotDuplicable = 1 in {
+  // Prints "$L_brx_<id>: .branchtargets".
+  def BRX_START : NVPTXInst<(outs), (ins i32imm:$id),
+                            "$$L_brx_$id: .branchtargets",
+                            [(brx_start (i32 imm:$id))]>;
+
+  // Prints "<target>,".
+  def BRX_ITEM : NVPTXInst<(outs), (ins brtarget:$target),
+                           "\t$target,",
+                           [(brx_item bb:$target)]>;
+
+  // Prints "<target>;" and then "brx.idx <val>, $L_brx_<id>;". Control never
+  // falls through the indirect branch, hence isBarrier.
+  let isBarrier = 1 in
+  def BRX_END : NVPTXInst<(outs),
+                          (ins brtarget:$target, Int32Regs:$val, i32imm:$id),
+                          "\t$target;\n\tbrx.idx \t$val, $$L_brx_$id;",
+                          [(brx_end bb:$target, (i32 Int32Regs:$val),
+                                    (i32 imm:$id))]>;
+}
+
+
include "NVPTXIntrinsics.td"
//-----------------------------------
diff --git a/llvm/test/CodeGen/NVPTX/jump-table.ll b/llvm/test/CodeGen/NVPTX/jump-table.ll
new file mode 100644
index 00000000000000..867e171a5840ae
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/jump-table.ll
@@ -0,0 +1,69 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s | FileCheck %s
+; RUN: %if ptxas %{ llc < %s | %ptxas-verify %}
+
+target triple = "nvptx64-nvidia-cuda"
+
+@out = addrspace(1) global i32 0, align 4
+
+define void @foo(i32 %i) {
+; CHECK-LABEL: foo(
+; CHECK: {
+; CHECK-NEXT: .reg .pred %p<2>;
+; CHECK-NEXT: .reg .b32 %r<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.u32 %r2, [foo_param_0];
+; CHECK-NEXT: setp.gt.u32 %p1, %r2, 3;
+; CHECK-NEXT: @%p1 bra $L__BB0_6;
+; CHECK-NEXT: // %bb.1: // %entry
+; CHECK-NEXT: $L_brx_0: .branchtargets
+; CHECK-NEXT: $L__BB0_2,
+; CHECK-NEXT: $L__BB0_3,
+; CHECK-NEXT: $L__BB0_4,
+; CHECK-NEXT: $L__BB0_5;
+; CHECK-NEXT: brx.idx %r2, $L_brx_0;
+; CHECK-NEXT: $L__BB0_2: // %case0
+; CHECK-NEXT: mov.b32 %r6, 0;
+; CHECK-NEXT: st.global.u32 [out], %r6;
+; CHECK-NEXT: bra.uni $L__BB0_6;
+; CHECK-NEXT: $L__BB0_4: // %case2
+; CHECK-NEXT: mov.b32 %r4, 2;
+; CHECK-NEXT: st.global.u32 [out], %r4;
+; CHECK-NEXT: bra.uni $L__BB0_6;
+; CHECK-NEXT: $L__BB0_5: // %case3
+; CHECK-NEXT: mov.b32 %r3, 3;
+; CHECK-NEXT: st.global.u32 [out], %r3;
+; CHECK-NEXT: bra.uni $L__BB0_6;
+; CHECK-NEXT: $L__BB0_3: // %case1
+; CHECK-NEXT: mov.b32 %r5, 1;
+; CHECK-NEXT: st.global.u32 [out], %r5;
+; CHECK-NEXT: $L__BB0_6: // %end
+; CHECK-NEXT: ret;
+entry:
+  switch i32 %i, label %end [
+    i32 0, label %case0
+    i32 1, label %case1
+    i32 2, label %case2
+    i32 3, label %case3
+  ]
+
+case0:
+  store i32 0, ptr addrspace(1) @out, align 4
+  br label %end
+
+case1:
+  store i32 1, ptr addrspace(1) @out, align 4
+  br label %end
+
+case2:
+  store i32 2, ptr addrspace(1) @out, align 4
+  br label %end
+
+case3:
+  store i32 3, ptr addrspace(1) @out, align 4
+  br label %end
+
+end:
+  ret void
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/102550
More information about the llvm-commits mailing list