[Mlir-commits] [mlir] 6dcb2a0 - [MLIR][NVVM] Add Float to TF32 conversion Op (#123199)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 17 05:32:30 PST 2025


Author: Durgadoss R
Date: 2025-01-17T19:02:25+05:30
New Revision: 6dcb2a09028b25f8a8cfbda486d9b87a42fd3b30

URL: https://github.com/llvm/llvm-project/commit/6dcb2a09028b25f8a8cfbda486d9b87a42fd3b30
DIFF: https://github.com/llvm/llvm-project/commit/6dcb2a09028b25f8a8cfbda486d9b87a42fd3b30.diff

LOG: [MLIR][NVVM] Add Float to TF32 conversion Op (#123199)

PR #121507 added 'cvt' intrinsics to convert
float to tf32, with the valid set of rounding and
saturation modes. This PR adds an NVVM dialect op
that lowers to these intrinsics.
* Lit tests are added to verify the lowering to the intrinsics.
* Negative tests are also added to check the error handling of invalid
rounding/saturation/relu combinations.

Signed-off-by: Durgadoss R <durgadossr at nvidia.com>

Added: 
    mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 04042903e343ed..bf3131932a56bc 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -970,6 +970,77 @@ def NVVM_CpAsyncMBarrierArriveSharedOp : NVVM_Op<"cp.async.mbarrier.arrive.share
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// NVVM Conversion Ops (for "cvt.*" family of PTX instructions)
+//===----------------------------------------------------------------------===//
+
+// Floating-point rounding modes used by PTX conversions, plus a NONE case.
+def FPRoundingModeNone : I32EnumAttrCase<"NONE", 0, "none">;
+def FPRoundingModeRN   : I32EnumAttrCase<"RN",   1, "rn">;
+def FPRoundingModeRM   : I32EnumAttrCase<"RM",   2, "rm">;
+def FPRoundingModeRP   : I32EnumAttrCase<"RP",   3, "rp">;
+def FPRoundingModeRZ   : I32EnumAttrCase<"RZ",   4, "rz">;
+def FPRoundingModeRNA  : I32EnumAttrCase<"RNA",  5, "rna">;
+// C++ enum ::mlir::NVVM::FPRoundingMode; IR syntax #nvvm.fp_rnd_mode<rn>.
+def FPRoundingMode : I32EnumAttr<"FPRoundingMode", "NVVM FPRoundingMode kind",
+  [FPRoundingModeNone, FPRoundingModeRN, FPRoundingModeRM,
+    FPRoundingModeRP, FPRoundingModeRZ, FPRoundingModeRNA]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def FPRoundingModeAttr : EnumAttr<NVVM_Dialect, FPRoundingMode, "fp_rnd_mode"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+// Saturation modes for PTX conversions: NONE, or SATFINITE (satfinite).
+def SaturationModeNone   : I32EnumAttrCase<"NONE", 0, "none">;
+def SaturationModeFinite : I32EnumAttrCase<"SATFINITE", 1, "satfinite">;
+// C++ enum ::mlir::NVVM::SaturationMode; IR syntax #nvvm.sat_mode<satfinite>.
+def SaturationMode : I32EnumAttr<"SaturationMode", "NVVM SaturationMode kind",
+  [SaturationModeNone, SaturationModeFinite]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def SaturationModeAttr : EnumAttr<NVVM_Dialect, SaturationMode, "sat_mode"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_CvtFloatToTF32Op : NVVM_Op<"cvt.float.to.tf32"> {
+  let summary = "Convert the given float input to TF32";
+  let description = [{
+    This Op converts the given f32 input to tf32, returned in `res` as an i32.
+    The `relu` attribute, when set, lowers to the '.relu' variant of
+    the cvt instruction. The `rnd` and `sat` attributes specify the
+    rounding and saturation modes respectively. Valid combinations are
+    `rn`/`rz` (with optional `relu`) and `rna` (with optional `satfinite`);
+    `rnd` must be set explicitly, since its default `none` is rejected.
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+  }];
+
+  let hasVerifier = 1;
+  let results = (outs I32:$res);
+  let arguments = (ins
+    F32:$src,
+    DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
+    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+    DefaultValuedAttr<BoolAttr, "false">:$relu);
+
+  let assemblyFormat = "$src attr-dict";
+
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode,
+                                              NVVM::SaturationMode,
+                                              bool hasRelu);
+  }];
+
+  string llvmBuilder = [{
+    auto intId = NVVM::CvtFloatToTF32Op::getIntrinsicID($rnd, $sat, $relu);
+    $res = createIntrinsicCall(builder, intId, {$src});
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// NVVM MMA Ops
+//===----------------------------------------------------------------------===//
 /// Helpers to instantiate different version of wmma intrinsics.
 /// This matches the hierarchy used in IntrinsicsNVVM.td to define all the
 /// combinations of the intrinsics.

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index d8fde3e765ac49..ccb5ad05f0bf72 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -138,6 +138,26 @@ LogicalResult CpAsyncBulkTensorReduceOp::verify() {
                                          getLoc());
 }
 
+LogicalResult CvtFloatToTF32Op::verify() {
+  // Only rn, rz and rna map to f2tf32 intrinsics; rn/rz reject a saturation
+  // mode and rna rejects relu.
+  NVVM::FPRoundingMode rndMode = getRnd();
+  bool isRnOrRz = rndMode == NVVM::FPRoundingMode::RN ||
+                  rndMode == NVVM::FPRoundingMode::RZ;
+  if (isRnOrRz) {
+    if (getSat() == NVVM::SaturationMode::NONE)
+      return success();
+    return emitError(
+        "Saturation mode not supported with rn/rz rounding modes.");
+  }
+  if (rndMode != NVVM::FPRoundingMode::RNA)
+    return emitError(
+        "Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.");
+  if (getRelu())
+    return emitError("Relu not supported with rna rounding mode.");
+  return success();
+}
+
 // Given the element type of an operand and whether or not it is an accumulator,
 // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
 // operand's element type.
@@ -1163,6 +1183,26 @@ llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
   llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
 }
 
+llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
+                                                     NVVM::SaturationMode sat,
+                                                     bool hasRelu) {
+  // verify() rejects relu with rna and sat with rn/rz; they are ignored here.
+  bool isSatFinite = sat == NVVM::SaturationMode::SATFINITE;
+  switch (rnd) {
+  case NVVM::FPRoundingMode::RNA:
+    return isSatFinite ? llvm::Intrinsic::nvvm_f2tf32_rna_satfinite
+                       : llvm::Intrinsic::nvvm_f2tf32_rna;
+  case NVVM::FPRoundingMode::RZ:
+    return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rz_relu
+                   : llvm::Intrinsic::nvvm_f2tf32_rz;
+  case NVVM::FPRoundingMode::RN:
+    return hasRelu ? llvm::Intrinsic::nvvm_f2tf32_rn_relu
+                   : llvm::Intrinsic::nvvm_f2tf32_rn;
+  default:
+    llvm_unreachable("Invalid RoundingMode for CvtFloatToTF32Op");
+  }
+}
+
 /// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
 /// have ConstantRangeAttr.
 static void nvvmInferResultRanges(Operation *op, Value result,

diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
new file mode 100644
index 00000000000000..90a232e4baac6f
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/cvt_tf32.mlir
@@ -0,0 +1,43 @@
+// RUN: mlir-translate -mlir-to-llvmir %s  -split-input-file --verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: @convert_float_to_tf32_rna
+llvm.func @convert_float_to_tf32_rna(%in: f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rna(float %{{.*}})
+  %tf32 = nvvm.cvt.float.to.tf32 %in {rnd = #nvvm.fp_rnd_mode<rna>}
+  llvm.return %tf32 : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rna_sf
+llvm.func @convert_float_to_tf32_rna_sf(%in: f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rna.satfinite(float %{{.*}})
+  %tf32 = nvvm.cvt.float.to.tf32 %in {rnd = #nvvm.fp_rnd_mode<rna>, sat = #nvvm.sat_mode<satfinite>}
+  llvm.return %tf32 : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rn
+llvm.func @convert_float_to_tf32_rn(%in: f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn(float %{{.*}})
+  %tf32 = nvvm.cvt.float.to.tf32 %in {rnd = #nvvm.fp_rnd_mode<rn>}
+  llvm.return %tf32 : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rn_relu
+llvm.func @convert_float_to_tf32_rn_relu(%in: f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rn.relu(float %{{.*}})
+  %tf32 = nvvm.cvt.float.to.tf32 %in {rnd = #nvvm.fp_rnd_mode<rn>, relu = true}
+  llvm.return %tf32 : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rz
+llvm.func @convert_float_to_tf32_rz(%in: f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz(float %{{.*}})
+  %tf32 = nvvm.cvt.float.to.tf32 %in {rnd = #nvvm.fp_rnd_mode<rz>}
+  llvm.return %tf32 : i32
+}
+
+// CHECK-LABEL: @convert_float_to_tf32_rz_relu
+llvm.func @convert_float_to_tf32_rz_relu(%in: f32) -> i32 {
+  // CHECK: %{{.*}} = call i32 @llvm.nvvm.f2tf32.rz.relu(float %{{.*}})
+  %tf32 = nvvm.cvt.float.to.tf32 %in {rnd = #nvvm.fp_rnd_mode<rz>, relu = true}
+  llvm.return %tf32 : i32
+}

diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 44c7126255dc4f..cb08064590bc30 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -106,3 +106,35 @@ llvm.func @tma_reduce_2d_im2col(%src : !llvm.ptr<3>, %tma_desc : !llvm.ptr, %d0
   nvvm.cp.async.bulk.tensor.reduce %tma_desc, %src, box[%d0, %d1] {redKind = #nvvm.tma_redux_kind<and>, mode = #nvvm.tma_store_mode<im2col>}: !llvm.ptr, !llvm.ptr<3>
   llvm.return
 }
+
+// -----
+
+llvm.func @convert_float_to_tf32_rna_relu(%in: f32) -> i32 {
+  // expected-error @below {{Relu not supported with rna rounding mode.}}
+  %tf32 = nvvm.cvt.float.to.tf32 %in {rnd = #nvvm.fp_rnd_mode<rna>, relu = true}
+  llvm.return %tf32 : i32
+}
+
+// -----
+
+llvm.func @convert_float_to_tf32_rn_sf(%in: f32) -> i32 {
+  // expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
+  %tf32 = nvvm.cvt.float.to.tf32 %in {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>}
+  llvm.return %tf32 : i32
+}
+
+// -----
+
+llvm.func @convert_float_to_tf32_rz_sf(%in: f32) -> i32 {
+  // expected-error @below {{Saturation mode not supported with rn/rz rounding modes.}}
+  %tf32 = nvvm.cvt.float.to.tf32 %in {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>}
+  llvm.return %tf32 : i32
+}
+
+// -----
+
+llvm.func @convert_float_to_tf32_no_rnd_mode(%in: f32) -> i32 {
+  // expected-error @below {{Only {rn,rz,rna} rounding modes supported for CvtFloatToTF32Op.}}
+  %tf32 = nvvm.cvt.float.to.tf32 %in
+  llvm.return %tf32 : i32
+}


        


More information about the Mlir-commits mailing list