[llvm] [NVPTX] Teach NVPTX about predicates (PR #67468)

via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 28 06:20:10 PDT 2023


https://github.com/ldrumm updated https://github.com/llvm/llvm-project/pull/67468

>From 9849e8004efa79b062b682a08b5f2ed83bd772e3 Mon Sep 17 00:00:00 2001
From: Luke Drummond <luke.drummond at codeplay.com>
Date: Tue, 5 Sep 2023 14:12:14 +0100
Subject: [PATCH] [NVPTX] Teach NVPTX about predicates

PTX is fully predicated[1], and Maxwell through Ampere take predicate
registers at the ISA level[2]. However, LLVM has not been exploiting
this: the only predicates it emits are those manually specified on the
conditional-branch (`CBranch`) instruction. As described in [1], every
PTX instruction can be predicated, using one of two forms:

  @<predicate_reg>
  @!<predicate_reg>

The first form enables the instruction if <predicate_reg> is nonzero,
the second if it is zero.

In this first part, we add such two-part predicates to the NVPTX
backend: a predicate register operand, which defaults to `$noreg` (i.e.
always true), and a predicate-inversion "switch" that inverts the
condition when set, which defaults to zero (not inverted).

e.g.

    ADDi64ri %0, 1, $noreg, 0

is unpredicated, but

    %2:int1regs = IMPLICIT_DEF
    ADDi64ri %0, 1, %2, 0

is predicated on `%2` and prints as:

```asm
 @%p1 add.s64 %rd3, %rd1, 1;
```
Finally, setting the switch to 1:

    %2:int1regs = IMPLICIT_DEF
    ADDi64ri %0, 1, %2, 1

gives the inverted version:

```asm
 @!%p1 add.s64 %rd3, %rd1, 1;
```
In every case the last two machine operands are the predicate register
(`$noreg` meaning "no predicate") and the inversion switch (0 uses the
predicate as-is, 1 inverts it).

The changes here are logically minimal and barely affect the generated
code, but they add the machinery for further optimization
opportunities, such as the if-conversion I'm working on.
Also missing here are some useful target hooks which I'll add in due
course:

  - getPredicationCost
  - optimizeCondBranch
  - reverseBranchCondition
  - isPredicated
  et al.

I also intend to add more `AllowModify` cases to `analyzeBranch` to
enable better machine block placement and other generic machine
optimizations.

Since the branching logic changes significantly, I've renamed the
branch instructions (`CBranch`/`GOTO` are now `Bra`/`Jump`) to make it
clear that their implementation has changed.

[1]: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#ptx-instructions
[2]: https://docs.nvidia.com/cuda/pdf/CUDA_Binary_Utilities.pdf Chapter 6
---
 .../NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp   |  13 +
 .../NVPTX/MCTargetDesc/NVPTXInstPrinter.h     |   2 +
 llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp     |   5 +
 llvm/lib/Target/NVPTX/NVPTXFrameLowering.cpp  |   8 +-
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp   | 196 ++++++++----
 llvm/lib/Target/NVPTX/NVPTXInstrFormats.td    |  42 ++-
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp      |  83 +++--
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td       |  51 ++--
 llvm/lib/Target/NVPTX/NVPTXPeephole.cpp       |   4 +-
 .../CodeGen/NVPTX/analyze_branch_crash.ll     |  10 +
 llvm/test/CodeGen/NVPTX/branch-fold.mir       |  16 +-
 llvm/test/CodeGen/NVPTX/branches.ll           |  23 ++
 llvm/test/CodeGen/NVPTX/branches.mir          | 285 ++++++++++++++++++
 llvm/test/CodeGen/NVPTX/nopred.mir            |  33 ++
 14 files changed, 644 insertions(+), 127 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/analyze_branch_crash.ll
 create mode 100644 llvm/test/CodeGen/NVPTX/branches.ll
 create mode 100644 llvm/test/CodeGen/NVPTX/branches.mir
 create mode 100644 llvm/test/CodeGen/NVPTX/nopred.mir

diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 5d27accdc198c1e..0f866fe1dfc80c7 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -31,6 +31,19 @@ NVPTXInstPrinter::NVPTXInstPrinter(const MCAsmInfo &MAI, const MCInstrInfo &MII,
                                    const MCRegisterInfo &MRI)
     : MCInstPrinter(MAI, MII, MRI) {}
 
+void NVPTXInstPrinter::printPredicateOperand(const MCInst *MI, int OpNum,
+                                             raw_ostream &OS,
+                                             const char * /*Modifier*/) {
+  assert(MI->getNumOperands() == OpNum + 2 &&
+         "predicate and switch must be last");
+  unsigned Reg = MI->getOperand(OpNum).getReg();
+  unsigned Sw = MI->getOperand(OpNum + 1).getImm();
+  if (!Reg)
+    return;
+  OS << (Sw ? "@!" : "@");
+  printRegName(OS, Reg);
+}
+
 void NVPTXInstPrinter::printRegName(raw_ostream &OS, MCRegister Reg) const {
   // Decode the virtual register
   // Must be kept in sync with NVPTXAsmPrinter::encodeVirtualRegister
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 49ad3f269229d5f..e69f2d8ba6ee9ce 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -47,6 +47,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
                        raw_ostream &O, const char *Modifier = nullptr);
   void printProtoIdent(const MCInst *MI, int OpNum,
                        raw_ostream &O, const char *Modifier = nullptr);
+  void printPredicateOperand(const MCInst *MI, int OpNum, raw_ostream &O,
+                             const char *Modifier = nullptr);
 };
 
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index 5d6127419d6318e..da42e6699a4abc9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -222,6 +222,11 @@ void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
     const MachineOperand &MO = MI->getOperand(0);
     OutMI.addOperand(GetSymbolRef(
       OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));
+    MCOperand Predicate, Sw;
+    if (lowerOperand(MI->getOperand(1), Predicate))
+      OutMI.addOperand(Predicate);
+    if (lowerOperand(MI->getOperand(2), Sw))
+      OutMI.addOperand(Sw);
     return;
   }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXFrameLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXFrameLowering.cpp
index 86fb367780dc1a0..efb86a6202e49fa 100644
--- a/llvm/lib/Target/NVPTX/NVPTXFrameLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXFrameLowering.cpp
@@ -58,12 +58,16 @@ void NVPTXFrameLowering::emitPrologue(MachineFunction &MF,
       MBBI = BuildMI(MBB, MBBI, dl,
                      MF.getSubtarget().getInstrInfo()->get(CvtaLocalOpcode),
                      NRI->getFrameRegister(MF))
-                 .addReg(NRI->getFrameLocalRegister(MF));
+                 .addReg(NRI->getFrameLocalRegister(MF))
+                 .addReg(0)
+                 .addImm(0);
     }
     BuildMI(MBB, MBBI, dl,
             MF.getSubtarget().getInstrInfo()->get(MovDepotOpcode),
             NRI->getFrameLocalRegister(MF))
-        .addImm(MF.getFunctionNumber());
+        .addImm(MF.getFunctionNumber())
+        .addReg(0)
+        .addImm(0);
   }
 }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 0aef2591c6e2394..c8b037ec89bb706 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -29,6 +29,9 @@ using namespace llvm;
 #define DEBUG_TYPE "nvptx-isel"
 #define PASS_NAME "NVPTX DAG->DAG Pattern Instruction Selection"
 
+#define PRED(reg) (CurDAG->getRegister((reg), MVT::i32))
+#define SW(val) (CurDAG->getTargetConstant((val), SDLoc(N), MVT::i1))
+
 /// createNVPTXISelDag - This pass converts a legalized DAG into a
 /// NVPTX-specific DAG, ready for instruction scheduling.
 FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM,
