[llvm-branch-commits] [mlir] ab31c28 - [MLIR][NVVM] Add support for narrow-fp to bf16x2 conversions (#200157)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jun 9 03:04:33 PDT 2026


Author: Srinivasa Ravi
Date: 2026-06-09T14:22:44+05:30
New Revision: ab31c28892a9ad5e016e94500861c93018736e7b

URL: https://github.com/llvm/llvm-project/commit/ab31c28892a9ad5e016e94500861c93018736e7b
DIFF: https://github.com/llvm/llvm-project/commit/ab31c28892a9ad5e016e94500861c93018736e7b.diff

LOG: [MLIR][NVVM] Add support for narrow-fp to bf16x2 conversions (#200157)

This change adds or updates the following NVVM Ops to support narrow-fp
(fp8/fp6/fp4) to bf16x2 conversions:

- `nvvm.convert.f6x2.to.bf16x2` (new)
- `nvvm.convert.f4x2.to.bf16x2` (new)
- `nvvm.convert.f8x2.to.bf16x2` (updated to also allow `E4M3FN` and `E5M2`
  source types, in addition to the previously supported `E8M0FNU`)

Also removes the now-redundant C++ verifiers for the narrow-fp to `f16x2`
conversions; the source type is instead validated in the ODS definition
via `TypeAttrOf`.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
    mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
    mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
    mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
    mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
    mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 01abc7e70f57c..9cbf76b9210be 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2201,41 +2201,127 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
   }];
 }
 
-class NVVM_ConvertToFP16x2Op_Base <string srcType, Type srcArgType, string dstType>
-: NVVM_SingleResultIntrinsicOp<"convert." # !tolower(srcType) # "x2.to." # !tolower(dstType) # "x2", [], "$dst"> {
-  let summary = "Convert a pair of " # !tolower(srcType) # " inputs to " # !tolower(dstType) # "x2";
+def SaturationModeSatfiniteOrNone :
+  ConfinedAttr<SaturationModeAttr, [EnumAttrIsOneOf<SaturationModeAttr,
+                [SaturationModeNone, SaturationModeFinite]>]>;
+
+class NVVM_ConvertToFP16x2Op_Base <string srcTypeStr, Type srcStorageType, string dstTypeStr, list<Type> supportedTypes, int needVerify = 0>
+: NVVM_SingleResultIntrinsicOp<"convert." # !tolower(srcTypeStr) # "x2.to." # !tolower(dstTypeStr) # "x2", [], "$dst"> {
+  let summary = "Convert a pair of " # !tolower(srcTypeStr) # " inputs to " # !tolower(dstTypeStr) # "x2";
   let description = [{
-    This Op converts the given }] # !tolower(srcType) # [{ inputs in a }] #
-    !if(!eq(srcType, "F4"), "packed i8", "i8x2 vector") # [{ to }] #
-    !tolower(dstType) # [{.
-
-    The result `dst` is represented as a vector of }] # !tolower(dstType) # [{ elements.
-    }] #
-    !if(!eq(dstType, "F16"),
-    [{The `relu` attribute, when set, lowers to the '.relu' variant of 
-    the cvt instruction."}], "") # [{
-    
+    This Op converts the given }] # !tolower(srcTypeStr) # [{ inputs in a }] #
+    !if(!eq(srcTypeStr, "F4"), "packed i8", "i8x2 vector") # [{ to }] #
+    !tolower(dstTypeStr) # [{.
+
+    The result `dst` is represented as a vector of }] # !tolower(dstTypeStr) # [{ elements.
+
+    The `relu` attribute, when set, lowers to the '.relu' variant of
+    the cvt instruction.}] #
+
+    !if(!eq(dstTypeStr, "BF16"),
+    [{
+
+    The `sat` attribute specifies the saturation mode.
+
+    The optional scaling-factors for each of the inputs are provided through
+    the operand `scaleFactor` as a packed i16 type. Only `ue8m0` is supported
+    as the type of the scale-factor currently.}], "") # [{
+
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
   }];
