[Mlir-commits] [mlir] [MLIR][NVVM] Add support for f6x2 conversion (PR #136537)

Srinivasa Ravi llvmlistbot at llvm.org
Thu Apr 24 23:54:39 PDT 2025


https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/136537

>From 9815a64c2b02ce339b60c728782488aa41cad10a Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 15 Apr 2025 15:47:58 +0530
Subject: [PATCH] [MLIR][NVVM] Add support for f6x2 conversion

This patch adds the `cvt.to.f6x2` NVVM dialect Op for converting a pair of
f32 values into the packed f6x2 types (e2m3x2 and e3m2x2).

For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 53 +++++++++++++++++++++
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp  | 16 +++++++
 mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir | 22 +++++++++
 3 files changed, 91 insertions(+)
 create mode 100644 mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0a6e66919f021..0e1481ea374db 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1066,6 +1066,59 @@ def NVVM_CvtFloatToTF32Op : NVVM_Op<"cvt.float.to.tf32"> {
   }];
 }
 
+def CVTFP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">;
+def CVTFP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">;
+
+def CVTFP6Type : I32EnumAttr<"CVTFP6Type", "NVVM CVTFP6Type kind",
+  [CVTFP6E2M3, CVTFP6E3M2]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def CVTFP6TypeAttr : EnumAttr<NVVM_Dialect, CVTFP6Type, "cvt_fp6_type"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
+  let summary = "Convert a pair of float inputs to f6x2";
+  let description = [{
+    This Op converts each of the given f32 inputs `a` and `b` to the specified
+    fp6 type, using round-to-nearest-even (`.rn`) and `.satfinite` saturation.
+    The result `dst` is represented either as an i16 type or as a vector
+    of two i8 types.
+    If `dst` is returned as an i16 type, the converted values are packed such
+    that the value converted from `a` is stored in the upper 8 bits of `dst`
+    and the value converted from `b` is stored in the lower 8 bits of `dst`;
+    within each 8-bit half, the 2 MSBs are padded with zeros.
+    If `dst` is returned as a vector type, this packed i16 value is bitcast
+    to a vector of two i8 elements, one per converted value.
+    The `relu` attribute, when set, lowers to the `.relu` variant of cvt.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+  }];
+  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+  let arguments = (ins
+    CVTFP6TypeAttr:$type,
+    F32:$a,
+    F32:$b,
+    DefaultValuedAttr<BoolAttr, "false">:$relu);
+  let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";
+
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP6Type type,
+                                              bool hasRelu);
+  }];
+
+  string llvmBuilder = [{
+    auto intId = NVVM::CvtToF6x2Op::getIntrinsicID($type, $relu);
+    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
+    if (op.getDst().getType().isInteger(16))
+      $dst = packedI16;
+    else
+      $dst = builder.CreateBitCast(
+          packedI16, llvm::FixedVectorType::get(builder.getInt8Ty(), 2));
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM MMA Ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index e3d496c983e59..44040401b0406 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1290,6 +1290,22 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
   }
 }
 
+#define CVT_TO_F6X2_ID_IMPL(type, has_relu)                                    \
+  ((has_relu) ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite         \
+              : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite)
+
+// Returns llvm.nvvm.ff.to.{e2m3,e3m2}x2.rn[.relu].satfinite for `type`.
+llvm::Intrinsic::ID CvtToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
+                                                bool hasRelu) {
+  switch (type) {
+  case NVVM::CVTFP6Type::E2M3:
+    return CVT_TO_F6X2_ID_IMPL(e2m3x2, hasRelu);
+  case NVVM::CVTFP6Type::E3M2:
+    return CVT_TO_F6X2_ID_IMPL(e3m2x2, hasRelu);
+  }
+  llvm_unreachable("Invalid CVTFP6Type for CvtToF6x2Op");
+}
+
 llvm::Intrinsic::ID
 Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
                                       LLVM::ModuleTranslation &mt,
diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir
new file mode 100644
index 0000000000000..2237e6faad52d
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir
@@ -0,0 +1,41 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @convert_float_to_fp6x2_packed
+llvm.func @convert_float_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+  %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+  %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : i16
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_float_to_fp6x2_vector
+llvm.func @convert_float_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
+  // CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
+  %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+  %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_float_to_fp6x2_relu_packed
+llvm.func @convert_float_to_fp6x2_relu_packed(%srcA : f32, %srcB : f32) {
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+  %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB {relu = true} : i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+  %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB {relu = true} : i16
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_float_to_fp6x2_relu_vector
+llvm.func @convert_float_to_fp6x2_relu_vector(%srcA : f32, %srcB : f32) {
+  // CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
+  %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB {relu = true} : vector<2xi8>
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+  %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB {relu = true} : vector<2xi8>
+  llvm.return
+}



More information about the Mlir-commits mailing list