@@ -534,7 +537,7 @@ bool NVPTXDAGToDAGISel::tryConstantFP(SDNode *N) {
   SDNode *LoadConstF16 = CurDAG->getMachineNode(
       (N->getValueType(0) == MVT::f16 ? NVPTX::LOAD_CONST_F16
                                       : NVPTX::LOAD_CONST_BF16),
-      SDLoc(N), N->getValueType(0), Val);
+      SDLoc(N), N->getValueType(0), Val, PRED(0), SW(0));
   ReplaceNode(N, LoadConstF16);
   return true;
 }
@@ -600,9 +603,15 @@ bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
   unsigned PTXCmpMode =
       getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
   SDLoc DL(N);
-  SDNode *SetP = CurDAG->getMachineNode(
-      NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
-      N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
+  SmallVector<SDValue> Ops = {
+      N->getOperand(0),
+      N->getOperand(1),
+      CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32),
+      PRED(0),
+      SW(0),
+  };
+  SDNode *SetP =
+      CurDAG->getMachineNode(NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1, Ops);
   ReplaceNode(N, SetP);
   return true;
 }
@@ -643,8 +652,8 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
   // Merge (f16 extractelt(V, 0), f16 extractelt(V,1))
   // into f16,f16 SplitF16x2(V)
   MVT EltVT = VT.getVectorElementType();
-  SDNode *ScatterOp =
-      CurDAG->getMachineNode(NVPTX::I32toV2I16, SDLoc(N), EltVT, EltVT, Vector);
+  SDNode *ScatterOp = CurDAG->getMachineNode(NVPTX::I32toV2I16, SDLoc(N), EltVT,
+                                             EltVT, Vector, PRED(0), SW(0));
   for (auto *Node : E0)
     ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 0));
   for (auto *Node : E1)
@@ -731,7 +740,7 @@ void NVPTXDAGToDAGISel::SelectTexSurfHandle(SDNode *N) {
   SDValue Wrapper = N->getOperand(1);
   SDValue GlobalVal = Wrapper.getOperand(0);
   ReplaceNode(N, CurDAG->getMachineNode(NVPTX::texsurf_handles, SDLoc(N),
-                                        MVT::i64, GlobalVal));
+                                        MVT::i64, GlobalVal, PRED(0), SW(0)));
 }
 
 void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
@@ -767,7 +776,7 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
       break;
     }
     ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0),
-                                          Src));
+                                          Src, PRED(0), SW(0)));
     return;
   } else {
     // Generic to specific
@@ -801,7 +810,7 @@ void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
       break;
     }
     ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0),
-                                          Src));
+                                          Src, PRED(0), SW(0)));
     return;
   }
 }
@@ -934,9 +943,15 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
                              NVPTX::LD_f32_avar, NVPTX::LD_f64_avar);
     if (!Opcode)
       return false;
-    SDValue Ops[] = { getI32Imm(isVolatile, dl), getI32Imm(CodeAddrSpace, dl),
-                      getI32Imm(vecType, dl), getI32Imm(fromType, dl),
-                      getI32Imm(fromTypeWidth, dl), Addr, Chain };
+    SDValue Ops[] = {getI32Imm(isVolatile, dl),
+                     getI32Imm(CodeAddrSpace, dl),
+                     getI32Imm(vecType, dl),
+                     getI32Imm(fromType, dl),
+                     getI32Imm(fromTypeWidth, dl),
+                     Addr,
+                     PRED(0),
+                     SW(0),
+                     Chain};
     NVPTXLD = CurDAG->getMachineNode(*Opcode, dl, TargetVT, MVT::Other, Ops);
   } else if (PointerSize == 64 ? SelectADDRsi64(N1.getNode(), N1, Base, Offset)
                                : SelectADDRsi(N1.getNode(), N1, Base, Offset)) {
@@ -945,9 +960,16 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
                              NVPTX::LD_f32_asi, NVPTX::LD_f64_asi);
     if (!Opcode)
       return false;
-    SDValue Ops[] = { getI32Imm(isVolatile, dl), getI32Imm(CodeAddrSpace, dl),
-                      getI32Imm(vecType, dl), getI32Imm(fromType, dl),
-                      getI32Imm(fromTypeWidth, dl), Base, Offset, Chain };
+    SDValue Ops[] = {getI32Imm(isVolatile, dl),
+                     getI32Imm(CodeAddrSpace, dl),
+                     getI32Imm(vecType, dl),
+                     getI32Imm(fromType, dl),
+                     getI32Imm(fromTypeWidth, dl),
+                     Base,
+                     Offset,
+                     PRED(0),
+                     SW(0),
+                     Chain};
     NVPTXLD = CurDAG->getMachineNode(*Opcode, dl, TargetVT, MVT::Other, Ops);
   } else if (PointerSize == 64 ? SelectADDRri64(N1.getNode(), N1, Base, Offset)
                                : SelectADDRri(N1.getNode(), N1, Base, Offset)) {
@@ -962,9 +984,16 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
                                NVPTX::LD_f32_ari, NVPTX::LD_f64_ari);
     if (!Opcode)
       return false;
-    SDValue Ops[] = { getI32Imm(isVolatile, dl), getI32Imm(CodeAddrSpace, dl),
-                      getI32Imm(vecType, dl), getI32Imm(fromType, dl),
-                      getI32Imm(fromTypeWidth, dl), Base, Offset, Chain };
+    SDValue Ops[] = {getI32Imm(isVolatile, dl),
+                     getI32Imm(CodeAddrSpace, dl),
+                     getI32Imm(vecType, dl),
+                     getI32Imm(fromType, dl),
+                     getI32Imm(fromTypeWidth, dl),
+                     Base,
+                     Offset,
+                     PRED(0),
+                     SW(0),
+                     Chain};
     NVPTXLD = CurDAG->getMachineNode(*Opcode, dl, TargetVT, MVT::Other, Ops);
   } else {
     if (PointerSize == 64)
@@ -978,9 +1007,15 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
                                NVPTX::LD_f32_areg, NVPTX::LD_f64_areg);
     if (!Opcode)
       return false;
-    SDValue Ops[] = { getI32Imm(isVolatile, dl), getI32Imm(CodeAddrSpace, dl),
-                      getI32Imm(vecType, dl), getI32Imm(fromType, dl),
-                      getI32Imm(fromTypeWidth, dl), N1, Chain };
+    SDValue Ops[] = {getI32Imm(isVolatile, dl),
+                     getI32Imm(CodeAddrSpace, dl),
+                     getI32Imm(vecType, dl),
+                     getI32Imm(fromType, dl),
+                     getI32Imm(fromTypeWidth, dl),
+                     N1,
+                     PRED(0),
+                     SW(0),
+                     Chain};
     NVPTXLD = CurDAG->getMachineNode(*Opcode, dl, TargetVT, MVT::Other, Ops);
   }
 
@@ -1090,9 +1125,15 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
     }
     if (!Opcode)
       return false;
-    SDValue Ops[] = { getI32Imm(IsVolatile, DL), getI32Imm(CodeAddrSpace, DL),
-                      getI32Imm(VecType, DL), getI32Imm(FromType, DL),
-                      getI32Imm(FromTypeWidth, DL), Addr, Chain };
+    SDValue Ops[] = {getI32Imm(IsVolatile, DL),
+                     getI32Imm(CodeAddrSpace, DL),
+                     getI32Imm(VecType, DL),
+                     getI32Imm(FromType, DL),
+                     getI32Imm(FromTypeWidth, DL),
+                     Addr,
+                     PRED(0),
+                     SW(0),
+                     Chain};
     LD = CurDAG->getMachineNode(*Opcode, DL, N->getVTList(), Ops);
   } else if (PointerSize == 64
                  ? SelectADDRsi64(Op1.getNode(), Op1, Base, Offset)
@@ -1115,9 +1156,16 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
     }
     if (!Opcode)
       return false;
