[Mlir-commits] [mlir] [MLIR][NVVM] Add support for f8x2 conversion (PR #137781)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Apr 29 03:28:35 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Srinivasa Ravi (Wolfram70)

<details>
<summary>Changes</summary>

This patch adds the `cvt.to.f8x2` NVVM dialect Op, which converts a pair of f32 values, an f16x2 vector, or a bf16x2 vector into an f8x2 value (e4m3x2, e5m2x2, or ue8m0x2).

For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt

---
Full diff: https://github.com/llvm/llvm-project/pull/137781.diff


4 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+104) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+97) 
- (added) mlir/test/Target/LLVMIR/nvvm/cvt_fp8x2.mlir (+71) 
- (modified) mlir/test/Target/LLVMIR/nvvmir-invalid.mlir (+88) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 27d54e7abeda9..f5eb91bc029f5 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1120,6 +1120,115 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
   }];
 }
 
+def CVTFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">;
+def CVTFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">;
+def CVTFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">;
+
+def CVTFP8Type : I32EnumAttr<"CVTFP8Type", "NVVM CVTFP8Type kind",
+  [CVTFP8E4M3, CVTFP8E5M2, CVTFP8UE8M0]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def CVTFP8TypeAttr : EnumAttr<NVVM_Dialect, CVTFP8Type, "cvt_fp8_type"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_CvtToF8x2Op : NVVM_Op<"cvt.to.f8x2"> {
+  let summary = "Convert f32 pairs or f16x2/bf16x2 vectors to f8x2";
+  let description = [{
+    This Op converts each of the given float inputs to the specified f8
+    type.
+    The result `dst` is either represented as an i16 type or a vector
+    of two i8 elements.
+    The following table describes the supported conversions and their formats:
+    ```
+    |-----------|-----------|--------------------------------------------------|
+    | Src Type  | Dst Type  | Description                                      |
+    |-----------|-----------|--------------------------------------------------|
+    |   f16x2   |  e4m3x2   |  Operand `b` must be omitted and `a` must be     |
+    |           |  e5m2x2   |  a vector of two F16s.                           |
+    |           |           |  If `dst` is returned as an i16 type, the        |
+    |           |           |  converted values are packed such that the       |
+    |           |           |  value converted from the first element of `a`   |
+    |           |           |  is stored in the upper 8 bits of `dst` and the  |
+    |           |           |  value converted from the second element of `a`  |
+    |           |           |  is stored in the lower 8 bits of `dst`.         |
+    |           |           |  If `dst` is returned as a vector type, each     |
+    |           |           |  converted value from `a` is stored as an i8     |
+    |           |           |  element in the vector.                          |
+    |-----------|-----------|--------------------------------------------------|
+    |   bf16x2  |  ue8m0x2  |  Operand `b` must be omitted and `a` must be     |
+    |           |           |  a vector of two BF16s.                          |
+    |           |           |  If `dst` is returned as an i16 type, the        |
+    |           |           |  converted values are packed such that the       |
+    |           |           |  value converted from the first element of `a`   |
+    |           |           |  is stored in the upper 8 bits of `dst` and the  |
+    |           |           |  value converted from the second element of `a`  |
+    |           |           |  is stored in the lower 8 bits of `dst`.         |
+    |           |           |  If `dst` is returned as a vector type, each     |
+    |           |           |  converted value from `a` is stored as an i8     |
+    |           |           |  element in the vector.                          |
+    |-----------|-----------|--------------------------------------------------|
+    |  f32, f32 |  e4m3x2   |  Both operands `a` and `b` must be provided and  |
+    |           |  e5m2x2   |  they must be F32 values.                        |
+    |           |  ue8m0x2  |  If `dst` is returned as an i16 type, the        |
+    |           |           |  converted values are packed such that the       |
+    |           |           |  value converted from `a` is stored in the       |
+    |           |           |  upper 8 bits of `dst` and the value converted   |
+    |           |           |  from `b` is stored in the lower 8 bits of       |
+    |           |           |  `dst`. If `dst` is returned as a vector type,   |
+    |           |           |  each converted value is stored as an i8         |
+    |           |           |  element in the vector.                          |
+    |-----------|-----------|--------------------------------------------------|
+    ```
+    When `dst` is a vector, it is the bitcast of the packed i16 value; since
+    NVPTX is little-endian, element 0 holds the lower 8 bits of that value.
+
+    The `relu` attribute, when set, lowers to the `.relu` variant of
+    the cvt instruction for conversions to the signed f8 types (e4m3 and e5m2).
+    The `rnd` and `sat` attributes specify the rounding and saturation modes
+    respectively. Conversions to e4m3x2/e5m2x2 require `rnd = rn` and
+    `sat = satfinite`; conversions to ue8m0x2 require `rnd = rz` or `rnd = rp`
+    and optionally accept `sat = satfinite`.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+  }];
+  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+  let arguments = (ins
+    CVTFP8TypeAttr:$type,
+    AnyTypeOf<[F32, VectorOfLengthAndType<[2], [F16, BF16]>]>:$a,
+    Optional<F32>:$b,
+    DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
+    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+    DefaultValuedAttr<BoolAttr, "false">:$relu);
+  let assemblyFormat = "$type $a (`,` $b^)? attr-dict `:` type($a) (`,` type($b)^)? `->` type($dst)";
+
+  let extraClassDeclaration = [{
+    bool isFromF32Type();
+    static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to,
+                                              bool isFromF32Type,
+                                              NVVM::FPRoundingMode rnd,
+                                              NVVM::SaturationMode sat,
+                                              bool hasRelu);
+  }];
+
+  string llvmBuilder = [{
+    auto intId = NVVM::CvtToF8x2Op::getIntrinsicID($type, op.isFromF32Type(), $rnd, $sat, $relu);
+    llvm::Value *packedI16;
+    if (op.isFromF32Type())
+      packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
+    else
+      packedI16 = createIntrinsicCall(builder, intId, {$a});
+    if (op.getDst().getType().isInteger(16))
+      $dst = packedI16;
+    else
+      $dst = builder.CreateBitCast(
+          packedI16, llvm::FixedVectorType::get(builder.getInt8Ty(), 2));
+  }];
+
+  let hasVerifier = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM MMA Ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 18453aa7f6ea9..c30c45abbdd02 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -32,6 +32,7 @@
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
 #include "llvm/IR/Type.h"
