[Mlir-commits] [mlir] [MLIR][NVVM] Add Float to TF32 conversion Op (PR #123199)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Jan 16 05:19:47 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Durgadoss R (durga4github)
<details>
<summary>Changes</summary>
PR #<!-- -->121507 added 'cvt' intrinsics to convert
float to tf32, with the valid set of rounding and
saturation modes. This PR adds an NVVM Dialect Op
for the same.
* lit tests are added to verify the lowering to intrinsics.
* Negative tests are also added to check the error-handling of invalid combinations.
---
Full diff: https://github.com/llvm/llvm-project/pull/123199.diff
4 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+71)
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+40)
- (added) mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir (+43)
- (modified) mlir/test/Target/LLVMIR/nvvmir-invalid.mlir (+32)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 04042903e343ed..bf3131932a56bc 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -970,6 +970,77 @@ def NVVM_CpAsyncMBarrierArriveSharedOp : NVVM_Op<"cp.async.mbarrier.arrive.share
}];
}
+//===----------------------------------------------------------------------===//
+// NVVM Conversion Ops (for "cvt.*" family of PTX instructions)
+//===----------------------------------------------------------------------===//
+
+// Attributes for the floating point rounding modes supported by PTX
+def FPRoundingModeNone : I32EnumAttrCase<"NONE", 0, "none">;
+def FPRoundingModeRN : I32EnumAttrCase<"RN", 1, "rn">;
+def FPRoundingModeRM : I32EnumAttrCase<"RM", 2, "rm">;
+def FPRoundingModeRP : I32EnumAttrCase<"RP", 3, "rp">;
+def FPRoundingModeRZ : I32EnumAttrCase<"RZ", 4, "rz">;
+def FPRoundingModeRNA : I32EnumAttrCase<"RNA", 5, "rna">;
+
+def FPRoundingMode : I32EnumAttr<"FPRoundingMode", "NVVM FPRoundingMode kind",
+ [FPRoundingModeNone, FPRoundingModeRN, FPRoundingModeRM,
+ FPRoundingModeRP, FPRoundingModeRZ, FPRoundingModeRNA]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def FPRoundingModeAttr : EnumAttr<NVVM_Dialect, FPRoundingMode, "fp_rnd_mode"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def SaturationModeNone : I32EnumAttrCase<"NONE", 0, "none">;
+def SaturationModeFinite : I32EnumAttrCase<"SATFINITE", 1, "satfinite">;
+
+def SaturationMode : I32EnumAttr<"SaturationMode", "NVVM SaturationMode kind",
+ [SaturationModeNone, SaturationModeFinite]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def SaturationModeAttr : EnumAttr<NVVM_Dialect, SaturationMode, "sat_mode"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_CvtFloatToTF32Op : NVVM_Op<"cvt.float.to.tf32"> {
+ let summary = "Convert the given float input to TF32";
+ let description = [{
+ This Op converts the given f32 input to tf32.
+ The result `res` is represented as an i32 type.
+ The `relu` attribute, when set, lowers to the '.relu' variant of
+ the cvt instruction. The `rnd` and `sat` attributes specify the
+ rounding and saturation modes respectively.
+ [For more information, see PTX ISA]
+ (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ }];
+
+ let hasVerifier = 1;
+ let results = (outs I32:$res);
+ let arguments = (ins
+ F32:$src,
+ DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
+ DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+ DefaultValuedAttr<BoolAttr, "false">:$relu);
+
+ let assemblyFormat = "$src attr-dict";
+
+ let extraClassDeclaration = [{
+ static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode,
+ NVVM::SaturationMode,
+ bool hasRelu);
+ }];
+
+ string llvmBuilder = [{
+ auto intId = NVVM::CvtFloatToTF32Op::getIntrinsicID($rnd, $sat, $relu);
+ $res = createIntrinsicCall(builder, intId, {$src});
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM MMA Ops
+//===----------------------------------------------------------------------===//
/// Helpers to instantiate different version of wmma intrinsics.
/// This matches the hierarchy used in IntrinsicsNVVM.td to define all the
/// combinations of the intrinsics.
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index d8fde3e765ac49..ccb5ad05f0bf72 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -138,6 +138,26 @@ LogicalResult CpAsyncBulkTensorReduceOp::verify() {
getLoc());
}
+LogicalResult CvtFloatToTF32Op::verify() {
+ using RndMode = NVVM::FPRoundingMode;
+ switch (getRnd()) {
+ case RndMode::RNA:
+ if (getRelu())
+ return emitError("Relu not supported with rna rounding mode.");
+ break;
+ case RndMode::RN:
+ case RndMode::RZ:
+ if (getSat() != NVVM::SaturationMode::NONE)
+ return emitError(
+ "Saturation mode not supported with rn/rz rounding modes.");
+ break;
+ default:
+ return emitError(
+ "Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.");
+ }
+ return success();
+}
+
// Given the element type of an operand and whether or not it is an accumulator,
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
// operand's element type.
@@ -1163,6 +1183,26 @@ llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
}
+llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
+ NVVM::SaturationMode sat,
+ bool hasRelu) {
+ using RndMode = NVVM::FPRoundingMode;
+ switch (rnd) {
+ case RndMode::RN:
+ return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rn_relu
+ : llvm::Intrinsic::nvvm_f2tf32_rn;
+ case RndMode::RZ:
+ return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rz_relu
+ : llvm::Intrinsic::nvvm_f2tf32_rz;
+ case RndMode::RNA:
+ return (sat == NVVM::SaturationMode::SATFINITE)
+ ? llvm::Intrinsic::nvvm_f2tf32_rna_satfinite
+ : llvm::Intrinsic::nvvm_f2tf32_rna;
+ default:
+ llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
+ }
+}
+
/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
new file mode 100644
index 00000000000000..90a232e4baac6f
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-translate -mlir-to-llvmir %s -split-input-file --verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: @convert_float_to_tf32_rna
+llvm.func @convert_float_to_tf32_rna(%src : f32) -> i32 {
+ // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rna(float %{{.*}})
+ %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rna>}
+ llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rna_sf
+llvm.func @convert_float_to_tf32_rna_sf(%src : f32) -> i32 {
+ // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rna.satfinite(float %{{.*}})
+ %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rna>, sat = #nvvm.sat_mode<satfinite>}
+ llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rn
+llvm.func @convert_float_to_tf32_rn(%src : f32) -> i32 {
+ // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn(float %{{.*}})
+ %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>}
+ llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rn_relu
+llvm.func @convert_float_to_tf32_rn_relu(%src : f32) -> i32 {
+ // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn.relu(float %{{.*}})
+ %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, relu=true}
+ llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rz
+llvm.func @convert_float_to_tf32_rz(%src : f32) -> i32 {
+ // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz(float %{{.*}})
+ %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>}
+ llvm.return %res : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rz_relu
+llvm.func @convert_float_to_tf32_rz_relu(%src : f32) -> i32 {
+ // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz.relu(float %{{.*}})
+ %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, relu=true}
+ llvm.return %res : i32
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 44c7126255dc4f..cb08064590bc30 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -106,3 +106,35 @@ llvm.func @tma_reduce_2d_im2col(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr, %d0
nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] {redKind = #nvvm.tma_redux_kind<and>, mode = #nvvm.tma_store_mode<im2col>}: !llvm.ptr, !llvm.ptr<3>
llvm.return
}
+
+// -----
+
+llvm.func @convert_float_to_tf32_rna_relu(%src : f32) -> i32 {
+ // expected-error @below {{Relu not supported with rna rounding mode.}}
+ %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rna>, relu=true}
+ llvm.return %res : i32
+}
+
+// -----
+
+llvm.func @convert_float_to_tf32_rn_sf(%src : f32) -> i32 {
+ // expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
+ %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>}
+ llvm.return %res : i32
+}
+
+// -----
+
+llvm.func @convert_float_to_tf32_rz_sf(%src : f32) -> i32 {
+ // expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
+ %res = nvvm.cvt.float.to.tf32 %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>}
+ llvm.return %res : i32
+}
+
+// -----
+
+llvm.func @convert_float_to_tf32_no_rnd_mode(%src : f32) -> i32 {
+ // expected-error @below {{Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.}}
+ %res = nvvm.cvt.float.to.tf32 %src
+ llvm.return %res : i32
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/123199
More information about the Mlir-commits
mailing list