-    SDValue Ops[] = { getI32Imm(IsVolatile, DL), getI32Imm(CodeAddrSpace, DL),
-                      getI32Imm(VecType, DL), getI32Imm(FromType, DL),
-                      getI32Imm(FromTypeWidth, DL), Base, Offset, Chain };
+    SDValue Ops[] = {getI32Imm(IsVolatile, DL),
+                     getI32Imm(CodeAddrSpace, DL),
+                     getI32Imm(VecType, DL),
+                     getI32Imm(FromType, DL),
+                     getI32Imm(FromTypeWidth, DL),
+                     Base,
+                     Offset,
+                     PRED(0),
+                     SW(0),
+                     Chain};
     LD = CurDAG->getMachineNode(*Opcode, DL, N->getVTList(), Ops);
   } else if (PointerSize == 64
                  ? SelectADDRri64(Op1.getNode(), Op1, Base, Offset)
@@ -1160,9 +1208,16 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
     }
     if (!Opcode)
       return false;
-    SDValue Ops[] = { getI32Imm(IsVolatile, DL), getI32Imm(CodeAddrSpace, DL),
-                      getI32Imm(VecType, DL), getI32Imm(FromType, DL),
-                      getI32Imm(FromTypeWidth, DL), Base, Offset, Chain };
+    SDValue Ops[] = {getI32Imm(IsVolatile, DL),
+                     getI32Imm(CodeAddrSpace, DL),
+                     getI32Imm(VecType, DL),
+                     getI32Imm(FromType, DL),
+                     getI32Imm(FromTypeWidth, DL),
+                     Base,
+                     Offset,
+                     PRED(0),
+                     SW(0),
+                     Chain};
 
     LD = CurDAG->getMachineNode(*Opcode, DL, N->getVTList(), Ops);
   } else {
@@ -1205,9 +1260,15 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
     }
     if (!Opcode)
       return false;
-    SDValue Ops[] = { getI32Imm(IsVolatile, DL), getI32Imm(CodeAddrSpace, DL),
-                      getI32Imm(VecType, DL), getI32Imm(FromType, DL),
-                      getI32Imm(FromTypeWidth, DL), Op1, Chain };
+    SDValue Ops[] = {getI32Imm(IsVolatile, DL),
+                     getI32Imm(CodeAddrSpace, DL),
+                     getI32Imm(VecType, DL),
+                     getI32Imm(FromType, DL),
+                     getI32Imm(FromTypeWidth, DL),
+                     Op1,
+                     PRED(0),
+                     SW(0),
+                     Chain};
     LD = CurDAG->getMachineNode(*Opcode, DL, N->getVTList(), Ops);
   }
 
@@ -1341,7 +1402,7 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
     }
     if (!Opcode)
       return false;
-    SDValue Ops[] = { Addr, Chain };
+    SDValue Ops[] = {Addr, PRED(0), SW(0), Chain};
     LD = CurDAG->getMachineNode(*Opcode, DL, InstVTList, Ops);
   } else if (TM.is64Bit() ? SelectADDRri64(Op1.getNode(), Op1, Base, Offset)
                           : SelectADDRri(Op1.getNode(), Op1, Base, Offset)) {
@@ -1464,7 +1525,7 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
     }
     if (!Opcode)
       return false;
-    SDValue Ops[] = {Base, Offset, Chain};
+    SDValue Ops[] = {Base, Offset, PRED(0), SW(0), Chain};
     LD = CurDAG->getMachineNode(*Opcode, DL, InstVTList, Ops);
   } else {
     if (TM.is64Bit()) {
@@ -1586,7 +1647,7 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
     }
     if (!Opcode)
       return false;
-    SDValue Ops[] = { Op1, Chain };
+    SDValue Ops[] = {Op1, PRED(0), SW(0), Chain};
     LD = CurDAG->getMachineNode(*Opcode, DL, InstVTList, Ops);
   }
 
@@ -1618,10 +1679,14 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
       SDValue Res(LD, i);
       SDValue OrigVal(N, i);
 
-      SDNode *CvtNode =
-        CurDAG->getMachineNode(CvtOpc, DL, OrigType, Res,
-                               CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE,
-                                                         DL, MVT::i32));
+      SDValue Ops[] = {
+          Res,
+          CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32),
+          PRED(0),
+          SW(0),
+
+      };
+      SDNode *CvtNode = CurDAG->getMachineNode(CvtOpc, DL, OrigType, Ops);
       ReplaceUses(OrigVal, SDValue(CvtNode, 0));
     }
   }
@@ -1709,6 +1774,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
                      getI32Imm(toType, dl),
                      getI32Imm(toTypeWidth, dl),
                      Addr,
+                     PRED(0),
+                     SW(0),
                      Chain};
     NVPTXST = CurDAG->getMachineNode(*Opcode, dl, MVT::Other, Ops);
   } else if (PointerSize == 64
@@ -1727,6 +1794,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
                      getI32Imm(toTypeWidth, dl),
                      Base,
                      Offset,
+                     PRED(0),
+                     SW(0),
                      Chain};
     NVPTXST = CurDAG->getMachineNode(*Opcode, dl, MVT::Other, Ops);
   } else if (PointerSize == 64
@@ -1752,6 +1821,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
                      getI32Imm(toTypeWidth, dl),
                      Base,
                      Offset,
+                     PRED(0),
+                     SW(0),
                      Chain};
     NVPTXST = CurDAG->getMachineNode(*Opcode, dl, MVT::Other, Ops);
   } else {
@@ -1773,6 +1844,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
                      getI32Imm(toType, dl),
                      getI32Imm(toTypeWidth, dl),
                      BasePtr,
+                     PRED(0),
+                     SW(0),
                      Chain};
     NVPTXST = CurDAG->getMachineNode(*Opcode, dl, MVT::Other, Ops);
   }
@@ -1982,6 +2055,8 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
   if (!Opcode)
     return false;
 
+  StOps.push_back(PRED(0));
+  StOps.push_back(SW(0));
   StOps.push_back(Chain);
 
   ST = CurDAG->getMachineNode(*Opcode, DL, MVT::Other, StOps);
@@ -1993,15 +2068,15 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
   return true;
 }
 
-bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
-  SDValue Chain = Node->getOperand(0);
-  SDValue Offset = Node->getOperand(2);
-  SDValue Glue = Node->getOperand(3);
-  SDLoc DL(Node);
-  MemSDNode *Mem = cast<MemSDNode>(Node);
+bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *N) {
+  SDValue Chain = N->getOperand(0);
+  SDValue Offset = N->getOperand(2);
+  SDValue Glue = N->getOperand(3);
+  SDLoc DL(N);
+  MemSDNode *Mem = cast<MemSDNode>(N);
 
   unsigned VecSize;
-  switch (Node->getOpcode()) {
+  switch (N->getOpcode()) {
   default:
     return false;
   case NVPTXISD::LoadParam:
@@ -2015,7 +2090,7 @@ bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
     break;
   }
 
-  EVT EltVT = Node->getValueType(0);
+  EVT EltVT = N->getValueType(0);
   EVT MemVT = Mem->getMemoryVT();
 
   std::optional<unsigned> Opcode;
@@ -2058,12 +2133,10 @@ bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
 
   unsigned OffsetVal = cast<ConstantSDNode>(Offset)->getZExtValue();
 
-  SmallVector<SDValue, 2> Ops;
-  Ops.push_back(CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32));
-  Ops.push_back(Chain);
-  Ops.push_back(Glue);
+  SDValue Ops[] = {CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), PRED(0),
+                   SW(0), Chain, Glue};
 
-  ReplaceNode(Node, CurDAG->getMachineNode(*Opcode, DL, VTs, Ops));
+  ReplaceNode(N, CurDAG->getMachineNode(*Opcode, DL, VTs, Ops));
   return true;
 }
 
@@ -2094,8 +2167,8 @@ bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
   SmallVector<SDValue, 6> Ops;
   for (unsigned i = 0; i < NumElts; ++i)
     Ops.push_back(N->getOperand(i + 2));
