[Mlir-commits] [mlir] [MLIR][NVVM] Update support for conversions to f8x2 and f6x2 types (PR #137781)

Srinivasa Ravi llvmlistbot at llvm.org
Mon May 5 04:49:35 PDT 2025


https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/137781

>From 5a694bff7a34b21d859a3918aad468c8153edc0a Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 15 Apr 2025 15:47:58 +0530
Subject: [PATCH] [MLIR][NVVM] Update support for conversions to f8x2 and f6x2
 types

This change:
- Adds the `cvt.f32x2.to.f8x2`, `cvt.f16x2.to.f8x2`, and `cvt.bf16x2.to.f8x2`
  Ops to the NVVM dialect for the conversions to `.e4m3x2`, `.e5m2x2`,
  and `.ue8m0x2` types.
- Renames the recently added `cvt.to.f6x2` Op to `cvt.f32x2.to.f6x2`
  for consistency with the other conversion Ops.

For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 151 +++++++++++++++++++-
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp  | 129 ++++++++++++++++-
 mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir |  17 ++-
 mlir/test/Target/LLVMIR/nvvm/cvt_fp8x2.mlir | 102 +++++++++++++
 mlir/test/Target/LLVMIR/nvvmir-invalid.mlir |  72 ++++++++++
 5 files changed, 455 insertions(+), 16 deletions(-)
 create mode 100644 mlir/test/Target/LLVMIR/nvvm/cvt_fp8x2.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 27d54e7abeda9..4b8485f3c3e7f 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1079,7 +1079,7 @@ def CVTFP6TypeAttr : EnumAttr<NVVM_Dialect, CVTFP6Type, "cvt_fp6_type"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
-def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
+def NVVM_CvtF32x2ToF6x2Op : NVVM_Op<"cvt.f32x2.to.f6x2"> {
   let summary = "Convert a pair of float inputs to f6x2";
   let description = [{
     This Op converts each of the given float inputs to the specified fp6 type.
@@ -1110,7 +1110,7 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
   }];
 
   string llvmBuilder = [{
-    auto intId = NVVM::CvtToF6x2Op::getIntrinsicID($type, $relu);
+    auto intId = NVVM::CvtF32x2ToF6x2Op::getIntrinsicID($type, $relu);
     llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
     if(op.getDst().getType().isInteger(16))
       $dst = packedI16;
@@ -1120,6 +1120,153 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
   }];
 }
 
+def CVTFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">;
+def CVTFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">;
+def CVTFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">;
+
+def CVTFP8Type : I32EnumAttr<"CVTFP8Type", "NVVM CVTFP8Type kind",
+  [CVTFP8E4M3, CVTFP8E5M2, CVTFP8UE8M0]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def CVTFP8TypeAttr : EnumAttr<NVVM_Dialect, CVTFP8Type, "cvt_fp8_type"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_CvtF32x2ToF8x2Op : NVVM_Op<"cvt.f32x2.to.f8x2"> {
+  let summary = "Convert a pair of float inputs to f8x2";
+  let description = [{
+    This Op converts each of the given float inputs to the specified fp8 type.
+    The result `dst` is represented as an i16 type or as a vector of two i8
+    types. If `dst` is returned as an i16 type, the converted values are
+    packed such that the value converted from `a` is stored in the upper 8
+    bits of `dst` and the value converted from `b` is stored in the lower 8
+    bits of `dst`. If `dst` is returned as a vector type, each converted value
+    is stored as an i8 element in the vector.
+    The `rnd` and `sat` attributes specify the rounding and saturation modes:
+    .e4m3 and .e5m2 require `rnd` = RN and `sat` = SATFINITE, while .ue8m0
+    requires `rnd` = RZ or RP and uses '.satfinite' only if `sat` = SATFINITE.
+    The `relu` attribute, when set, lowers to '.relu'; it is invalid for .ue8m0.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+  }];
+  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+  let arguments = (ins
+    CVTFP8TypeAttr:$type,
+    F32:$a,
+    F32:$b,
+    DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
+    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+    DefaultValuedAttr<BoolAttr, "false">:$relu);
+  let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";
+
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to,
+                                              NVVM::FPRoundingMode rnd,
+                                              NVVM::SaturationMode sat,
+                                              bool hasRelu);
+  }];
+
+  string llvmBuilder = [{
+    auto intId = NVVM::CvtF32x2ToF8x2Op::getIntrinsicID($type, $rnd, $sat, $relu);
+    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
+    if (op.getDst().getType().isInteger(16))
+      $dst = packedI16;
+    else
+      $dst = builder.CreateBitCast(packedI16,
+                      llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
+  }];
+
+  let hasVerifier = 1;
+}
+
+def NVVM_CvtF16x2ToF8x2Op : NVVM_Op<"cvt.f16x2.to.f8x2"> {
+  let summary = "Convert an f16x2 input to f8x2";
+  let description = [{
+    This Op converts the given f16 inputs in an f16x2 vector to the specified
+    f8 type; only the .e4m3 and .e5m2 types are supported.
+    The result `dst` is represented as an i16 type or as a vector
+    of two i8 types.
+    If `dst` is returned as an i16 type, the converted values from `a`
+    are packed such that the value converted from the first element of `a`
+    is stored in the upper 8 bits of `dst` and the value converted from the
+    second element of `a` is stored in the lower 8 bits of `dst`.
+    If `dst` is returned as a vector type, each converted value is stored as an
+    i8 element in the vector.
+    The `relu` attribute, when set, lowers to the '.relu' variant of
+    the cvt instruction.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+  }];
+  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+  let arguments = (ins
+    CVTFP8TypeAttr:$type,
+    VectorOfLengthAndType<[2], [F16]>:$a,
+    DefaultValuedAttr<BoolAttr, "false">:$relu);
+  let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";
+
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to,
+                                              bool hasRelu);
+  }];
+
+  string llvmBuilder = [{
+    auto intId = NVVM::CvtF16x2ToF8x2Op::getIntrinsicID($type, $relu);
+    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
+    if (op.getDst().getType().isInteger(16))
+      $dst = packedI16;
+    else
+      $dst = builder.CreateBitCast(packedI16,
+                      llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
+  }];
+
+  let hasVerifier = 1;
+}
+
+def NVVM_CvtBF16x2ToF8x2Op : NVVM_Op<"cvt.bf16x2.to.f8x2"> {
+  let summary = "Convert a bf16x2 input to f8x2";
+  let description = [{
+    This Op converts the given bf16 inputs in a bf16x2 vector to the specified
+    f8 type; only the .ue8m0 type is supported.
+    The result `dst` is represented as an i16 type or as a vector
+    of two i8 types.
+    If `dst` is returned as an i16 type, the converted values from `a`
+    are packed such that the value converted from the first element of `a`
+    is stored in the upper 8 bits of `dst` and the value converted from the
+    second element of `a` is stored in the lower 8 bits of `dst`.
+    If `dst` is returned as a vector type, each converted value is stored as an
+    i8 element in the vector.
+    The `rnd` attribute (RZ or RP) specifies the rounding mode; setting `sat`
+    to SATFINITE selects the '.satfinite' variant.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+  }];
+  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+  let arguments = (ins
+    CVTFP8TypeAttr:$type,
+    VectorOfLengthAndType<[2], [BF16]>:$a,
+    DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
+    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat);
+  let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";
+
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd,
+                                              NVVM::SaturationMode sat);
+  }];
+
+  string llvmBuilder = [{
+    auto intId = NVVM::CvtBF16x2ToF8x2Op::getIntrinsicID($rnd, $sat);
+    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
+    if (op.getDst().getType().isInteger(16))
+      $dst = packedI16;
+    else
+      $dst = builder.CreateBitCast(packedI16,
+                      llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
+  }];
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM MMA Ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 18453aa7f6ea9..803d6d0d792a7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -133,6 +133,61 @@ LogicalResult CvtFloatToTF32Op::verify() {
   return success();
 }
 
+LogicalResult CvtF32x2ToF8x2Op::verify() {
+  using RndMode = NVVM::FPRoundingMode;
+  using SatMode = NVVM::SaturationMode;
+
+  bool isRoundingModeRN = getRnd() == RndMode::RN;
+  bool isRoundingModeRZ = getRnd() == RndMode::RZ;
+  bool isRoundingModeRP = getRnd() == RndMode::RP;
+  bool isSatFinite = getSat() == SatMode::SATFINITE;
+  bool hasRelu = getRelu();
+
+  // Each fp8 type accepts only a subset of rounding/saturation/relu options.
+  switch (getType()) {
+  case CVTFP8Type::E4M3:
+  case CVTFP8Type::E5M2:
+    if (!isRoundingModeRN)
+      return emitOpError("Only RN rounding mode is supported for conversions "
+                         "from f32x2 to .e4m3x2 or .e5m2x2 types");
+    if (!isSatFinite)
+      return emitOpError("Only SATFINITE saturation mode is supported for "
+                         "conversions from f32x2 to .e4m3x2 or .e5m2x2 types");
+    break;
+  case CVTFP8Type::UE8M0:
+    if (!(isRoundingModeRZ || isRoundingModeRP))
+      return emitOpError("Only RZ or RP rounding modes are supported for "
+                         "conversions from f32x2 to .ue8m0x2 type");
+    if (hasRelu)
+      return emitOpError("relu not supported for conversions to .ue8m0x2 type");
+    break;
+  }
+  return success();
+}
+
+LogicalResult CvtF16x2ToF8x2Op::verify() {
+  // Only .e4m3 and .e5m2 have an f16x2 -> f8x2 conversion.
+  if (getType() == CVTFP8Type::UE8M0)
+    return emitOpError("Only .e4m3 or .e5m2 types are supported for "
+                       "conversions from f16x2 to f8x2.");
+  return success();
+}
+
+LogicalResult CvtBF16x2ToF8x2Op::verify() {
+  using RndMode = NVVM::FPRoundingMode;
+
+  // bf16x2 -> f8x2 is only available for .ue8m0 with RZ or RP rounding.
+  if (getType() != CVTFP8Type::UE8M0)
+    return emitOpError(
+        "Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.");
+
+  auto rnd = getRnd();
+  if (!(rnd == RndMode::RZ || rnd == RndMode::RP))
+    return emitOpError("Only RZ and RP rounding modes are supported for "
+                       "conversions from bf16x2 to f8x2.");
+  return success();
+}
+
 LogicalResult BulkStoreOp::verify() {
   if (getInitVal() != 0)
     return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -1290,17 +1345,81 @@ llvm::Intrinsic::ID CvtFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
   }
 }
 