@@ -133,6 +134,72 @@ LogicalResult CvtFloatToTF32Op::verify() {
   return success();
 }
 
+/// Returns true when the op converts a pair of f32 scalars (`a` and `b`)
+/// rather than a single f16x2 or bf16x2 vector.
+bool CvtToF8x2Op::isFromF32Type() { return getA().getType().isF32(); }
+
+/// Checks that the source operands, rounding mode, saturation mode and relu
+/// flag form a combination supported by one of the cvt.*.f8x2 variants.
+LogicalResult CvtToF8x2Op::verify() {
+  bool isFromF32 = false;
+  bool isFromF16x2 = false;
+  bool isFromBF16x2 = false;
+
+  bool isRoundingModeRN = getRnd() == NVVM::FPRoundingMode::RN;
+  bool isRoundingModeRZ = getRnd() == NVVM::FPRoundingMode::RZ;
+  bool isRoundingModeRP = getRnd() == NVVM::FPRoundingMode::RP;
+
+  bool isSatFinite = getSat() == NVVM::SaturationMode::SATFINITE;
+
+  bool hasRelu = getRelu();
+
+  // ODS restricts `a` to f32, vector<2xf16> or vector<2xbf16>, so any
+  // non-vector source is an f32 scalar.
+  if (auto vecType = dyn_cast<VectorType>(getA().getType())) {
+    isFromF16x2 = vecType.getElementType().isF16();
+    isFromBF16x2 = vecType.getElementType().isBF16();
+  } else {
+    isFromF32 = true;
+  }
+
+  // f32 sources need both scalars; vector sources carry both values in `a`.
+  if (isFromF32 && !getB())
+    return emitOpError("expected two f32 inputs for converting from f32");
+  if (!isFromF32 && getB())
+    return emitOpError(
+        "expected only a single f32, vector<2xf16> or vector<2xbf16> input "
+        "for converting from f16x2 or bf16x2, got two inputs instead.");
+
+  // No `default:` label: the switch covers every CVTFP8Type enumerator, so
+  // -Wswitch will flag any enumerator added later.
+  switch (getType()) {
+  case NVVM::CVTFP8Type::E4M3:
+  case NVVM::CVTFP8Type::E5M2:
+    if (!(isFromF32 || isFromF16x2))
+      return emitOpError("expected f32 or f16x2 input for conversions to "
+                         ".e4m3x2 or .e5m2x2 types");
+    if (!isRoundingModeRN)
+      return emitOpError("RN rounding mode required for conversions to .e4m3x2 "
+                         "or .e5m2x2 types");
+    if (!isSatFinite)
+      return emitOpError("SATFINITE saturation mode required for conversions "
+                         "to .e4m3x2 or .e5m2x2 types");
+    break;
+  case NVVM::CVTFP8Type::UE8M0:
+    if (!(isFromF32 || isFromBF16x2))
+      return emitOpError(
+          "expected f32 or bf16x2 input for conversions to .ue8m0x2 type");
+    if (!(isRoundingModeRP || isRoundingModeRZ))
+      return emitOpError(
+          "RP or RZ rounding mode required for conversions to .ue8m0x2 type");
+    if (hasRelu)
+      return emitOpError("relu not supported for conversions to .ue8m0x2 type");
+    break;
+  }
+
+  return success();
+}
+
 LogicalResult BulkStoreOp::verify() {
   if (getInitVal() != 0)
     return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -1304,6 +1371,50 @@ llvm::Intrinsic::ID CvtToF6x2Op::getIntrinsicID(NVVM::CVTFP6Type type,
   }
 }
 
