[llvm] [NVPTX] Add float to tf32 conversion intrinsic (PR #121507)

Durgadoss R via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 9 06:09:42 PST 2025


https://github.com/durga4github updated https://github.com/llvm/llvm-project/pull/121507

>From bf34c2e595f9503306fa9ac59d60ba95a8762daa Mon Sep 17 00:00:00 2001
From: Durgadoss R <durgadossr at nvidia.com>
Date: Tue, 7 Jan 2025 20:16:17 +0530
Subject: [PATCH] [NVPTX] Add float to tf32 conversion intrinsics

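This patch adds the missing variants of the float-to-tf32
conversion intrinsics: rna.satfinite, rn, rn.relu, rz and
rz.relu. The rna.satfinite variant requires sm_89 and PTX 8.1;
the rn/rz variants (with or without relu) require sm_90 and
PTX 7.8. Lit tests are added for all the new intrinsics.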

PTX Spec link:
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt

Signed-off-by: Durgadoss R <durgadossr at nvidia.com>
---
 llvm/include/llvm/IR/IntrinsicsNVVM.td        | 10 +++
 .../NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp   |  5 ++
 llvm/lib/Target/NVPTX/NVPTX.h                 |  3 +-
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td       |  9 +++
 llvm/lib/Target/NVPTX/NVPTXIntrinsics.td      | 22 ++++--
 llvm/test/CodeGen/NVPTX/convert-sm89.ll       |  7 ++
 llvm/test/CodeGen/NVPTX/convert-sm90.ll       | 68 +++++++++++++++++++
 7 files changed, 119 insertions(+), 5 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/convert-sm90.ll

diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index fd07d131ce15b2..8c171aa73e05d9 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1438,6 +1438,16 @@ let TargetPrefix = "nvvm" in {
 
   def int_nvvm_f2tf32_rna : ClangBuiltin<"__nvvm_f2tf32_rna">,
       Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_f2tf32_rna_satfinite : ClangBuiltin<"__nvvm_f2tf32_rna_satfinite">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_f2tf32_rn : ClangBuiltin<"__nvvm_f2tf32_rn">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_f2tf32_rn_relu : ClangBuiltin<"__nvvm_f2tf32_rn_relu">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_f2tf32_rz : ClangBuiltin<"__nvvm_f2tf32_rz">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
+  def int_nvvm_f2tf32_rz_relu : ClangBuiltin<"__nvvm_f2tf32_rz_relu">,
+      Intrinsic<[llvm_i32_ty], [llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
 
   def int_nvvm_ff_to_e4m3x2_rn : ClangBuiltin<"__nvvm_ff_to_e4m3x2_rn">,
       Intrinsic<[llvm_i16_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index d34f45fcac0087..c11262149c813e 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -110,6 +110,11 @@ void NVPTXInstPrinter::printCvtMode(const MCInst *MI, int OpNum, raw_ostream &O,
     if (Imm & NVPTX::PTXCvtMode::SAT_FLAG)
       O << ".sat";
     return;
+  } else if (Modifier == "satfinite") {
+    // SATFINITE flag
+    if (Imm & NVPTX::PTXCvtMode::SATFINITE_FLAG)
+      O << ".satfinite";
+    return;
   } else if (Modifier == "relu") {
     // RELU flag
     if (Imm & NVPTX::PTXCvtMode::RELU_FLAG)
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index ca915cd3f3732f..f45126c408b0dc 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -182,7 +182,8 @@ enum CvtMode {
   BASE_MASK = 0x0F,
   FTZ_FLAG = 0x10,
   SAT_FLAG = 0x20,
-  RELU_FLAG = 0x40
+  RELU_FLAG = 0x40,
+  SATFINITE_FLAG = 0x80
 };
 }
 
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index c3e72d6ce3a3f8..2fe01be0c3010a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -64,6 +64,9 @@ def CvtNONE_RELU   : PatLeaf<(i32 0x40)>;
 def CvtRN_RELU   : PatLeaf<(i32 0x45)>;
 def CvtRZ_RELU   : PatLeaf<(i32 0x46)>;
 
+def CvtNONE_SATFINITE : PatLeaf<(i32 0x80)>;
+def CvtRNA_SATFINITE  : PatLeaf<(i32 0x89)>;
+
 def CvtMode : Operand<i32> {
   let PrintMethod = "printCvtMode";
 }
@@ -725,6 +728,12 @@ let hasSideEffects = false in {
 
   def CVT_f16x2_e4m3x2 : CVT_f16x2_fp8<"e4m3">;
   def CVT_f16x2_e5m2x2 : CVT_f16x2_fp8<"e5m2">;
+
+  // Float to TF32 conversions.
+  def CVT_tf32_f32 : NVPTXInst<(outs Int32Regs:$dst),
+                     (ins Float32Regs:$src, CvtMode:$mode),
+                     !strconcat("cvt${mode:base}${mode:relu}${mode:satfinite}.",
+                     "tf32.f32 \t$dst, $src;"), []>;
 }
 
 def fpround_oneuse : PatFrag<(ops node:$a), (fpround node:$a), [{
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 8ede1ec4f20dc9..a4c1f6e0730f9a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1660,10 +1660,24 @@ def : Pat<(int_nvvm_f2bf16_rz f32:$a),
 def : Pat<(int_nvvm_f2bf16_rz_relu f32:$a),
           (CVT_bf16_f32 $a, CvtRZ_RELU)>;
 
-def CVT_tf32_f32 :
-   NVPTXInst<(outs Int32Regs:$dest), (ins Float32Regs:$a),
-                   "cvt.rna.tf32.f32 \t$dest, $a;",
-       [(set i32:$dest, (int_nvvm_f2tf32_rna f32:$a))]>;
+def : Pat<(int_nvvm_f2tf32_rna f32:$a),
+          (CVT_tf32_f32 $a, CvtRNA)>,
+          Requires<[hasPTX<70>, hasSM<80>]>;
+def : Pat<(int_nvvm_f2tf32_rna_satfinite f32:$a),
+          (CVT_tf32_f32 $a, CvtRNA_SATFINITE)>,
+          Requires<[hasPTX<81>, hasSM<89>]>;
+def : Pat<(int_nvvm_f2tf32_rn f32:$a),
+          (CVT_tf32_f32 $a, CvtRN)>,
+          Requires<[hasPTX<78>, hasSM<90>]>;
+def : Pat<(int_nvvm_f2tf32_rn_relu f32:$a),
+          (CVT_tf32_f32 $a, CvtRN_RELU)>,
+          Requires<[hasPTX<78>, hasSM<90>]>;
+def : Pat<(int_nvvm_f2tf32_rz f32:$a),
+          (CVT_tf32_f32 $a, CvtRZ)>,
+          Requires<[hasPTX<78>, hasSM<90>]>;
+def : Pat<(int_nvvm_f2tf32_rz_relu f32:$a),
+          (CVT_tf32_f32 $a, CvtRZ_RELU)>,
+          Requires<[hasPTX<78>, hasSM<90>]>;
 
 def INT_NVVM_LOHI_I2D : F_MATH_2<"mov.b64 \t$dst, {{$src0, $src1}};",
   Float64Regs, Int32Regs, Int32Regs, int_nvvm_lohi_i2d>;
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm89.ll b/llvm/test/CodeGen/NVPTX/convert-sm89.ll
index 5d0576aebbe089..30fd76f5a31c23 100644
--- a/llvm/test/CodeGen/NVPTX/convert-sm89.ll
+++ b/llvm/test/CodeGen/NVPTX/convert-sm89.ll
@@ -84,3 +84,12 @@ define <2 x half> @cvt_rn_relu_f16x2_e5m2x2(i16 %in) {
   %val = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 %in);
   ret <2 x half> %val
 }
+
+declare i32 @llvm.nvvm.f2tf32.rna.satfinite(float)
+
+; CHECK-LABEL: cvt_rna_satfinite_tf32_f32
+define i32 @cvt_rna_satfinite_tf32_f32(float %f1) {
+; CHECK: cvt.rna.satfinite.tf32.f32
+  %val = call i32 @llvm.nvvm.f2tf32.rna.satfinite(float %f1)
+  ret i32 %val
+}
diff --git a/llvm/test/CodeGen/NVPTX/convert-sm90.ll b/llvm/test/CodeGen/NVPTX/convert-sm90.ll
new file mode 100644
index 00000000000000..5f610e0e91f888
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/convert-sm90.ll
@@ -0,0 +1,68 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 | FileCheck --check-prefixes=CHECK %s
+; RUN: %if ptxas-12.0 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_90 -mattr=+ptx78 | %ptxas-verify -arch=sm_90 %}
+
+declare i32 @llvm.nvvm.f2tf32.rn(float %f1)
+declare i32 @llvm.nvvm.f2tf32.rn.relu(float %f1)
+declare i32 @llvm.nvvm.f2tf32.rz(float %f1)
+declare i32 @llvm.nvvm.f2tf32.rz.relu(float %f1)
+
+define i32 @cvt_rn_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rn_tf32_f32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [cvt_rn_tf32_f32_param_0];
+; CHECK-NEXT:    cvt.rn.tf32.f32 %r1, %f1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %val = call i32 @llvm.nvvm.f2tf32.rn(float %f1)
+  ret i32 %val
+}
+
+define i32 @cvt_rn_relu_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rn_relu_tf32_f32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [cvt_rn_relu_tf32_f32_param_0];
+; CHECK-NEXT:    cvt.rn.relu.tf32.f32 %r1, %f1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %val = call i32 @llvm.nvvm.f2tf32.rn.relu(float %f1)
+  ret i32 %val
+}
+
+define i32 @cvt_rz_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rz_tf32_f32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [cvt_rz_tf32_f32_param_0];
+; CHECK-NEXT:    cvt.rz.tf32.f32 %r1, %f1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %val = call i32 @llvm.nvvm.f2tf32.rz(float %f1)
+  ret i32 %val
+}
+
+define i32 @cvt_rz_relu_tf32_f32(float %f1) {
+; CHECK-LABEL: cvt_rz_relu_tf32_f32(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-NEXT:    .reg .f32 %f<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.f32 %f1, [cvt_rz_relu_tf32_f32_param_0];
+; CHECK-NEXT:    cvt.rz.relu.tf32.f32 %r1, %f1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %val = call i32 @llvm.nvvm.f2tf32.rz.relu(float %f1)
+  ret i32 %val
+}



More information about the llvm-commits mailing list