[llvm] [NVPTX] Rework and cleanup FTZ ISel (PR #146410)

via llvm-commits llvm-commits at lists.llvm.org
Mon Jun 30 12:27:07 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Alex MacLean (AlexMaclean)

<details>
<summary>Changes</summary>

This change cleans up DAG-to-DAG instruction selection around FTZ and the SETP comparison mode. It is largely non-functional, though support for `{sin,cos}.approx.ftz.f32` is added.

---

Patch is 346.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146410.diff


35 Files Affected:

- (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp (+74-33) 
- (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h (+1) 
- (modified) llvm/lib/Target/NVPTX/NVPTX.h (+15-22) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+18-28) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (+1-2) 
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+325-609) 
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+69-84) 
- (modified) llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir (+4-4) 
- (modified) llvm/test/CodeGen/NVPTX/atomics-sm70.ll (+4-4) 
- (modified) llvm/test/CodeGen/NVPTX/atomics-sm90.ll (+4-4) 
- (modified) llvm/test/CodeGen/NVPTX/atomics.ll (+1-1) 
- (modified) llvm/test/CodeGen/NVPTX/bf16-instructions.ll (+6-6) 
- (modified) llvm/test/CodeGen/NVPTX/branch-fold.mir (+2-2) 
- (modified) llvm/test/CodeGen/NVPTX/cmpxchg-sm60.ll (+180-180) 
- (modified) llvm/test/CodeGen/NVPTX/cmpxchg-sm70.ll (+180-180) 
- (modified) llvm/test/CodeGen/NVPTX/cmpxchg-sm90.ll (+180-180) 
- (modified) llvm/test/CodeGen/NVPTX/cmpxchg.ll (+40-40) 
- (modified) llvm/test/CodeGen/NVPTX/compare-int.ll (+8-8) 
- (modified) llvm/test/CodeGen/NVPTX/distributed-shared-cluster.ll (+10-10) 
- (modified) llvm/test/CodeGen/NVPTX/extractelement.ll (+1-1) 
- (modified) llvm/test/CodeGen/NVPTX/f16-instructions.ll (+6-4) 
- (modified) llvm/test/CodeGen/NVPTX/f16x2-instructions.ll (+1-1) 
- (modified) llvm/test/CodeGen/NVPTX/fast-math.ll (+28) 
- (modified) llvm/test/CodeGen/NVPTX/i1-select.ll (+7-7) 
- (modified) llvm/test/CodeGen/NVPTX/i128.ll (+36-36) 
- (modified) llvm/test/CodeGen/NVPTX/i16x2-instructions.ll (+7-7) 
- (modified) llvm/test/CodeGen/NVPTX/i8x4-instructions.ll (+32-32) 
- (modified) llvm/test/CodeGen/NVPTX/inline-asm-b128-test1.ll (+1-1) 
- (modified) llvm/test/CodeGen/NVPTX/inline-asm-b128-test3.ll (+2-2) 
- (modified) llvm/test/CodeGen/NVPTX/jump-table.ll (+1-1) 
- (modified) llvm/test/CodeGen/NVPTX/load-with-non-coherent-cache.ll (+2-2) 
- (modified) llvm/test/CodeGen/NVPTX/lower-aggr-copies.ll (+1-1) 
- (modified) llvm/test/CodeGen/NVPTX/math-intrins.ll (+42-42) 
- (modified) llvm/test/CodeGen/NVPTX/sext-setcc.ll (+6-6) 
- (modified) llvm/test/CodeGen/NVPTX/tid-range.ll (+1-1) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 28f6968ee6caf..4ba3d0f1eaccd 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -154,73 +154,114 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
   llvm_unreachable("Invalid conversion modifier");
 }
 
