[Mlir-commits] [mlir] [MLIR][NVVM] Update Float to TF32 conversion Op (PR #125048)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 30 02:33:57 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Srinivasa Ravi (Wolfram70)

<details>
<summary>Changes</summary>

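This change updates the float-to-TF32 conversion MLIR Op (`nvvm.cvt.float.to.tf32`) so that it also lowers to the new intrinsics introduced for sm_100 in PTX ISA 8.6. It also removes the verifier check that rejected the `satfinite` saturation mode with the `rn`/`rz` rounding modes, since those combinations are now supported. The new intrinsics are: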

- `nvvm_f2tf32_rn_satfinite`
- `nvvm_f2tf32_rn_relu_satfinite`
- `nvvm_f2tf32_rz_satfinite`
- `nvvm_f2tf32_rz_relu_satfinite`

PTX ISA reference (`cvt` instruction):
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt

---
Full diff: https://github.com/llvm/llvm-project/pull/125048.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+19-10) 
- (modified) mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir (+28) 
- (modified) mlir/test/Target/LLVMIR/nvvmir-invalid.mlir (-16) 


``````````diff
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 3e0a6987bd85b0..1ad20bb35273ea 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -147,9 +147,6 @@ LogicalResult CvtFloatToTF32Op::verify() {
     break;
   case RndMode::RN:
   case RndMode::RZ:
-    if (getSat() != NVVM::SaturationMode::NONE)
-      return emitError(
-          "Saturation mode not supported with rn/rz rounding modes.");
     break;
   default:
     return emitError(
@@ -1225,17 +1222,29 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
                                                      NVVM::SaturationMode sat,
                                                      bool hasRelu) {
   using RndMode = NVVM::FPRoundingMode;
+  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
+  bool hasReluAndSatFinite = hasRelu && hasSatFinite;
   switch (rnd) {
   case RndMode::RN:
-    return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rn_relu
-                   : llvm::Intrinsic::nvvm_f2tf32_rn;
+    if(hasReluAndSatFinite)
+      return llvm::Intrinsic::nvvm_f2tf32_rn_relu_satfinite;
+    if(hasRelu)
+      return llvm::Intrinsic::nvvm_f2tf32_rn_relu;
+    if(hasSatFinite)
+      return llvm::Intrinsic::nvvm_f2tf32_rn_satfinite;
+    return llvm::Intrinsic::nvvm_f2tf32_rn;
   case RndMode::RZ:
-    return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rz_relu
-                   : llvm::Intrinsic::nvvm_f2tf32_rz;
+    if(hasReluAndSatFinite)
+      return llvm::Intrinsic::nvvm_f2tf32_rz_relu_satfinite;
+    if(hasRelu)
+      return llvm::Intrinsic::nvvm_f2tf32_rz_relu;
+    if(hasSatFinite)
+      return llvm::Intrinsic::nvvm_f2tf32_rz_satfinite;
+    return llvm::Intrinsic::nvvm_f2tf32_rz;
   case RndMode::RNA:
-    return (sat == NVVM::SaturationMode::SATFINITE)
-               ? llvm::Intrinsic::nvvm_f2tf32_rna_satfinite
-               : llvm::Intrinsic::nvvm_f2tf32_rna;
+    return hasSatFinite
+            ? llvm::Intrinsic::nvvm_f2tf32_rna_satfinite
+            : llvm::Intrinsic::nvvm_f2tf32_rna;
   default:
     llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
   }
diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
index 90a232e4baac6f..ff7bad0149d4cf 100644
--- a/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
@@ -28,6 +28,20 @@ llvm.func @convert_float_to_tf32_rn_relu(%src : f32) -> i32 {
   llvm.return %res : i32
 }
 
+// CHECK-LABEL: @convert_float_to_tf32_rn_sf
+llvm.func @convert_float_to_tf32_rn_sf(%src : f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn.satfinite(float %{{.*}})
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>}
+  llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rn_relu_sf
+llvm.func @convert_float_to_tf32_rn_relu_sf(%src : f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn.relu.satfinite(float %{{.*}})
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, relu=true, sat = #nvvm.sat_mode<satfinite>}
+  llvm.return %res : i32
+}
+
 // CHECK-LABEL: @convert_float_to_tf32_rz
 llvm.func @convert_float_to_tf32_rz(%src : f32) -> i32 {
   // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz(float %{{.*}})
@@ -41,3 +55,17 @@ llvm.func @convert_float_to_tf32_rz_relu(%src : f32) -> i32 {
   %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, relu=true}
   llvm.return %res : i32
 }
+
+// CHECK-LABEL: @convert_float_to_tf32_rz_sf
+llvm.func @convert_float_to_tf32_rz_sf(%src : f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz.satfinite(float %{{.*}})
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>}
+  llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rz_relu_sf
+llvm.func @convert_float_to_tf32_rz_relu_sf(%src : f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz.relu.satfinite(float %{{.*}})
+  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, relu=true, sat = #nvvm.sat_mode<satfinite>}
+  llvm.return %res : i32
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index cb08064590bc30..8957377607dad6 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -117,22 +117,6 @@ llvm.func @convert_float_to_tf32_rna_relu(%src : f32) -> i32 {
 
 // -----
 
-llvm.func @convert_float_to_tf32_rn_sf(%src : f32) -> i32 {
-  // expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
-  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>}
-  llvm.return %res : i32
-}
-
-// -----
-
-llvm.func @convert_float_to_tf32_rz_sf(%src : f32) -> i32 {
-  // expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
-  %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>}
-  llvm.return %res : i32
-}
-
-// -----
-
 llvm.func @convert_float_to_tf32_no_rnd_mode(%src : f32) -> i32 {
   // expected-error @below {{Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.}}
   %res = nvvm.cvt.float.to.tf32 %src

``````````

</details>


https://github.com/llvm/llvm-project/pull/125048


More information about the Mlir-commits mailing list