[Mlir-commits] [mlir] [MLIR][NVVM] Add support for f6x2 conversion (PR #136537)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Apr 20 22:38:50 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Srinivasa Ravi (Wolfram70)

<details>
<summary>Changes</summary>

This patch adds the `nvvm.cvt.to.f6x2` NVVM dialect Op for converting a pair of f32 values into the packed f6x2 types `e2m3x2` and `e3m2x2`.

For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt

---
Full diff: https://github.com/llvm/llvm-project/pull/136537.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+49) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+45) 
- (added) mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir (+22) 
- (modified) mlir/test/Target/LLVMIR/nvvmir-invalid.mlir (+16) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8a54804b220a1..b8581a7504c67 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1032,6 +1032,55 @@ def NVVM_CvtFloatToTF32Op : NVVM_Op<"cvt.float.to.tf32"> {
   }];
 }
 
+// FP6 target formats for nvvm.cvt.to.f6x2. Each case gives the C++
+// enumerator name, its integer value, and the keyword used in textual IR
+// (e.g. `<e2m3>`).
+def FP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">;
+def FP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">;
+
+// C++ enum ::mlir::NVVM::FP6Type.
+def FP6Type : I32EnumAttr<"FP6Type", "NVVM FP6Type kind",
+  [FP6E2M3, FP6E3M2]> {
+  // Do not generate the generic specialized attribute; the dialect-specific
+  // FP6TypeAttr below is used instead.
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+// NVVM dialect attribute wrapping FP6Type, printed and parsed as `<value>`
+// (the `<e2m3>` / `<e3m2>` syntax used by nvvm.cvt.to.f6x2).
+def FP6TypeAttr : EnumAttr<NVVM_Dialect, FP6Type, "fp6_type"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
+  let summary = "Convert a pair of f32 inputs to a packed f6x2 value";
+  let description = [{
+    This Op converts the two f32 inputs `a` and `b` to the FP6 format
+    selected by `type` (`e2m3` or `e3m2`) and packs both converted values
+    into the result `dst`. The result is either an i16 or a vector<2xi8>;
+    in the vector case the packed i16 is bitcast to vector<2xi8>.
+
+    The `relu` attribute, when set, lowers to the `.relu` variant of the
+    cvt instruction. The `rnd` and `sat` attributes specify the rounding
+    and saturation modes respectively; only `rn` rounding and `satfinite`
+    saturation are currently supported (other values are rejected by the
+    verifier).
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+  }];
+  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+  let arguments = (ins
+    FP6TypeAttr:$type,
+    F32:$a,
+    F32:$b,
+    DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::RN">:$rnd,
+    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::SATFINITE">:$sat,
+    DefaultValuedAttr<BoolAttr, "false">:$relu);
+  let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";
+
+  let extraClassDeclaration = [{
+    /// Returns the ff.to.<type>x2.rn[.relu].satfinite intrinsic for `type`.
+    static llvm::Intrinsic::ID getIntrinsicID(NVVM::FP6Type type,
+                                              bool hasRelu);
+    /// Returns true if `dst` is the packed i16 form (not vector<2xi8>).
+    bool isPacked();
+    /// Returns `packedI16` as-is for the i16 form, otherwise bitcasts it to
+    /// <2 x i8>.
+    llvm::Value *getCastedResult(llvm::Value *packedI16,
+                                 llvm::IRBuilderBase &builder);
+  }];
+
+  string llvmBuilder = [{
+    auto intId = NVVM::CvtToF6x2Op::getIntrinsicID($type, $relu);
+    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
+    $dst = op.getCastedResult(packedI16, builder);
+  }];
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM MMA Ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 556114f4370b3..8540a88653973 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -32,6 +32,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Type.h"
 #include "llvm/Support/Casting.h"
@@ -160,6 +161,33 @@ LogicalResult CvtFloatToTF32Op::verify() {
   return success();
 }
 
+/// Returns true when the result uses the packed i16 form rather than
+/// vector<2xi8> (the only two result types the Op allows).
+bool CvtToF6x2Op::isPacked() { return getDst().getType().isInteger(16); }
+
+/// Adapts the i16 produced by the f6x2 conversion intrinsic to the Op's
+/// declared result type: the packed i16 form is returned unchanged, while
+/// the vector form is obtained by bitcasting to <2 x i8>.
+llvm::Value *CvtToF6x2Op::getCastedResult(llvm::Value *packedI16,
+                                          llvm::IRBuilderBase &builder) {
+  if (!isPacked()) {
+    llvm::Type *i8x2Ty = llvm::FixedVectorType::get(builder.getInt8Ty(), 2);
+    return builder.CreateBitCast(packedI16, i8x2Ty);
+  }
+  return packedI16;
+}
+
+/// Rejects rounding/saturation modes other than RN and SATFINITE, the only
+/// modes for which CvtToF6x2Op::getIntrinsicID selects an intrinsic.
+LogicalResult CvtToF6x2Op::verify() {
+  const bool isRN = getRnd() == NVVM::FPRoundingMode::RN;
+  if (!isRN)
+    return emitOpError("RN rounding mode required for CvtToF6x2Op.");
+
+  const bool isSatFinite = getSat() == NVVM::SaturationMode::SATFINITE;
+  if (!isSatFinite)
+    return emitOpError("SATFINITE saturation mode required for CvtToF6x2Op.");
+
+  return success();
+}
+
 LogicalResult BulkStoreOp::verify() {
   if (getInitVal() != 0)
     return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -1300,6 +1328,23 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
   }
 }
 