+void NVPTXInstPrinter::printFTZFlag(const MCInst *MI, int OpNum, raw_ostream &O) {
+  const MCOperand &MO = MI->getOperand(OpNum);
+  const int Imm = MO.getImm();
+  if (Imm)
+    O << ".ftz";
+}
+
 void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum, raw_ostream &O,
                                     StringRef Modifier) {
   const MCOperand &MO = MI->getOperand(OpNum);
   int64_t Imm = MO.getImm();
 
-  if (Modifier == "ftz") {
-    // FTZ flag
-    if (Imm & NVPTX::PTXCmpMode::FTZ_FLAG)
-      O << ".ftz";
-    return;
-  } else if (Modifier == "base") {
-    switch (Imm & NVPTX::PTXCmpMode::BASE_MASK) {
+  if (Modifier == "FCmp") {
+    switch (Imm) {
     default:
       return;
     case NVPTX::PTXCmpMode::EQ:
-      O << ".eq";
+      O << "eq";
       return;
     case NVPTX::PTXCmpMode::NE:
-      O << ".ne";
+      O << "ne";
       return;
     case NVPTX::PTXCmpMode::LT:
-      O << ".lt";
+      O << "lt";
       return;
     case NVPTX::PTXCmpMode::LE:
-      O << ".le";
+      O << "le";
       return;
     case NVPTX::PTXCmpMode::GT:
-      O << ".gt";
+      O << "gt";
       return;
     case NVPTX::PTXCmpMode::GE:
-      O << ".ge";
-      return;
-    case NVPTX::PTXCmpMode::LO:
-      O << ".lo";
-      return;
-    case NVPTX::PTXCmpMode::LS:
-      O << ".ls";
-      return;
-    case NVPTX::PTXCmpMode::HI:
-      O << ".hi";
-      return;
-    case NVPTX::PTXCmpMode::HS:
-      O << ".hs";
+      O << "ge";
       return;
     case NVPTX::PTXCmpMode::EQU:
-      O << ".equ";
+      O << "equ";
       return;
     case NVPTX::PTXCmpMode::NEU:
-      O << ".neu";
+      O << "neu";
       return;
     case NVPTX::PTXCmpMode::LTU:
-      O << ".ltu";
+      O << "ltu";
       return;
     case NVPTX::PTXCmpMode::LEU:
-      O << ".leu";
+      O << "leu";
       return;
     case NVPTX::PTXCmpMode::GTU:
-      O << ".gtu";
+      O << "gtu";
       return;
     case NVPTX::PTXCmpMode::GEU:
-      O << ".geu";
+      O << "geu";
       return;
     case NVPTX::PTXCmpMode::NUM:
-      O << ".num";
+      O << "num";
       return;
     case NVPTX::PTXCmpMode::NotANumber:
-      O << ".nan";
+      O << "nan";
+      return;
+    }
+  }
+  if (Modifier == "ICmp") {
+    switch (Imm) {
+    default:
+      llvm_unreachable("Invalid ICmp mode");
+    case NVPTX::PTXCmpMode::EQ:
+      O << "eq";
+      return;
+    case NVPTX::PTXCmpMode::NE:
+      O << "ne";
+      return;
+    case NVPTX::PTXCmpMode::LT:
+    case NVPTX::PTXCmpMode::LTU:
+      O << "lt";
+      return;
+    case NVPTX::PTXCmpMode::LE:
+    case NVPTX::PTXCmpMode::LEU:
+      O << "le";
+      return;
+    case NVPTX::PTXCmpMode::GT:
+    case NVPTX::PTXCmpMode::GTU:
+      O << "gt";
+      return;
+    case NVPTX::PTXCmpMode::GE:
+    case NVPTX::PTXCmpMode::GEU:
+      O << "ge";
+      return;
+      
+    }
+  }
+  if (Modifier == "IType") {
+    switch (Imm) {
+    default:
+      llvm_unreachable("Invalid ICmp mode");
+    case NVPTX::PTXCmpMode::EQ:
+    case NVPTX::PTXCmpMode::NE:
+      O << "b";
+      return;
+    case NVPTX::PTXCmpMode::LT:
+    case NVPTX::PTXCmpMode::LE:
+    case NVPTX::PTXCmpMode::GT:
+    case NVPTX::PTXCmpMode::GE:
+      O << "s";
+      return;
+    case NVPTX::PTXCmpMode::LTU:
+    case NVPTX::PTXCmpMode::LEU:
+    case NVPTX::PTXCmpMode::GTU:
+    case NVPTX::PTXCmpMode::GEU:
+      O << "u";
       return;
     }
   }
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 6189284e8a58c..193c436939f66 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -54,6 +54,7 @@ class NVPTXInstPrinter : public MCInstPrinter {
   void printCTAGroup(const MCInst *MI, int OpNum, raw_ostream &O);
   void printCallOperand(const MCInst *MI, int OpNum, raw_ostream &O,
                         StringRef Modifier = {});
+  void printFTZFlag(const MCInst *MI, int OpNum, raw_ostream &O);
 };
 
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index b7fd7090299a9..cfe2c8e7d2ed1 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -19,7 +19,7 @@
 #include "llvm/Support/AtomicOrdering.h"
 #include "llvm/Support/CodeGen.h"
 #include "llvm/Target/TargetMachine.h"
-
+#include "llvm/CodeGen/ISDOpcodes.h"
 namespace llvm {
 class FunctionPass;
 class MachineFunctionPass;
@@ -218,28 +218,21 @@ enum CvtMode {
 /// PTXCmpMode - Comparison mode enumeration
 namespace PTXCmpMode {
 enum CmpMode {
-  EQ = 0,
-  NE,
-  LT,
-  LE,
-  GT,
-  GE,
-  LO,
-  LS,
-  HI,
-  HS,
-  EQU,
-  NEU,
-  LTU,
-  LEU,
-  GTU,
-  GEU,
-  NUM,
+  EQ = ISD::SETEQ,
+  NE = ISD::SETNE,
+  LT = ISD::SETLT,
+  LE = ISD::SETLE,
+  GT = ISD::SETGT,
+  GE = ISD::SETGE,
+  EQU = ISD::SETUEQ,
+  NEU = ISD::SETUNE,
+  LTU = ISD::SETULT,
+  LEU = ISD::SETULE,
+  GTU = ISD::SETUGT,
+  GEU = ISD::SETUGE,
+  NUM = ISD::SETO,
   // NAN is a MACRO
-  NotANumber,
-
-  BASE_MASK = 0xFF,
-  FTZ_FLAG = 0x100
+  NotANumber = ISD::SETUO,
 };
 }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 5631342ecc13e..75461a4b6213a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -363,23 +363,29 @@ bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
 
 // Map ISD:CONDCODE value to appropriate CmpMode expected by
 // NVPTXInstPrinter::printCmpMode()
-static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
+SDValue NVPTXDAGToDAGISel::getPTXCmpMode(const CondCodeSDNode &CondCode) {
   using NVPTX::PTXCmpMode::CmpMode;
-  unsigned PTXCmpMode = [](ISD::CondCode CC) {
+  const unsigned PTXCmpMode = [](ISD::CondCode CC) {
     switch (CC) {
     default:
       llvm_unreachable("Unexpected condition code.");
     case ISD::SETOEQ:
+    case ISD::SETEQ:
       return CmpMode::EQ;
     case ISD::SETOGT:
+    case ISD::SETGT:
       return CmpMode::GT;
     case ISD::SETOGE:
+    case ISD::SETGE:
       return CmpMode::GE;
     case ISD::SETOLT:
+    case ISD::SETLT:
       return CmpMode::LT;
     case ISD::SETOLE:
+    case ISD::SETLE:
       return CmpMode::LE;
     case ISD::SETONE:
+    case ISD::SETNE:
       return CmpMode::NE;
     case ISD::SETO:
       return CmpMode::NUM;
@@ -397,45 +403,29 @@ static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
       return CmpMode::LEU;
     case ISD::SETUNE:
       return CmpMode::NEU;
-    case ISD::SETEQ:
-      return CmpMode::EQ;
-    case ISD::SETGT:
-      return CmpMode::GT;
-    case ISD::SETGE:
-      return CmpMode::GE;
-    case ISD::SETLT:
-      return CmpMode::LT;
-    case ISD::SETLE:
-      return CmpMode::LE;
-    case ISD::SETNE:
-      return CmpMode::NE;
     }
   }(CondCode.get());
-
-  if (FTZ)
-    PTXCmpMode |= NVPTX::PTXCmpMode::FTZ_FLAG;
-
-  return PTXCmpMode;
+  return CurDAG->getTargetConstant(PTXCmpMode, SDLoc(), MVT::i32);
 }
 
 bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
-  unsigned PTXCmpMode =
-      getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
+  SDValue PTXCmpMode =
+      getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)));
   SDLoc DL(N);
   SDNode *SetP = CurDAG->getMachineNode(
-      NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
-      N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
+      NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1, {N->getOperand(0),
+      N->getOperand(1), PTXCmpMode, CurDAG->getTargetConstant(useF32FTZ() ? 1 : 0, DL, MVT::i1)});
   ReplaceNode(N, SetP);
   return true;
 }
 
 bool NVPTXDAGToDAGISel::SelectSETP_BF16X2(SDNode *N) {
-  unsigned PTXCmpMode =
-      getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
+  SDValue PTXCmpMode =
+      getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)));
   SDLoc DL(N);
   SDNode *SetP = CurDAG->getMachineNode(
-      NVPTX::SETP_bf16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
-      N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
+      NVPTX::SETP_bf16x2rr, DL, MVT::i1, MVT::i1, {N->getOperand(0),
+      N->getOperand(1), PTXCmpMode, CurDAG->getTargetConstant(useF32FTZ() ? 1 : 0, DL, MVT::i1)});
   ReplaceNode(N, SetP);
   return true;
 }