-  let results = (outs VectorOfLengthAndType<[2], [!cast<Type>(dstType)]>:$dst);
-  let arguments = !if(!eq(dstType, "F16"),
-    (ins srcArgType:$src,
-         DefaultValuedAttr<BoolAttr, "false">:$relu,
-         TypeAttr:$srcType),
-    (ins srcArgType:$src,
-         TypeAttr:$srcType));
-  let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)";
-  let hasVerifier = 1;
+  let results = (outs VectorOfLengthAndType<[2], [!cast<Type>(dstTypeStr)]>:$dst);
+  let arguments = !if(!eq(dstTypeStr, "F16"),
+    (ins srcStorageType:$src,
+         TypeAttrOf<AnyTypeOf<supportedTypes>>:$srcType,
+         DefaultValuedAttr<BoolAttr, "false">:$relu),
+    (ins srcStorageType:$src,
+         Optional<I16>:$scaleFactor,
+         TypeAttrOf<AnyTypeOf<supportedTypes>>:$srcType,
+         DefaultValuedAttr<SaturationModeSatfiniteOrNone, "SaturationMode::NONE">:$sat,
+         DefaultValuedAttr<BoolAttr, "false">:$relu));
+  let assemblyFormat = 
+    !if(!eq(dstTypeStr, "F16"),
+      "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)",
+      "$src (`,` $scaleFactor^)? attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)");
+  let hasVerifier = needVerify;
 }
 
 def NVVM_ConvertF8x2ToF16x2Op :
-  NVVM_ConvertToFP16x2Op_Base<"F8", VectorOfLengthAndType<[2], [I8]>, "F16">;
+  NVVM_ConvertToFP16x2Op_Base<"F8", VectorOfLengthAndType<[2], [I8]>, "F16", 
+    [F8E4M3FN, F8E5M2]>;
 def NVVM_ConvertF8x2ToBF16x2Op :
-  NVVM_ConvertToFP16x2Op_Base<"F8", VectorOfLengthAndType<[2], [I8]>, "BF16">;
+  NVVM_ConvertToFP16x2Op_Base<"F8", VectorOfLengthAndType<[2], [I8]>, "BF16",
+    [F8E8M0FNU, F8E4M3FN, F8E5M2], 1> {
+  let append description = [{
+
+    Example:
+
+    ```mlir
+    // Basic conversion from f8E4M3FN.
+    %res1 = nvvm.convert.f8x2.to.bf16x2 %src
+        : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+
+    // Conversion from f8E5M2 with relu and saturation.
+    %res2 = nvvm.convert.f8x2.to.bf16x2 %src
+        {relu = true, sat = #nvvm.sat_mode<satfinite>}
+        : vector<2xi8> (f8E5M2) -> vector<2xbf16>
+
+    // Conversion with a packed ue8m0 scale-factor.
+    %res3 = nvvm.convert.f8x2.to.bf16x2 %src, %scaleFactor
+        : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+    ```
+  }];
+}
 def NVVM_ConvertF6x2ToF16x2Op :
-  NVVM_ConvertToFP16x2Op_Base<"F6", VectorOfLengthAndType<[2], [I8]>, "F16">;
+  NVVM_ConvertToFP16x2Op_Base<"F6", VectorOfLengthAndType<[2], [I8]>, "F16", 
+    [F6E2M3FN, F6E3M2FN]>;
+def NVVM_ConvertF6x2ToBF16x2Op :
+  NVVM_ConvertToFP16x2Op_Base<"F6", VectorOfLengthAndType<[2], [I8]>, "BF16",
+    [F6E2M3FN, F6E3M2FN]> {
+  let append description = [{
+
+    Example:
+
+    ```mlir
+    // Basic conversion from f6E2M3FN.
+    %res1 = nvvm.convert.f6x2.to.bf16x2 %src
+        : vector<2xi8> (f6E2M3FN) -> vector<2xbf16>
+
+    // Conversion from f6E3M2FN with relu and saturation.
+    %res2 = nvvm.convert.f6x2.to.bf16x2 %src
+        {relu = true, sat = #nvvm.sat_mode<satfinite>}
+        : vector<2xi8> (f6E3M2FN) -> vector<2xbf16>
+
+    // Conversion with a packed ue8m0 scale-factor.
+    %res3 = nvvm.convert.f6x2.to.bf16x2 %src, %scaleFactor
+        : vector<2xi8> (f6E2M3FN) -> vector<2xbf16>
+    ```
+  }];
+}
 def NVVM_ConvertF4x2ToF16x2Op :