+// Expands to the ff.to.<type>.rn[.relu].satfinite intrinsic ID. Both inputs
+// are explicit macro parameters, so the expansion does not depend on names
+// in the caller's scope, and it is parenthesized to be safe in expressions.
+#define CVT_TO_F6X2_ID_IMPL(type, hasRelu)                                     \
+  ((hasRelu) ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite          \
+             : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite)
+
+/// Returns the intrinsic implementing CvtToF6x2Op for the given FP6 format.
+/// `hasRelu` selects the .relu variant. Only RN/SATFINITE intrinsics are
+/// selected here, matching what CvtToF6x2Op::verify enforces.
+llvm::Intrinsic::ID CvtToF6x2Op::getIntrinsicID(NVVM::FP6Type type,
+                                                bool hasRelu) {
+  // Fully covered switch without `default:` so that -Wswitch flags any
+  // FP6Type enumerator added later.
+  switch (type) {
+  case NVVM::FP6Type::E2M3:
+    return CVT_TO_F6X2_ID_IMPL(e2m3x2, hasRelu);
+  case NVVM::FP6Type::E3M2:
+    return CVT_TO_F6X2_ID_IMPL(e3m2x2, hasRelu);
+  }
+  llvm_unreachable("Invalid FP6Type for CvtToF6x2Op");
+}
+
+#undef CVT_TO_F6X2_ID_IMPL
+
 llvm::Intrinsic::ID
 Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
                                       LLVM::ModuleTranslation &mt,
diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir
new file mode 100644
index 0000000000000..2237e6faad52d
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @convert_float_to_fp6x2_packed
+llvm.func @convert_float_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+  %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+  %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : i16
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_float_to_fp6x2_vector
+llvm.func @convert_float_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
+  // CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
+  %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+  %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
+  llvm.return
+}
+
+// The `relu` attribute selects the .relu variant of the intrinsic, for both
+// result forms.
+// CHECK-LABEL: @convert_float_to_fp6x2_relu
+llvm.func @convert_float_to_fp6x2_relu(%srcA : f32, %srcB : f32) {
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+  %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB {relu = true} : i16
+  // CHECK: %[[RELU:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[RELU]] to <2 x i8>
+  %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB {relu = true} : vector<2xi8>
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index f87f11daeef54..5fcef1aa67139 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -176,3 +176,19 @@ llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) {
   %0 = nvvm.match.sync any %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i1)>
   llvm.return
 }
+
+// -----
+
+// Non-RN rounding modes are rejected by the verifier.
+llvm.func @nvvm_cvt_to_f6x2_invalid_rounding(%a : f32, %b : f32) {
+  // expected-error @below {{RN rounding mode required for CvtToF6x2Op.}}
+  %res = nvvm.cvt.to.f6x2 <e2m3> %a, %b {rnd = #nvvm.fp_rnd_mode<rna>} : i16
+  llvm.return
+}
+
+// -----
+
+// Non-SATFINITE saturation modes are rejected by the verifier.
+llvm.func @nvvm_cvt_to_f6x2_invalid_saturation(%a : f32, %b : f32) {
+  // expected-error @below {{SATFINITE saturation mode required for CvtToF6x2Op.}}
+  %res = nvvm.cvt.to.f6x2 <e3m2> %a, %b {sat = #nvvm.sat_mode<none>} : i16
+  llvm.return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/136537


More information about the Mlir-commits mailing list