@@ -1953,7 +1943,7 @@ bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
     llvm_unreachable("Unexpected opcode");
   };
 
-  int Opcode = IsVec ? NVPTX::BFMA16x2rrr : NVPTX::BFMA16rrr;
+  int Opcode = IsVec ? NVPTX::FMA_BF16x2rrr : NVPTX::FMA_BF16rrr;
   MachineSDNode *FMA = CurDAG->getMachineNode(Opcode, DL, VT, Operands);
   ReplaceNode(N, FMA);
   return true;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 0e4dec1adca67..b314c4ccefe8b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -104,12 +104,11 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   }
 
   bool SelectADDR(SDValue Addr, SDValue &Base, SDValue &Offset);
+  SDValue getPTXCmpMode(const CondCodeSDNode &CondCode);
   SDValue selectPossiblyImm(SDValue V);
 
   bool ChkMemSDNodeAddressSpace(SDNode *N, unsigned int spN) const;
 
-  static unsigned GetConvertOpcode(MVT DestTy, MVT SrcTy, LoadSDNode *N);
-
   // Returns the Memory Order and Scope that the PTX memory instruction should
   // use, and inserts appropriate fence instruction before the memory
   // instruction, if needed to implement the instructions memory order. Required
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 1a2515b7f66f3..9ef9ce3b7bb8d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -68,48 +68,28 @@ def CvtMode : Operand<i32> {
   let PrintMethod = "printCvtMode";
 }
 
