[llvm] [NVPTX] Add float to tf32 conversion intrinsic (PR #121507)
Durgadoss R via llvm-commits
llvm-commits at lists.llvm.org
Fri Jan 3 11:14:25 PST 2025
https://github.com/durga4github updated https://github.com/llvm/llvm-project/pull/121507
>From c75c744e5464d829e1288d0c4a99b16e0ef4cd3a Mon Sep 17 00:00:00 2001
From: Durgadoss R <durgadossr at nvidia.com>
Date: Tue, 24 Dec 2024 19:34:31 +0530
Subject: [PATCH] [NVPTX] Add convert float to tf32 intrinsics
This patch adds an intrinsic to convert float to tf32.
* This intrinsic uses flags for rounding and saturation
modes as well as relu. The backend looks through these
flags and lowers to the appropriate instruction.
* Docs are updated to describe the usage of the flag arguments.
* Lit tests are added for all the combinations.
Note: We already have an intrinsic 'llvm.nvvm.f2tf32.rna'
which caters only to one variant of the PTX instruction. Once
this change lands, I will submit a follow-up PR to auto-upgrade
it to use the generic variant.
PTX Spec link:
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
Signed-off-by: Durgadoss R <durgadossr at nvidia.com>
---
llvm/docs/NVPTXUsage.rst | 50 ++++++++++++++
llvm/include/llvm/IR/IntrinsicsNVVM.td | 8 +++
llvm/include/llvm/IR/NVVMIntrinsicFlags.h | 6 ++
.../NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp | 50 ++++++++++++++
.../NVPTX/MCTargetDesc/NVPTXInstPrinter.h | 6 ++
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 48 +++++++++++++
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h | 1 +
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 16 +++++
llvm/lib/Target/NVPTX/NVPTXSubtarget.h | 1 +
llvm/test/CodeGen/NVPTX/convert-sm80.ll | 18 +++++
llvm/test/CodeGen/NVPTX/convert-sm89.ll | 10 +++
llvm/test/CodeGen/NVPTX/convert-sm90.ll | 67 +++++++++++++++++++
12 files changed, 281 insertions(+)
create mode 100644 llvm/test/CodeGen/NVPTX/convert-sm90.ll
diff --git a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst
index 313e84f3722a95..2970687f8d377a 100644
--- a/llvm/docs/NVPTXUsage.rst
+++ b/llvm/docs/NVPTXUsage.rst
@@ -462,6 +462,56 @@ to left-shift the found bit into the most-significant bit position, otherwise
the result is the shift amount needed to right-shift the found bit into the
least-significant bit position. 0xffffffff is returned if no 1 bit is found.
+Conversion Intrinsics (for cvt.* PTX instructions)
+--------------------------------------------------
+
+'``llvm.nvvm.convert.to.tf32.float``'
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Syntax:
+"""""""
+
+.. code-block:: llvm
+
+ declare i32 @llvm.nvvm.convert.to.tf32.float(float %f1, metadata !round_mode, i8 %flag_sat_mode, i1 %flag_relu)
+
+Overview:
+"""""""""
+
+The '``@llvm.nvvm.convert.to.tf32.float``' intrinsic lowers to
+the ``cvt.*.tf32.f32`` set of PTX instructions.
+
+* The first argument is the input float to be converted to TF32.
+
+* The second argument (denoted by ``metadata !round_mode``) specifies
+ the floating-point rounding mode of the conversion. It is a metadata
+ node whose first operand is one of the rounding-mode strings used by
+ the constrained-fp intrinsics, documented here:
+ `<https://llvm.org/docs/LangRef.html#constrainedfp>`_.
+ For example, ``!{!"round.tonearest"}``.
+
+ Only ``round.tonearest``, ``round.towardzero`` and ``round.tonearestaway``
+ are valid for this intrinsic; ``round.tonearest`` and ``round.towardzero``
+ additionally require sm_90 or newer.
+
+* The third argument (denoted by ``i8 %flag_sat_mode``) denotes the
+ saturation modifier for this intrinsic. As of now, it can either
+ be None or Satfinite, according to the enumeration below:
+
+ ========== ================
+ Enum Value Saturation Mode
+ ========== ================
+ ``0`` NONE
+ ``1`` SATFINITE
+ ========== ================
+
+* The last argument (denoted by ``i1 %flag_relu``) when set, generates
+ the ``.relu`` variant of the instruction.
+
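+* Invalid values or unsupported combinations of the compile-time
+ arguments may result in errors during code generation. For example,
+ the ``.relu`` variant is not available with ``round.tonearestaway``,
+ and SATFINITE is only supported with ``round.tonearestaway`` and
+ requires PTX ISA 8.1 or newer.
+
+For more information, refer to the PTX ISA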
+`<https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt>`_.
+
TMA family of Intrinsics
------------------------
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index fd07d131ce15b2..9acae49c24892f 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1466,6 +1466,14 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_e5m2x2_to_f16x2_rn_relu : ClangBuiltin<"__nvvm_e5m2x2_to_f16x2_rn_relu">,
Intrinsic<[llvm_v2f16_ty], [llvm_i16_ty], [IntrNoMem, IntrNoCallback]>;
+// Convert float to TF32.
+// The result is returned as an i32 (the TF32 value held in a 32-bit
+// register). The rounding mode is passed as a metadata node wrapping a
+// constrained-fp rounding string (e.g. !{!"round.tonearest"}); the
+// saturation mode (an nvvm::SaturationMode value) and the relu flag must be
+// immediates so the backend can pick the matching cvt.*.tf32.f32 variant.
+def int_nvvm_convert_to_tf32_float : DefaultAttrsIntrinsic<[llvm_i32_ty],
+ [llvm_float_ty, // Input float
+ llvm_metadata_ty, // Metadata for Rounding modes
+ llvm_i8_ty, // Flag for Saturation modes (nvvm::SaturationMode)
+ llvm_i1_ty], // Flag for relu
+ [IntrNoMem, ImmArg<ArgIndex<2>>, ImmArg<ArgIndex<3>>]>;
+
// FNS
def int_nvvm_fns : ClangBuiltin<"__nvvm_fns">,
diff --git a/llvm/include/llvm/IR/NVVMIntrinsicFlags.h b/llvm/include/llvm/IR/NVVMIntrinsicFlags.h
index dfb6e857b3a6ad..9c7a21d32d6721 100644
--- a/llvm/include/llvm/IR/NVVMIntrinsicFlags.h
+++ b/llvm/include/llvm/IR/NVVMIntrinsicFlags.h
@@ -34,6 +34,12 @@ enum class TMAReductionOp : uint8_t {
XOR = 7,
};
+// Saturation modes carried by the i8 saturation flag of NVVM intrinsics.
+// NONE requests no saturation modifier; SATFINITE corresponds to the PTX
+// ".satfinite" modifier.
+enum class SaturationMode : uint8_t { NONE = 0, SATFINITE = 1 };
+
} // namespace nvvm
} // namespace llvm
#endif // LLVM_IR_NVVMINTRINSICFLAGS_H
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 65e1893d3f3bdf..90ff445d889fa2 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -453,3 +453,53 @@ void NVPTXInstPrinter::printTmaReductionMode(const MCInst *MI, int OpNum,
llvm_unreachable(
"Invalid Reduction Op in printCpAsyncBulkTensorReductionMode");
}
+
+// Prints the PTX rounding modifier (".rn", ".rna", ".rz", ".rp", ".rm") for
+// an operand holding an llvm::RoundingMode value. Modes without a PTX
+// spelling (e.g. Dynamic/Invalid) indicate a selection bug.
+void NVPTXInstPrinter::printFPRoundingMode(const MCInst *MI, int OpNum,
+                                           raw_ostream &O,
+                                           const char *Modifier) {
+  const MCOperand &MO = MI->getOperand(OpNum);
+  using RndMode = llvm::RoundingMode;
+
+  switch (static_cast<RndMode>(MO.getImm())) {
+  case RndMode::NearestTiesToEven:
+    O << ".rn";
+    return;
+  case RndMode::NearestTiesToAway:
+    O << ".rna";
+    return;
+  case RndMode::TowardZero:
+    O << ".rz";
+    return;
+  case RndMode::TowardPositive:
+    O << ".rp";
+    return;
+  case RndMode::TowardNegative:
+    O << ".rm";
+    return;
+  default:
+    break;
+  }
+  // Printing nothing here would emit a cvt without its mandatory rounding
+  // modifier, i.e. invalid PTX, so fail loudly instead (consistent with
+  // printSaturationMode).
+  llvm_unreachable("Invalid FP rounding mode in printFPRoundingMode");
+}
+
+// Prints the PTX saturation modifier for an nvvm::SaturationMode operand:
+// nothing for NONE, ".satfinite" for SATFINITE.
+void NVPTXInstPrinter::printSaturationMode(const MCInst *MI, int OpNum,
+                                           raw_ostream &O,
+                                           const char *Modifier) {
+  using Mode = nvvm::SaturationMode;
+  const Mode SatKind = static_cast<Mode>(MI->getOperand(OpNum).getImm());
+
+  if (SatKind == Mode::NONE)
+    return; // The default mode has no textual modifier.
+  if (SatKind == Mode::SATFINITE) {
+    O << ".satfinite";
+    return;
+  }
+  llvm_unreachable("Invalid mode in printSaturationMode");
+}
+
+// Prints ".relu" when the flag operand is non-zero, and nothing otherwise.
+void NVPTXInstPrinter::printReluModifier(const MCInst *MI, int OpNum,
+                                         raw_ostream &O, const char *Modifier) {
+  const bool HasRelu = MI->getOperand(OpNum).getImm() != 0;
+  if (!HasRelu)
+    return;
+  O << ".relu";
+}
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 2b19386ef17fe5..7c3be27751ca14 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -56,6 +56,12 @@ class NVPTXInstPrinter : public MCInstPrinter {
const char *Modifier = nullptr);
void printTmaReductionMode(const MCInst *MI, int OpNum, raw_ostream &O,
const char *Modifier = nullptr);
+ /// Prints an llvm::RoundingMode operand as a PTX rounding modifier.
+ void printFPRoundingMode(const MCInst *MI, int OpNum, raw_ostream &O,
+ const char *Modifier = nullptr);
+ /// Prints an nvvm::SaturationMode operand ("" or ".satfinite").
+ void printSaturationMode(const MCInst *MI, int OpNum, raw_ostream &O,
+ const char *Modifier = nullptr);
+ /// Prints ".relu" when the flag operand is non-zero.
+ void printReluModifier(const MCInst *MI, int OpNum, raw_ostream &O,
+ const char *Modifier = nullptr);
};
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index c51729e224bf54..7210a09d8caf7f 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -728,7 +728,55 @@ bool NVPTXDAGToDAGISel::tryIntrinsicNoChain(SDNode *N) {
case Intrinsic::nvvm_texsurf_handle_internal:
SelectTexSurfHandle(N);
return true;
+ case Intrinsic::nvvm_convert_to_tf32_float:
+ SelectCvtFloatToTF32(N);
+ return true;
+ }
+}
+
+void NVPTXDAGToDAGISel::SelectCvtFloatToTF32(SDNode *N) {
+  // Selects llvm.nvvm.convert.to.tf32.float into the cvt_float_to_tf32
+  // machine node after validating the flag operands against the subtarget.
+  //
+  // Operands of N:
+  // 0 - IID
+  // 1 - Input float
+  // 2 - Rounding mode as string metadata
+  // 3 - Saturation mode (an nvvm::SaturationMode value)
+  // 4 - Relu flag
+  uint64_t Sat = N->getConstantOperandVal(3);
+  bool IsRelu = N->getConstantOperandVal(4) == 1;
+
+  if (!Subtarget->hasTF32Math())
+    report_fatal_error("TF32 destination format requires at least sm80");
+
+  // Reject out-of-range saturation values with a proper diagnostic;
+  // otherwise they reach printSaturationMode and hit llvm_unreachable.
+  using SatMode = nvvm::SaturationMode;
+  if (Sat != static_cast<uint64_t>(SatMode::NONE) &&
+      Sat != static_cast<uint64_t>(SatMode::SATFINITE))
+    report_fatal_error("Invalid saturation mode in SelectCvtFloatToTF32");
+  bool IsSatFinite = static_cast<SatMode>(Sat) == SatMode::SATFINITE;
+  if (IsSatFinite && Subtarget->getPTXVersion() < 81)
+    report_fatal_error("satfinite modifier requires PTX version 8.1 or higher");
+
+  // The metadata is expected to be a node whose first operand is a rounding
+  // string such as "round.tonearest". convertStrToRoundingMode returns
+  // std::nullopt for unknown strings, so the result must be checked before
+  // it is dereferenced.
+  const MDNode *MD = cast<MDNodeSDNode>(N->getOperand(2))->getMD();
+  const MDString *RndString =
+      MD->getNumOperands() > 0 ? dyn_cast_or_null<MDString>(MD->getOperand(0))
+                               : nullptr;
+  std::optional<RoundingMode> RndVal;
+  if (RndString)
+    RndVal = convertStrToRoundingMode(RndString->getString());
+  if (!RndVal)
+    report_fatal_error("Invalid FP rounding mode in SelectCvtFloatToTF32");
+
+  switch (*RndVal) {
+  case RoundingMode::NearestTiesToAway:
+    // cvt.rna has no .relu variant.
+    if (IsRelu)
+      report_fatal_error("relu not supported with rna rounding mode");
+    break;
+  case RoundingMode::NearestTiesToEven:
+  case RoundingMode::TowardZero: {
+    if (Subtarget->getSmVersion() < 90)
+      report_fatal_error("rn/rz rounding modes require at least sm90");
+    if (IsSatFinite)
+      report_fatal_error("satfinite not supported with rn/rz rounding modes");
+    break;
+  }
+  default:
+    report_fatal_error("Invalid FP rounding mode in SelectCvtFloatToTF32");
   }
+
+  SDLoc DL(N);
+  SDValue Ops[] = {N->getOperand(1),
+                   getI32Imm(static_cast<unsigned>(*RndVal), DL),
+                   getI32Imm(Sat, DL), getI32Imm(IsRelu, DL)};
+  ReplaceNode(N, CurDAG->getMachineNode(NVPTX::cvt_float_to_tf32, DL,
+                                        N->getVTList(), Ops));
 }
void NVPTXDAGToDAGISel::SelectTexSurfHandle(SDNode *N) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index c307f28fcc6c0a..3e22ef5bab9931 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -73,6 +73,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
bool tryIntrinsicChain(SDNode *N);
bool tryIntrinsicVoid(SDNode *N);
void SelectTexSurfHandle(SDNode *N);
+ /// Selects the llvm.nvvm.convert.to.tf32.float intrinsic.
+ void SelectCvtFloatToTF32(SDNode *N);
bool tryLoad(SDNode *N);
bool tryLoadVector(SDNode *N);
bool tryLDGLDU(SDNode *N);
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 8ede1ec4f20dc9..3274c1ef4260db 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1802,6 +1802,22 @@ def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn Int16Regs:$a),
def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu Int16Regs:$a),
(CVT_f16x2_e5m2x2 $a, CvtRN_RELU)>;
+// Operand holding an llvm::RoundingMode value, printed as .rn/.rna/.rz/...
+def FPRoundingMode : Operand<i32> {
+ let PrintMethod = "printFPRoundingMode";
+}
+
+// Operand holding an nvvm::SaturationMode value, printed as "" or .satfinite.
+def SatMode : Operand<i32> {
+ let PrintMethod = "printSaturationMode";
+}
+
+// Operand holding the relu flag, printed as .relu when non-zero.
+def ReluFlag : Operand<i32> {
+ let PrintMethod = "printReluModifier";
+}
+
+// Float-to-TF32 conversion. It has no ISel pattern; the node is created
+// directly by NVPTXDAGToDAGISel::SelectCvtFloatToTF32, which validates the
+// rounding/saturation/relu combination.
+// NOTE(review): confirm against the PTX ISA whether .relu must precede
+// .satfinite; the two never co-occur under the current selection checks,
+// so the print order below is only relevant if that changes.
+def cvt_float_to_tf32 : NVPTXInst<(outs Int32Regs:$dest),
+ (ins Float32Regs:$a, FPRoundingMode:$rnd, SatMode:$sat, ReluFlag:$relu),
+ "cvt${rnd:rnd}${sat:sat}${relu:relu}.tf32.f32 \t$dest, $a;", []>;
+
//
// FNS
//
diff --git a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
index 7555a2368ec963..9f0b437bd32dc5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ b/llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -83,6 +83,7 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
bool hasFP16Math() const { return SmVersion >= 53; }
bool hasBF16Math() const { return SmVersion >= 80; }
bool allowFP16Math() const;
+ // True when TF32 conversions (cvt.*.tf32.f32) are available: sm_80 or newer
+ // together with PTX ISA 7.0 or newer.
+ bool hasTF32Math() const { return PTXVersion >= 70 && SmVersion >= 80; }
bool hasMaskOperator() const { return PTXVersion >= 71; }
bool hasNoReturn() const { return SmVersion >= 30 && PTXVersion >= 64; }
// Does SM & PTX support memory orderings (weak and atomic: relaxed, acquire,
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm80.ll b/llvm/test/CodeGen/NVPTX/convert-sm80.ll
index aebc28b1cfea3e..7172a980ec686f 100644
--- a/llvm/test/CodeGen/NVPTX/convert-sm80.ll
+++ b/llvm/test/CodeGen/NVPTX/convert-sm80.ll
@@ -261,3 +261,21 @@ define <2 x half> @fold_ff2f16x2(float %lo, float %hi) {
%v1 = insertelement <2 x half> %v0, half %hih, i64 1
ret <2 x half> %v1
}
+
+declare i32 @llvm.nvvm.convert.to.tf32.float(float %f1, metadata, i8, i1)
+
+; rna is the only rounding mode available below sm_90 (no saturation, no relu).
+define i32 @cvt_rna_tf32_f32_flags(float %f1) {
+; CHECK-LABEL: cvt_rna_tf32_f32_flags(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.f32 %f1, [cvt_rna_tf32_f32_flags_param_0];
+; CHECK-NEXT: cvt.rna.tf32.f32 %r1, %f1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %val = call i32 @llvm.nvvm.convert.to.tf32.float(float %f1, metadata !0, i8 0, i1 0)
+ ret i32 %val
+}
+!0 = !{!"round.tonearestaway"}
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm89.ll b/llvm/test/CodeGen/NVPTX/convert-sm89.ll
index 5d0576aebbe089..579d00efc1de24 100644
--- a/llvm/test/CodeGen/NVPTX/convert-sm89.ll
+++ b/llvm/test/CodeGen/NVPTX/convert-sm89.ll
@@ -84,3 +84,13 @@ define <2 x half> @cvt_rn_relu_f16x2_e5m2x2(i16 %in) {
%val = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 %in);
ret <2 x half> %val
}
+
+declare i32 @llvm.nvvm.convert.to.tf32.float(float %f1, metadata, i8, i1)
+
+; i8 1 selects SATFINITE; .satfinite needs PTX ISA 8.1+ and is only accepted
+; together with the rna rounding mode.
+; CHECK-LABEL: cvt_rna_satfinite_tf32_f32
+define i32 @cvt_rna_satfinite_tf32_f32(float %f1) {
+; CHECK: cvt.rna.satfinite.tf32.f32
+ %val = call i32 @llvm.nvvm.convert.to.tf32.float(float %f1, metadata !0, i8 1, i1 0)
+ ret i32 %val
+}
+!0 = !{!"round.tonearestaway"}
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm90.ll b/llvm/test/CodeGen/NVPTX/convert-sm90.ll
new file mode 100644
index 00000000000000..196ba084a90699
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/convert-sm90.ll
@@ -0,0 +1,67 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78| FileCheck --check-prefixes=CHECK %s
+; RUN: %if ptxas-12.0 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78| %ptxas-verify -arch=sm_90 %}
+
+declare i32 @llvm.nvvm.convert.to.tf32.float(float %f1, metadata, i8, i1)
+
+; The rn and rz rounding modes require sm_90; each is tested with and
+; without the relu flag (i1 argument).
+define i32 @cvt_rn_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rn_tf32_f32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_tf32_f32_param_0];
+; CHECK-NEXT: cvt.rn.tf32.f32 %r1, %f1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %val = call i32 @llvm.nvvm.convert.to.tf32.float(float %f1, metadata !0, i8 0, i1 0)
+ ret i32 %val
+}
+
+define i32 @cvt_rn_relu_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rn_relu_tf32_f32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.f32 %f1, [cvt_rn_relu_tf32_f32_param_0];
+; CHECK-NEXT: cvt.rn.relu.tf32.f32 %r1, %f1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %val = call i32 @llvm.nvvm.convert.to.tf32.float(float %f1, metadata !0, i8 0, i1 1)
+ ret i32 %val
+}
+
+define i32 @cvt_rz_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rz_tf32_f32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_tf32_f32_param_0];
+; CHECK-NEXT: cvt.rz.tf32.f32 %r1, %f1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %val = call i32 @llvm.nvvm.convert.to.tf32.float(float %f1, metadata !1, i8 0, i1 0)
+ ret i32 %val
+}
+
+define i32 @cvt_rz_relu_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rz_relu_tf32_f32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.f32 %f1, [cvt_rz_relu_tf32_f32_param_0];
+; CHECK-NEXT: cvt.rz.relu.tf32.f32 %r1, %f1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT: ret;
+ %val = call i32 @llvm.nvvm.convert.to.tf32.float(float %f1, metadata !1, i8 0, i1 1)
+ ret i32 %val
+}
+!0 = !{!"round.tonearest"}
+!1 = !{!"round.towardzero"}
More information about the llvm-commits
mailing list