-  Ops.push_back(CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32));
-  Ops.push_back(Chain);
+  Ops.append({CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), PRED(0),
+              SW(0), Chain});
 
   // Determine target opcode
   // If we have an i1, use an 8-bit store. The lowering code in
@@ -2166,10 +2239,9 @@ bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
   SmallVector<SDValue, 8> Ops;
   for (unsigned i = 0; i < NumElts; ++i)
     Ops.push_back(N->getOperand(i + 3));
-  Ops.push_back(CurDAG->getTargetConstant(ParamVal, DL, MVT::i32));
-  Ops.push_back(CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32));
-  Ops.push_back(Chain);
-  Ops.push_back(Glue);
+  Ops.append({CurDAG->getTargetConstant(ParamVal, DL, MVT::i32),
+              CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32), PRED(0),
+              SW(0), Chain, Glue});
 
   // Determine target opcode
   // If we have an i1, use an 8-bit store. The lowering code in
@@ -2747,6 +2819,8 @@ bool NVPTXDAGToDAGISel::tryTextureIntrinsic(SDNode *N) {
 
   // Copy over operands
   SmallVector<SDValue, 8> Ops(drop_begin(N->ops()));
+  Ops.push_back(PRED(0));
+  Ops.push_back(SW(0));
   Ops.push_back(N->getOperand(0)); // Move chain to the back.
 
   ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getVTList(), Ops));
@@ -3256,6 +3330,8 @@ bool NVPTXDAGToDAGISel::trySurfaceIntrinsic(SDNode *N) {
 
   // Copy over operands
   SmallVector<SDValue, 8> Ops(drop_begin(N->ops()));
+  Ops.push_back(PRED(0));
+  Ops.push_back(SW(0));
   Ops.push_back(N->getOperand(0)); // Move chain to the back.
 
   ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getVTList(), Ops));
@@ -3462,7 +3538,7 @@ bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) {
   }
 
   SDValue Ops[] = {
-    Val, Start, Len
+      Val, Start, Len, PRED(0), SW(0),
   };
 
   ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getVTList(), Ops));
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td b/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td
index 9220f4766d92c60..e0660c50ff5c5e0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrFormats.td
@@ -10,24 +10,42 @@
 //  Describe NVPTX instructions format
 //
 //===----------------------------------------------------------------------===//
-
 // Vector instruction type enum
 class VecInstTypeEnum<bits<4> val> {
   bits<4> Value=val;
 }
+
 def VecNOP : VecInstTypeEnum<0>;
 
-// Generic NVPTX Format
+def pred : PredicateOperand<
+    /*ValueType*/i1,
+    /*OpTypes*/(ops Int1Regs),
+    (ops (i32 /*AlwaysVal*/zero_reg))> {
+  let PrintMethod = "printPredicateOperand";
+  let ParserMatchClass = ?;
+  let DecoderMethod = ?;
+}
+
+def switch : OperandWithDefaultOps</*ValueType*/i1, /*defaultops*/(ops (i1 0))>;
 
-class NVPTXInst<dag outs, dag ins, string asmstr, list<dag> pattern>
-  : Instruction {
+// Generic NVPTX Format
+class NVPTXInstBase <dag outs, dag Ins, string asmstr, list<dag>
+pattern, int defaultPreds = 1> : Instruction {
   field bits<14> Inst;
 
   let Namespace = "NVPTX";
   dag OutOperandList = outs;
-  dag InOperandList = ins;
-  let AsmString = asmstr;
+  dag InOperandList = !if(!and(defaultPreds, isPredicable),
+                          !con(Ins, (ins pred:$predicate, switch:$invert_pred)),
+                          Ins);
+
+  let AsmString = !if(!and(defaultPreds, isPredicable),
+                      !strconcat("$predicate ", asmstr),
+                      asmstr);
   let Pattern = pattern;
+  // All PTX instructions are predicable
+  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#predicated-execution
+  let isPredicable = true;
 
   // TSFlagFields
   bits<4> VecInstType = VecNOP.Value;
@@ -56,3 +74,15 @@ class NVPTXInst<dag outs, dag ins, string asmstr, list<dag> pattern>
   let TSFlags{11}      = IsSurfTexQuery;
   let TSFlags{12}      = IsTexModeUnified;
 }
+
+class NVPTXInst<dag outs, dag ins, string asmstr, list<dag> pattern> :
+NVPTXInstBase<outs, ins, asmstr, pattern, 1>;
+
+class NVPTXInstWithExplicitPreds<dag outs, dag ins, string asmstr, list<dag>
+pattern> : NVPTXInstBase<outs, ins, asmstr, pattern, 0> ;
+
+class NVPTXPseudo<dag outs, dag ins, string asmstr, list<dag> pattern, int
+defaultPreds> : NVPTXInstBase<outs, ins, asmstr, pattern, defaultPreds> {
+  let isPseudo = 1;
+  let isCodeGenOnly = 1;
+}
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
index b0d792b5ee3fe69..2d397da0012c554 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
@@ -29,6 +29,15 @@ void NVPTXInstrInfo::anchor() {}
 
 NVPTXInstrInfo::NVPTXInstrInfo() : RegInfo() {}
 
+static bool isKnownBranch(MachineInstr &MI) {
+  switch (MI.getOpcode()) {
+  case NVPTX::Bra:
+  case NVPTX::BraUni:
+  case NVPTX::Jump:
+    return true;
+  }
+  return false;
+}
 void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
                                  MachineBasicBlock::iterator I,
                                  const DebugLoc &DL, MCRegister DestReg,
@@ -61,7 +70,9 @@ void NVPTXInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
     llvm_unreachable("Bad register copy");
   }
   BuildMI(MBB, I, DL, get(Op), DestReg)
-      .addReg(SrcReg, getKillRegState(KillSrc));
+      .addReg(SrcReg, getKillRegState(KillSrc))
+      .addReg(0)
+      .addImm(0);
 }
 
 /// analyzeBranch - Analyze the branching code at the end of MBB, returning