-  NVVM_ConvertToFP16x2Op_Base<"F4", I8, "F16">;
+  NVVM_ConvertToFP16x2Op_Base<"F4", I8, "F16", [F4E2M1FN]>;
+def NVVM_ConvertF4x2ToBF16x2Op :
+  NVVM_ConvertToFP16x2Op_Base<"F4", I8, "BF16", [F4E2M1FN]> {
+  let append description = [{
+
+    Example:
+
+    ```mlir
+    // Basic conversion; the f4x2 source is packed in a single i8.
+    %res1 = nvvm.convert.f4x2.to.bf16x2 %src
+        : i8 (f4E2M1FN) -> vector<2xbf16>
+
+    // Conversion with relu and saturation.
+    %res2 = nvvm.convert.f4x2.to.bf16x2 %src
+        {relu = true, sat = #nvvm.sat_mode<satfinite>}
+        : i8 (f4E2M1FN) -> vector<2xbf16>
+
+    // Conversion with a packed ue8m0 scale-factor.
+    %res3 = nvvm.convert.f4x2.to.bf16x2 %src, %scaleFactor
+        : i8 (f4E2M1FN) -> vector<2xbf16>
+    ```
+  }];
+}
 
 def NVVM_ConvertF32x2ToS2F6x2Op : NVVM_Op<"convert.f32x2.to.s2f6x2"> {
   let summary = "Convert a pair of f32 inputs to S2F6x2";

diff  --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 00c997ec7a031..2d929f740f137 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -542,47 +542,20 @@ LogicalResult ConvertF32x2ToF4x2Op::verify() {
   return success();
 }
 
-LogicalResult ConvertF8x2ToF16x2Op::verify() {
-  mlir::MLIRContext *ctx = getContext();
-
-  if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
-    return emitOpError("Only ")
-           << mlir::Float8E4M3FNType::get(ctx) << " and "
-           << mlir::Float8E5M2Type::get(ctx)
-           << " types are supported for conversions from f8x2 to f16x2.";
-
-  return success();
-}
-
 LogicalResult ConvertF8x2ToBF16x2Op::verify() {
   mlir::MLIRContext *ctx = getContext();
-  if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
-    return emitOpError("Only ")
-           << mlir::Float8E8M0FNUType::get(ctx)
-           << " type is supported for conversions from f8x2 to bf16x2.";
-
-  return success();
-}
-
-LogicalResult ConvertF6x2ToF16x2Op::verify() {
-  mlir::MLIRContext *ctx = getContext();
-
-  if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
-    return emitOpError("Only ")
-           << mlir::Float6E2M3FNType::get(ctx) << " and "
-           << mlir::Float6E3M2FNType::get(ctx)
-           << " types are supported for conversions from f6x2 to f16x2.";
-
-  return success();
-}
-
-LogicalResult ConvertF4x2ToF16x2Op::verify() {
-  mlir::MLIRContext *ctx = getContext();
-
-  if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
-    return emitOpError("Only ")
-           << mlir::Float4E2M1FNType::get(ctx)
-           << " type is supported for conversions from f4x2 to f16x2.";
+  if (llvm::isa<Float8E8M0FNUType>(getSrcType())) {
+    if (getSat() != SaturationMode::NONE)
+      return emitOpError(
+                 "Only NONE saturation mode is supported for conversions from ")
+             << Float8E8M0FNUType::get(ctx) << " type";
+    if (getScaleFactor())
+      return emitOpError("scaleFactor not supported for conversions from ")
+             << Float8E8M0FNUType::get(ctx) << " type";
+    if (getRelu())
+      return emitOpError("relu not supported for conversions from ")
+             << Float8E8M0FNUType::get(ctx) << " type";
+  }
 
   return success();
 }
@@ -4798,13 +4771,52 @@ NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
 NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
     Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
   auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
+  bool hasScale = static_cast<bool>(curOp.getScaleFactor());
+  bool hasSatfinite = curOp.getSat() == NVVM::SaturationMode::SATFINITE;
+  bool hasRelu = curOp.getRelu();
+
+  static constexpr llvm::Intrinsic::ID E4M3Ids[] = {
+      llvm::Intrinsic::nvvm_e4m3x2_to_bf16x2_rn_scale_n2_ue8m0,
+      llvm::Intrinsic::nvvm_e4m3x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
+      llvm::Intrinsic::nvvm_e4m3x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
+      llvm::Intrinsic::nvvm_e4m3x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
+  };
 
-  llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
+  static constexpr llvm::Intrinsic::ID E5M2Ids[] = {
+      llvm::Intrinsic::nvvm_e5m2x2_to_bf16x2_rn_scale_n2_ue8m0,
+      llvm::Intrinsic::nvvm_e5m2x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
+      llvm::Intrinsic::nvvm_e5m2x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
+      llvm::Intrinsic::nvvm_e5m2x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
+  };
+
+  llvm::Intrinsic::ID intId =
+      llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+          .Case([&](Float8E8M0FNUType type) {
+            return llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
+          })
+          .Case([&](Float8E4M3FNType type) {
+            return E4M3Ids[hasSatfinite << 1 | hasRelu];
+          })
+          .Case([&](Float8E5M2Type type) {
+            return E5M2Ids[hasSatfinite << 1 | hasRelu];
+          })
+          .Default([](mlir::Type type) {
+            llvm_unreachable("Invalid type for ConvertF8x2ToBF16x2Op");
+            return llvm::Intrinsic::not_intrinsic;
+          });
   llvm::Value *packedI16 =
       builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
                             llvm::Type::getInt16Ty(builder.getContext()));
 
