[llvm] [NVPTX] Rework and cleanup FTZ ISel (PR #146410)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 30 12:27:07 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-nvptx
Author: Alex MacLean (AlexMaclean)
<details>
<summary>Changes</summary>
This change cleans up DAG-to-DAG instruction selection around FTZ and the SETP comparison mode. It is largely non-functional, though support for `{sin,cos}.approx.ftz.f32` is added.
---
Patch is 346.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/146410.diff
35 Files Affected:
- (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp (+74-33)
- (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h (+1)
- (modified) llvm/lib/Target/NVPTX/NVPTX.h (+15-22)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+18-28)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (+1-2)
- (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+325-609)
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+69-84)
- (modified) llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir (+4-4)
- (modified) llvm/test/CodeGen/NVPTX/atomics-sm70.ll (+4-4)
- (modified) llvm/test/CodeGen/NVPTX/atomics-sm90.ll (+4-4)
- (modified) llvm/test/CodeGen/NVPTX/atomics.ll (+1-1)
- (modified) llvm/test/CodeGen/NVPTX/bf16-instructions.ll (+6-6)
- (modified) llvm/test/CodeGen/NVPTX/branch-fold.mir (+2-2)
- (modified) llvm/test/CodeGen/NVPTX/cmpxchg-sm60.ll (+180-180)
- (modified) llvm/test/CodeGen/NVPTX/cmpxchg-sm70.ll (+180-180)
- (modified) llvm/test/CodeGen/NVPTX/cmpxchg-sm90.ll (+180-180)
- (modified) llvm/test/CodeGen/NVPTX/cmpxchg.ll (+40-40)
- (modified) llvm/test/CodeGen/NVPTX/compare-int.ll (+8-8)
- (modified) llvm/test/CodeGen/NVPTX/distributed-shared-cluster.ll (+10-10)
- (modified) llvm/test/CodeGen/NVPTX/extractelement.ll (+1-1)
- (modified) llvm/test/CodeGen/NVPTX/f16-instructions.ll (+6-4)
- (modified) llvm/test/CodeGen/NVPTX/f16x2-instructions.ll (+1-1)
- (modified) llvm/test/CodeGen/NVPTX/fast-math.ll (+28)
- (modified) llvm/test/CodeGen/NVPTX/i1-select.ll (+7-7)
- (modified) llvm/test/CodeGen/NVPTX/i128.ll (+36-36)
- (modified) llvm/test/CodeGen/NVPTX/i16x2-instructions.ll (+7-7)
- (modified) llvm/test/CodeGen/NVPTX/i8x4-instructions.ll (+32-32)
- (modified) llvm/test/CodeGen/NVPTX/inline-asm-b128-test1.ll (+1-1)
- (modified) llvm/test/CodeGen/NVPTX/inline-asm-b128-test3.ll (+2-2)
- (modified) llvm/test/CodeGen/NVPTX/jump-table.ll (+1-1)
- (modified) llvm/test/CodeGen/NVPTX/load-with-non-coherent-cache.ll (+2-2)
- (modified) llvm/test/CodeGen/NVPTX/lower-aggr-copies.ll (+1-1)
- (modified) llvm/test/CodeGen/NVPTX/math-intrins.ll (+42-42)
- (modified) llvm/test/CodeGen/NVPTX/sext-setcc.ll (+6-6)
- (modified) llvm/test/CodeGen/NVPTX/tid-range.ll (+1-1)
``````````diff
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 28f6968ee6caf..4ba3d0f1eaccd 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -154,73 +154,114 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
llvm_unreachable("Invalid conversion modifier");
}
+void NVPTXInstPrinter::printFTZFlag(const MCInst *MI, int OpNum, raw_ostream &O) {
+ const MCOperand &MO = MI->getOperand(OpNum);
+ const int Imm = MO.getImm();
+ if (Imm)
+ O << ".ftz";
+}
+
void NVPTXInstPrinter::printCmpMode(const MCInst *MI, int OpNum, raw_ostream &O,
StringRef Modifier) {
const MCOperand &MO = MI->getOperand(OpNum);
int64_t Imm = MO.getImm();
- if (Modifier == "ftz") {
- // FTZ flag
- if (Imm & NVPTX::PTXCmpMode::FTZ_FLAG)
- O << ".ftz";
- return;
- } else if (Modifier == "base") {
- switch (Imm & NVPTX::PTXCmpMode::BASE_MASK) {
+ if (Modifier == "FCmp") {
+ switch (Imm) {
default:
return;
case NVPTX::PTXCmpMode::EQ:
- O << ".eq";
+ O << "eq";
return;
case NVPTX::PTXCmpMode::NE:
- O << ".ne";
+ O << "ne";
return;
case NVPTX::PTXCmpMode::LT:
- O << ".lt";
+ O << "lt";
return;
case NVPTX::PTXCmpMode::LE:
- O << ".le";
+ O << "le";
return;
case NVPTX::PTXCmpMode::GT:
- O << ".gt";
+ O << "gt";
return;
case NVPTX::PTXCmpMode::GE:
- O << ".ge";
- return;
- case NVPTX::PTXCmpMode::LO:
- O << ".lo";
- return;
- case NVPTX::PTXCmpMode::LS:
- O << ".ls";
- return;
- case NVPTX::PTXCmpMode::HI:
- O << ".hi";
- return;
- case NVPTX::PTXCmpMode::HS:
- O << ".hs";
+ O << "ge";
return;
case NVPTX::PTXCmpMode::EQU:
- O << ".equ";
+ O << "equ";
return;
case NVPTX::PTXCmpMode::NEU:
- O << ".neu";
+ O << "neu";
return;
case NVPTX::PTXCmpMode::LTU:
- O << ".ltu";
+ O << "ltu";
return;
case NVPTX::PTXCmpMode::LEU:
- O << ".leu";
+ O << "leu";
return;
case NVPTX::PTXCmpMode::GTU:
- O << ".gtu";
+ O << "gtu";
return;
case NVPTX::PTXCmpMode::GEU:
- O << ".geu";
+ O << "geu";
return;
case NVPTX::PTXCmpMode::NUM:
- O << ".num";
+ O << "num";
return;
case NVPTX::PTXCmpMode::NotANumber:
- O << ".nan";
+ O << "nan";
+ return;
+ }
+ }
+ if (Modifier == "ICmp") {
+ switch (Imm) {
+ default:
+ llvm_unreachable("Invalid ICmp mode");
+ case NVPTX::PTXCmpMode::EQ:
+ O << "eq";
+ return;
+ case NVPTX::PTXCmpMode::NE:
+ O << "ne";
+ return;
+ case NVPTX::PTXCmpMode::LT:
+ case NVPTX::PTXCmpMode::LTU:
+ O << "lt";
+ return;
+ case NVPTX::PTXCmpMode::LE:
+ case NVPTX::PTXCmpMode::LEU:
+ O << "le";
+ return;
+ case NVPTX::PTXCmpMode::GT:
+ case NVPTX::PTXCmpMode::GTU:
+ O << "gt";
+ return;
+ case NVPTX::PTXCmpMode::GE:
+ case NVPTX::PTXCmpMode::GEU:
+ O << "ge";
+ return;
+
+ }
+ }
+ if (Modifier == "IType") {
+ switch (Imm) {
+ default:
+ llvm_unreachable("Invalid ICmp mode");
+ case NVPTX::PTXCmpMode::EQ:
+ case NVPTX::PTXCmpMode::NE:
+ O << "b";
+ return;
+ case NVPTX::PTXCmpMode::LT:
+ case NVPTX::PTXCmpMode::LE:
+ case NVPTX::PTXCmpMode::GT:
+ case NVPTX::PTXCmpMode::GE:
+ O << "s";
+ return;
+ case NVPTX::PTXCmpMode::LTU:
+ case NVPTX::PTXCmpMode::LEU:
+ case NVPTX::PTXCmpMode::GTU:
+ case NVPTX::PTXCmpMode::GEU:
+ O << "u";
return;
}
}
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 6189284e8a58c..193c436939f66 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -54,6 +54,7 @@ class NVPTXInstPrinter : public MCInstPrinter {
void printCTAGroup(const MCInst *MI, int OpNum, raw_ostream &O);
void printCallOperand(const MCInst *MI, int OpNum, raw_ostream &O,
StringRef Modifier = {});
+ void printFTZFlag(const MCInst *MI, int OpNum, raw_ostream &O);
};
}
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index b7fd7090299a9..cfe2c8e7d2ed1 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -19,7 +19,7 @@
#include "llvm/Support/AtomicOrdering.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Target/TargetMachine.h"
-
+#include "llvm/CodeGen/ISDOpcodes.h"
namespace llvm {
class FunctionPass;
class MachineFunctionPass;
@@ -218,28 +218,21 @@ enum CvtMode {
/// PTXCmpMode - Comparison mode enumeration
namespace PTXCmpMode {
enum CmpMode {
- EQ = 0,
- NE,
- LT,
- LE,
- GT,
- GE,
- LO,
- LS,
- HI,
- HS,
- EQU,
- NEU,
- LTU,
- LEU,
- GTU,
- GEU,
- NUM,
+ EQ = ISD::SETEQ,
+ NE = ISD::SETNE,
+ LT = ISD::SETLT,
+ LE = ISD::SETLE,
+ GT = ISD::SETGT,
+ GE = ISD::SETGE,
+ EQU = ISD::SETUEQ,
+ NEU = ISD::SETUNE,
+ LTU = ISD::SETULT,
+ LEU = ISD::SETULE,
+ GTU = ISD::SETUGT,
+ GEU = ISD::SETUGE,
+ NUM = ISD::SETO,
// NAN is a MACRO
- NotANumber,
-
- BASE_MASK = 0xFF,
- FTZ_FLAG = 0x100
+ NotANumber = ISD::SETUO,
};
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 5631342ecc13e..75461a4b6213a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -363,23 +363,29 @@ bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
// Map ISD:CONDCODE value to appropriate CmpMode expected by
// NVPTXInstPrinter::printCmpMode()
-static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
+SDValue NVPTXDAGToDAGISel::getPTXCmpMode(const CondCodeSDNode &CondCode) {
using NVPTX::PTXCmpMode::CmpMode;
- unsigned PTXCmpMode = [](ISD::CondCode CC) {
+ const unsigned PTXCmpMode = [](ISD::CondCode CC) {
switch (CC) {
default:
llvm_unreachable("Unexpected condition code.");
case ISD::SETOEQ:
+ case ISD::SETEQ:
return CmpMode::EQ;
case ISD::SETOGT:
+ case ISD::SETGT:
return CmpMode::GT;
case ISD::SETOGE:
+ case ISD::SETGE:
return CmpMode::GE;
case ISD::SETOLT:
+ case ISD::SETLT:
return CmpMode::LT;
case ISD::SETOLE:
+ case ISD::SETLE:
return CmpMode::LE;
case ISD::SETONE:
+ case ISD::SETNE:
return CmpMode::NE;
case ISD::SETO:
return CmpMode::NUM;
@@ -397,45 +403,29 @@ static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
return CmpMode::LEU;
case ISD::SETUNE:
return CmpMode::NEU;
- case ISD::SETEQ:
- return CmpMode::EQ;
- case ISD::SETGT:
- return CmpMode::GT;
- case ISD::SETGE:
- return CmpMode::GE;
- case ISD::SETLT:
- return CmpMode::LT;
- case ISD::SETLE:
- return CmpMode::LE;
- case ISD::SETNE:
- return CmpMode::NE;
}
}(CondCode.get());
-
- if (FTZ)
- PTXCmpMode |= NVPTX::PTXCmpMode::FTZ_FLAG;
-
- return PTXCmpMode;
+ return CurDAG->getTargetConstant(PTXCmpMode, SDLoc(), MVT::i32);
}
bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
- unsigned PTXCmpMode =
- getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
+ SDValue PTXCmpMode =
+ getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)));
SDLoc DL(N);
SDNode *SetP = CurDAG->getMachineNode(
- NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
- N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
+ NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1, {N->getOperand(0),
+ N->getOperand(1), PTXCmpMode, CurDAG->getTargetConstant(useF32FTZ() ? 1 : 0, DL, MVT::i1)});
ReplaceNode(N, SetP);
return true;
}
bool NVPTXDAGToDAGISel::SelectSETP_BF16X2(SDNode *N) {
- unsigned PTXCmpMode =
- getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
+ SDValue PTXCmpMode =
+ getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)));
SDLoc DL(N);
SDNode *SetP = CurDAG->getMachineNode(
- NVPTX::SETP_bf16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
- N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
+ NVPTX::SETP_bf16x2rr, DL, MVT::i1, MVT::i1, {N->getOperand(0),
+ N->getOperand(1), PTXCmpMode, CurDAG->getTargetConstant(useF32FTZ() ? 1 : 0, DL, MVT::i1)});
ReplaceNode(N, SetP);
return true;
}
@@ -1953,7 +1943,7 @@ bool NVPTXDAGToDAGISel::tryBF16ArithToFMA(SDNode *N) {
llvm_unreachable("Unexpected opcode");
};
- int Opcode = IsVec ? NVPTX::BFMA16x2rrr : NVPTX::BFMA16rrr;
+ int Opcode = IsVec ? NVPTX::FMA_BF16x2rrr : NVPTX::FMA_BF16rrr;
MachineSDNode *FMA = CurDAG->getMachineNode(Opcode, DL, VT, Operands);
ReplaceNode(N, FMA);
return true;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 0e4dec1adca67..b314c4ccefe8b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -104,12 +104,11 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
}
bool SelectADDR(SDValue Addr, SDValue &Base, SDValue &Offset);
+ SDValue getPTXCmpMode(const CondCodeSDNode &CondCode);
SDValue selectPossiblyImm(SDValue V);
bool ChkMemSDNodeAddressSpace(SDNode *N, unsigned int spN) const;
- static unsigned GetConvertOpcode(MVT DestTy, MVT SrcTy, LoadSDNode *N);
-
// Returns the Memory Order and Scope that the PTX memory instruction should
// use, and inserts appropriate fence instruction before the memory
// instruction, if needed to implement the instructions memory order. Required
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 1a2515b7f66f3..9ef9ce3b7bb8d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -68,48 +68,28 @@ def CvtMode : Operand<i32> {
let PrintMethod = "printCvtMode";
}
+// FTZ flag
+
+def FTZ : PatLeaf<(i1 1)>;
+def NoFTZ : PatLeaf<(i1 0)>;
+
+def getFTZFlag : SDNodeXForm<imm, [{
+ (void)N;
+ return CurDAG->getTargetConstant(useF32FTZ() ? 1 : 0, SDLoc(), MVT::i1);
+}]>;
+
+def FTZFlag : OperandWithDefaultOps<i1, (ops (getFTZFlag (i1 0)))> {
+ let PrintMethod = "printFTZFlag";
+}
+
// Compare modes
// These must match the enum in NVPTX.h
-def CmpEQ : PatLeaf<(i32 0)>;
-def CmpNE : PatLeaf<(i32 1)>;
-def CmpLT : PatLeaf<(i32 2)>;
-def CmpLE : PatLeaf<(i32 3)>;
-def CmpGT : PatLeaf<(i32 4)>;
-def CmpGE : PatLeaf<(i32 5)>;
-def CmpLO : PatLeaf<(i32 6)>;
-def CmpLS : PatLeaf<(i32 7)>;
-def CmpHI : PatLeaf<(i32 8)>;
-def CmpHS : PatLeaf<(i32 9)>;
-def CmpEQU : PatLeaf<(i32 10)>;
-def CmpNEU : PatLeaf<(i32 11)>;
-def CmpLTU : PatLeaf<(i32 12)>;
-def CmpLEU : PatLeaf<(i32 13)>;
-def CmpGTU : PatLeaf<(i32 14)>;
-def CmpGEU : PatLeaf<(i32 15)>;
-def CmpNUM : PatLeaf<(i32 16)>;
-def CmpNAN : PatLeaf<(i32 17)>;
-
-def CmpEQ_FTZ : PatLeaf<(i32 0x100)>;
-def CmpNE_FTZ : PatLeaf<(i32 0x101)>;
-def CmpLT_FTZ : PatLeaf<(i32 0x102)>;
-def CmpLE_FTZ : PatLeaf<(i32 0x103)>;
-def CmpGT_FTZ : PatLeaf<(i32 0x104)>;
-def CmpGE_FTZ : PatLeaf<(i32 0x105)>;
-def CmpEQU_FTZ : PatLeaf<(i32 0x10A)>;
-def CmpNEU_FTZ : PatLeaf<(i32 0x10B)>;
-def CmpLTU_FTZ : PatLeaf<(i32 0x10C)>;
-def CmpLEU_FTZ : PatLeaf<(i32 0x10D)>;
-def CmpGTU_FTZ : PatLeaf<(i32 0x10E)>;
-def CmpGEU_FTZ : PatLeaf<(i32 0x10F)>;
-def CmpNUM_FTZ : PatLeaf<(i32 0x110)>;
-def CmpNAN_FTZ : PatLeaf<(i32 0x111)>;
+def CmpEQ : PatLeaf<(i32 17)>;
+def CmpNE : PatLeaf<(i32 22)>;
def CmpMode : Operand<i32> {
let PrintMethod = "printCmpMode";
}
-def VecElement : Operand<i32> {
- let PrintMethod = "printVecElement";
-}
// PRMT modes
// These must match the enum in NVPTX.h
@@ -152,8 +132,6 @@ def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
-def True : Predicate<"true">;
-
class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>;
class hasSM<int version>: Predicate<"Subtarget->getSmVersion() >= " # version>;
@@ -198,7 +176,7 @@ def RI64 : Operand<Any>;
// Utility class to wrap up information about a register and DAG type for more
// convenient iteration and parameterization
-class RegTyInfo<ValueType ty, NVPTXRegClass rc, Operand imm, SDNode imm_node,
+class RegTyInfo<ValueType ty, NVPTXRegClass rc, string str, Operand imm, SDNode imm_node,
bit supports_imm = 1> {
ValueType Ty = ty;
NVPTXRegClass RC = rc;
@@ -206,20 +184,21 @@ class RegTyInfo<ValueType ty, NVPTXRegClass rc, Operand imm, SDNode imm_node,
SDNode ImmNode = imm_node;
bit SupportsImm = supports_imm;
int Size = ty.Size;
+ string Str = str;
}
-def I1RT : RegTyInfo<i1, B1, i1imm, imm>;
-def I16RT : RegTyInfo<i16, B16, i16imm, imm>;
-def I32RT : RegTyInfo<i32, B32, i32imm, imm>;
-def I64RT : RegTyInfo<i64, B64, i64imm, imm>;
+def I1RT : RegTyInfo<i1, B1, "pred", i1imm, imm>;
+def I16RT : RegTyInfo<i16, B16, "b16", i16imm, imm>;
+def I32RT : RegTyInfo<i32, B32, "b32", i32imm, imm>;
+def I64RT : RegTyInfo<i64, B64, "b64", i64imm, imm>;
-def F32RT : RegTyInfo<f32, B32, f32imm, fpimm>;
-def F64RT : RegTyInfo<f64, B64, f64imm, fpimm>;
-def F16RT : RegTyInfo<f16, B16, f16imm, fpimm, supports_imm = 0>;
-def BF16RT : RegTyInfo<bf16, B16, bf16imm, fpimm, supports_imm = 0>;
+def F32RT : RegTyInfo<f32, B32, "f32", f32imm, fpimm>;
+def F64RT : RegTyInfo<f64, B64, "f64", f64imm, fpimm>;
+def F16RT : RegTyInfo<f16, B16, "f16", f16imm, fpimm, supports_imm = 0>;
+def BF16RT : RegTyInfo<bf16, B16, "bf16", bf16imm, fpimm, supports_imm = 0>;
-def F16X2RT : RegTyInfo<v2f16, B32, ?, ?, supports_imm = 0>;
-def BF16X2RT : RegTyInfo<v2bf16, B32, ?, ?, supports_imm = 0>;
+def F16X2RT : RegTyInfo<v2f16, B32, "f16x2", ?, ?, supports_imm = 0>;
+def BF16X2RT : RegTyInfo<v2bf16, B32, "bf16x2", ?, ?, supports_imm = 0>;
// This class provides a basic wrapper around an NVPTXInst that abstracts the
@@ -321,76 +300,57 @@ multiclass ADD_SUB_INT_CARRY<string op_str, SDNode op_node, bit commutative> {
// Also defines ftz (flush subnormal inputs and results to sign-preserving
// zero) variants for fp32 functions.
multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
+ defvar nan_str = !if(NaN, ".NaN", "");
if !not(NaN) then {
- def f64rr :
+ def _f64_rr :
BasicNVPTXInst<(outs B64:$dst),
(ins B64:$a, B64:$b),
OpcStr # ".f64",
[(set f64:$dst, (OpNode f64:$a, f64:$b))]>;
- def f64ri :
+ def _f64_ri :
BasicNVPTXInst<(outs B64:$dst),
(ins B64:$a, f64imm:$b),
OpcStr # ".f64",
[(set f64:$dst, (OpNode f64:$a, fpimm:$b))]>;
}
- def f32rr_ftz :
- BasicNVPTXInst<(outs B32:$dst),
+ def _f32_rr :
+ BasicFlagsNVPTXInst<(outs B32:$dst),
(ins B32:$a, B32:$b),
- OpcStr # ".ftz.f32",
- [(set f32:$dst, (OpNode f32:$a, f32:$b))]>,
- Requires<[doF32FTZ]>;
- def f32ri_ftz :
- BasicNVPTXInst<(outs B32:$dst),
- (ins B32:$a, f32imm:$b),
- OpcStr # ".ftz.f32",
- [(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>,
- Requires<[doF32FTZ]>;
- def f32rr :
- BasicNVPTXInst<(outs B32:$dst),
- (ins B32:$a, B32:$b),
- OpcStr # ".f32",
+ (ins FTZFlag:$ftz),
+ OpcStr # "$ftz" # nan_str # ".f32",
[(set f32:$dst, (OpNode f32:$a, f32:$b))]>;
- def f32ri :
- BasicNVPTXInst<(outs B32:$dst),
+ def _f32_ri :
+ BasicFlagsNVPTXInst<(outs B32:$dst),
(ins B32:$a, f32imm:$b),
- OpcStr # ".f32",
+ (ins FTZFlag:$ftz),
+ OpcStr # "$ftz" # nan_str # ".f32",
[(set f32:$dst, (OpNode f32:$a, fpimm:$b))]>;
- def f16rr_ftz :
- BasicNVPTXInst<(outs B16:$dst),
- (ins B16:$a, B16:$b),
- OpcStr # ".ftz.f16",
- [(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
- Requires<[useFP16Math, doF32FTZ]>;
- def f16rr :
- BasicNVPTXInst<(outs B16:$dst),
+ def _f16_rr :
+ BasicFlagsNVPTXInst<(outs B16:$dst),
(ins B16:$a, B16:$b),
- OpcStr # ".f16",
+ (ins FTZFlag:$ftz),
+ OpcStr # "$ftz" # nan_str # ".f16",
[(set f16:$dst, (OpNode f16:$a, f16:$b))]>,
- Requires<[useFP16Math, hasSM<80>, hasPTX<70>]>;
+ Requires<[useFP16Math]>;
- def f16x2rr_ftz :
- BasicNVPTXInst<(outs B32:$dst),
- (ins B32:$a, B32:$b),
- OpcStr # ".ftz.f16x2",
- [(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
- Requires<[useFP16Math, hasSM<80>, hasPTX<70>, doF32FTZ]>;
- def f16x2rr :
- BasicNVPTXInst<(outs B32:$dst),
+ def _f16x2_rr :
+ BasicFlagsNVPTXInst<(outs B32:$dst),
(ins B32:$a, B32:$b),
- OpcStr # ".f16x2",
+ (ins FTZFlag:$ftz),
+ OpcStr # "$ftz" # nan_str # ".f16x2",
[(set v2f16:$dst, (OpNode v2f16:$a, v2f16:$b))]>,
Requires<[useFP16Math, hasSM<80>, hasPTX<70>]>;
- def bf16rr :
+ def _bf16_rr :
BasicNVPTXInst<(outs B16:$dst),
(ins B16:$a, B16:$b),
- OpcStr # ".bf16",
+ OpcStr # nan_str # ".bf16",
[(set bf16:$dst, (OpNode bf16:$a, bf16:$b))]>,
Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
- def bf16x2rr :
+ def _bf16x2_rr :
BasicNVPTXInst<(outs B32:$dst),
(ins B32:$a, B32:$b),
- OpcStr # ".bf16x2",
+ OpcStr # nan_str # ".bf16x2",
[(set v2bf16:$dst, (OpNode v2bf16:$a, v2bf16:$b))]>,
Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
}
@@ -415,52 +375,31 @@ multiclass F3<string op_str, SDPatternOperator op_pat> {
(ins B64:$a, f64imm:$b),
op_str # ".f64",
[(set f64:$dst, (op_pat f64:$a, fpimm:$b))]>;
- def f32rr_ftz :
- BasicNVPTXInst<(outs B32:$dst),
- (ins B32:$a, B32:$b),
- op_str # ".ftz.f32",
- [(set f32:$dst, (op_pat f32:$a, f32:$b))]>,
- Requires<[doF32FTZ]>;
- def f32ri_ftz :
- BasicNVPTXInst<(outs B32:$dst),
- (ins B32:$a, f32imm:$b),
- op_str # ".ftz.f32",
- [(set f32:$dst, (op_pat f32:$a, fpimm:$b))]>,
- Requires<[doF32FTZ]>;
def f32rr :
- BasicNVPTXInst<(outs B32:$dst),
+ BasicFlagsNVPTXInst<(outs B32:$dst),
(ins B32:$a, B32:$b),
- op_str # ".f32",
+ (ins FTZFlag:$ftz),
+ op_str # "$ftz.f32",
[(set f32:$dst, (op_pat f32:$a, f32:$b))]>;
def f32ri :
- BasicNVPTXInst<(outs B32:$dst),
+ BasicFlagsNVPTXInst<(outs B32:$dst),
(ins B32:$a, f32imm:$b),
- op_str # ".f32",
+ (ins FTZFlag:$ftz),
+ op_str # "$ftz.f32",
[(set f32:$dst, (op_pat f32:$a, fpimm:$b))]>;
- def f16rr_ftz :
- BasicNVPTXInst<(outs B16:$dst),
- (ins B16:$a, B16:$b),
- op_str # ".ftz.f16",
- [(set f16:$dst, (op_pat f16:$a, f16:$b))]>,
- Requires<[useFP16Math, doF32FTZ]>;
def f16rr :
- BasicNVPTXInst<(outs B16:$dst),
+ BasicFlagsNVPTXInst<(outs B16:$dst),
(ins B16:$a, B16:$b),
- op_str # ".f16",
+ (ins FTZFlag:$ftz),
+ op_str # "$ftz.f16",
[(set f16:$dst, (op_pat f16:$a, f16:$b))]>,
Requires<[useFP16Math]>;
-
- def f16x2rr_ftz :
- BasicNVPTXInst<(outs B32:$dst),
- (...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/146410
More information about the llvm-commits mailing list