+// FTZ flag
+
+def FTZ : PatLeaf<(i1 1)>;
+def NoFTZ : PatLeaf<(i1 0)>;
+
+def getFTZFlag : SDNodeXForm<imm, [{
+  (void)N;
+  return CurDAG->getTargetConstant(useF32FTZ() ? 1 : 0, SDLoc(), MVT::i1);
+}]>;
+
+def FTZFlag : OperandWithDefaultOps<i1, (ops (getFTZFlag (i1 0)))> {
+  let PrintMethod = "printFTZFlag";
+}
+
 // Compare modes
 // These must match the enum in NVPTX.h
-def CmpEQ   : PatLeaf<(i32 0)>;
-def CmpNE   : PatLeaf<(i32 1)>;
-def CmpLT   : PatLeaf<(i32 2)>;
-def CmpLE   : PatLeaf<(i32 3)>;
-def CmpGT   : PatLeaf<(i32 4)>;
-def CmpGE   : PatLeaf<(i32 5)>;
-def CmpLO   : PatLeaf<(i32 6)>;
-def CmpLS   : PatLeaf<(i32 7)>;
-def CmpHI   : PatLeaf<(i32 8)>;
-def CmpHS   : PatLeaf<(i32 9)>;
-def CmpEQU  : PatLeaf<(i32 10)>;
-def CmpNEU  : PatLeaf<(i32 11)>;
-def CmpLTU  : PatLeaf<(i32 12)>;
-def CmpLEU  : PatLeaf<(i32 13)>;
-def CmpGTU  : PatLeaf<(i32 14)>;
-def CmpGEU  : PatLeaf<(i32 15)>;
-def CmpNUM  : PatLeaf<(i32 16)>;
-def CmpNAN  : PatLeaf<(i32 17)>;
-
-def CmpEQ_FTZ   : PatLeaf<(i32 0x100)>;
-def CmpNE_FTZ   : PatLeaf<(i32 0x101)>;
-def CmpLT_FTZ   : PatLeaf<(i32 0x102)>;
-def CmpLE_FTZ   : PatLeaf<(i32 0x103)>;
-def CmpGT_FTZ   : PatLeaf<(i32 0x104)>;
-def CmpGE_FTZ   : PatLeaf<(i32 0x105)>;
-def CmpEQU_FTZ  : PatLeaf<(i32 0x10A)>;
-def CmpNEU_FTZ  : PatLeaf<(i32 0x10B)>;
-def CmpLTU_FTZ  : PatLeaf<(i32 0x10C)>;
-def CmpLEU_FTZ  : PatLeaf<(i32 0x10D)>;
-def CmpGTU_FTZ  : PatLeaf<(i32 0x10E)>;
-def CmpGEU_FTZ  : PatLeaf<(i32 0x10F)>;
-def CmpNUM_FTZ  : PatLeaf<(i32 0x110)>;
-def CmpNAN_FTZ  : PatLeaf<(i32 0x111)>;
+def CmpEQ   : PatLeaf<(i32 17)>;
+def CmpNE   : PatLeaf<(i32 22)>;
 
 def CmpMode : Operand<i32> {
   let PrintMethod = "printCmpMode";
 }