-#define CVT_TO_F6X2_ID_IMPL(type, has_relu)                                    \
+#define GET_FLOAT_TO_F6X2_ID(type, has_relu)                                   \
   has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite            \
             : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
 
-llvm::Intrinsic::ID CvtToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
-                                                bool hasRelu) {
+llvm::Intrinsic::ID CvtF32x2ToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
+                                                     bool hasRelu) {
   switch (type) {
   case NVVM::CVTFP6Type::E2M3:
-    return CVT_TO_F6X2_ID_IMPL(e2m3x2, hasRelu);
+    return GET_FLOAT_TO_F6X2_ID(e2m3x2, hasRelu);
   case NVVM::CVTFP6Type::E3M2:
-    return CVT_TO_F6X2_ID_IMPL(e3m2x2, hasRelu);
+    return GET_FLOAT_TO_F6X2_ID(e3m2x2, hasRelu);
+  }
+  llvm_unreachable("Invalid conversion in CvtF32x2ToF6x2Op");
+}
+
+#define GET_FLOAT_TO_F8X2_US_ID(rnd, has_satf)                                 \
+  has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite             \
+           : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
+
+#define GET_FLOAT_TO_F8X2_S_ID(type, has_relu)                                 \
+  has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu                      \
+           : llvm::Intrinsic::nvvm_ff_to_##type##_rn
+
+llvm::Intrinsic::ID CvtF32x2ToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type type,
+                                                     NVVM::FPRoundingMode rnd,
+                                                     NVVM::SaturationMode sat,
+                                                     bool hasRelu) {
+  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
+
+  switch (type) {
+  case NVVM::CVTFP8Type::E4M3:
+    return GET_FLOAT_TO_F8X2_S_ID(e4m3x2, hasRelu);
+  case NVVM::CVTFP8Type::E5M2:
+    return GET_FLOAT_TO_F8X2_S_ID(e5m2x2, hasRelu);
+  case NVVM::CVTFP8Type::UE8M0:
+    // Only .ue8m0 honors `rnd`/`sat`; the .e4m3/.e5m2 intrinsics are fixed RN.
+    if (rnd == NVVM::FPRoundingMode::RZ)
+      return GET_FLOAT_TO_F8X2_US_ID(rz, hasSatFinite);
+    if (rnd == NVVM::FPRoundingMode::RP)
+      return GET_FLOAT_TO_F8X2_US_ID(rp, hasSatFinite);
+  }
+  llvm_unreachable("Invalid conversion in CvtF32x2ToF8x2Op");
+}
+
+#define GET_F16X2_TO_F8X2_ID(type, has_relu)                                   \
+  has_relu ? llvm::Intrinsic::nvvm_f16x2_to_##type##_rn_relu                   \
+           : llvm::Intrinsic::nvvm_f16x2_to_##type##_rn
+
+llvm::Intrinsic::ID CvtF16x2ToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type type,
+                                                     bool hasRelu) {
+  switch (type) {
+  case NVVM::CVTFP8Type::E4M3:
+    return GET_F16X2_TO_F8X2_ID(e4m3x2, hasRelu);
+  case NVVM::CVTFP8Type::E5M2:
+    return GET_F16X2_TO_F8X2_ID(e5m2x2, hasRelu);
+  default:
+    llvm_unreachable("Invalid CVTFP8Type for CvtF16x2ToF8x2Op");
+  }
+}
+
+#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)                                   \
+  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite         \
+           : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
+
+llvm::Intrinsic::ID
+CvtBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
+                                  NVVM::SaturationMode sat) {
+  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
+  switch (rnd) {
+  case NVVM::FPRoundingMode::RZ:
+    return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
+  case NVVM::FPRoundingMode::RP:
+    return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
+  default:
+    llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
   }
 }
 
diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir
index 2237e6faad52d..8ccc656e57e1c 100644
--- a/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/cvt_fp6x2.mlir
@@ -1,22 +1,21 @@
 // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
 
-// CHECK-LABEL: @convert_float_to_fp6x2_packed
-llvm.func @convert_float_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
+// CHECK-LABEL: @convert_f32x2_to_f6x2_packed
+llvm.func @convert_f32x2_to_f6x2_packed(%srcA : f32, %srcB : f32) {
   //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
-  %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : i16
+  %res1 = nvvm.cvt.f32x2.to.f6x2 <e2m3> %srcA, %srcB : i16
   //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
-  %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : i16
+  %res2 = nvvm.cvt.f32x2.to.f6x2 <e3m2> %srcA, %srcB : i16
   llvm.return
 }
 
-// CHECK-LABEL: @convert_float_to_fp6x2_vector
-llvm.func @convert_float_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
+// CHECK-LABEL: @convert_f32x2_to_f6x2_vector
+llvm.func @convert_f32x2_to_f6x2_vector(%srcA : f32, %srcB : f32) {
   //CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
   //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
-  %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
+  %res1 = nvvm.cvt.f32x2.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8>
   //CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
   //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
-  %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
+  %res2 = nvvm.cvt.f32x2.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8>
   llvm.return
 }