-  return {intId, {packedI16}};
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(packedI16);
+  if (!isa<Float8E8M0FNUType>(curOp.getSrcType()))
+    args.push_back(
+        hasScale ? mt.lookupValue(curOp.getScaleFactor())
+                 : builder.getInt16(0x7f7f)); // default scale factor (value of
+                                              // 1 for both elements)
+
+  return {intId, std::move(args)};
 }
 
 NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
@@ -4835,6 +4847,52 @@ NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
   return {intId, {packedI16}};
 }
 
+NVVM::IDArgPair ConvertF6x2ToBF16x2Op::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto curOp = cast<NVVM::ConvertF6x2ToBF16x2Op>(op);
+  bool hasScale = static_cast<bool>(curOp.getScaleFactor());
+  bool hasSatfinite = curOp.getSat() == NVVM::SaturationMode::SATFINITE;
+  bool hasRelu = curOp.getRelu();
+
+  static constexpr llvm::Intrinsic::ID E2M3Ids[] = {
+      llvm::Intrinsic::nvvm_e2m3x2_to_bf16x2_rn_scale_n2_ue8m0,
+      llvm::Intrinsic::nvvm_e2m3x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
+      llvm::Intrinsic::nvvm_e2m3x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
+      llvm::Intrinsic::nvvm_e2m3x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
+  };
+
+  static constexpr llvm::Intrinsic::ID E3M2Ids[] = {
+      llvm::Intrinsic::nvvm_e3m2x2_to_bf16x2_rn_scale_n2_ue8m0,
+      llvm::Intrinsic::nvvm_e3m2x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
+      llvm::Intrinsic::nvvm_e3m2x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
+      llvm::Intrinsic::nvvm_e3m2x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
+  };
+
+  unsigned idx = (hasSatfinite << 1) | hasRelu;
+  llvm::Intrinsic::ID intId =
+      llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+          .Case([&](Float6E2M3FNType type) { return E2M3Ids[idx]; })
+          .Case([&](Float6E3M2FNType type) { return E3M2Ids[idx]; })
+          .Default([](mlir::Type type) {
+            llvm_unreachable("Invalid type for ConvertF6x2ToBF16x2Op");
+            return llvm::Intrinsic::not_intrinsic;
+          });
+
+  llvm::Value *packedI16 =
+      builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+                            llvm::Type::getInt16Ty(builder.getContext()));
+
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(packedI16);
+  args.push_back(
+      hasScale
+          ? mt.lookupValue(curOp.getScaleFactor())
+          : builder.getInt16(
+                0x7f7f)); // default scale factor (value of 1 for both elements)
+
+  return {intId, std::move(args)};
+}
+
 NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
     Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
   auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