-def VecElement : Operand<i32> {
-  let PrintMethod = "printVecElement";
-}
 
 // PRMT modes
 // These must match the enum in NVPTX.h
@@ -152,8 +132,6 @@ def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
 def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
 def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
 
-def True : Predicate<"true">;
-
 class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>;
 class hasSM<int version>: Predicate<"Subtarget->getSmVersion() >= " # version>;
 
@@ -198,7 +176,7 @@ def RI64 : Operand<Any>;
 
 // Utility class to wrap up information about a register and DAG type for more
 // convenient iteration and parameterization
-class RegTyInfo<ValueType ty, NVPTXRegClass rc, Operand imm, SDNode imm_node,
+class RegTyInfo<ValueType ty, NVPTXRegClass rc, string str, Operand imm, SDNode imm_node,
                 bit supports_imm = 1> {
   ValueType Ty = ty;
   NVPTXRegClass RC = rc;
@@ -206,20 +184,21 @@ class RegTyInfo<ValueType ty, NVPTXRegClass rc, Operand imm, SDNode imm_node,
   SDNode ImmNode = imm_node;
   bit SupportsImm = supports_imm;
   int Size = ty.Size;
+  string Str = str;
 }
 
-def I1RT     : RegTyInfo<i1,  B1,  i1imm,  imm>;
-def I16RT    : RegTyInfo<i16, B16, i16imm, imm>;
-def I32RT    : RegTyInfo<i32, B32, i32imm, imm>;
-def I64RT    : RegTyInfo<i64, B64, i64imm, imm>;
+def I1RT     : RegTyInfo<i1,  B1,  "pred", i1imm,  imm>;
+def I16RT    : RegTyInfo<i16, B16, "b16",  i16imm, imm>;
+def I32RT    : RegTyInfo<i32, B32, "b32",  i32imm, imm>;
+def I64RT    : RegTyInfo<i64, B64, "b64",  i64imm, imm>;
 