-
diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_fp8x2.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_fp8x2.mlir
new file mode 100644
index 0000000000000..8ea0bbabe4d0a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/cvt_fp8x2.mlir
@@ -0,0 +1,102 @@
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: @convert_f32x2_to_f8x2_e4m3
+llvm.func @convert_f32x2_to_f8x2_e4m3(%srcA : f32, %srcB : f32) {
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}})
+  %res1 = nvvm.cvt.f32x2.to.f8x2 <e4m3> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %{{.*}}, float %{{.*}})
+  %res2 = nvvm.cvt.f32x2.to.f8x2 <e4m3> %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_f8x2_e5m2
+llvm.func @convert_f32x2_to_f8x2_e5m2(%srcA : f32, %srcB : f32) {
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float %{{.*}}, float %{{.*}})
+  %res1 = nvvm.cvt.f32x2.to.f8x2 <e5m2> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn.relu(float %{{.*}}, float %{{.*}})
+  %res2 = nvvm.cvt.f32x2.to.f8x2 <e5m2> %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : i16
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_f8x2_ue8m0
+llvm.func @convert_f32x2_to_f8x2_ue8m0(%srcA : f32, %srcB : f32) {
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float %{{.*}}, float %{{.*}})
+  %res1 = nvvm.cvt.f32x2.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rp(float %{{.*}}, float %{{.*}})
+  %res2 = nvvm.cvt.f32x2.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rp>} : i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz.satfinite(float %{{.*}}, float %{{.*}})
+  %res3 = nvvm.cvt.f32x2.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rp.satfinite(float %{{.*}}, float %{{.*}})
+  %res4 = nvvm.cvt.f32x2.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : i16
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_f8x2_vector_return
+llvm.func @convert_f32x2_to_f8x2_vector_return(%srcA : f32, %srcB : f32) {
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+  %res1 = nvvm.cvt.f32x2.to.f8x2 <e4m3> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8>
+  // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
+  %res2 = nvvm.cvt.f32x2.to.f8x2 <e4m3> %srcA, %srcB {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8>
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @convert_f16x2_to_f8x2_e4m3
+llvm.func @convert_f16x2_to_f8x2_e4m3(%src : vector<2xf16>) {
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %{{.*}})
+  %res1 = nvvm.cvt.f16x2.to.f8x2 <e4m3> %src : vector<2xf16> -> i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn.relu(<2 x half> %{{.*}})
+  %res2 = nvvm.cvt.f16x2.to.f8x2 <e4m3> %src {relu = true} : vector<2xf16> -> i16
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f16x2_to_f8x2_e5m2
+llvm.func @convert_f16x2_to_f8x2_e5m2(%src : vector<2xf16>) {
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> %{{.*}})
+  %res1 = nvvm.cvt.f16x2.to.f8x2 <e5m2> %src : vector<2xf16> -> i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn.relu(<2 x half> %{{.*}})
+  %res2 = nvvm.cvt.f16x2.to.f8x2 <e5m2> %src {relu = true} : vector<2xf16> -> i16
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f16x2_to_f8x2_vector_return
+llvm.func @convert_f16x2_to_f8x2_vector_return(%src : vector<2xf16>) {
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+  %res1 = nvvm.cvt.f16x2.to.f8x2 <e4m3> %src : vector<2xf16> -> vector<2xi8>
+  // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
+  %res2 = nvvm.cvt.f16x2.to.f8x2 <e5m2> %src : vector<2xf16> -> vector<2xi8>
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @convert_bf16x2_to_f8x2_ue8m0
+llvm.func @convert_bf16x2_to_f8x2_ue8m0(%src : vector<2xbf16>) {
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}})
+  %res1 = nvvm.cvt.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp(<2 x bfloat> %{{.*}})
+  %res2 = nvvm.cvt.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rp>} : vector<2xbf16> -> i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz.satfinite(<2 x bfloat> %{{.*}})
+  %res3 = nvvm.cvt.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp.satfinite(<2 x bfloat> %{{.*}})
+  %res4 = nvvm.cvt.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> i16
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_bf16x2_to_f8x2_vector_return
+llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) {
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+  %res1 = nvvm.cvt.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> vector<2xi8>
+  // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp.satfinite(<2 x bfloat> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
+  %res2 = nvvm.cvt.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8>
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index f87f11daeef54..accec9c7af4f2 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -176,3 +176,75 @@ llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) {
   %0 = nvvm.match.sync any %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i1)>
   llvm.return
 }
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_f8x2_invalid_rounding_e4m3(%a : f32, %b : f32) {
+  // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}}
+  %res = nvvm.cvt.f32x2.to.f8x2 <e4m3> %a, %b {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_f8x2_invalid_rounding_e5m2(%a : f32, %b : f32) {
+  // expected-error @below {{Only RN rounding mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}}
+  %res = nvvm.cvt.f32x2.to.f8x2 <e5m2> %a, %b {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_f8x2_invalid_rounding_ue8m0(%a : f32, %b : f32) {
+  // expected-error @below {{Only RZ or RP rounding modes are supported for conversions from f32x2 to .ue8m0x2 type}}
+  %res = nvvm.cvt.f32x2.to.f8x2 <ue8m0> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>} : i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_f8x2_invalid_saturation_e4m3(%a : f32, %b : f32) {
+  // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}}
+  %res = nvvm.cvt.f32x2.to.f8x2 <e4m3> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_f8x2_invalid_saturation_e5m2(%a : f32, %b : f32) {
+  // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from f32x2 to .e4m3x2 or .e5m2x2 types}}
+  %res = nvvm.cvt.f32x2.to.f8x2 <e5m2> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f32x2_to_f8x2_relu_not_supported_ue8m0(%a : f32, %b : f32) {
+  // expected-error @below {{relu not supported for conversions to .ue8m0x2 type}}
+  %res = nvvm.cvt.f32x2.to.f8x2 <ue8m0> %a, %b {rnd = #nvvm.fp_rnd_mode<rp>, relu = true} : i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f16x2_to_f8x2_invalid_type(%src : vector<2xf16>) {
+  // expected-error @below {{Only .e4m3 or .e5m2 types are supported for conversions from f16x2 to f8x2.}}
+  %res = nvvm.cvt.f16x2.to.f8x2 <ue8m0> %src : vector<2xf16> -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) {
+  // expected-error @below {{Only .ue8m0 type is supported for conversions from bf16x2 to f8x2.}}
+  %res = nvvm.cvt.bf16x2.to.f8x2 <e4m3> %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
+  // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}}
+  %res = nvvm.cvt.bf16x2.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> i16
+  llvm.return
+}



More information about the Mlir-commits mailing list