[llvm] [NVPTX] Add float to tf32 conversion intrinsic (PR #121507)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 2 09:53:23 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-nvptx
Author: Durgadoss R (durga4github)
<details>
<summary>Changes</summary>
This patch adds an intrinsic to convert float to tf32.
* This intrinsic takes flag arguments for the rounding mode, saturation mode, and relu modifier. The backend inspects these flags and lowers the call to the appropriate instruction.
* Docs have been updated to describe the usage of flag arguments.
* Lit tests are added for all the combinations.
TODO: We already have an intrinsic 'llvm.nvvm.f2tf32.rna' which caters only to one variant of the PTX instruction. Once this change lands, I will submit a follow-up PR to auto-upgrade it to use the generic variant.
PTX Spec link:
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
---
Full diff: https://github.com/llvm/llvm-project/pull/121507.diff
12 Files Affected:
- (modified) llvm/docs/NVPTXUsage.rst (+60)
- (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+9)
- (modified) llvm/include/llvm/IR/NVVMIntrinsicFlags.h (+16)
- (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp (+53)
- (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h (+6)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+46)
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (+1)
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+16)
- (modified) llvm/lib/Target/NVPTX/NVPTXSubtarget.h (+1)
- (modified) llvm/test/CodeGen/NVPTX/convert-sm80.ll (+17)
- (modified) llvm/test/CodeGen/NVPTX/convert-sm89.ll (+9)
- (added) llvm/test/CodeGen/NVPTX/convert-sm90.ll (+65)
``````````diff
diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index 313e84f3722a95..f6d5d27b8850c9 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -462,6 +462,66 @@ to left-shift the found bit into the most-significant bit position, otherwise
the result is the shift amount needed to right-shift the found bit into the
least-significant bit position. 0xffffffff is returned if no 1 bit is found.
+Conversion Intrinsics (for cvt.* PTX instructions)
+--------------------------------------------------
+
+'``llvm.nvvm.cvt.float.to.tf32``'
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+ declare i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8 %flag_fp_rnd_mode, i8 %flag_sat_mode, i1 %flag_relu)
+
+Overview:
+"""""""""
+
+The '``@llvm.nvvm.cvt.float.to.tf32``' intrinsic lowers to
+the ``cvt.*.tf32.f32`` set of PTX instructions.
+
+* The first argument is the input float to be converted to TF32.
+ This is followed by three flag arguments encoding the rounding mode,
+ saturation mode, and the relu modifier respectively.
+
+* The second argument (denoted by ``i8 %flag_fp_rnd_mode``) specifies
+  the floating-point rounding mode for this instruction.
+  This must be a compile-time constant and the encoding is as below:
+
+  ========== ==============
+  Enum Value Rounding Mode
+  ========== ==============
+  ``0``      NONE
+  ``1``      ROUND_RZ
+  ``2``      ROUND_RN
+  ``3``      ROUND_RP
+  ``4``      ROUND_RM
+  ``5``      ROUND_RNA
+  ========== ==============
+
+  The valid rounding modes are ``RNA``, ``RN``, and ``RZ``; ``RN`` and
+  ``RZ`` require ``sm_90`` or later.
+
+* The third argument (denoted by ``i8 %flag_sat_mode``) specifies the
+  saturation modifier for this intrinsic. Currently, it can be either
+  NONE or SATFINITE, according to the enumeration below:
+
+  ========== ================
+  Enum Value Saturation Mode
+  ========== ================
+  ``0``      NONE
+  ``1``      SATFINITE
+  ========== ================
+
+  ``SATFINITE`` is supported only with the ``RNA`` rounding mode.
+
+* The last argument (denoted by ``i1 %flag_relu``), when set, generates
+  the ``.relu`` variant of the instruction. It is supported only with
+  the ``RN`` and ``RZ`` rounding modes.
+
+* Invalid values or unsupported combinations of the compile-time flag
+  arguments may result in errors during code generation.
+
+For more information, refer to the PTX ISA
+`<https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt>`_.
+
TMA family of Intrinsics
------------------------
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index fd07d131ce15b2..870378bda44b0a 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1466,6 +1466,15 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_e5m2x2_to_f16x2_rn_relu : ClangBuiltin<"__nvvm_e5m2x2_to_f16x2_rn_relu">,
Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
+// Convert Float to TF32.
+// Lowers to one of the cvt.*.tf32.f32 PTX instructions; the converted value
+// is returned as an i32. All three flag arguments must be immediates
+// (ImmArg) and use the nvvm::FPRoundingMode / nvvm::SaturationMode
+// encodings from llvm/IR/NVVMIntrinsicFlags.h.
+def int_nvvm_cvt_float_to_tf32 : Intrinsic<[llvm_i32_ty],
+ [llvm_float_ty, // Input float
+ llvm_i8_ty, // Rounding mode (nvvm::FPRoundingMode encoding)
+ llvm_i8_ty, // Saturation mode (nvvm::SaturationMode encoding)
+ llvm_i1_ty], // Relu flag: 1 selects the .relu variant
+ [IntrNoMem, IntrNoCallback,
+ ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<3>>]>;
+
// FNS
def int_nvvm_fns : ClangBuiltin<"__nvvm_fns">,
diff --git a/llvm/include/llvm/IR/NVVMIntrinsicFlags.h b/llvm/include/llvm/IR/NVVMIntrinsicFlags.h
index dfb6e857b3a6ad..3dfa58313e3b60 100644
--- a/llvm/include/llvm/IR/NVVMIntrinsicFlags.h
+++ b/llvm/include/llvm/IR/NVVMIntrinsicFlags.h
@@ -34,6 +34,22 @@ enum class TMAReductionOp : uint8_t {
XOR = 7,
};
+/// Rounding modes for floating-point conversions. The numeric values are the
+/// encodings passed as immediate flag arguments to NVVM intrinsics such as
+/// llvm.nvvm.cvt.float.to.tf32, so existing values must stay stable.
+enum class FPRoundingMode : uint8_t {
+  NONE = 0,      ///< No rounding-mode modifier.
+  ROUND_RZ = 1,  ///< .rz  - round toward zero.
+  ROUND_RN = 2,  ///< .rn  - round to nearest, ties to even.
+  ROUND_RP = 3,  ///< .rp  - round toward positive infinity.
+  ROUND_RM = 4,  ///< .rm  - round toward negative infinity.
+  ROUND_RNA = 5, ///< .rna - round to nearest, ties away from zero.
+};
+
+/// Saturation modifiers for floating-point conversions, encoded the same way
+/// as the intrinsic's immediate saturation flag.
+enum class SaturationMode : uint8_t {
+  NONE = 0,      ///< No saturation modifier.
+  SATFINITE = 1, ///< .satfinite modifier.
+};
+
} // namespace nvvm
} // namespace llvm
#endif // LLVM_IR_NVVMINTRINSICFLAGS_H
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 65e1893d3f3bdf..06dc60da9e6462 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -453,3 +453,56 @@ void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum,
llvm_unreachable(
"Invalid Reduction Op in printCpAsyncBulkTensorReductionMode");
}
+
+/// Prints the PTX rounding-mode suffix (".rn", ".rz", ...) for the
+/// nvvm::FPRoundingMode immediate at operand OpNum. NONE prints nothing.
+/// Modifier is unused.
+void NVPTXInstPrinter::printFPRoundingMode(const MCInst *MI, int OpNum,
+                                           raw_ostream &O,
+                                           const char *Modifier) {
+  const char *Suffix = nullptr;
+  switch (static_cast<nvvm::FPRoundingMode>(MI->getOperand(OpNum).getImm())) {
+  case nvvm::FPRoundingMode::NONE:
+    Suffix = "";
+    break;
+  case nvvm::FPRoundingMode::ROUND_RZ:
+    Suffix = ".rz";
+    break;
+  case nvvm::FPRoundingMode::ROUND_RN:
+    Suffix = ".rn";
+    break;
+  case nvvm::FPRoundingMode::ROUND_RP:
+    Suffix = ".rp";
+    break;
+  case nvvm::FPRoundingMode::ROUND_RM:
+    Suffix = ".rm";
+    break;
+  case nvvm::FPRoundingMode::ROUND_RNA:
+    Suffix = ".rna";
+    break;
+  }
+  // Any encoding outside the enum is a selection bug.
+  if (!Suffix)
+    llvm_unreachable("Invalid mode in printFPRoundingMode");
+  O << Suffix;
+}
+
+/// Prints ".satfinite" when the nvvm::SaturationMode immediate at operand
+/// OpNum is SATFINITE, and nothing for NONE. Modifier is unused.
+void NVPTXInstPrinter::printSaturationMode(const MCInst *MI, int OpNum,
+                                           raw_ostream &O,
+                                           const char *Modifier) {
+  auto SatKind =
+      static_cast<nvvm::SaturationMode>(MI->getOperand(OpNum).getImm());
+  if (SatKind == nvvm::SaturationMode::SATFINITE) {
+    O << ".satfinite";
+    return;
+  }
+  if (SatKind == nvvm::SaturationMode::NONE)
+    return;
+  // Any encoding outside the enum is a selection bug.
+  llvm_unreachable("Invalid mode in printSaturationMode");
+}
+
+/// Prints ".relu" when the immediate at operand OpNum is non-zero.
+/// Modifier is unused.
+void NVPTXInstPrinter::printReluModifier(const MCInst *MI, int OpNum,
+                                         raw_ostream &O, const char *Modifier) {
+  if (MI->getOperand(OpNum).getImm() != 0)
+    O << ".relu";
+}
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 2b19386ef17fe5..7c3be27751ca14 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -56,6 +56,12 @@ class NVPTXInstPrinter : public MCInstPrinter {
const char *Modifier = nullptr);
void printTmaReductionMode(const MCInst *MI, int OpNum, raw_ostream &O,
const char *Modifier = nullptr);
+ void printFPRoundingMode(const MCInst *MI, int OpNum, raw_ostream &O,
+ const char *Modifier = nullptr);
+ void printSaturationMode(const MCInst *MI, int OpNum, raw_ostream &O,
+ const char *Modifier = nullptr);
+ void printReluModifier(const MCInst *MI, int OpNum, raw_ostream &O,
+ const char *Modifier = nullptr);
};
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index c51729e224bf54..82ac658c4a4570 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -728,7 +728,53 @@ bool NVPTXDAGToDAGISel::tryIntrinsicNoChain(SDNode *N) {
case Intrinsic::nvvm_texsurf_handle_internal:
SelectTexSurfHandle(N);
return true;
+ case Intrinsic::nvvm_cvt_float_to_tf32:
+ SelectCvtFloatToTF32(N);
+ return true;
+ }
+}
+
+/// Selects the generic cvt.*.tf32.f32 machine instruction for
+/// llvm.nvvm.cvt.float.to.tf32 after verifying that the requested flag
+/// combination is encodable on the current subtarget. Unsupported or invalid
+/// combinations are reported with report_fatal_error.
+void NVPTXDAGToDAGISel::SelectCvtFloatToTF32(SDNode *N) {
+  // Operands of N:
+  // 0 - IID
+  // 1 - Input Float
+  // 2 - Rounding Mode (nvvm::FPRoundingMode)
+  // 3 - Saturation Mode (nvvm::SaturationMode)
+  // 4 - Relu Flag
+  uint64_t Rnd = N->getConstantOperandVal(2);
+  uint64_t Sat = N->getConstantOperandVal(3);
+  bool IsRelu = N->getConstantOperandVal(4) == 1;
+
+  if (!Subtarget->hasTF32Math())
+    report_fatal_error("TF32 destination format requires at least sm80");
+
+  // Reject unknown saturation encodings here; otherwise they would reach
+  // NVPTXInstPrinter::printSaturationMode, which treats them as unreachable.
+  using SatMode = nvvm::SaturationMode;
+  if (Sat != static_cast<uint64_t>(SatMode::NONE) &&
+      Sat != static_cast<uint64_t>(SatMode::SATFINITE))
+    report_fatal_error("Invalid saturation mode in SelectCvtFloatToTF32");
+
+  // cvt.rna.satfinite.tf32.f32 needs both PTX ISA 8.1 and sm_89; sm_80 alone
+  // (hasTF32Math) is not sufficient.
+  bool IsSatFinite = static_cast<SatMode>(Sat) == SatMode::SATFINITE;
+  if (IsSatFinite && Subtarget->getPTXVersion() < 81)
+    report_fatal_error(
+        "satfinite modifier requires PTX version 8.1 or higher");
+  if (IsSatFinite && Subtarget->getSmVersion() < 89)
+    report_fatal_error("satfinite modifier requires at least sm89");
+
+  using RndMode = nvvm::FPRoundingMode;
+  switch (static_cast<RndMode>(Rnd)) {
+  case RndMode::ROUND_RNA:
+    if (IsRelu)
+      report_fatal_error("relu not supported with rna rounding mode");
+    break;
+  case RndMode::ROUND_RN:
+  case RndMode::ROUND_RZ: {
+    // cvt.{rn,rz}{.relu}.tf32.f32 needs sm_90 and PTX ISA 7.8.
+    if (Subtarget->getSmVersion() < 90)
+      report_fatal_error("rn/rz rounding modes require at least sm90");
+    if (Subtarget->getPTXVersion() < 78)
+      report_fatal_error(
+          "rn/rz rounding modes require PTX version 7.8 or higher");
+    if (IsSatFinite)
+      report_fatal_error("satfinite not supported with rn/rz rounding modes");
+    break;
+  }
+  default:
+    // NONE, RP and RM have no tf32.f32 form; anything else is not a valid
+    // nvvm::FPRoundingMode encoding.
+    report_fatal_error("Invalid FP rounding mode in SelectCvtFloatToTF32");
   }
+
+  SDLoc DL(N);
+  SDValue Ops[] = {N->getOperand(1), getI32Imm(Rnd, DL), getI32Imm(Sat, DL),
+                   getI32Imm(IsRelu, DL)};
+  ReplaceNode(N, CurDAG->getMachineNode(NVPTX::cvt_float_to_tf32, DL,
+                                        N->getVTList(), Ops));
 }
void NVPTXDAGToDAGISel::SelectTexSurfHandle(SDNode *N) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index c307f28fcc6c0a..3e22ef5bab9931 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -73,6 +73,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
bool tryIntrinsicChain(SDNode *N);
bool tryIntrinsicVoid(SDNode *N);
void SelectTexSurfHandle(SDNode *N);
+ void SelectCvtFloatToTF32(SDNode *N);
bool tryLoad(SDNode *N);
bool tryLoadVector(SDNode *N);
bool tryLDGLDU(SDNode *N);
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 8ede1ec4f20dc9..3274c1ef4260db 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1802,6 +1802,22 @@ def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn Int16Regs:$a),
def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu Int16Regs:$a),
(CVT_f16x2_e5m2x2 $a, CvtRN_RELU)>;
+// Operand kinds for the immediate flags of cvt_float_to_tf32. Each one is
+// printed by the named NVPTXInstPrinter method, which turns the immediate
+// into the matching PTX modifier suffix (or nothing).
+def FPRoundingMode : Operand<i32> {
+ let PrintMethod = "printFPRoundingMode";
+}
+
+def SatMode : Operand<i32> {
+ let PrintMethod = "printSaturationMode";
+}
+
+def ReluFlag : Operand<i32> {
+ let PrintMethod = "printReluModifier";
+}
+
+// Generic cvt.*.tf32.f32: the rounding, saturation and relu suffixes all come
+// from immediate operands. It has no selection pattern ([]); it is created
+// directly by NVPTXDAGToDAGISel::SelectCvtFloatToTF32, which also validates
+// the flag combination.
+def cvt_float_to_tf32 : NVPTXInst<(outs Int32Regs:$dest),
+ (ins Float32Regs:$a, FPRoundingMode:$rnd, SatMode:$sat, ReluFlag:$relu),
+ "cvt${rnd:rnd}${sat:sat}${relu:relu}.tf32.f32 \t$dest, $a;", []>;
+
//
// FNS
//
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index 7555a2368ec963..9f0b437bd32dc5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -83,6 +83,7 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
bool hasFP16Math() const { return SmVersion >= 53; }
bool hasBF16Math() const { return SmVersion >= 80; }
bool allowFP16Math() const;
+  // TF32 conversions (cvt.*.tf32.f32) need PTX ISA 7.0 and sm_80 or newer.
+  bool hasTF32Math() const { return PTXVersion >= 70 && SmVersion >= 80; }
bool hasMaskOperator() const { return PTXVersion >= 71; }
bool hasNoReturn() const { return SmVersion >= 30 && PTXVersion >= 64; }
// Does SM & PTX support memory orderings (weak and atomic: relaxed, acquire,
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm80.ll b/llvm/test/CodeGen/NVPTX/convert-sm80.ll
index aebc28b1cfea3e..d54fa73d306ffa 100644
--- a/llvm/test/CodeGen/NVPTX/convert-sm80.ll
+++ b/llvm/test/CodeGen/NVPTX/convert-sm80.ll
@@ -261,3 +261,20 @@ define <2 x half> @fold_ff2f16x2(float %lo, float %hi) {
%v1 = insertelement <2 x half> %v0, half %hih, i64 1
ret <2 x half> %v1
}
+
+declare i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8, i8, i1)
+
+; Generic intrinsic with rounding = 5 (RNA), saturation = 0 (NONE), relu = 0:
+; expected to select plain cvt.rna.tf32.f32.
+define i32 @cvt_rna_tf32_f32_flags(float %f1) {
+; CHECK-LABEL: cvt_rna_tf32_f32_flags(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.f32 %f1, [cvt_rna_tf32_f32_flags_param_0];
+; CHECK-NEXT: cvt.rna.tf32.f32 %r1, %f1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %val = call i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8 5, i8 0, i1 0)
+ ret i32 %val
+}
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm89.ll b/llvm/test/CodeGen/NVPTX/convert-sm89.ll
index 5d0576aebbe089..3ff58b95348095 100644
--- a/llvm/test/CodeGen/NVPTX/convert-sm89.ll
+++ b/llvm/test/CodeGen/NVPTX/convert-sm89.ll
@@ -84,3 +84,12 @@ define <2 x half> @cvt_rn_relu_f16x2_e5m2x2(i16 %in) {
%val = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 %in);
ret <2 x half> %val
}
+
+declare i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8, i8, i1)
+
+; Generic intrinsic with rounding = 5 (RNA), saturation = 1 (SATFINITE),
+; relu = 0: expected to select cvt.rna.satfinite.tf32.f32.
+; CHECK-LABEL: cvt_rna_satfinite_tf32_f32
+define i32 @cvt_rna_satfinite_tf32_f32(float %f1) {
+; CHECK: cvt.rna.satfinite.tf32.f32
+ %val = call i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8 5, i8 1, i1 0)
+ ret i32 %val
+}
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm90.ll b/llvm/test/CodeGen/NVPTX/convert-sm90.ll
new file mode 100644
index 00000000000000..8f932005830250
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/convert-sm90.ll
@@ -0,0 +1,65 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78| FileCheck --check-prefixes=CHECK %s
+; RUN: %if ptxas-12.0 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78| %ptxas-verify -arch=sm_90 %}
+
+declare i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8, i8, i1)
+
+; Flags: rounding = 2 (RN), saturation = 0 (NONE), relu = 0.
+define i32 @cvt_rn_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rn_tf32_f32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_tf32_f32_param_0];
+; CHECK-NEXT: cvt.rn.tf32.f32 %r1, %f1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %val = call i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8 2, i8 0, i1 0)
+ ret i32 %val
+}
+
+; Flags: rounding = 2 (RN), saturation = 0 (NONE), relu = 1.
+define i32 @cvt_rn_relu_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rn_relu_tf32_f32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_relu_tf32_f32_param_0];
+; CHECK-NEXT: cvt.rn.relu.tf32.f32 %r1, %f1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %val = call i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8 2, i8 0, i1 1)
+ ret i32 %val
+}
+
+; Flags: rounding = 1 (RZ), saturation = 0 (NONE), relu = 0.
+define i32 @cvt_rz_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rz_tf32_f32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_tf32_f32_param_0];
+; CHECK-NEXT: cvt.rz.tf32.f32 %r1, %f1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %val = call i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8 1, i8 0, i1 0)
+ ret i32 %val
+}
+
+; Flags: rounding = 1 (RZ), saturation = 0 (NONE), relu = 1.
+define i32 @cvt_rz_relu_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rz_relu_tf32_f32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_relu_tf32_f32_param_0];
+; CHECK-NEXT: cvt.rz.relu.tf32.f32 %r1, %f1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %val = call i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8 1, i8 0, i1 1)
+ ret i32 %val
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/121507
More information about the llvm-commits
mailing list