-def F32RT    : RegTyInfo<f32, B32, f32imm, fpimm>;
-def F64RT    : RegTyInfo<f64, B64, f64imm, fpimm>;
-def F16RT    : RegTyInfo<f16, B16, f16imm, fpimm, supports_imm = 0>;
-def BF16RT   : RegTyInfo<bf16, B16, bf16imm, fpimm, supports_imm = 0>;
+def F32RT    : RegTyInfo<f32,  B32, "f32",  f32imm,  fpimm>;
+def F64RT    : RegTyInfo<f64,  B64, "f64",  f64imm,  fpimm>;
+def F16RT    : RegTyInfo<f16,  B16, "f16",  f16imm,  fpimm, supports_imm = 0>;
+def BF16RT   : RegTyInfo<bf16, B16, "bf16", bf16imm, fpimm, supports_imm = 0>;
 
-def F16X2RT  : RegTyInfo<v2f16, B32, ?, ?, supports_imm = 0>;
-def BF16X2RT : RegTyInfo<v2bf16, B32, ?, ?, supports_imm = 0>;
+def F16X2RT  : RegTyInfo<v2f16, B32, "f16x2", ?, ?, supports_imm = 0>;
+def BF16X2RT : RegTyInfo<v2bf16, B32, "bf16x2", ?, ?, supports_imm = 0>;
 
 
 // This class provides a basic wrapper around an NVPTXInst that abstracts the