@@ -4859,6 +4917,44 @@ NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
   return {intId, {extendedI16}};
 }
 
+NVVM::IDArgPair ConvertF4x2ToBF16x2Op::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto curOp = cast<NVVM::ConvertF4x2ToBF16x2Op>(op);
+  bool hasScale = static_cast<bool>(curOp.getScaleFactor());
+  bool hasSatfinite = curOp.getSat() == NVVM::SaturationMode::SATFINITE;
+  bool hasRelu = curOp.getRelu();
+
+  static constexpr llvm::Intrinsic::ID E2M1Ids[] = {
+      llvm::Intrinsic::nvvm_e2m1x2_to_bf16x2_rn_scale_n2_ue8m0,
+      llvm::Intrinsic::nvvm_e2m1x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
+      llvm::Intrinsic::nvvm_e2m1x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
+      llvm::Intrinsic::nvvm_e2m1x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
+  };
+
+  unsigned idx = (hasSatfinite << 1) | hasRelu;
+  llvm::Intrinsic::ID intId =
+      llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+          .Case([&](Float4E2M1FNType type) { return E2M1Ids[idx]; })
+          .Default([](mlir::Type type) {
+            llvm_unreachable("Invalid type for ConvertF4x2ToBF16x2Op");
+            return llvm::Intrinsic::not_intrinsic;
+          });
+
+  llvm::Value *extendedI16 =
+      builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
+                         llvm::Type::getInt16Ty(builder.getContext()));
+
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(extendedI16);
+  args.push_back(
+      hasScale
+          ? mt.lookupValue(curOp.getScaleFactor())
+          : builder.getInt16(
+                0x7f7f)); // default scale factor (value of 1 for both elements)
+
+  return {intId, std::move(args)};
+}
+
 NVVM::IDArgPair ConvertF32x2ToS2F6x2Op::getIntrinsicIDAndArgs(
     Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
   auto thisOp = cast<NVVM::ConvertF32x2ToS2F6x2Op>(op);

diff  --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
index 3d3bd714fa8fa..fc25a194ef236 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
@@ -49,3 +49,25 @@ llvm.func @convert_f4x2_to_f16x2(%src : i8) {
   %res2 = nvvm.convert.f4x2.to.f16x2 %src {relu = true} : i8 (f4E2M1FN)-> vector<2xf16>
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: @convert_f4x2_to_bf16x2
+llvm.func @convert_f4x2_to_bf16x2(%src : i8, %scale_factor : i16) {
+  // CHECK: %[[res1:.*]] = zext i8 %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e2m1x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %[[res1]], i16 32639)
+  %res1 = nvvm.convert.f4x2.to.bf16x2 %src : i8 (f4E2M1FN) -> vector<2xbf16>
+  // CHECK: %[[res2:.*]] = zext i8 %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e2m1x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %[[res2]], i16 32639)
+  %res2 = nvvm.convert.f4x2.to.bf16x2 %src {relu = true} : i8 (f4E2M1FN) -> vector<2xbf16>
+  // CHECK: %[[res3:.*]] = zext i8 %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e2m1x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %[[res3]], i16 32639)
+  %res3 = nvvm.convert.f4x2.to.bf16x2 %src {sat = #nvvm.sat_mode<satfinite>} : i8 (f4E2M1FN) -> vector<2xbf16>
+  // CHECK: %[[res4:.*]] = zext i8 %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e2m1x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %[[res4]], i16 32639)
+  %res4 = nvvm.convert.f4x2.to.bf16x2 %src {relu = true, sat = #nvvm.sat_mode<satfinite>} : i8 (f4E2M1FN) -> vector<2xbf16>
+  // CHECK: %[[res5:.*]] = zext i8 %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e2m1x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %[[res5]], i16 %{{.*}})
+  %res5 = nvvm.convert.f4x2.to.bf16x2 %src, %scale_factor : i8 (f4E2M1FN) -> vector<2xbf16>
+  llvm.return
+}

diff  --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
index 8d9e5ff2a6a82..e83f4fe6449db 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
@@ -114,3 +114,45 @@ llvm.func @convert_f6x2_to_f16x2_e3m2(%src : vector<2xi8>) {
   %res2 = nvvm.convert.f6x2.to.f16x2 %src {relu = true} : vector<2xi8> (f6E3M2FN)-> vector<2xf16>
   llvm.return
 }
+
+// -----
+
+// CHECK-LABEL: @convert_f6x2_to_bf16x2_e2m3
+llvm.func @convert_f6x2_to_bf16x2_e2m3(%src : vector<2xi8>, %scale_factor : i16) {
+  // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e2m3x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %[[res1]], i16 32639)
+  %res1 = nvvm.convert.f6x2.to.bf16x2 %src : vector<2xi8> (f6E2M3FN) -> vector<2xbf16>
+  // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e2m3x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %[[res2]], i16 32639)
+  %res2 = nvvm.convert.f6x2.to.bf16x2 %src {relu = true} : vector<2xi8> (f6E2M3FN) -> vector<2xbf16>
+  // CHECK: %[[res3:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e2m3x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %[[res3]], i16 32639)
+  %res3 = nvvm.convert.f6x2.to.bf16x2 %src {sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f6E2M3FN) -> vector<2xbf16>
+  // CHECK: %[[res4:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e2m3x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %[[res4]], i16 32639)
+  %res4 = nvvm.convert.f6x2.to.bf16x2 %src {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f6E2M3FN) -> vector<2xbf16>
+  // CHECK: %[[res5:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e2m3x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %[[res5]], i16 %{{.*}})
+  %res5 = nvvm.convert.f6x2.to.bf16x2 %src, %scale_factor : vector<2xi8> (f6E2M3FN) -> vector<2xbf16>
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f6x2_to_bf16x2_e3m2
+llvm.func @convert_f6x2_to_bf16x2_e3m2(%src : vector<2xi8>, %scale_factor : i16) {
+  // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e3m2x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %[[res1]], i16 32639)
+  %res1 = nvvm.convert.f6x2.to.bf16x2 %src : vector<2xi8> (f6E3M2FN) -> vector<2xbf16>
+  // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e3m2x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %[[res2]], i16 32639)
+  %res2 = nvvm.convert.f6x2.to.bf16x2 %src {relu = true} : vector<2xi8> (f6E3M2FN) -> vector<2xbf16>
+  // CHECK: %[[res3:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e3m2x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %[[res3]], i16 32639)
+  %res3 = nvvm.convert.f6x2.to.bf16x2 %src {sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f6E3M2FN) -> vector<2xbf16>
+  // CHECK: %[[res4:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e3m2x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %[[res4]], i16 32639)
+  %res4 = nvvm.convert.f6x2.to.bf16x2 %src {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f6E3M2FN) -> vector<2xbf16>
+  // CHECK: %[[res5:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e3m2x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %[[res5]], i16 %{{.*}})
+  %res5 = nvvm.convert.f6x2.to.bf16x2 %src, %scale_factor : vector<2xi8> (f6E3M2FN) -> vector<2xbf16>
+  llvm.return
+}

diff  --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
index d8002d790b6a2..317e95dc3a75b 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
@@ -159,3 +159,43 @@ llvm.func @convert_f8x2_to_bf16x2_ue8m0(%src : vector<2xi8>) {
   %res1 = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E8M0FNU)-> vector<2xbf16>
   llvm.return
 }
+
+// CHECK-LABEL: @convert_f8x2_to_bf16x2_e4m3
+llvm.func @convert_f8x2_to_bf16x2_e4m3(%src : vector<2xi8>, %scale_factor : i16) {
+  // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e4m3x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %[[res1]], i16 32639)
+  %res1 = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+  // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e4m3x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %[[res2]], i16 32639)
+  %res2 = nvvm.convert.f8x2.to.bf16x2 %src {relu = true} : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+  // CHECK: %[[res3:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e4m3x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %[[res3]], i16 32639)
+  %res3 = nvvm.convert.f8x2.to.bf16x2 %src {sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+  // CHECK: %[[res4:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e4m3x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %[[res4]], i16 32639)
+  %res4 = nvvm.convert.f8x2.to.bf16x2 %src {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+  // CHECK: %[[res5:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e4m3x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %[[res5]], i16 %{{.*}})
+  %res5 = nvvm.convert.f8x2.to.bf16x2 %src, %scale_factor : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_f8x2_to_bf16x2_e5m2
+llvm.func @convert_f8x2_to_bf16x2_e5m2(%src : vector<2xi8>, %scale_factor : i16) {
+  // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e5m2x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %[[res1]], i16 32639)
+  %res1 = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E5M2) -> vector<2xbf16>
+  // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e5m2x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %[[res2]], i16 32639)
+  %res2 = nvvm.convert.f8x2.to.bf16x2 %src {relu = true} : vector<2xi8> (f8E5M2) -> vector<2xbf16>
+  // CHECK: %[[res3:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e5m2x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %[[res3]], i16 32639)
+  %res3 = nvvm.convert.f8x2.to.bf16x2 %src {sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f8E5M2) -> vector<2xbf16>
+  // CHECK: %[[res4:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e5m2x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %[[res4]], i16 32639)
+  %res4 = nvvm.convert.f8x2.to.bf16x2 %src {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f8E5M2) -> vector<2xbf16>
+  // CHECK: %[[res5:.*]] = bitcast <2 x i8> %{{.*}} to i16
+  // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.e5m2x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %[[res5]], i16 %{{.*}})
+  %res5 = nvvm.convert.f8x2.to.bf16x2 %src, %scale_factor : vector<2xi8> (f8E5M2) -> vector<2xbf16>
+  llvm.return
+}

diff  --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 82e7373a40baa..07c34f10d0e3d 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -219,7 +219,7 @@ llvm.func @nvvm_cvt_f32x2_to_f4x2_invalid_type(%a : f32, %b : f32) {
 // -----
 
 llvm.func @nvvm_cvt_f8x2_to_f16x2_invalid_type(%src : vector<2xi8>) {
-  // expected-error @below {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f8x2 to f16x2.}}
+  // expected-error @below {{op attribute 'srcType' failed to satisfy constraint: type attribute of f8E4M3FN type or f8E5M2 type}}
   %res = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E4M3) -> vector<2xf16>
   llvm.return
 }
@@ -227,29 +227,93 @@ llvm.func @nvvm_cvt_f8x2_to_f16x2_invalid_type(%src : vector<2xi8>) {
 // -----
 
 llvm.func @nvvm_cvt_f8x2_to_bf16x2_invalid_type(%src : vector<2xi8>) {
-  // expected-error @below {{Only 'f8E8M0FNU' type is supported for conversions from f8x2 to bf16x2.}}
-  %res = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+  // expected-error @below {{op attribute 'srcType' failed to satisfy constraint: type attribute of f8E8M0FNU type or f8E4M3FN type or f8E5M2 type}}
+  %res = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E4M3) -> vector<2xbf16>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f8x2_to_bf16x2_ue8m0_invalid_sat(%src : vector<2xi8>) {
+  // expected-error @below {{Only NONE saturation mode is supported for conversions from 'f8E8M0FNU' type}}
+  %res = nvvm.convert.f8x2.to.bf16x2 %src {sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> (f8E8M0FNU) -> vector<2xbf16>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f8x2_to_bf16x2_ue8m0_invalid_scale(%src : vector<2xi8>, %sf : i16) {
+  // expected-error @below {{scaleFactor not supported for conversions from 'f8E8M0FNU' type}}
+  %res = nvvm.convert.f8x2.to.bf16x2 %src, %sf : vector<2xi8> (f8E8M0FNU) -> vector<2xbf16>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f8x2_to_bf16x2_ue8m0_invalid_relu(%src : vector<2xi8>) {
+  // expected-error @below {{relu not supported for conversions from 'f8E8M0FNU' type}}
+  %res = nvvm.convert.f8x2.to.bf16x2 %src {relu = true} : vector<2xi8> (f8E8M0FNU) -> vector<2xbf16>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f8x2_to_bf16x2_invalid_sat(%src : vector<2xi8>) {
+  // expected-error @below {{op attribute 'sat' failed to satisfy constraint: Describes the saturation mode whose value is one of {none, satfinite}}}
+  %res = nvvm.convert.f8x2.to.bf16x2 %src {sat = #nvvm.sat_mode<sat>} : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f6x2_to_bf16x2_invalid_sat(%src : vector<2xi8>) {
+  // expected-error @below {{op attribute 'sat' failed to satisfy constraint: Describes the saturation mode whose value is one of {none, satfinite}}}
+  %res = nvvm.convert.f6x2.to.bf16x2 %src {sat = #nvvm.sat_mode<sat>} : vector<2xi8> (f6E2M3FN) -> vector<2xbf16>
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f4x2_to_bf16x2_invalid_sat(%src : i8) {
+  // expected-error @below {{op attribute 'sat' failed to satisfy constraint: Describes the saturation mode whose value is one of {none, satfinite}}}
+  %res = nvvm.convert.f4x2.to.bf16x2 %src {sat = #nvvm.sat_mode<sat>} : i8 (f4E2M1FN) -> vector<2xbf16>
   llvm.return
 }
 
 // -----
 
 llvm.func @nvvm_cvt_f6x2_to_f16x2_invalid_type(%src : vector<2xi8>) {
-  // expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f6x2 to f16x2.}}
+  // expected-error @below {{op attribute 'srcType' failed to satisfy constraint: type attribute of f6E2M3FN type or f6E3M2FN type}}
   %res = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xf16>
   llvm.return
 }
 
 // -----
 
+llvm.func @nvvm_cvt_f6x2_to_bf16x2_invalid_type(%src : vector<2xi8>) {
+  // expected-error @below {{op attribute 'srcType' failed to satisfy constraint: type attribute of f6E2M3FN type or f6E3M2FN type}}
+  %res = nvvm.convert.f6x2.to.bf16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+  llvm.return
+}
+
+// -----
+
 llvm.func @nvvm_cvt_f4x2_to_f16x2_invalid_type(%src : i8) {
-  // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f4x2 to f16x2.}}
+  // expected-error @below {{op attribute 'srcType' failed to satisfy constraint: type attribute of f4E2M1FN type}}
   %res = nvvm.convert.f4x2.to.f16x2 %src : i8 (f6E2M3FN) -> vector<2xf16>
   llvm.return
 }
 
 // -----
 
+llvm.func @nvvm_cvt_f4x2_to_bf16x2_invalid_type(%src : i8) {
+  // expected-error @below {{op attribute 'srcType' failed to satisfy constraint: type attribute of f4E2M1FN type}}
+  %res = nvvm.convert.f4x2.to.bf16x2 %src : i8 (f6E2M3FN) -> vector<2xbf16>
+  llvm.return
+}
+
+// -----
+
 llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
   // expected-error @below {{cache eviction priority supported only for cache level L2}}
   nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>


        


More information about the llvm-branch-commits mailing list