+// Helpers expanding to the NVVM intrinsic IDs used by CvtToF8x2Op. Each
+// expansion is fully parenthesized so it can be nested inside `?:`.
+#define CVT_TO_UE8M0X2_IMPL(fromtype, rndm, has_sat)                           \
+  ((has_sat) ? llvm::Intrinsic::nvvm_##fromtype##_to_ue8m0x2##rndm##_satfinite \
+             : llvm::Intrinsic::nvvm_##fromtype##_to_ue8m0x2##rndm)
+
+#define GET_CVT_TO_UE8M0X2_ID(fromtype, rnd, has_sat)                          \
+  (((rnd) == NVVM::FPRoundingMode::RZ)                                         \
+       ? CVT_TO_UE8M0X2_IMPL(fromtype, _rz, has_sat)                           \
+       : CVT_TO_UE8M0X2_IMPL(fromtype, _rp, has_sat))
+
+#define GET_CVT_TO_F8X2_ID(fromtype, totype, has_relu)                         \
+  ((has_relu) ? llvm::Intrinsic::nvvm_##fromtype##_to_##totype##_rn_relu       \
+              : llvm::Intrinsic::nvvm_##fromtype##_to_##totype##_rn)
+
+/// Maps a verified CvtToF8x2Op configuration to its NVVM intrinsic. The
+/// e4m3/e5m2 branches do not inspect `rnd` or `sat` because the verifier only
+/// admits RN rounding with SATFINITE saturation for those types.
+llvm::Intrinsic::ID CvtToF8x2Op::getIntrinsicID(NVVM::CVTFP8Type to,
+                                                bool isFromF32Type,
+                                                NVVM::FPRoundingMode rnd,
+                                                NVVM::SaturationMode sat,
+                                                bool hasRelu) {
+  bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
+
+  switch (to) {
+  case NVVM::CVTFP8Type::E4M3:
+    return isFromF32Type ? GET_CVT_TO_F8X2_ID(ff, e4m3x2, hasRelu)
+                         : GET_CVT_TO_F8X2_ID(f16x2, e4m3x2, hasRelu);
+  case NVVM::CVTFP8Type::E5M2:
+    return isFromF32Type ? GET_CVT_TO_F8X2_ID(ff, e5m2x2, hasRelu)
+                         : GET_CVT_TO_F8X2_ID(f16x2, e5m2x2, hasRelu);
+  case NVVM::CVTFP8Type::UE8M0:
+    // The verifier restricts `rnd` to RZ or RP here; non-RZ selects RP.
+    return isFromF32Type ? GET_CVT_TO_UE8M0X2_ID(ff, rnd, hasSatFinite)
+                         : GET_CVT_TO_UE8M0X2_ID(bf16x2, rnd, hasSatFinite);
+  }
+  llvm_unreachable("Invalid CVTFP8Type for CvtToF8x2Op");
+}
+
+#undef GET_CVT_TO_F8X2_ID
+#undef GET_CVT_TO_UE8M0X2_ID
+#undef CVT_TO_UE8M0X2_IMPL
+
 llvm::Intrinsic::ID
 Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
                                       LLVM::ModuleTranslation &mt,
diff --git a/mlir/test/Target/LLVMIR/nvvm/cvt_fp8x2.mlir b/mlir/test/Target/LLVMIR/nvvm/cvt_fp8x2.mlir
new file mode 100644
index 0000000000000..06573ce53676f
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/cvt_fp8x2.mlir
@@ -0,0 +1,80 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @convert_float_to_fp8x2_packed
+llvm.func @convert_float_to_fp8x2_packed(%srcA : f32, %srcB : f32) -> !llvm.void {
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}})
+  %res1 = nvvm.cvt.to.f8x2 <e4m3> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float %{{.*}}, float %{{.*}})
+  %res2 = nvvm.cvt.to.f8x2 <e5m2> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float %{{.*}}, float %{{.*}})
+  %res3 = nvvm.cvt.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : f32, f32 -> i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rp(float %{{.*}}, float %{{.*}})
+  %res4 = nvvm.cvt.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rp>} : f32, f32 -> i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz.satfinite(float %{{.*}}, float %{{.*}})
+  %res5 = nvvm.cvt.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : f32, f32 -> i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.ue8m0x2.rp.satfinite(float %{{.*}}, float %{{.*}})
+  %res6 = nvvm.cvt.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : f32, f32 -> i16
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_float_to_fp8x2_vector
+llvm.func @convert_float_to_fp8x2_vector(%srcA : f32, %srcB : f32) -> !llvm.void {
+  // CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e4m3x2.rn(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
+  %res1 = nvvm.cvt.to.f8x2 <e4m3> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> vector<2xi8>
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e5m2x2.rn(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+  %res2 = nvvm.cvt.to.f8x2 <e5m2> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> vector<2xi8>
+  // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.ue8m0x2.rz(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
+  %res3 = nvvm.cvt.to.f8x2 <ue8m0> %srcA, %srcB {rnd = #nvvm.fp_rnd_mode<rz>} : f32, f32 -> vector<2xi8>
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_float_to_fp8x2_with_relu
+llvm.func @convert_float_to_fp8x2_with_relu(%srcA : f32, %srcB : f32) -> !llvm.void {
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e4m3x2.rn.relu(float %{{.*}}, float %{{.*}})
+  %res1 = nvvm.cvt.to.f8x2 <e4m3> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : f32, f32 -> i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e5m2x2.rn.relu(float %{{.*}}, float %{{.*}})
+  %res2 = nvvm.cvt.to.f8x2 <e5m2> %srcA, %srcB {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : f32, f32 -> i16
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f16x2_to_fp8x2
+llvm.func @convert_f16x2_to_fp8x2(%src : vector<2xf16>) -> !llvm.void {
+  // CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
+  %res1 = nvvm.cvt.to.f8x2 <e4m3> %src {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16> -> vector<2xi8>
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn(<2 x half> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+  %res2 = nvvm.cvt.to.f8x2 <e5m2> %src {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16> -> vector<2xi8>
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn(<2 x half> %{{.*}})
+  %res3 = nvvm.cvt.to.f8x2 <e4m3> %src {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xf16> -> i16
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_bf16x2_to_fp8x2
+llvm.func @convert_bf16x2_to_fp8x2(%src : vector<2xbf16>) -> !llvm.void {
+  // CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
+  %res1 = nvvm.cvt.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> vector<2xi8>
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp(<2 x bfloat> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+  %res2 = nvvm.cvt.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rp>} : vector<2xbf16> -> vector<2xi8>
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz.satfinite(<2 x bfloat> %{{.*}})
+  %res3 = nvvm.cvt.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rz>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> i16
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp.satfinite(<2 x bfloat> %{{.*}})
+  %res4 = nvvm.cvt.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> i16
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f16x2_to_fp8x2_with_relu
+llvm.func @convert_f16x2_to_fp8x2_with_relu(%src : vector<2xf16>) -> !llvm.void {
+  // CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.f16x2.to.e4m3x2.rn.relu(<2 x half> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8>
+  %res1 = nvvm.cvt.to.f8x2 <e4m3> %src {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xf16> -> vector<2xi8>
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e5m2x2.rn.relu(<2 x half> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+  %res2 = nvvm.cvt.to.f8x2 <e5m2> %src {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xf16> -> vector<2xi8>
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index f87f11daeef54..fc00ea6ee7003 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -176,3 +176,107 @@ llvm.func @nvvm_match_sync_any(%val32: i32, %thread_mask: i32) {
   %0 = nvvm.match.sync any %thread_mask, %val32 : i32 -> !llvm.struct<(i32, i1)>
   llvm.return
 }
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_rounding_e4m3(%a : f32, %b : f32) {
+  // expected-error @below {{RN rounding mode required for conversions to .e4m3x2 or .e5m2x2 types}}
+  %res = nvvm.cvt.to.f8x2 <e4m3> %a, %b {rnd = #nvvm.fp_rnd_mode<rz>} : f32, f32 -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_rounding_e5m2(%a : f32, %b : f32) {
+  // expected-error @below {{RN rounding mode required for conversions to .e4m3x2 or .e5m2x2 types}}
+  %res = nvvm.cvt.to.f8x2 <e5m2> %a, %b {rnd = #nvvm.fp_rnd_mode<rp>} : f32, f32 -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_rounding_ue8m0(%a : f32, %b : f32) {
+  // expected-error @below {{RP or RZ rounding mode required for conversions to .ue8m0x2 type}}
+  %res = nvvm.cvt.to.f8x2 <ue8m0> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>} : f32, f32 -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_default_rounding_e4m3(%a : f32, %b : f32) {
+  // expected-error @below {{RN rounding mode required for conversions to .e4m3x2 or .e5m2x2 types}}
+  %res = nvvm.cvt.to.f8x2 <e4m3> %a, %b {sat = #nvvm.sat_mode<satfinite>} : f32, f32 -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_default_rounding_ue8m0(%a : f32, %b : f32) {
+  // expected-error @below {{RP or RZ rounding mode required for conversions to .ue8m0x2 type}}
+  %res = nvvm.cvt.to.f8x2 <ue8m0> %a, %b : f32, f32 -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_saturation_e4m3(%a : f32, %b : f32) {
+  // expected-error @below {{SATFINITE saturation mode required for conversions to .e4m3x2 or .e5m2x2 types}}
+  %res = nvvm.cvt.to.f8x2 <e4m3> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : f32, f32 -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_saturation_e5m2(%a : f32, %b : f32) {
+  // expected-error @below {{SATFINITE saturation mode required for conversions to .e4m3x2 or .e5m2x2 types}}
+  %res = nvvm.cvt.to.f8x2 <e5m2> %a, %b {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : f32, f32 -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_relu_not_supported_ue8m0(%a : f32, %b : f32) {
+  // expected-error @below {{relu not supported for conversions to .ue8m0x2 type}}
+  %res = nvvm.cvt.to.f8x2 <ue8m0> %a, %b {rnd = #nvvm.fp_rnd_mode<rz>, relu = true} : f32, f32 -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_from_type_e4m3(%src : vector<2xbf16>) {
+  // expected-error @below {{expected f32 or f16x2 input for conversions to .e4m3x2 or .e5m2x2 types}}
+  %res = nvvm.cvt.to.f8x2 <e4m3> %src : vector<2xbf16> -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_from_type_e5m2(%src : vector<2xbf16>) {
+  // expected-error @below {{expected f32 or f16x2 input for conversions to .e4m3x2 or .e5m2x2 types}}
+  %res = nvvm.cvt.to.f8x2 <e5m2> %src : vector<2xbf16> -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_wrong_from_type_ue8m0(%src : vector<2xf16>) {
+  // expected-error @below {{expected f32 or bf16x2 input for conversions to .ue8m0x2 type}}
+  %res = nvvm.cvt.to.f8x2 <ue8m0> %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xf16> -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_two_inputs_with_fromfp16x2(%src : vector<2xf16>, %b : f32) {
+  // expected-error @below {{expected only a single f32, vector<2xf16> or vector<2xbf16> input for converting from f16x2 or bf16x2, got two inputs instead.}}
+  %res = nvvm.cvt.to.f8x2 <e4m3> %src, %b : vector<2xf16>, f32 -> i16
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_to_f8x2_missing_second_input(%a : f32) {
+  // expected-error @below {{expected two f32 inputs for converting from f32}}
+  %res = nvvm.cvt.to.f8x2 <e4m3> %a {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : f32 -> i16
+  llvm.return
+}
``````````

</details>


https://github.com/llvm/llvm-project/pull/137781


More information about the Mlir-commits mailing list