@@ -321,76 +300,57 @@ multiclass ADD_SUB_INT_CARRY<string op_str, SDNode op_node, bit commutative> {
 // Also defines ftz (flush subnormal inputs and results to sign-preserving
 // zero) variants for fp32 functions.
 multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
+  defvar nan_str = !if(NaN, ".NaN", "");
   if !not(NaN) then {
-   def f64rr :
+   def _f64_rr :
      BasicNVPTXInst<(outs B64:$dst),
                (ins B64:$a, B64:$b),
                OpcStr # ".f64",
                [(set f64:$dst, (OpNode f64:$a, f64:$b))]>;
-   def f64ri :
+   def _f64_ri :
      BasicNVPTXInst<(outs B64:$dst),
                (ins B64:$a, f64imm:$b),
                OpcStr # ".f64",
                [(set f64:$dst, (OpNode f64:$a, fpimm:$b))]>;
   }
-   def f32rr_ftz :
-     BasicNVPTXInst<(outs B32:$dst),
+   def _f32_rr :
+     BasicFlagsNVPTXInst<(outs B32:$dst),
                (ins B32:$a, B32:$b),
-               OpcStr # ".ftz.f32",
-               [(set f32:$dst, (OpNode f32:$a, f32:$b))]>,
-               Requires<[doF32FTZ]>;
-   def f32ri_ftz :
-     BasicNVPTXInst<(outs B32:$dst),
-               (ins B32:$a, f32imm:$b),
-               OpcStr # ".ftz.f32",
-               [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>,
-               Requires<[doF32FTZ]>;
-   def f32rr :
-     BasicNVPTXInst<(outs B32:$dst),
-               (ins B32:$a, B32:$b),
-               OpcStr # ".f32",
+               (ins FTZFlag:$ftz),
+               OpcStr # "$ftz" # nan_str # ".f32",
                [(set f32:$dst, (OpNode f32:$a, f32:$b))]>;
-   def f32ri :
-     BasicNVPTXInst<(outs B32:$dst),
+   def _f32_ri :
+     BasicFlagsNVPTXInst<(outs B32:$dst),
                (ins B32:$a, f32imm:$b),
-               OpcStr # ".f32",
+               (ins FTZFlag:$ftz),
+               OpcStr # "$ftz" # nan_str # ".f32",
                [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>;
 
-   def f16rr_ftz :
-     BasicNVPTXInst<(outs B16:$dst),
-               (ins B16:$a, B16:$b),
-               OpcStr # ".ftz.f16",
-               [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
-               Requires<[useFP16Math, doF32FTZ]>;
-   def f16rr :
-     BasicNVPTXInst<(outs B16:$dst),
+   def _f16_rr :
+     BasicFlagsNVPTXInst<(outs B16:$dst),
                (ins B16:$a, B16:$b),
-               OpcStr # ".f16",
+               (ins FTZFlag:$ftz),
+               OpcStr # "$ftz" # nan_str # ".f16",
                [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
-               Requires<[useFP16Math, hasSM<80>, hasPTX<70>]>;
+               Requires<[useFP16Math]>;
 
-   def f16x2rr_ftz :
-     BasicNVPTXInst<(outs B32:$dst),
-               (ins B32:$a, B32:$b),
-               OpcStr # ".ftz.f16x2",
-               [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
-               Requires<[useFP16Math, hasSM<80>, hasPTX<70>, doF32FTZ]>;
-   def f16x2rr :
-     BasicNVPTXInst<(outs B32:$dst),
+   def _f16x2_rr :
+     BasicFlagsNVPTXInst<(outs B32:$dst),
                (ins B32:$a, B32:$b),
-               OpcStr # ".f16x2",
+               (ins FTZFlag:$ftz),
+               OpcStr # "$ftz" # nan_str # ".f16x2",
                [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
                Requires<[useFP16Math, hasSM<80>, hasPTX<70>]>;
-   def bf16rr :
+   def _bf16_rr :
      BasicNVPTXInst<(outs B16:$dst),
                (ins B16:$a, B16:$b),
-               OpcStr # ".bf16",
+               OpcStr # nan_str # ".bf16",
                [(set bf16:$dst, (OpNode bf16:$a, bf16:$b))]>,
                Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
-   def bf16x2rr :
+   def _bf16x2_rr :
      BasicNVPTXInst<(outs B32:$dst),
                (ins B32:$a, B32:$b),
-               OpcStr # ".bf16x2",
+               OpcStr # nan_str # ".bf16x2",
                [(set v2bf16:$dst, (OpNode v2bf16:$a, v2bf16:$b))]>,
                Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
 }
@@ -415,52 +375,31 @@ multiclass F3<string op_str, SDPatternOperator op_pat> {
               (ins B64:$a, f64imm:$b),
               op_str # ".f64",
               [(set f64:$dst, (op_pat f64:$a, fpimm:$b))]>;
-  def f32rr_ftz :
-    BasicNVPTXInst<(outs B32:$dst),
-              (ins B32:$a, B32:$b),
-              op_str # ".ftz.f32",
-              [(set f32:$dst, (op_pat f32:$a, f32:$b))]>,
-              Requires<[doF32FTZ]>;
-  def f32ri_ftz :
-    BasicNVPTXInst<(outs B32:$dst),
-              (ins B32:$a, f32imm:$b),
-              op_str # ".ftz.f32",
-              [(set f32:$dst, (op_pat f32:$a, fpimm:$b))]>,
-              Requires<[doF32FTZ]>;
   def f32rr :
-    BasicNVPTXInst<(outs B32:$dst),
+    BasicFlagsNVPTXInst<(outs B32:$dst),
               (ins B32:$a, B32:$b),
-              op_str # ".f32",
+              (ins FTZFlag:$ftz),
+              op_str # "$ftz.f32",
               [(set f32:$dst, (op_pat f32:$a, f32:$b))]>;
   def f32ri :
-    BasicNVPTXInst<(outs B32:$dst),
+    BasicFlagsNVPTXInst<(outs B32:$dst),
               (ins B32:$a, f32imm:$b),
-              op_str # ".f32",
+              (ins FTZFlag:$ftz),
+              op_str # "$ftz.f32",
               [(set f32:$dst, (op_pat f32:$a, fpimm:$b))]>;
 
-  def f16rr_ftz :
-    BasicNVPTXInst<(outs B16:$dst),
-              (ins B16:$a, B16:$b),
-              op_str # ".ftz.f16",
-              [(set f16:$dst, (op_pat f16:$a, f16:$b))]>,
-              Requires<[useFP16Math, doF32FTZ]>;
   def f16rr :
-    BasicNVPTXInst<(outs B16:$dst),
+    BasicFlagsNVPTXInst<(outs B16:$dst),
               (ins B16:$a, B16:$b),
-              op_str # ".f16",
+              (ins FTZFlag:$ftz),
+              op_str # "$ftz.f16",
               [(set f16:$dst, (op_pat f16:$a, f16:$b))]>,
               Requires<[useFP16Math]>;
-
-  def f16x2rr_ftz :
-    BasicNVPTXInst<(outs B32:$dst),
-              (...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/146410


More information about the llvm-commits mailing list