[llvm] [NVPTX] Add float to tf32 conversion intrinsic (PR #121507)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 2 09:53:23 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: Durgadoss R (durga4github)

<details>
<summary>Changes</summary>

This patch adds an intrinsic to convert float to tf32.

* This intrinsic uses flags for rounding, saturation modes, and relu. The backend looks through these flags and lowers them to the appropriate instruction.
* Docs have been updated to describe the usage of flag arguments.
* Lit tests are added for all the combinations.

TODO: We already have an intrinsic 'llvm.nvvm.f2tf32.rna' which caters only to one variant of the PTX instruction. Once this change lands, I will submit a follow-up PR to auto-upgrade it to use the generic variant.

PTX Spec link:
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt

---
Full diff: https://github.com/llvm/llvm-project/pull/121507.diff


12 Files Affected:

- (modified) llvm/docs/NVPTXUsage.rst (+60) 
- (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+9) 
- (modified) llvm/include/llvm/IR/NVVMIntrinsicFlags.h (+16) 
- (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp (+53) 
- (modified) llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h (+6) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (+46) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h (+1) 
- (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+16) 
- (modified) llvm/lib/Target/NVPTX/NVPTXSubtarget.h (+1) 
- (modified) llvm/test/CodeGen/NVPTX/convert-sm80.ll (+17) 
- (modified) llvm/test/CodeGen/NVPTX/convert-sm89.ll (+9) 
- (added) llvm/test/CodeGen/NVPTX/convert-sm90.ll (+65) 


``````````diff
diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index 313e84f3722a95..f6d5d27b8850c9 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -462,6 +462,66 @@ to left-shift the found bit into the most-significant bit position, otherwise
 the result is the shift amount needed to right-shift the found bit into the
 least-significant bit position. 0xffffffff is returned if no 1 bit is found.
 
+Conversion Intrinsics (for cvt.* PTX instructions)
+--------------------------------------------------
+
+'``llvm.nvvm.cvt.float.to.tf32``'
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+  declare i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8 %flag_fp_rnd_mode, i8 %flag_sat_mode, i1 %flag_relu)
+
+Overview:
+"""""""""
+
+The '``@llvm.nvvm.cvt.float.to.tf32``' intrinsic lowers to
+the ``cvt.*.tf32.f32`` set of PTX instructions.
+
+* The first argument is the input float to be converted to TF32.
+  This is followed by three flag arguments encoding the rounding mode,
+  saturation mode, and the relu modifier respectively.
+
+* The second argument (``i8 %flag_fp_rnd_mode``) selects
+  the floating-point rounding mode for this instruction.
+  It must be a compile-time constant, encoded as shown below:
+
+  ========== ==============
+  Enum Value  Rounding Mode
+  ========== ==============
+  ``0``       NONE
+  ``1``       ROUND_RZ
+  ``2``       ROUND_RN
+  ``3``       ROUND_RP
+  ``4``       ROUND_RM
+  ``5``       ROUND_RNA
+  ========== ==============
+
+  Of these, only ``RNA``, ``RN`` and ``RZ`` are valid for this intrinsic.
+
+* The third argument (denoted by ``i8 %flag_sat_mode``) denotes the
+  saturation modifier for this intrinsic. As of now, it can be either
+  ``NONE`` or ``SATFINITE``, according to the enumeration below:
+
+  ========== ================
+  Enum Value  Saturation Mode
+  ========== ================
+  ``0``       NONE
+  ``1``       SATFINITE
+  ========== ================
+
+* The last argument (denoted by ``i1 %flag_relu``) when set, generates
+  the ``.relu`` variant of the instruction.
+
+* Invalid values for the compile-time flag arguments may lead
+  to error(s) during Codegen.
+
+For more information, refer to the PTX ISA
+`<https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt>`_.
+
 TMA family of Intrinsics
 ------------------------
 
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index fd07d131ce15b2..870378bda44b0a 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1466,6 +1466,15 @@ let TargetPrefix = "nvvm" in {
   def int_nvvm_e5m2x2_to_f16x2_rn_relu : ClangBuiltin<"__nvvm_e5m2x2_to_f16x2_rn_relu">,
       Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
 
+// Convert Float to TF32
+def int_nvvm_cvt_float_to_tf32 : Intrinsic<[llvm_i32_ty],
+    [llvm_float_ty, // Input float
+     llvm_i8_ty,    // Flag for Rounding Modes
+     llvm_i8_ty,    // Flag for Saturation Modes
+     llvm_i1_ty],   // Flag for relu
+    [IntrNoMem, IntrNoCallback,
+     ImmArg<ArgIndex<1>>, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<3>>]>;
+
 // FNS
 
   def int_nvvm_fns : ClangBuiltin<"__nvvm_fns">,
diff --git a/llvm/include/llvm/IR/NVVMIntrinsicFlags.h b/llvm/include/llvm/IR/NVVMIntrinsicFlags.h
index dfb6e857b3a6ad..3dfa58313e3b60 100644
--- a/llvm/include/llvm/IR/NVVMIntrinsicFlags.h
+++ b/llvm/include/llvm/IR/NVVMIntrinsicFlags.h
@@ -34,6 +34,22 @@ enum class TMAReductionOp : uint8_t {
   XOR = 7,
 };
 
+// Rounding Modes for floating point types
+enum class FPRoundingMode : uint8_t {
+  NONE = 0,
+  ROUND_RZ = 1,  // roundTowardZero
+  ROUND_RN = 2,  // roundToNearest-TiesToEven
+  ROUND_RP = 3,  // roundTowardPositiveInf
+  ROUND_RM = 4,  // roundTowardNegativeInf
+  ROUND_RNA = 5, // roundToNearest-TiesAwayFromZero
+};
+
+// Saturation Modes
+enum class SaturationMode : uint8_t {
+  NONE = 0,
+  SATFINITE = 1,
+};
+
 } // namespace nvvm
 } // namespace llvm
 #endif // LLVM_IR_NVVMINTRINSICFLAGS_H
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 65e1893d3f3bdf..06dc60da9e6462 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -453,3 +453,56 @@ void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum,
   llvm_unreachable(
       "Invalid Reduction Op in printCpAsyncBulkTensorReductionMode");
 }
+
+void NVPTXInstPrinter::printFPRoundingMode(const MCInst *MI, int OpNum,
+                                           raw_ostream &O,
+                                           const char *Modifier) {
+  const MCOperand &MO = MI->getOperand(OpNum);
+  using Mode = nvvm::FPRoundingMode;
+
+  switch (static_cast<Mode>(MO.getImm())) {
+  case Mode::NONE:
+    O << "";
+    return;
+  case Mode::ROUND_RN:
+    O << ".rn";
+    return;
+  case Mode::ROUND_RNA:
+    O << ".rna";
+    return;
+  case Mode::ROUND_RZ:
+    O << ".rz";
+    return;
+  case Mode::ROUND_RP:
+    O << ".rp";
+    return;
+  case Mode::ROUND_RM:
+    O << ".rm";
+    return;
+  }
+  llvm_unreachable("Invalid mode in printFPRoundingMode");
+}
+
+void NVPTXInstPrinter::printSaturationMode(const MCInst *MI, int OpNum,
+                                           raw_ostream &O,
+                                           const char *Modifier) {
+  const MCOperand &MO = MI->getOperand(OpNum);
+  using Mode = nvvm::SaturationMode;
+
+  switch (static_cast<Mode>(MO.getImm())) {
+  case Mode::NONE:
+    O << "";
+    return;
+  case Mode::SATFINITE:
+    O << ".satfinite";
+    return;
+  }
+  llvm_unreachable("Invalid mode in printSaturationMode");
+}
+
+void NVPTXInstPrinter::printReluModifier(const MCInst *MI, int OpNum,
+                                         raw_ostream &O, const char *Modifier) {
+  const MCOperand &MO = MI->getOperand(OpNum);
+  if (MO.getImm())
+    O << ".relu";
+}
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 2b19386ef17fe5..7c3be27751ca14 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -56,6 +56,12 @@ class NVPTXInstPrinter : public MCInstPrinter {
                      const char *Modifier = nullptr);
   void printTmaReductionMode(const MCInst *MI, int OpNum, raw_ostream &O,
                              const char *Modifier = nullptr);
+  void printFPRoundingMode(const MCInst *MI, int OpNum, raw_ostream &O,
+                           const char *Modifier = nullptr);
+  void printSaturationMode(const MCInst *MI, int OpNum, raw_ostream &O,
+                           const char *Modifier = nullptr);
+  void printReluModifier(const MCInst *MI, int OpNum, raw_ostream &O,
+                         const char *Modifier = nullptr);
 };
 
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index c51729e224bf54..82ac658c4a4570 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -728,7 +728,53 @@ bool NVPTXDAGToDAGISel::tryIntrinsicNoChain(SDNode *N) {
   case Intrinsic::nvvm_texsurf_handle_internal:
     SelectTexSurfHandle(N);
     return true;
+  case Intrinsic::nvvm_cvt_float_to_tf32:
+    SelectCvtFloatToTF32(N);
+    return true;
+  }
+}
+
+void NVPTXDAGToDAGISel::SelectCvtFloatToTF32(SDNode *N) {
+  // 0 - IID
+  // 1 - Input Float
+  // 2 - Rounding Mode
+  // 3 - Saturation Mode
+  // 4 - Relu Flag
+  uint64_t Rnd = N->getConstantOperandVal(2);
+  uint64_t Sat = N->getConstantOperandVal(3);
+  bool IsRelu = N->getConstantOperandVal(4) == 1;
+
+  if (!Subtarget->hasTF32Math())
+    report_fatal_error("TF32 destination format requires at least sm80");
+
+  using SatMode = nvvm::SaturationMode;
+  bool IsSatFinite = static_cast<SatMode>(Sat) == SatMode::SATFINITE;
+  if (IsSatFinite && Subtarget->getPTXVersion() < 81)
+    report_fatal_error("satfinite modifier requires PTX version 8.1 or higher");
+
+  using RndMode = nvvm::FPRoundingMode;
+  switch (static_cast<RndMode>(Rnd)) {
+  case RndMode::ROUND_RNA:
+    if (IsRelu)
+      report_fatal_error("relu not supported with rna rounding mode");
+    break;
+  case RndMode::ROUND_RN:
+  case RndMode::ROUND_RZ: {
+    if (Subtarget->getSmVersion() < 90)
+      report_fatal_error("rn/rz rounding modes require at least sm90");
+    if (IsSatFinite)
+      report_fatal_error("satfinite not supported with rn/rz rounding modes");
+    break;
+  }
+  default:
+    report_fatal_error("Invalid FP rounding mode in SelectCvtFloatToTF32");
   }
+
+  SDLoc DL(N);
+  SDValue Ops[] = {N->getOperand(1), getI32Imm(Rnd, DL), getI32Imm(Sat, DL),
+                   getI32Imm(IsRelu, DL)};
+  ReplaceNode(N, CurDAG->getMachineNode(NVPTX::cvt_float_to_tf32, DL,
+                                        N->getVTList(), Ops));
 }
 
 void NVPTXDAGToDAGISel::SelectTexSurfHandle(SDNode *N) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index c307f28fcc6c0a..3e22ef5bab9931 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -73,6 +73,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   bool tryIntrinsicChain(SDNode *N);
   bool tryIntrinsicVoid(SDNode *N);
   void SelectTexSurfHandle(SDNode *N);
+  void SelectCvtFloatToTF32(SDNode *N);
   bool tryLoad(SDNode *N);
   bool tryLoadVector(SDNode *N);
   bool tryLDGLDU(SDNode *N);
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 8ede1ec4f20dc9..3274c1ef4260db 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1802,6 +1802,22 @@ def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn Int16Regs:$a),
 def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu Int16Regs:$a),
           (CVT_f16x2_e5m2x2 $a, CvtRN_RELU)>;
 
+def FPRoundingMode : Operand<i32> {
+  let PrintMethod = "printFPRoundingMode";
+}
+
+def SatMode : Operand<i32> {
+  let PrintMethod = "printSaturationMode";
+}
+
+def ReluFlag : Operand<i32> {
+  let PrintMethod = "printReluModifier";
+}
+
+def cvt_float_to_tf32 : NVPTXInst<(outs Int32Regs:$dest),
+    (ins Float32Regs:$a, FPRoundingMode:$rnd, SatMode:$sat, ReluFlag:$relu),
+    "cvt${rnd:rnd}${sat:sat}${relu:relu}.tf32.f32 \t$dest, $a;", []>;
+
 //
 // FNS
 //
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index 7555a2368ec963..9f0b437bd32dc5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -83,6 +83,7 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
   bool hasFP16Math() const { return SmVersion >= 53; }
   bool hasBF16Math() const { return SmVersion >= 80; }
   bool allowFP16Math() const;
+  bool hasTF32Math() const { return SmVersion >= 80 && PTXVersion >= 70; }
   bool hasMaskOperator() const { return PTXVersion >= 71; }
   bool hasNoReturn() const { return SmVersion >= 30 && PTXVersion >= 64; }
   // Does SM & PTX support memory orderings (weak and atomic: relaxed, acquire,
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm80.ll b/llvm/test/CodeGen/NVPTX/convert-sm80.ll
index aebc28b1cfea3e..d54fa73d306ffa 100644
--- a/llvm/test/CodeGen/NVPTX/convert-sm80.ll
+++ b/llvm/test/CodeGen/NVPTX/convert-sm80.ll
@@ -261,3 +261,20 @@ define <2 x half> @fold_ff2f16x2(float %lo, float %hi) {
   %v1 = insertelement <2 x half> %v0, half %hih, i64 1
   ret <2 x half> %v1
 }
+
+declare i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8, i8, i1)
+
+define i32 @cvt_rna_tf32_f32_flags(float %f1) {
+; CHECK-LABEL: cvt_rna_tf32_f32_flags(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [cvt_rna_tf32_f32_flags_param_0];
+; CHECK-NEXT:    cvt.rna.tf32.f32 %r1, %f1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %val = call i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8 5, i8 0, i1 0)
+  ret i32 %val
+}
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm89.ll b/llvm/test/CodeGen/NVPTX/convert-sm89.ll
index 5d0576aebbe089..3ff58b95348095 100644
--- a/llvm/test/CodeGen/NVPTX/convert-sm89.ll
+++ b/llvm/test/CodeGen/NVPTX/convert-sm89.ll
@@ -84,3 +84,12 @@ define <2 x half> @cvt_rn_relu_f16x2_e5m2x2(i16 %in) {
   %val = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 %in);
   ret <2 x half> %val
 }
+
+declare i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8, i8, i1)
+
+; CHECK-LABEL: cvt_rna_satfinite_tf32_f32
+define i32 @cvt_rna_satfinite_tf32_f32(float %f1) {
+; CHECK: cvt.rna.satfinite.tf32.f32
+  %val = call i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8 5, i8 1, i1 0)
+  ret i32 %val
+}
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm90.ll b/llvm/test/CodeGen/NVPTX/convert-sm90.ll
new file mode 100644
index 00000000000000..8f932005830250
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/convert-sm90.ll
@@ -0,0 +1,65 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78| FileCheck --check-prefixes=CHECK %s
+; RUN: %if ptxas-12.0 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78| %ptxas-verify -arch=sm_90 %}
+
+declare i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8, i8, i1)
+
+define i32 @cvt_rn_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rn_tf32_f32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [cvt_rn_tf32_f32_param_0];
+; CHECK-NEXT:    cvt.rn.tf32.f32 %r1, %f1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %val = call i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8 2, i8 0, i1 0)
+  ret i32 %val
+}
+
+define i32 @cvt_rn_relu_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rn_relu_tf32_f32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [cvt_rn_relu_tf32_f32_param_0];
+; CHECK-NEXT:    cvt.rn.relu.tf32.f32 %r1, %f1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %val = call i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8 2, i8 0, i1 1)
+  ret i32 %val
+}
+
+define i32 @cvt_rz_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rz_tf32_f32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [cvt_rz_tf32_f32_param_0];
+; CHECK-NEXT:    cvt.rz.tf32.f32 %r1, %f1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %val = call i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8 1, i8 0, i1 0)
+  ret i32 %val
+}
+
+define i32 @cvt_rz_relu_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rz_relu_tf32_f32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [cvt_rz_relu_tf32_f32_param_0];
+; CHECK-NEXT:    cvt.rz.relu.tf32.f32 %r1, %f1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %val = call i32 @llvm.nvvm.cvt.float.to.tf32(float %f1, i8 1, i8 0, i1 1)
+  ret i32 %val
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/121507


More information about the llvm-commits mailing list