[Mlir-commits] [mlir] 83cad68 - [MLIR][NVVM] Update Float to TF32 conversion Op (#125048)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 31 22:02:48 PST 2025
Author: Srinivasa Ravi
Date: 2025-02-01T11:32:44+05:30
New Revision: 83cad6805d144d941bdda99d71a6df2cf113a76d
URL: https://github.com/llvm/llvm-project/commit/83cad6805d144d941bdda99d71a6df2cf113a76d
DIFF: https://github.com/llvm/llvm-project/commit/83cad6805d144d941bdda99d71a6df2cf113a76d.diff
LOG: [MLIR][NVVM] Update Float to TF32 conversion Op (#125048)
This change updates the float-to-TF32 conversion op (`nvvm.cvt.float.to.tf32`)
to also lower to the new intrinsics introduced for sm_100 with PTX 8.6:
- `nvvm_f2tf32_rn_satfinite`
- `nvvm_f2tf32_rn_relu_satfinite`
- `nvvm_f2tf32_rz_satfinite`
- `nvvm_f2tf32_rz_relu_satfinite`
PTX Spec Reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
Added:
Modified:
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 3e0a6987bd85b0a..a5d09eaa34eb548 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -147,9 +147,6 @@ LogicalResult CvtFloatToTF32Op::verify() {
break;
case RndMode::RN:
case RndMode::RZ:
- if (getSat() != NVVM::SaturationMode::NONE)
- return emitError(
- "Saturation mode not supported with rn/rz rounding modes.");
break;
default:
return emitError(
@@ -1221,21 +1218,26 @@ llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
}
+#define CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
+ hasRelu ? llvm::Intrinsic::nvvm_f2tf32_##rnd##relu##sf \
+ : llvm::Intrinsic::nvvm_f2tf32_##rnd##sf
+
+#define GET_CVT_F2TF32_ID(rnd, relu, sf) \
+ hasSatFinite ? CVT_F2TF32_ID_IMPL(rnd, relu, sf) \
+ : CVT_F2TF32_ID_IMPL(rnd, relu, )
+
llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
NVVM::SaturationMode sat,
bool hasRelu) {
using RndMode = NVVM::FPRoundingMode;
+ bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
switch (rnd) {
case RndMode::RN:
- return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rn_relu
- : llvm::Intrinsic::nvvm_f2tf32_rn;
+ return GET_CVT_F2TF32_ID(rn, _relu, _satfinite);
case RndMode::RZ:
- return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rz_relu
- : llvm::Intrinsic::nvvm_f2tf32_rz;
+ return GET_CVT_F2TF32_ID(rz, _relu, _satfinite);
case RndMode::RNA:
- return (sat == NVVM::SaturationMode::SATFINITE)
- ? llvm::Intrinsic::nvvm_f2tf32_rna_satfinite
- : llvm::Intrinsic::nvvm_f2tf32_rna;
+ return GET_CVT_F2TF32_ID(rna, , _satfinite);
default:
llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
index 90a232e4baac6f0..ff7bad0149d4cf8 100644
--- a/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
@@ -28,6 +28,20 @@ llvm.func @convert_float_to_tf32_rn_relu(%src : f32) -> i32 {
llvm.return %res : i32
}
+// CHECK-LABEL: @convert_float_to_tf32_rn_sf
+llvm.func @convert_float_to_tf32_rn_sf(%src : f32) -> i32 {
+ // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn.satfinite(float %{{.*}})
+ %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>}
+ llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rn_relu_sf
+llvm.func @convert_float_to_tf32_rn_relu_sf(%src : f32) -> i32 {
+ // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn.relu.satfinite(float %{{.*}})
+ %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, relu=true, sat = #nvvm.sat_mode<satfinite>}
+ llvm.return %res : i32
+}
+
// CHECK-LABEL: @convert_float_to_tf32_rz
llvm.func @convert_float_to_tf32_rz(%src : f32) -> i32 {
// CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz(float %{{.*}})
@@ -41,3 +55,17 @@ llvm.func @convert_float_to_tf32_rz_relu(%src : f32) -> i32 {
%res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, relu=true}
llvm.return %res : i32
}
+
+// CHECK-LABEL: @convert_float_to_tf32_rz_sf
+llvm.func @convert_float_to_tf32_rz_sf(%src : f32) -> i32 {
+ // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz.satfinite(float %{{.*}})
+ %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>}
+ llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rz_relu_sf
+llvm.func @convert_float_to_tf32_rz_relu_sf(%src : f32) -> i32 {
+ // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz.relu.satfinite(float %{{.*}})
+ %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, relu=true, sat = #nvvm.sat_mode<satfinite>}
+ llvm.return %res : i32
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index cb08064590bc301..8957377607dad6c 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -117,22 +117,6 @@ llvm.func @convert_float_to_tf32_rna_relu(%src : f32) -> i32 {
// -----
-llvm.func @convert_float_to_tf32_rn_sf(%src : f32) -> i32 {
- // expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
- %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>}
- llvm.return %res : i32
-}
-
-// -----
-
-llvm.func @convert_float_to_tf32_rz_sf(%src : f32) -> i32 {
- // expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
- %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>}
- llvm.return %res : i32
-}
-
-// -----
-
llvm.func @convert_float_to_tf32_no_rnd_mode(%src : f32) -> i32 {
// expected-error @below {{Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.}}
%res = nvvm.cvt.float.to.tf32 %src
More information about the Mlir-commits
mailing list