@@ -92,50 +103,73 @@ bool NVPTXInstrInfo::analyzeBranch(MachineBasicBlock &MBB,
                                    MachineBasicBlock *&FBB,
                                    SmallVectorImpl<MachineOperand> &Cond,
                                    bool AllowModify) const {
+  auto getBranchTarget = [](MachineInstr &MI) {
+    assert(MI.isBranch());
+    return MI.getOperand(0).getMBB();
+  };
+  auto getPred = [](MachineInstr &MI) {
+    assert(MI.isPredicable());
+    // The predicate and predicate switch are always the last two operands
+    return MI.getOperand(MI.getNumOperands() - 2);
+  };
+  auto isConditional = [&](MachineInstr &MI) { return !!getPred(MI).getReg(); };
+  // Cond records only the predicate register, not the inversion switch, so
+  // insertBranch cannot rebuild an inverted branch: treat inverted branches
+  // (and anything that is not a branch we know) as unanalyzable.
+  auto isAnalyzable = [](MachineInstr &MI) {
+    return isKnownBranch(MI) &&
+           MI.getOperand(MI.getNumOperands() - 1).getImm() == 0;
+  };
   // If the block has no terminators, it just falls into the block after it.
   MachineBasicBlock::iterator I = MBB.end();
-  if (I == MBB.begin() || !isUnpredicatedTerminator(*--I))
+  if (I == MBB.begin() || !(--I)->isTerminator())
     return false;
+  if (I->isReturn())
+    return true;
 
   // Get the last instruction in the block.
   MachineInstr &LastInst = *I;
+  assert(LastInst.isTerminator());
 
   // If there is only one terminator instruction, process it.
-  if (I == MBB.begin() || !isUnpredicatedTerminator(*--I)) {
-    if (LastInst.getOpcode() == NVPTX::GOTO) {
-      TBB = LastInst.getOperand(0).getMBB();
-      return false;
-    } else if (LastInst.getOpcode() == NVPTX::CBranch) {
-      // Block ends with fall-through condbranch.
-      TBB = LastInst.getOperand(1).getMBB();
-      Cond.push_back(LastInst.getOperand(0));
+  if (I == MBB.begin() || !(--I)->isTerminator()) {
+    if (!isAnalyzable(LastInst)) {
+      // Unknown terminator, or a branch whose condition Cond can't represent.
+      return true;
+    }
+    if (!isConditional(LastInst)) {
+      TBB = getBranchTarget(LastInst);
       return false;
     }
-    // Otherwise, don't know what this is.
-    return true;
+    // Block ends with fall-through condbranch.
+    TBB = getBranchTarget(LastInst);
+    Cond.push_back(getPred(LastInst));
+    return false;
   }
 
   // Get the instruction before it if it's a terminator.
   MachineInstr &SecondLastInst = *I;
 
   // If there are three terminators, we don't know what sort of block this is.
-  if (I != MBB.begin() && isUnpredicatedTerminator(*--I))
+  if (I != MBB.begin() && (--I)->isTerminator())
     return true;
 
-  // If the block ends with NVPTX::GOTO and NVPTX:CBranch, handle it.
-  if (SecondLastInst.getOpcode() == NVPTX::CBranch &&
-      LastInst.getOpcode() == NVPTX::GOTO) {
-    TBB = SecondLastInst.getOperand(1).getMBB();
-    Cond.push_back(SecondLastInst.getOperand(0));
-    FBB = LastInst.getOperand(0).getMBB();
+  // Both terminators must be branches whose conditions Cond can represent.
+  if (!isAnalyzable(SecondLastInst) || !isAnalyzable(LastInst))
+    return true;
+
+  // If the block ends with unconditional preceded by a conditional, handle it.
+  if (isConditional(SecondLastInst) && !isConditional(LastInst)) {
+    TBB = getBranchTarget(SecondLastInst);
+    Cond.push_back(getPred(SecondLastInst));
+    FBB = getBranchTarget(LastInst);
     return false;
   }
 
-  // If the block ends with two NVPTX:GOTOs, handle it.  The second one is not
-  // executed, so remove it.
-  if (SecondLastInst.getOpcode() == NVPTX::GOTO &&
-      LastInst.getOpcode() == NVPTX::GOTO) {
-    TBB = SecondLastInst.getOperand(0).getMBB();
+  // If the block ends with two unconditional jumps, handle it. The second one
+  // is not executed, so remove it.
+  if (!isConditional(SecondLastInst) && !isConditional(LastInst)) {
+    TBB = getBranchTarget(SecondLastInst);
     I = LastInst;
     if (AllowModify)
       I->eraseFromParent();
@@ -153,7 +187,7 @@ unsigned NVPTXInstrInfo::removeBranch(MachineBasicBlock &MBB,
   if (I == MBB.begin())
     return 0;
   --I;
-  if (I->getOpcode() != NVPTX::GOTO && I->getOpcode() != NVPTX::CBranch)
+  if (!isKnownBranch(*I))
     return 0;
 
   // Remove the branch.
@@ -164,7 +198,7 @@ unsigned NVPTXInstrInfo::removeBranch(MachineBasicBlock &MBB,
   if (I == MBB.begin())
     return 1;
   --I;
-  if (I->getOpcode() != NVPTX::CBranch)
+  if (!isKnownBranch(*I))
     return 1;
 
   // Remove the branch.
@@ -188,14 +222,14 @@ unsigned NVPTXInstrInfo::insertBranch(MachineBasicBlock &MBB,
   // One-way branch.
   if (!FBB) {
     if (Cond.empty()) // Unconditional branch
-      BuildMI(&MBB, DL, get(NVPTX::GOTO)).addMBB(TBB);
+      BuildMI(&MBB, DL, get(NVPTX::Jump)).addMBB(TBB).addReg(0).addImm(0);
     else // Conditional branch
-      BuildMI(&MBB, DL, get(NVPTX::CBranch)).add(Cond[0]).addMBB(TBB);
+      BuildMI(&MBB, DL, get(NVPTX::Bra)).addMBB(TBB).add(Cond[0]).addImm(0);
     return 1;
   }
 
   // Two-way Conditional Branch.
-  BuildMI(&MBB, DL, get(NVPTX::CBranch)).add(Cond[0]).addMBB(TBB);
-  BuildMI(&MBB, DL, get(NVPTX::GOTO)).addMBB(FBB);
+  BuildMI(&MBB, DL, get(NVPTX::Bra)).addMBB(TBB).add(Cond[0]).addImm(0);
+  BuildMI(&MBB, DL, get(NVPTX::Jump)).addMBB(FBB).addReg(0).addImm(0);
   return 2;
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 28c4cadb303ad4f..a2c0fd966e4fa65 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -3498,33 +3498,44 @@ defm : CVT_ROUND<frint, CvtRNI, CvtRNI_FTZ>;
 // Control-flow
 //-----------------------------------
 
-let isTerminator=1 in {
-   let isReturn=1, isBarrier=1 in
-      def Return : NVPTXInst<(outs), (ins), "ret;", [(retglue)]>;
-
-   let isBranch=1 in
-      def CBranch : NVPTXInst<(outs), (ins Int1Regs:$a, brtarget:$target),
-                              "@$a bra \t$target;",
-                              [(brcond Int1Regs:$a, bb:$target)]>;
-   let isBranch=1 in
-      def CBranchOther : NVPTXInst<(outs), (ins Int1Regs:$a, brtarget:$target),
-                                   "@!$a bra \t$target;", []>;
-
-   let isBranch=1, isBarrier=1 in
-      def GOTO : NVPTXInst<(outs), (ins brtarget:$target),
-                           "bra.uni \t$target;", [(br bb:$target)]>;
+// As with the uniform call instructions, uniform branches and returns are
+// modelled as distinct instructions; BraUni and RetUni have no ISel patterns.
+let isTerminator = 1 in {
+  let isBranch = 1 in {
+    def Bra : NVPTXInst<(outs), (ins brtarget:$target), "bra \t$target;", []>;
+    def BraUni : NVPTXInst<(outs), (ins brtarget:$target), "bra.uni \t$target;", []>;
+    // LLVM won't let us model conditional and unconditional branches with a
+    // single instruction, but PTX has only one "true" branch instruction.
+    // For codegen purposes, and to placate the verifier, we model a separate
+    // "unconditional" branch with a pseudo that prints exactly like BraUni.
+    let isBarrier = 1 in
+      def Jump : NVPTXPseudo<(outs), (ins brtarget:$target), "bra.uni \t$target;",
+      [], /* defaultPreds */ 1>;
+  }
+  let isReturn = 1, isBarrier = 1 in {
+    def Ret : NVPTXInst<(outs), (ins), "ret;", [(retglue)]>;
+    def RetUni : NVPTXInst<(outs), (ins), "ret.uni;", []>;
+  }
+}
 
-def : Pat<(brcond (i32 Int32Regs:$a), bb:$target),
-          (CBranch (SETP_u32ri Int32Regs:$a, 0, CmpNE), bb:$target)>;
-
+// Conditional branch on an i1 predicate register (not inverted).
+def : Pat<(brcond Int1Regs:$p, bb:$target),
+          (Bra bb:$target, Int1Regs:$p, /* invert */ 0)>;
+// An unconditional branch is by definition non-divergent, so select the Jump
+// pseudo (printed as bra.uni) with the default always-true predicate.
+def : Pat<(br bb:$target), (Jump bb:$target, zero_reg, /* invert */ 0)>;
+// Conditional branch on an i32 condition: taken when the value is non-zero.
+def : Pat<(brcond (i32 Int32Regs:$p), bb:$target),
+          (Bra bb:$target, (SETP_u32ri Int32Regs:$p, 0, CmpNE), /* invert */ 0)>;
+
+// Inverted conditional branch (selected as Bra with invert = 1, '@!pred bra').
 // SelectionDAGBuilder::visitSWitchCase() will invert the condition of a
 // conditional branch if the target block is the next block so that the code
 // can fall through to the target block.  The invertion is done by 'xor
 // condition, 1', which will be translated to (setne condition, -1).  Since ptx
 // supports '@!pred bra target', we should use it.
-def : Pat<(brcond (i1 (setne Int1Regs:$a, -1)), bb:$target),
-          (CBranchOther Int1Regs:$a, bb:$target)>;
+def : Pat<(brcond (i1 (setne Int1Regs:$p, -1)), bb:$target),
+            (Bra bb:$target, Int1Regs:$p, /* invert */1)>;
 
 // Call
 def SDT_NVPTXCallSeqStart : SDCallSeqStart<[SDTCisVT<0, i32>,
diff --git a/llvm/lib/Target/NVPTX/NVPTXPeephole.cpp b/llvm/lib/Target/NVPTX/NVPTXPeephole.cpp
index 0968701737e88d6..9d4806245ce6125 100644
--- a/llvm/lib/Target/NVPTX/NVPTXPeephole.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXPeephole.cpp
@@ -120,7 +120,9 @@ static void CombineCVTAToLocal(MachineInstr &Root) {
       BuildMI(MF, Root.getDebugLoc(), TII->get(Prev.getOpcode()),
               Root.getOperand(0).getReg())
           .addReg(NRI->getFrameLocalRegister(MF))
-          .add(Prev.getOperand(2));
+          .add(Prev.getOperand(2))
+          .addReg(0)
+          .addImm(0);
 
   MBB.insert((MachineBasicBlock::iterator)&Root, MIB);
 
diff --git a/llvm/test/CodeGen/NVPTX/analyze_branch_crash.ll b/llvm/test/CodeGen/NVPTX/analyze_branch_crash.ll
new file mode 100644
index 000000000000000..f56d081cb24eac4
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/analyze_branch_crash.ll
@@ -0,0 +1,10 @@
+; RUN: llc < %s -march=nvptx64 -verify-machineinstrs -o /dev/null
+; Regression test: don't crash analyzing unconditional branches (incl. self-loops).
+
+define void @crash() {
+entry:
+  br label %loop
+
+loop:
+  br label %loop
+}
diff --git a/llvm/test/CodeGen/NVPTX/branch-fold.mir b/llvm/test/CodeGen/NVPTX/branch-fold.mir
index 8bdac44c4f2350e..32eab35acb2c8c8 100644
--- a/llvm/test/CodeGen/NVPTX/branch-fold.mir
+++ b/llvm/test/CodeGen/NVPTX/branch-fold.mir
@@ -47,7 +47,7 @@ body:             |
   ; CHECK: bb.0.bb:
   ; CHECK-NEXT:   successors: %bb.1(0x40000000), %bb.3(0x40000000)
   ; CHECK-NEXT: {{  $}}
-  ; CHECK-NEXT:   CBranch undef %2:int1regs, %bb.3
+  ; CHECK-NEXT:   Bra %bb.3, undef %2:int1regs, 0
   ; CHECK-NEXT: {{  $}}
   ; CHECK-NEXT: bb.1.bb1.preheader:
   ; CHECK-NEXT:   successors: %bb.2(0x80000000)
@@ -59,16 +59,16 @@ body:             |
   ; CHECK-NEXT: {{  $}}
   ; CHECK-NEXT:   [[ADDi64ri:%[0-9]+]]:int64regs = ADDi64ri [[ADDi64ri]], 1
   ; CHECK-NEXT:   [[SETP_s64ri:%[0-9]+]]:int1regs = SETP_s64ri [[ADDi64ri]], 1, 2
-  ; CHECK-NEXT:   CBranch [[SETP_s64ri]], %bb.2
+  ; CHECK-NEXT:   Bra %bb.2, [[SETP_s64ri]]
   ; CHECK-NEXT: {{  $}}
   ; CHECK-NEXT: bb.3.bb4:
   ; CHECK-NEXT:   successors: %bb.3(0x80000000)
   ; CHECK-NEXT: {{  $}}
-  ; CHECK-NEXT:   GOTO %bb.3
+  ; CHECK-NEXT:   Jump %bb.3
   bb.0.bb:
     successors: %bb.1, %bb.3
 
-    CBranch undef %2:int1regs, %bb.3
+    Bra %bb.3, undef %2:int1regs, 0
 
   bb.1.bb1.preheader:
     %5:int64regs = IMPLICIT_DEF
@@ -76,11 +76,11 @@ body:             |
   bb.2.bb1:
     successors: %bb.2(0x7c000000), %bb.3(0x04000000)
 
-    %5:int64regs = ADDi64ri %5, 1
-    %4:int1regs = SETP_s64ri %5, 1, 2
-    CBranch %4, %bb.2
+    %5:int64regs = ADDi64ri %5, 1, $noreg, 0
+    %4:int1regs = SETP_s64ri %5, 1, 2, $noreg, 0
+    Bra %bb.2, %4, 0
 
   bb.3.bb4:
-    GOTO %bb.3
+    Jump %bb.3, $noreg, 0
 
 ...
diff --git a/llvm/test/CodeGen/NVPTX/branches.ll b/llvm/test/CodeGen/NVPTX/branches.ll
new file mode 100644
index 000000000000000..0963268fbf60da4
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/branches.ll
@@ -0,0 +1,23 @@
+; RUN: llc < %s -stop-after=finalize-isel -O0 -march=nvptx64 | FileCheck %s
+; Check that conditional/unconditional branches and returns select Bra/Jump/Ret.
+define dso_local i32 @foo(i32 noundef %i, i32 noundef %b) {
+; CHECK-LABEL: body:
+entry:
+  %cmp = icmp eq i32 %i, 4
+  br i1 %cmp, label %bb.1, label %bb.2
+; CHECK: [[COND:%[0-9]+]]:int1regs = SETP_s32ri
+; CHECK: Bra %bb.2, killed [[COND]], 0
+; CHECK: Jump %bb.1, $noreg, 0
+
+bb.1:
+; CHECK-LABEL: bb.1:
+  %add = add nsw i32 %i, %b
+  br label %bb.2
+; CHECK: Jump %bb.2, $noreg, 0
+
+bb.2:
+; CHECK-LABEL: bb.2:
+  %ret = phi i32 [%add, %bb.1], [%i, %entry]
+  ret i32 %ret
+; CHECK: Ret $noreg, 0
+}
diff --git a/llvm/test/CodeGen/NVPTX/branches.mir b/llvm/test/CodeGen/NVPTX/branches.mir
new file mode 100644
index 000000000000000..7063c7076f06cf6
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/branches.mir
@@ -0,0 +1,285 @@
+# NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py UTC_ARGS: --version 3
+# RUN: llc -o - %s -march=nvptx64 -run-pass=branch-folder | FileCheck %s
+---
+name:            two_way_conditional
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: int32regs, preferred-register: '' }
+  - { id: 1, class: int32regs, preferred-register: '' }
+  - { id: 2, class: int1regs, preferred-register: '' }
+  - { id: 3, class: int1regs, preferred-register: '' }
+body:             |
+  ; CHECK-LABEL: name: two_way_conditional
+  ; CHECK: bb.0:
+  ; CHECK-NEXT:   successors: %bb.1(0x40000000), %bb.2(0x40000000)
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT:   [[DEF:%[0-9]+]]:int32regs = IMPLICIT_DEF
+  ; CHECK-NEXT:   [[DEF1:%[0-9]+]]:int32regs = IMPLICIT_DEF
+  ; CHECK-NEXT:   [[DEF2:%[0-9]+]]:int1regs = IMPLICIT_DEF
+  ; CHECK-NEXT:   [[DEF3:%[0-9]+]]:int1regs = IMPLICIT_DEF
+  ; CHECK-NEXT:   Bra %bb.1, [[DEF3]], 0
+  ; CHECK-NEXT:   Jump %bb.2, [[DEF2]], 0
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT: bb.1:
+  ; CHECK-NEXT:   StoreRetvalI32 [[DEF]], 0, $noreg, 0 :: (store (s32), align 1)
+  ; CHECK-NEXT:   Ret $noreg, 0
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT: bb.2:
+  ; CHECK-NEXT:   StoreRetvalI32 [[DEF1]], 0, $noreg, 0 :: (store (s32), align 1)
+  ; CHECK-NEXT:   Ret $noreg, 0
+  bb.0:
+    successors: %bb.1(0x40000000), %bb.2(0x40000000)
+
+    %0:int32regs = IMPLICIT_DEF
+    %1:int32regs = IMPLICIT_DEF
+    %2:int1regs = IMPLICIT_DEF
+    %3:int1regs = IMPLICIT_DEF
+    Bra %bb.1, %3, 0
+    Jump %bb.2, %2, 0
+
+  bb.1:
+    StoreRetvalI32 %0, 0, $noreg, 0 :: (store (s32), align 1)
+    Ret $noreg, 0
+
+  bb.2:
+    StoreRetvalI32 %1, 0, $noreg, 0 :: (store (s32), align 1)
+    Ret $noreg, 0
+
+...
+---
+name:            conditional_with_fallthrough
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: int32regs, preferred-register: '' }
+  - { id: 1, class: int32regs, preferred-register: '' }
+  - { id: 2, class: int32regs, preferred-register: '' }
+  - { id: 3, class: int32regs, preferred-register: '' }
+  - { id: 4, class: int1regs, preferred-register: '' }
+body:             |
+  ; CHECK-LABEL: name: conditional_with_fallthrough
+  ; CHECK: bb.0:
+  ; CHECK-NEXT:   successors: %bb.1(0x40000000), %bb.2(0x40000000)
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT:   [[LD_i32_avar:%[0-9]+]]:int32regs = LD_i32_avar 0, 4, 1, 0, 32, &foo_param_1, $noreg, 0
+  ; CHECK-NEXT:   [[LD_i32_avar1:%[0-9]+]]:int32regs = LD_i32_avar 0, 4, 1, 0, 32, &foo_param_0, $noreg, 0
+  ; CHECK-NEXT:   [[SETP_s32ri:%[0-9]+]]:int1regs = SETP_s32ri [[LD_i32_avar1]], 4, 1, $noreg, 0
+  ; CHECK-NEXT:   Bra %bb.2, killed [[SETP_s32ri]], 0
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT: bb.1:
+  ; CHECK-NEXT:   StoreRetvalI32 [[LD_i32_avar]], 0, $noreg, 0 :: (store (s32), align 1)
+  ; CHECK-NEXT:   Ret $noreg, 0
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT: bb.2:
+  ; CHECK-NEXT:   StoreRetvalI32 [[LD_i32_avar1]], 0, $noreg, 0 :: (store (s32), align 1)
+  ; CHECK-NEXT:   Ret $noreg, 0
+  bb.0:
+    successors: %bb.1(0x40000000), %bb.2(0x40000000)
+
+    %3:int32regs = LD_i32_avar 0, 4, 1, 0, 32, &foo_param_1, $noreg, 0
+    %2:int32regs = LD_i32_avar 0, 4, 1, 0, 32, &foo_param_0, $noreg, 0
+    %4:int1regs = SETP_s32ri %2, 4, 1, $noreg, 0
+    Bra %bb.2, killed %4, 0
+
+  bb.1:
+    StoreRetvalI32 %3, 0, $noreg, 0 :: (store (s32), align 1)
+    Ret $noreg, 0
+
+  bb.2:
+    StoreRetvalI32 %2, 0, $noreg, 0 :: (store (s32), align 1)
+    Ret $noreg, 0
+
+...
+---
+name:            two_way_inverted_conditional
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: int32regs, preferred-register: '' }
+  - { id: 1, class: int32regs, preferred-register: '' }
+  - { id: 2, class: int32regs, preferred-register: '' }
+  - { id: 3, class: int32regs, preferred-register: '' }
+  - { id: 4, class: int1regs, preferred-register: '' }
+body:             |
+  ; CHECK-LABEL: name: two_way_inverted_conditional
+  ; CHECK: bb.0:
+  ; CHECK-NEXT:   successors: %bb.1(0x40000000), %bb.2(0x40000000)
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT:   [[LD_i32_avar:%[0-9]+]]:int32regs = LD_i32_avar 0, 4, 1, 0, 32, &foo_param_1, $noreg, 0
+  ; CHECK-NEXT:   [[LD_i32_avar1:%[0-9]+]]:int32regs = LD_i32_avar 0, 4, 1, 0, 32, &foo_param_0, $noreg, 0
+  ; CHECK-NEXT:   [[SETP_s32ri:%[0-9]+]]:int1regs = SETP_s32ri [[LD_i32_avar1]], 4, 1, $noreg, 0
+  ; CHECK-NEXT:   Bra %bb.2, [[SETP_s32ri]], 0
+  ; CHECK-NEXT:   Bra %bb.1, killed [[SETP_s32ri]], -1
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT: bb.1:
+  ; CHECK-NEXT:   StoreRetvalI32 [[LD_i32_avar]], 0, $noreg, 0 :: (store (s32), align 1)
+  ; CHECK-NEXT:   Ret $noreg, 0
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT: bb.2:
+  ; CHECK-NEXT:   StoreRetvalI32 [[LD_i32_avar1]], 0, $noreg, 0 :: (store (s32), align 1)
+  ; CHECK-NEXT:   Ret $noreg, 0
+  bb.0:
+    successors: %bb.1(0x40000000), %bb.2(0x40000000)
+
+    %3:int32regs = LD_i32_avar 0, 4, 1, 0, 32, &foo_param_1, $noreg, 0
+    %2:int32regs = LD_i32_avar 0, 4, 1, 0, 32, &foo_param_0, $noreg, 0
+    %4:int1regs = SETP_s32ri %2, 4, 1, $noreg, 0
+    Bra %bb.2, %4, 0
+    Bra %bb.1, killed %4, -1
+
+  bb.1:
+    StoreRetvalI32 %3, 0, $noreg, 0 :: (store (s32), align 1)
+    Ret $noreg, 0
+
+  bb.2:
+    StoreRetvalI32 %2, 0, $noreg, 0 :: (store (s32), align 1)
+    Ret $noreg, 0
+
+...
+---
+name:            empty_block_with_fallthrough
+tracksRegLiveness: true
+registers:
+  - { id: 1, class: int32regs, preferred-register: '' }
+body:             |
+  ; CHECK-LABEL: name: empty_block_with_fallthrough
+  ; CHECK: bb.0:
+  ; CHECK-NEXT:   [[LD_i32_avar:%[0-9]+]]:int32regs = LD_i32_avar 0, 4, 1, 0, 32, &foo_param_1, $noreg, 0
+  ; CHECK-NEXT:   StoreRetvalI32 [[LD_i32_avar]], 0, $noreg, 0 :: (store (s32), align 1)
+  ; CHECK-NEXT:   Ret $noreg, 0
+  bb.0:
+    successors: %bb.1(0x80000000)
+
+  bb.1:
+    %1:int32regs = LD_i32_avar 0, 4, 1, 0, 32, &foo_param_1, $noreg, 0
+    StoreRetvalI32 %1, 0, $noreg, 0 :: (store (s32), align 1)
+    Ret $noreg, 0
+
+...
+---
+name:            empty_return_block
+tracksRegLiveness: true
+registers:
+body:             |
+  ; CHECK-LABEL: name: empty_return_block
+  ; CHECK: bb.0:
+  ; CHECK-NEXT:   Ret $noreg, 0
+  bb.0:
+    successors: %bb.1(0x80000000)
+
+  bb.1:
+    Ret $noreg, 0
+
+...
+---
+name:            three_way_branch
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: int32regs, preferred-register: '' }
+  - { id: 1, class: int32regs, preferred-register: '' }
+  - { id: 2, class: int32regs, preferred-register: '' }
+  - { id: 3, class: int1regs, preferred-register: '' }
+  - { id: 4, class: int1regs, preferred-register: '' }
+body:             |
+  ; CHECK-LABEL: name: three_way_branch
+  ; CHECK: bb.0:
+  ; CHECK-NEXT:   successors: %bb.1(0x40000000), %bb.2(0x40000000)
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT:   [[DEF:%[0-9]+]]:int32regs = IMPLICIT_DEF
+  ; CHECK-NEXT:   [[DEF1:%[0-9]+]]:int32regs = IMPLICIT_DEF
+  ; CHECK-NEXT:   [[DEF2:%[0-9]+]]:int32regs = IMPLICIT_DEF
+  ; CHECK-NEXT:   [[DEF3:%[0-9]+]]:int1regs = IMPLICIT_DEF
+  ; CHECK-NEXT:   [[DEF4:%[0-9]+]]:int1regs = IMPLICIT_DEF
+  ; CHECK-NEXT:   Bra %bb.2, killed [[DEF4]], 0
+  ; CHECK-NEXT:   Bra %bb.1, killed [[DEF3]], -1
+  ; CHECK-NEXT:   Jump %bb.-1, $noreg, 0
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT: bb.1:
+  ; CHECK-NEXT:   StoreRetvalI32 [[DEF1]], 0, $noreg, 0 :: (store (s32), align 1)
+  ; CHECK-NEXT:   Ret $noreg, 0
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT: bb.2:
+  ; CHECK-NEXT:   StoreRetvalI32 [[DEF2]], 0, $noreg, 0 :: (store (s32), align 1)
+  ; CHECK-NEXT:   Ret $noreg, 0
+  bb.0:
+    successors: %bb.1(0x40000000), %bb.2(0x40000000)
+    ; NOTE(review): %bb.3 is missing from successors; output Jumps to %bb.-1.
+    %0:int32regs = IMPLICIT_DEF
+    %1:int32regs = IMPLICIT_DEF
+    %2:int32regs = IMPLICIT_DEF
+    %3:int1regs = IMPLICIT_DEF
+    %4:int1regs = IMPLICIT_DEF
+    Bra %bb.2, killed %4, 0
+    Bra %bb.1, killed %3, -1
+    Jump %bb.3, $noreg, 0
+
+  bb.1:
+    StoreRetvalI32 %1, 0, $noreg, 0 :: (store (s32), align 1)
+    Ret $noreg, 0
+
+  bb.2:
+    StoreRetvalI32 %2, 0, $noreg, 0 :: (store (s32), align 1)
+    Ret $noreg, 0
+
+  bb.3:
+    StoreRetvalI32 %0, 0, $noreg, 0 :: (store (s32), align 1)
+    Ret $noreg, 0
+
+...
+
+
+---
+name:            back_to_back_jumps
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: int32regs, preferred-register: '' }
+  - { id: 1, class: int32regs, preferred-register: '' }
+  - { id: 2, class: int32regs, preferred-register: '' }
+  - { id: 3, class: int1regs, preferred-register: '' }
+  - { id: 4, class: int1regs, preferred-register: '' }
+body:             |
+  ; CHECK-LABEL: name: back_to_back_jumps
+  ; CHECK: bb.0:
+  ; CHECK-NEXT:   successors: %bb.1(0x2aaaaaab), %bb.2(0x2aaaaaab), %bb.3(0x2aaaaaab)
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT:   [[DEF:%[0-9]+]]:int32regs = IMPLICIT_DEF
+  ; CHECK-NEXT:   [[DEF1:%[0-9]+]]:int32regs = IMPLICIT_DEF
+  ; CHECK-NEXT:   [[DEF2:%[0-9]+]]:int32regs = IMPLICIT_DEF
+  ; CHECK-NEXT:   [[DEF3:%[0-9]+]]:int1regs = IMPLICIT_DEF
+  ; CHECK-NEXT:   [[DEF4:%[0-9]+]]:int1regs = IMPLICIT_DEF
+  ; CHECK-NEXT:   Jump %bb.3, [[DEF3]], 0
+  ; CHECK-NEXT:   Jump %bb.2, [[DEF4]], 0
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT: bb.1:
+  ; CHECK-NEXT:   StoreRetvalI32 [[DEF1]], 0, $noreg, 0 :: (store (s32), align 1)
+  ; CHECK-NEXT:   Ret $noreg, 0
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT: bb.2:
+  ; CHECK-NEXT:   StoreRetvalI32 [[DEF2]], 0, $noreg, 0 :: (store (s32), align 1)
+  ; CHECK-NEXT:   Ret $noreg, 0
+  ; CHECK-NEXT: {{  $}}
+  ; CHECK-NEXT: bb.3:
+  ; CHECK-NEXT:   StoreRetvalI32 [[DEF]], 0, $noreg, 0 :: (store (s32), align 1)
+  ; CHECK-NEXT:   Ret $noreg, 0
+  bb.0:
+    successors: %bb.1(0x40000000), %bb.2(0x40000000), %bb.3(0x40000000)
+
+    %0:int32regs = IMPLICIT_DEF
+    %1:int32regs = IMPLICIT_DEF
+    %2:int32regs = IMPLICIT_DEF
+    %3:int1regs = IMPLICIT_DEF
+    %4:int1regs = IMPLICIT_DEF
+    Jump %bb.3, %3, 0
+    Jump %bb.2, %4, 0
+
+  bb.1:
+    StoreRetvalI32 %1, 0, $noreg, 0 :: (store (s32), align 1)
+    Ret $noreg, 0
+
+  bb.2:
+    StoreRetvalI32 %2, 0, $noreg, 0 :: (store (s32), align 1)
+    Ret $noreg, 0
+
+  bb.3:
+    StoreRetvalI32 %0, 0, $noreg, 0 :: (store (s32), align 1)
+    Ret $noreg, 0
+
+...
diff --git a/llvm/test/CodeGen/NVPTX/nopred.mir b/llvm/test/CodeGen/NVPTX/nopred.mir
new file mode 100644
index 000000000000000..2d6baa31e75ca50
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/nopred.mir
@@ -0,0 +1,33 @@
+# NOTE: CHECK lines are hand-written and match the emitted PTX text.
+# RUN: llc -o - %s -march=nvptx64 -mcpu=sm_35 | FileCheck %s
+
+---
+name:            foo
+alignment:       1
+tracksRegLiveness: true
+registers:
+  - { id: 0, class: int64regs }
+  - { id: 1, class: int64regs }
+  - { id: 2, class: int1regs }
+  - { id: 3, class: int64regs }
+  - { id: 4, class: int64regs }
+frameInfo:
+  maxAlignment:    1
+machineFunctionInfo: {}
+body:             |
+  bb.0:
+    successors:
+    %0:int64regs = IMPLICIT_DEF
+    %1:int64regs = IMPLICIT_DEF
+    %2:int1regs = IMPLICIT_DEF
+    ; Predicated on %2 with invert = 1, so the add is printed as "@!%p1 add".
+    %3:int64regs = ADDi64ri %0, 1, %2, 1
+    %4:int64regs = SELP_u64rr %0, %3, %2, $noreg, 0
+    StoreRetvalI64 %4, 0, $noreg, 0
+    ; CHECK: @!%p1 add.s64   %rd3, %rd1, 1;
+    ; CHECK: selp.u64 %rd4, %rd1, %rd3, %p1;
+    ; CHECK: st.param.b64 [func_retval0+0], %rd4;
+
+
+
+...



More information about the llvm-commits mailing list