[llvm] [NVPTX] Add float to tf32 conversion intrinsic (PR #121507)

Durgadoss R via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 3 11:14:25 PST 2025


https://github.com/durga4github updated https://github.com/llvm/llvm-project/pull/121507

>From c75c744e5464d829e1288d0c4a99b16e0ef4cd3a Mon Sep 17 00:00:00 2001
From: Durgadoss R <durgadossr at nvidia.com>
Date: Tue, 24 Dec 2024 19:34:31 +0530
Subject: [PATCH] [NVPTX] Add convert float to tf32 intrinsics

This patch adds an intrinsic to convert float to tf32.

* This intrinsic uses flags for rounding and saturation
  modes as well as relu. The backend looks through these
  flags and lowers to the appropriate instruction.
* Docs are updated to describe the usage of the flag arguments.
* Lit tests are added for all the combinations.

Note: We already have an intrinsic 'llvm.nvvm.f2tf32.rna'
which caters only to one variant of the PTX instruction. Once
this change lands, I will submit a follow-up PR to auto-upgrade
it to use the generic variant.

PTX Spec link:
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt

Signed-off-by: Durgadoss R <durgadossr at nvidia.com>
---
 llvm/docs/NVPTXUsage.rst                      | 50 ++++++++++++++
 llvm/include/llvm/IR/IntrinsicsNVVM.td        |  8 +++
 llvm/include/llvm/IR/NVVMIntrinsicFlags.h     |  6 ++
 .../NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp   | 50 ++++++++++++++
 .../NVPTX/MCTargetDesc/NVPTXInstPrinter.h     |  6 ++
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp   | 48 +++++++++++++
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h     |  1 +
 llvm/lib/Target/NVPTX/NVPTXIntrinsics.td      | 16 +++++
 llvm/lib/Target/NVPTX/NVPTXSubtarget.h        |  1 +
 llvm/test/CodeGen/NVPTX/convert-sm80.ll       | 18 +++++
 llvm/test/CodeGen/NVPTX/convert-sm89.ll       | 10 +++
 llvm/test/CodeGen/NVPTX/convert-sm90.ll       | 67 +++++++++++++++++++
 12 files changed, 281 insertions(+)
 create mode 100644 llvm/test/CodeGen/NVPTX/convert-sm90.ll

diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index 313e84f3722a95..2970687f8d377a 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -462,6 +462,56 @@ to left-shift the found bit into the most-significant bit position, otherwise
 the result is the shift amount needed to right-shift the found bit into the
 least-significant bit position. 0xffffffff is returned if no 1 bit is found.
 
+Conversion Intrinsics (for cvt.* PTX instructions)
+--------------------------------------------------
+
+'``llvm.nvvm.convert.to.tf32.float``'
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+  declare i32 @llvm.nvvm.convert.to.tf32.float(float %f1, metadata !round_mode, i8 %flag_sat_mode, i1 %flag_relu)
+
+Overview:
+"""""""""
+
+The '``@llvm.nvvm.convert.to.tf32.float``' intrinsic lowers to
+the ``cvt.*.tf32.f32`` set of PTX instructions.
+
+* The first argument is the input float to be converted to TF32.
+
+* The second argument (denoted by ``metadata !round_mode``) specifies
+  the floating-point rounding mode to use for the conversion.
+  The metadata strings are the same as the ones used for constrained-fp
+  intrinsics, documented here:
+  `<https://llvm.org/docs/LangRef.html#constrainedfp>`_.
+
+  Only the ``round.tonearest``, ``round.towardzero`` and
+  ``round.tonearestaway`` rounding modes are valid for this intrinsic.
+
+* The third argument (denoted by ``i8 %flag_sat_mode``) denotes the
+  saturation modifier for this intrinsic. Currently, it can be either
+  NONE or SATFINITE, encoded as shown below:
+
+  ========== ================
+  Enum Value  Saturation Mode
+  ========== ================
+  ``0``       NONE
+  ``1``       SATFINITE
+  ========== ================
+
+* The last argument (denoted by ``i1 %flag_relu``), when set, generates
+  the ``.relu`` variant of the instruction.
+
+* Invalid values for the compile-time arguments may result in
+  error(s) during code generation.
+
+For more information, refer to the PTX ISA
+`<https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt>`_.
+
 TMA family of Intrinsics
 ------------------------
 
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index fd07d131ce15b2..9acae49c24892f 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1466,6 +1466,14 @@ let TargetPrefix = "nvvm" in {
   def int_nvvm_e5m2x2_to_f16x2_rn_relu : ClangBuiltin<"__nvvm_e5m2x2_to_f16x2_rn_relu">,
       Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
 
+// Convert float to TF32 (lowered to the cvt.*.tf32.f32 PTX instructions).
+// Operands:
+//   0: input float
+//   1: metadata string naming the rounding mode, using the constrained-FP
+//      spelling (e.g. !"round.tonearest")
+//   2: i8 saturation mode (nvvm::SaturationMode); must be an immediate
+//   3: i1 relu flag; must be an immediate
+// The converted value is returned as an i32.
+def int_nvvm_convert_to_tf32_float : DefaultAttrsIntrinsic<[llvm_i32_ty],
+    [llvm_float_ty,    // Input float
+     llvm_metadata_ty, // Metadata for the rounding mode
+     llvm_i8_ty,       // Flag for the saturation mode
+     llvm_i1_ty],      // Flag for relu
+    [IntrNoMem, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<3>>]>;
+
 // FNS
 
   def int_nvvm_fns : ClangBuiltin<"__nvvm_fns">,
diff --git a/llvm/include/llvm/IR/NVVMIntrinsicFlags.h b/llvm/include/llvm/IR/NVVMIntrinsicFlags.h
index dfb6e857b3a6ad..9c7a21d32d6721 100644
--- a/llvm/include/llvm/IR/NVVMIntrinsicFlags.h
+++ b/llvm/include/llvm/IR/NVVMIntrinsicFlags.h
@@ -34,6 +34,12 @@ enum class TMAReductionOp : uint8_t {
   XOR = 7,
 };
 
+// Saturation modifiers selectable through the i8 saturation-mode flag of
+// NVVM intrinsics such as llvm.nvvm.convert.to.tf32.float.
+enum class SaturationMode : uint8_t {
+  NONE = 0,      // no saturation modifier
+  SATFINITE = 1, // the ".satfinite" modifier
+};
+
 } // namespace nvvm
 } // namespace llvm
 #endif // LLVM_IR_NVVMINTRINSICFLAGS_H
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 65e1893d3f3bdf..90ff445d889fa2 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -453,3 +453,53 @@ void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum,
   llvm_unreachable(
       "Invalid Reduction Op in printCpAsyncBulkTensorReductionMode");
 }
+
+void NVPTXInstPrinter::printFPRoundingMode(const MCInst *MI, int OpNum,
+                                           raw_ostream &O,
+                                           const char *Modifier) {
+  // Print the PTX rounding suffix for an llvm::RoundingMode immediate.
+  switch (static_cast<llvm::RoundingMode>(MI->getOperand(OpNum).getImm())) {
+  case llvm::RoundingMode::NearestTiesToEven:
+    O << ".rn";
+    return;
+  case llvm::RoundingMode::NearestTiesToAway:
+    O << ".rna";
+    return;
+  case llvm::RoundingMode::TowardZero:
+    O << ".rz";
+    return;
+  case llvm::RoundingMode::TowardPositive:
+    O << ".rp";
+    return;
+  case llvm::RoundingMode::TowardNegative:
+    O << ".rm";
+    return;
+  default:
+    // The remaining modes (e.g. Dynamic, Invalid) have no PTX modifier.
+    // Printing nothing would silently emit an instruction without its
+    // rounding suffix, so treat this as an instruction-selection bug.
+    llvm_unreachable("Invalid FP rounding mode in printFPRoundingMode");
+  }
+}
+
+void NVPTXInstPrinter::printSaturationMode(const MCInst *MI, int OpNum,
+                                           raw_ostream &O,
+                                           const char *Modifier) {
+  // Emit the PTX suffix for an nvvm::SaturationMode immediate: nothing for
+  // NONE, ".satfinite" for SATFINITE.
+  const auto SatKind =
+      static_cast<nvvm::SaturationMode>(MI->getOperand(OpNum).getImm());
+  if (SatKind == nvvm::SaturationMode::NONE)
+    return;
+  if (SatKind == nvvm::SaturationMode::SATFINITE) {
+    O << ".satfinite";
+    return;
+  }
+  llvm_unreachable("Invalid mode in printSaturationMode");
+}
+
+void NVPTXInstPrinter::printReluModifier(const MCInst *MI, int OpNum,
+                                         raw_ostream &O, const char *Modifier) {
+  // A non-zero immediate selects the ".relu" form of the instruction.
+  const int64_t ReluImm = MI->getOperand(OpNum).getImm();
+  if (ReluImm == 0)
+    return;
+  O << ".relu";
+}
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 2b19386ef17fe5..7c3be27751ca14 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -56,6 +56,12 @@ class NVPTXInstPrinter : public MCInstPrinter {
                      const char *Modifier = nullptr);
   void printTmaReductionMode(const MCInst *MI, int OpNum, raw_ostream &O,
                              const char *Modifier = nullptr);
+  /// Print the PTX rounding suffix (.rn, .rna, .rz, .rp, .rm) for an
+  /// immediate operand holding an llvm::RoundingMode value.
+  void printFPRoundingMode(const MCInst *MI, int OpNum, raw_ostream &O,
+                           const char *Modifier = nullptr);
+  /// Print ".satfinite" (or nothing) for an nvvm::SaturationMode operand.
+  void printSaturationMode(const MCInst *MI, int OpNum, raw_ostream &O,
+                           const char *Modifier = nullptr);
+  /// Print ".relu" when the immediate operand is non-zero.
+  void printReluModifier(const MCInst *MI, int OpNum, raw_ostream &O,
+                         const char *Modifier = nullptr);
 };
 
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index c51729e224bf54..7210a09d8caf7f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -728,7 +728,55 @@ bool NVPTXDAGToDAGISel::tryIntrinsicNoChain(SDNode *N) {
   case Intrinsic::nvvm_texsurf_handle_internal:
     SelectTexSurfHandle(N);
     return true;
+  case Intrinsic::nvvm_convert_to_tf32_float:
+    SelectCvtFloatToTF32(N);
+    return true;
+  }
+}
+
+// Select llvm.nvvm.convert.to.tf32.float into the cvt_float_to_tf32 machine
+// instruction. The rounding mode, saturation mode and relu flag are all
+// compile-time operands; combinations that are invalid or unsupported on the
+// current SM/PTX target are diagnosed with report_fatal_error.
+void NVPTXDAGToDAGISel::SelectCvtFloatToTF32(SDNode *N) {
+  // Operands of the intrinsic node:
+  // 0 - IID
+  // 1 - Input float
+  // 2 - Rounding mode as string metadata
+  // 3 - Saturation mode
+  // 4 - Relu flag
+  uint64_t Sat = N->getConstantOperandVal(3);
+  bool IsRelu = N->getConstantOperandVal(4) == 1;
+
+  if (!Subtarget->hasTF32Math())
+    report_fatal_error("TF32 destination format requires at least sm80");
+
+  using SatMode = nvvm::SaturationMode;
+  // Reject unknown saturation modes here with a proper diagnostic; otherwise
+  // they would only be caught by llvm_unreachable in the asm printer.
+  if (Sat > static_cast<uint64_t>(SatMode::SATFINITE))
+    report_fatal_error("Invalid saturation mode in SelectCvtFloatToTF32");
+  bool IsSatFinite = static_cast<SatMode>(Sat) == SatMode::SATFINITE;
+  if (IsSatFinite && Subtarget->getPTXVersion() < 81)
+    report_fatal_error("satfinite modifier requires PTX version 8.1 or higher");
+
+  // The rounding mode is an MDString such as "round.tonearest". A malformed
+  // node or an unrecognized string maps to RoundingMode::Invalid. The default
+  // case below then diagnoses it, instead of dereferencing an empty optional.
+  const MDNode *MD = cast<MDNodeSDNode>(N->getOperand(2))->getMD();
+  StringRef RndString;
+  if (MD->getNumOperands() > 0)
+    if (const auto *RndMD = dyn_cast_or_null<MDString>(MD->getOperand(0)))
+      RndString = RndMD->getString();
+  RoundingMode RndVal =
+      convertStrToRoundingMode(RndString).value_or(RoundingMode::Invalid);
+  switch (RndVal) {
+  case RoundingMode::NearestTiesToAway:
+    if (IsRelu)
+      report_fatal_error("relu not supported with rna rounding mode");
+    break;
+  case RoundingMode::NearestTiesToEven:
+  case RoundingMode::TowardZero:
+    if (Subtarget->getSmVersion() < 90)
+      report_fatal_error("rn/rz rounding modes require at least sm90");
+    if (IsSatFinite)
+      report_fatal_error("satfinite not supported with rn/rz rounding modes");
+    break;
+  default:
+    report_fatal_error("Invalid FP rounding mode in SelectCvtFloatToTF32");
   }
+
+  SDLoc DL(N);
+  SDValue Ops[] = {N->getOperand(1),
+                   getI32Imm(static_cast<unsigned>(RndVal), DL),
+                   getI32Imm(Sat, DL), getI32Imm(IsRelu, DL)};
+  ReplaceNode(N, CurDAG->getMachineNode(NVPTX::cvt_float_to_tf32, DL,
+                                        N->getVTList(), Ops));
 }
 
 void NVPTXDAGToDAGISel::SelectTexSurfHandle(SDNode *N) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index c307f28fcc6c0a..3e22ef5bab9931 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -73,6 +73,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
   bool tryIntrinsicChain(SDNode *N);
   bool tryIntrinsicVoid(SDNode *N);
   void SelectTexSurfHandle(SDNode *N);
+  // Lowers llvm.nvvm.convert.to.tf32.float to cvt_float_to_tf32.
+  void SelectCvtFloatToTF32(SDNode *N);
   bool tryLoad(SDNode *N);
   bool tryLoadVector(SDNode *N);
   bool tryLDGLDU(SDNode *N);
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 8ede1ec4f20dc9..3274c1ef4260db 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1802,6 +1802,22 @@ def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn Int16Regs:$a),
 def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu Int16Regs:$a),
           (CVT_f16x2_e5m2x2 $a, CvtRN_RELU)>;
 
+// Immediate operand carrying an llvm::RoundingMode value; printed as the PTX
+// rounding suffix by printFPRoundingMode.
+def FPRoundingMode : Operand<i32> {
+  let PrintMethod = "printFPRoundingMode";
+}
+
+// Immediate operand carrying an nvvm::SaturationMode value; printed by
+// printSaturationMode (".satfinite" or nothing).
+def SatMode : Operand<i32> {
+  let PrintMethod = "printSaturationMode";
+}
+
+// Immediate relu flag; printReluModifier prints ".relu" when it is non-zero.
+def ReluFlag : Operand<i32> {
+  let PrintMethod = "printReluModifier";
+}
+
+// cvt{.rnd}{.satfinite}{.relu}.tf32.f32 dest, a
+// There is no selection pattern. NVPTXDAGToDAGISel::SelectCvtFloatToTF32
+// builds this instruction directly and validates the operand combination.
+def cvt_float_to_tf32 : NVPTXInst<(outs Int32Regs:$dest),
+    (ins Float32Regs:$a, FPRoundingMode:$rnd, SatMode:$sat, ReluFlag:$relu),
+    "cvt${rnd:rnd}${sat:sat}${relu:relu}.tf32.f32 \t$dest, $a;", []>;
+
 //
 // FNS
 //
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index 7555a2368ec963..9f0b437bd32dc5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -83,6 +83,7 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
   bool hasFP16Math() const { return SmVersion >= 53; }
   bool hasBF16Math() const { return SmVersion >= 80; }
   bool allowFP16Math() const;
+  // TF32 conversions need at least sm_80 and PTX ISA 7.0.
+  bool hasTF32Math() const { return PTXVersion >= 70 && SmVersion >= 80; }
   bool hasMaskOperator() const { return PTXVersion >= 71; }
   bool hasNoReturn() const { return SmVersion >= 30 && PTXVersion >= 64; }
   // Does SM & PTX support memory orderings (weak and atomic: relaxed, acquire,
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm80.ll b/llvm/test/CodeGen/NVPTX/convert-sm80.ll
index aebc28b1cfea3e..7172a980ec686f 100644
--- a/llvm/test/CodeGen/NVPTX/convert-sm80.ll
+++ b/llvm/test/CodeGen/NVPTX/convert-sm80.ll
@@ -261,3 +261,21 @@ define <2 x half> @fold_ff2f16x2(float %lo, float %hi) {
   %v1 = insertelement <2 x half> %v0, half %hih, i64 1
   ret <2 x half> %v1
 }
+
+declare i32 @llvm.nvvm.convert.to.tf32.float(float %f1, metadata, i8, i1)
+
+; rna rounding with no saturation and no relu: the form that needs only sm_80
+; (rn/rz require sm_90, and .satfinite requires PTX ISA 8.1).
+define i32 @cvt_rna_tf32_f32_flags(float %f1) {
+; CHECK-LABEL: cvt_rna_tf32_f32_flags(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [cvt_rna_tf32_f32_flags_param_0];
+; CHECK-NEXT:    cvt.rna.tf32.f32 %r1, %f1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %val = call i32 @llvm.nvvm.convert.to.tf32.float(float %f1, metadata !0, i8 0, i1 0)
+  ret i32 %val
+}
+!0 = !{!"round.tonearestaway"}
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm89.ll b/llvm/test/CodeGen/NVPTX/convert-sm89.ll
index 5d0576aebbe089..579d00efc1de24 100644
--- a/llvm/test/CodeGen/NVPTX/convert-sm89.ll
+++ b/llvm/test/CodeGen/NVPTX/convert-sm89.ll
@@ -84,3 +84,13 @@ define <2 x half> @cvt_rn_relu_f16x2_e5m2x2(i16 %in) {
   %val = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 %in);
   ret <2 x half> %val
 }
+
+declare i32 @llvm.nvvm.convert.to.tf32.float(float %f1, metadata, i8, i1)
+
+; rna rounding with .satfinite (i8 1 = SATFINITE). ISel rejects .satfinite
+; below PTX ISA 8.1, so this relies on this file's run lines targeting ptx81
+; or newer - TODO confirm.
+; CHECK-LABEL: cvt_rna_satfinite_tf32_f32
+define i32 @cvt_rna_satfinite_tf32_f32(float %f1) {
+; CHECK: cvt.rna.satfinite.tf32.f32
+  %val = call i32 @llvm.nvvm.convert.to.tf32.float(float %f1, metadata !0, i8 1, i1 0)
+  ret i32 %val
+}
+!0 = !{!"round.tonearestaway"}
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm90.ll b/llvm/test/CodeGen/NVPTX/convert-sm90.ll
new file mode 100644
index 00000000000000..196ba084a90699
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/convert-sm90.ll
@@ -0,0 +1,67 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78| FileCheck --check-prefixes=CHECK %s
+; RUN: %if ptxas-12.0 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78| %ptxas-verify -arch=sm_90 %}
+
+; tf32 conversions with rn/rz rounding (with or without relu) are only
+; accepted by ISel on sm_90 and newer. Rounding-mode metadata used below:
+; !0 = round.tonearest, !1 = round.towardzero.
+declare i32 @llvm.nvvm.convert.to.tf32.float(float %f1, metadata, i8, i1)
+
+define i32 @cvt_rn_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rn_tf32_f32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [cvt_rn_tf32_f32_param_0];
+; CHECK-NEXT:    cvt.rn.tf32.f32 %r1, %f1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %val = call i32 @llvm.nvvm.convert.to.tf32.float(float %f1, metadata !0, i8 0, i1 0)
+  ret i32 %val
+}
+
+define i32 @cvt_rn_relu_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rn_relu_tf32_f32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [cvt_rn_relu_tf32_f32_param_0];
+; CHECK-NEXT:    cvt.rn.relu.tf32.f32 %r1, %f1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %val = call i32 @llvm.nvvm.convert.to.tf32.float(float %f1, metadata !0, i8 0, i1 1)
+  ret i32 %val
+}
+
+define i32 @cvt_rz_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rz_tf32_f32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [cvt_rz_tf32_f32_param_0];
+; CHECK-NEXT:    cvt.rz.tf32.f32 %r1, %f1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %val = call i32 @llvm.nvvm.convert.to.tf32.float(float %f1, metadata !1, i8 0, i1 0)
+  ret i32 %val
+}
+
+define i32 @cvt_rz_relu_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rz_relu_tf32_f32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [cvt_rz_relu_tf32_f32_param_0];
+; CHECK-NEXT:    cvt.rz.relu.tf32.f32 %r1, %f1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %val = call i32 @llvm.nvvm.convert.to.tf32.float(float %f1, metadata !1, i8 0, i1 1)
+  ret i32 %val
+}
+!0 = !{!"round.tonearest"}
+!1 = !{!"round.towardzero"}



More information about the llvm-commits mailing list