[Mlir-commits] [mlir] [MLIR][NVVM] Add new narrow FP convert Ops (PR #184291)
Srinivasa Ravi
llvmlistbot at llvm.org
Thu Apr 2 02:45:57 PDT 2026
https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/184291
>From 7153047317d84d74edc91a9563dbb0e889422340 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Mon, 2 Mar 2026 11:13:22 +0000
Subject: [PATCH 1/3] [MLIR][NVVM] Add new narrow FP convert Ops
This change adds the following NVVM Ops for new narrow FP conversions
introduced in PTX 9.1:
- `convert.{f32x2/bf16x2}.to.s2f6x2`
- `convert.s2f6x2.to.bf16x2`
- `convert.bf16x2.to.f8x2` (extended for `f8E4M3FN` and `f8E5M2` types)
- `convert.{f16x2/bf16x2}.to.f6x2`
- `convert.{f16x2/bf16x2}.to.f4x2`
PTX ISA Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 177 ++++++++++-
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 277 ++++++++++++++++--
.../invalid-convert-stochastic-rounding.mlir | 2 +-
.../Target/LLVMIR/nvvm/convert_fp4x2.mlir | 28 ++
.../LLVMIR/nvvm/convert_fp4x2_invalid.mlir | 17 ++
.../Target/LLVMIR/nvvm/convert_fp6x2.mlir | 77 ++++-
.../LLVMIR/nvvm/convert_fp6x2_invalid.mlir | 17 ++
.../Target/LLVMIR/nvvm/convert_fp8x2.mlir | 25 ++
.../LLVMIR/nvvm/convert_fp8x2_invalid.mlir | 41 +++
.../Target/LLVMIR/nvvm/convert_s2f6x2.mlir | 181 ++++++++++++
mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 16 -
11 files changed, 812 insertions(+), 46 deletions(-)
create mode 100644 mlir/test/Target/LLVMIR/nvvm/convert_fp4x2_invalid.mlir
create mode 100644 mlir/test/Target/LLVMIR/nvvm/convert_fp6x2_invalid.mlir
create mode 100644 mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir
create mode 100644 mlir/test/Target/LLVMIR/nvvm/convert_s2f6x2.mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index f8e1ab38e80d4..ef69def599ccc 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1972,6 +1972,45 @@ def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
}];
}
+class NVVM_ConvertFPx2ToF4x2Op<string srcType>
+ : NVVM_Op<"convert."#!tolower(srcType)#"x2.to.f4x2"> {
+ let summary = "Convert a pair of "#srcType#" inputs to F4x2";
+ let description = [{
+ This Op converts each of the given }]#srcType#[{ inputs to the specified
+ F4x2 type.
+ The result `dst` is returned as an i8 type where the two converted values
+ (one per element of the vector operand `a`) are packed, one in the upper
+ 4 bits and the other in the lower 4 bits of `dst`; see the PTX ISA `cvt`
+ documentation for the element-to-nibble ordering.
+ The `relu` attribute, when set, lowers to the '.relu' variant of
+ the cvt instruction.
+ }];
+ let hasVerifier = 1;
+ let results = (outs I8:$dst);
+ let arguments = (ins
+ VectorOfLengthAndType<[2], [!cast<Type>(srcType)]>:$a,
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$dstTy);
+
+ let assemblyFormat =
+ "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
+
+ let extraClassDeclaration = [{
+ static NVVM::IDArgPair
+ getIntrinsicIDAndArgs(NVVM::Convert}]#srcType#[{x2ToF4x2Op &op,
+ LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+ }];
+
+ string llvmBuilder = [{
+ auto [intId, args] = NVVM::Convert}]#srcType#[{x2ToF4x2Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
+ llvm::Value *packedI16 = createIntrinsicCall(builder, intId, args);
+ $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
+ }];
+}
+
+def NVVM_ConvertF16x2ToF4x2Op : NVVM_ConvertFPx2ToF4x2Op<"F16">;
+def NVVM_ConvertBF16x2ToF4x2Op : NVVM_ConvertFPx2ToF4x2Op<"BF16">;
+
def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
let summary = "Convert a pair of float inputs to f6x2";
let description = [{
@@ -2015,6 +2054,43 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
}];
}
+class NVVM_ConvertFPx2ToF6x2Op<string srcType>
+ : NVVM_Op<"convert."#!tolower(srcType)#"x2.to.f6x2"> {
+ let summary = "Convert a pair of "#srcType#" inputs to F6x2";
+ let description = [{
+ This Op converts each of the given }]#srcType#[{ inputs to the specified
+ F6x2 type.
+ The result `dst` can be an I16 or a vector type (vector<2xi8>). In the
+ vector form, each converted value is stored as an i8 element in the vector.
+ The `relu` attribute, when set, lowers to the '.relu' variant of
+ the cvt instruction.
+ }];
+
+ let hasVerifier = 1;
+ let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+ let arguments = (ins
+ VectorOfLengthAndType<[2], [!cast<Type>(srcType)]>:$a,
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$dstTy);
+ let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
+ let extraClassDeclaration = [{
+ static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy, bool hasRelu);
+ }];
+
+ string llvmBuilder = [{
+ auto intId = NVVM::Convert}]#srcType#[{x2ToF6x2Op::getIntrinsicID($dstTy, $relu);
+ llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
+ if(op.getDst().getType().isInteger(16))
+ $dst = packedI16;
+ else
+ $dst = builder.CreateBitCast(packedI16,
+ llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
+ }];
+}
+
+def NVVM_ConvertF16x2ToF6x2Op : NVVM_ConvertFPx2ToF6x2Op<"F16">;
+def NVVM_ConvertBF16x2ToF6x2Op : NVVM_ConvertFPx2ToF6x2Op<"BF16">;
+
def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> {
let summary = "Convert a pair of float inputs to f8x2";
let description = [{
@@ -2130,16 +2206,18 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
VectorOfLengthAndType<[2], [BF16]>:$a,
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
TypeAttr:$dstTy);
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
let extraClassDeclaration = [{
- static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd,
- NVVM::SaturationMode sat);
+ static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
+ NVVM::FPRoundingMode rnd,
+ NVVM::SaturationMode sat, bool hasRelu);
}];
string llvmBuilder = [{
- auto intId = NVVM::ConvertBF16x2ToF8x2Op::getIntrinsicID($rnd, $sat);
+ auto intId = NVVM::ConvertBF16x2ToF8x2Op::getIntrinsicID($dstTy, $rnd, $sat, $relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
@@ -2185,6 +2263,99 @@ def NVVM_ConvertF6x2ToF16x2Op :
def NVVM_ConvertF4x2ToF16x2Op :
NVVM_ConvertToFP16x2Op_Base<"F4", I8, "F16">;
+def NVVM_ConvertF32x2ToS2F6x2Op : NVVM_Op<"convert.f32x2.to.s2f6x2"> {
+ let summary = "Convert a pair of f32 inputs to S2F6x2";
+ let description = [{
+ This Op converts each of the given f32 inputs to the
+ S2F6x2 type. The result `dst` can be an I16 or vector<2xi8>.
+ The `relu` attribute, when set, lowers to the '.relu' variant
+ of the cvt instruction. The optional scaling-factor for the
+ conversion is provided through the operand `scaleFactor`.
+ Only `ue8m0` is supported as the type of the scale-factor currently.
+ }];
+
+ let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+ let arguments = (ins F32:$a, F32:$b,
+ Optional<I16>:$scaleFactor,
+ DefaultValuedAttr<BoolAttr, "false">:$relu);
+ let assemblyFormat =
+ "$a `,` $b (`,` $scaleFactor^)? attr-dict `:` type($dst)";
+ let extraClassDeclaration = [{
+ static IDArgPair getIntrinsicIDAndArgs(Operation &op,
+ LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+ }];
+
+ string llvmBuilder = [{
+ auto [id, args] = NVVM::ConvertF32x2ToS2F6x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+ llvm::Value *packedI16 = createIntrinsicCall(builder, id, args);
+ if(op.getDst().getType().isInteger(16))
+ $dst = packedI16;
+ else
+ $dst = builder.CreateBitCast(packedI16,
+ llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
+ }];
+}
+
+def NVVM_ConvertBF16x2ToS2F6x2Op : NVVM_Op<"convert.bf16x2.to.s2f6x2"> {
+ let summary = "Convert a pair of BF16 inputs to S2F6x2";
+ let description = [{
+ This Op converts each of the given BF16 inputs to the S2F6x2 type.
+ The result `dst` can be an I16 or vector<2xi8>. The `relu` attribute, when
+ set, lowers to the '.relu' variant of the cvt instruction. The optional
+ scaling-factor for the conversion is provided through the operand `scaleFactor`.
+ }];
+
+ let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+ let arguments = (ins
+ VectorOfLengthAndType<[2], [BF16]>:$a,
+ Optional<I16>:$scaleFactor,
+ DefaultValuedAttr<BoolAttr, "false">:$relu);
+ let assemblyFormat =
+ "$a (`,` $scaleFactor^)? attr-dict `:` type($a) `->` type($dst)";
+ let extraClassDeclaration = [{
+ static IDArgPair getIntrinsicIDAndArgs(Operation &op,
+ LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+ }];
+
+ string llvmBuilder = [{
+ auto [id, args] = NVVM::ConvertBF16x2ToS2F6x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+ llvm::Value *packedI16 = createIntrinsicCall(builder, id, args);
+ if(op.getDst().getType().isInteger(16))
+ $dst = packedI16;
+ else
+ $dst = builder.CreateBitCast(packedI16,
+ llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
+ }];
+}
+
+def NVVM_ConvertS2F6x2ToBF16x2Op : NVVM_Op<"convert.s2f6x2.to.bf16x2"> {
+ let summary = "Convert s2f6x2 to bf16x2";
+ let description = [{
+ This Op converts the given s2f6x2 to bf16x2 values.
+ The `relu` attribute, when set, lowers to the '.relu' variant
+ of the cvt instruction. The optional scaling-factor for the
+ conversion is provided through the operand `scaleFactor`.
+ Only `ue8m0` is supported as the type of the scale-factor currently.
+ }];
+
+ let results = (outs VectorOfLengthAndType<[2], [BF16]>:$dst);
+ let arguments = (ins I16:$a,
+ Optional<I16>:$scaleFactor,
+ DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+ DefaultValuedAttr<BoolAttr, "false">:$relu);
+ let assemblyFormat =
+ "$a (`,` $scaleFactor^)? attr-dict `:` type($a) `->` type($dst)";
+ let extraClassDeclaration = [{
+ static IDArgPair getIntrinsicIDAndArgs(Operation &op,
+ LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+ }];
+
+ string llvmBuilder = [{
+ auto [id, args] = NVVM::ConvertS2F6x2ToBF16x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+ $dst = createIntrinsicCall(builder, id, args);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// NVVM Stochastic Rounding Conversion Ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 6ccd59cec65bc..a27ca2ebdfe78 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -347,6 +347,32 @@ LogicalResult ConvertF32x2ToF6x2Op::verify() {
return success();
}
+LogicalResult ConvertF16x2ToF6x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
+ return emitOpError("Only ")
+ << mlir::Float6E2M3FNType::get(ctx) << " and "
+ << mlir::Float6E3M2FNType::get(ctx)
+ << " types are supported for conversions from f16x2 to f6x2.";
+ }
+
+ return success();
+}
+
+LogicalResult ConvertBF16x2ToF6x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
+ return emitOpError("Only ")
+ << mlir::Float6E2M3FNType::get(ctx) << " and "
+ << mlir::Float6E3M2FNType::get(ctx)
+ << " types are supported for conversions from bf16x2 to f6x2.";
+ }
+
+ return success();
+}
+
LogicalResult ConvertF32x2ToF8x2Op::verify() {
using RndMode = NVVM::FPRoundingMode;
using SatMode = NVVM::SaturationMode;
@@ -414,18 +440,48 @@ LogicalResult ConvertF16x2ToF8x2Op::verify() {
LogicalResult ConvertBF16x2ToF8x2Op::verify() {
using RndMode = NVVM::FPRoundingMode;
+ using SatMode = NVVM::SaturationMode;
- if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
- return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext())
- << " type is supported for conversions from "
- "bf16x2 to f8x2.";
+ bool isRoundingModeRN = getRnd() == RndMode::RN;
+ bool isRoundingModeRZ = getRnd() == RndMode::RZ;
+ bool isRoundingModeRP = getRnd() == RndMode::RP;
+ bool isSatFinite = getSat() == SatMode::SATFINITE;
+ bool hasRelu = getRelu();
- auto rnd = getRnd();
- if (rnd != RndMode::RZ && rnd != RndMode::RP)
- return emitOpError("Only RZ and RP rounding modes are supported for "
- "conversions from bf16x2 to f8x2.");
+ mlir::MLIRContext *ctx = getContext();
- return success();
+ return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy())
+ .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
+ [&](mlir::Type) -> LogicalResult {
+ if (!isRoundingModeRN)
+ return emitOpError("Only RN rounding mode is supported for "
+ "conversions from bf16x2 to ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx) << " types";
+ if (!isSatFinite)
+ return emitOpError("Only SATFINITE saturation mode is supported "
+ "for conversions from bf16x2 to ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx) << " types";
+ return success();
+ })
+ .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
+ if (!(isRoundingModeRZ || isRoundingModeRP))
+ return emitOpError("Only RZ and RP rounding modes are supported for "
+ "conversions from bf16x2 to ")
+ << mlir::Float8E8M0FNUType::get(ctx) << " type";
+ if (hasRelu)
+ return emitOpError("relu not supported for conversions to ")
+ << mlir::Float8E8M0FNUType::get(ctx) << " type";
+ return success();
+ })
+ .Default([&](mlir::Type) -> LogicalResult {
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(ctx) << ", "
+ << mlir::Float8E5M2Type::get(ctx) << ", and "
+ << mlir::Float8E8M0FNUType::get(ctx)
+ << " types are supported for conversions from bf16x2 to f8x2.";
+ });
}
LogicalResult ConvertF32x2ToF4x2Op::verify() {
@@ -439,6 +495,26 @@ LogicalResult ConvertF32x2ToF4x2Op::verify() {
return success();
}
+LogicalResult ConvertF16x2ToF4x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from f16x2 to f4x2.";
+ return success();
+}
+
+LogicalResult ConvertBF16x2ToF4x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from bf16x2 to f4x2.";
+ return success();
+}
+
LogicalResult ConvertF8x2ToF16x2Op::verify() {
mlir::MLIRContext *ctx = getContext();
@@ -4185,6 +4261,80 @@ llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
});
}
+NVVM::IDArgPair
+ConvertF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF16x2ToF4x2Op &op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ mlir::Type dstTy = op.getDstTy();
+ bool hasRelu = op.getRelu();
+
+ llvm::Intrinsic::ID intId = llvm::Intrinsic::not_intrinsic;
+
+ if (llvm::isa<mlir::Float4E2M1FNType>(dstTy))
+ intId = hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_satfinite;
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getA()));
+
+ return {intId, std::move(args)};
+}
+
+NVVM::IDArgPair
+ConvertBF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertBF16x2ToF4x2Op &op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ mlir::Type dstTy = op.getDstTy();
+ bool hasRelu = op.getRelu();
+
+ llvm::Intrinsic::ID intId = llvm::Intrinsic::not_intrinsic;
+
+ if (llvm::isa<mlir::Float4E2M1FNType>(dstTy))
+ intId = hasRelu ? llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_satfinite;
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getA()));
+
+ return {intId, std::move(args)};
+}
+
+llvm::Intrinsic::ID ConvertF16x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
+ bool hasRelu) {
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_satfinite;
+ })
+ .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_satfinite;
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid conversion in ConvertF16x2ToF6x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+}
+
+llvm::Intrinsic::ID ConvertBF16x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
+ bool hasRelu) {
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
+ return hasRelu
+ ? llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_satfinite;
+ })
+ .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
+ return hasRelu
+ ? llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_satfinite;
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid conversion in ConvertBF16x2ToF6x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+}
+
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
: llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
@@ -4240,22 +4390,39 @@ llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
});
}
-#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
- has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
- : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
-
llvm::Intrinsic::ID
-ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
- NVVM::SaturationMode sat) {
+ConvertBF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
+ NVVM::FPRoundingMode rnd,
+ NVVM::SaturationMode sat, bool hasRelu) {
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
- switch (rnd) {
- case NVVM::FPRoundingMode::RZ:
- return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
- case NVVM::FPRoundingMode::RP:
- return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
- default:
- llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
- }
+
+ static constexpr llvm::Intrinsic::ID ue8m0x2IDs[] = {
+ llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz,
+ llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp,
+ llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz_satfinite,
+ llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp_satfinite,
+ };
+
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
+ return hasRelu
+ ? llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_satfinite;
+ })
+ .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
+ return hasRelu
+ ? llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_satfinite;
+ })
+ .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
+ bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
+ unsigned index = (hasSatFinite << 1) | hasRoundingModeRP;
+ return ue8m0x2IDs[index];
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid conversion in ConvertBF16x2ToF8x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
}
NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
@@ -4350,6 +4517,70 @@ NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
return {intId, {extendedI16}};
}
+NVVM::IDArgPair ConvertF32x2ToS2F6x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::ConvertF32x2ToS2F6x2Op>(op);
+ bool hasRelu = thisOp.getRelu();
+ bool hasScale = static_cast<bool>(thisOp.getScaleFactor());
+
+ llvm::Intrinsic::ID id =
+ hasRelu
+ ? llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0
+ : llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_satfinite_scale_n2_ue8m0;
+
+ // Fill the Intrinsic Args
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getA()));
+ args.push_back(mt.lookupValue(thisOp.getB()));
+ args.push_back(hasScale ? mt.lookupValue(thisOp.getScaleFactor())
+ : builder.getInt16(0x7f7f));
+ return {id, std::move(args)};
+}
+
+NVVM::IDArgPair ConvertBF16x2ToS2F6x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::ConvertBF16x2ToS2F6x2Op>(op);
+ bool hasRelu = thisOp.getRelu();
+ bool hasScale = static_cast<bool>(thisOp.getScaleFactor());
+
+ llvm::Intrinsic::ID id =
+ hasRelu
+ ? llvm::Intrinsic::
+ nvvm_bf16x2_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0
+ : llvm::Intrinsic::nvvm_bf16x2_to_s2f6x2_rn_satfinite_scale_n2_ue8m0;
+
+ // Fill the Intrinsic Args
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getA()));
+ args.push_back(hasScale ? mt.lookupValue(thisOp.getScaleFactor())
+ : builder.getInt16(0x7f7f));
+ return {id, std::move(args)};
+}
+
+NVVM::IDArgPair ConvertS2F6x2ToBF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::ConvertS2F6x2ToBF16x2Op>(op);
+ bool hasRelu = thisOp.getRelu();
+ bool hasScale = static_cast<bool>(thisOp.getScaleFactor());
+ bool hasSat = thisOp.getSat() == NVVM::SaturationMode::SATFINITE;
+
+ static constexpr llvm::Intrinsic::ID ids[] = {
+ llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_scale_n2_ue8m0,
+ llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
+ llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
+ llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
+ };
+
+ unsigned idx = (hasSat << 1) | hasRelu;
+
+ // Fill the Intrinsic Args
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(thisOp.getA()));
+ args.push_back(hasScale ? mt.lookupValue(thisOp.getScaleFactor())
+ : builder.getInt16(0x7f7f));
+ return {ids[idx], std::move(args)};
+}
+
llvm::Intrinsic::ID
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
LLVM::ModuleTranslation &mt,
diff --git a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
index 506b81e1e7048..3e763a036f2d3 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
@@ -64,7 +64,7 @@ llvm.func @convert_f32x2_to_f8x2_rs_not_supported(%a : f32, %b : f32) {
// -----
llvm.func @convert_bf16x2_to_f8x2_rs_not_supported(%src : vector<2xbf16>) {
- // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}}
+ // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to 'f8E8M0FNU' type}}
%res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16> -> i16 (f8E8M0FNU)
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
index 451475ca76027..3d3bd714fa8fa 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
@@ -11,6 +11,34 @@ llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
llvm.return
}
+// -----
+
+// CHECK-LABEL: @convert_f16x2_to_f4x2
+llvm.func @convert_f16x2_to_f4x2(%srcA : vector<2xf16>) {
+ // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e2m1x2.rn.satfinite(<2 x half> %{{.*}})
+ // CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
+ %res1 = nvvm.convert.f16x2.to.f4x2 %srcA : vector<2xf16> -> i8 (f4E2M1FN)
+ // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.f16x2.to.e2m1x2.rn.relu.satfinite(<2 x half> %{{.*}})
+ // CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
+ %res2 = nvvm.convert.f16x2.to.f4x2 %srcA {relu = true} : vector<2xf16> -> i8 (f4E2M1FN)
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @convert_bf16x2_to_f4x2
+llvm.func @convert_bf16x2_to_f4x2(%srcA : vector<2xbf16>) {
+ // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.e2m1x2.rn.satfinite(<2 x bfloat> %{{.*}})
+ // CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
+ %res1 = nvvm.convert.bf16x2.to.f4x2 %srcA : vector<2xbf16> -> i8 (f4E2M1FN)
+ // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.bf16x2.to.e2m1x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
+ // CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
+ %res2 = nvvm.convert.bf16x2.to.f4x2 %srcA {relu = true} : vector<2xbf16> -> i8 (f4E2M1FN)
+ llvm.return
+}
+
+// -----
+
// CHECK-LABEL: @convert_f4x2_to_f16x2
llvm.func @convert_f4x2_to_f16x2(%src : i8) {
// CHECK: %[[res1:.*]] = zext i8 %{{.*}} to i16
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2_invalid.mlir
new file mode 100644
index 0000000000000..637db9e8ca07a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2_invalid.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-translate -mlir-to-llvmir -verify-diagnostics %s
+
+// -----
+
+llvm.func @convert_f16x2_to_f4x2_invalid_type(%src : vector<2xf16>) {
+ // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f16x2 to f4x2.}}
+ %res = nvvm.convert.f16x2.to.f4x2 %src : vector<2xf16> -> i8 (f8E4M3FN)
+ llvm.return
+}
+
+// -----
+
+llvm.func @convert_bf16x2_to_f4x2_invalid_type(%src : vector<2xbf16>) {
+ // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from bf16x2 to f4x2.}}
+ %res = nvvm.convert.bf16x2.to.f4x2 %src : vector<2xbf16> -> i8 (f8E4M3FN)
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
index 61a7a48f40d54..8d9e5ff2a6a82 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
@@ -1,11 +1,20 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
-// CHECK-LABEL: @convert_f32x2_to_fp6x2_packed
-llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
+// CHECK-LABEL: @convert_f32x2_to_fp6x2_e2m3
+llvm.func @convert_f32x2_to_fp6x2_e2m3(%srcA : f32, %srcB : f32) {
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
%res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E2M3FN)
+ //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB {relu = true} : i16 (f6E2M3FN)
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f32x2_to_fp6x2_e3m2
+llvm.func @convert_f32x2_to_fp6x2_e3m2(%srcA : f32, %srcB : f32) {
//CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
- %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E3M2FN)
+ %res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E3M2FN)
+ //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+ %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB {relu = true} : i16 (f6E3M2FN)
llvm.return
}
@@ -22,6 +31,68 @@ llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
// -----
+// CHECK-LABEL: @convert_f16x2_to_fp6x2_e2m3
+llvm.func @convert_f16x2_to_fp6x2_e2m3(%srcA : vector<2xf16>) {
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e2m3x2.rn.satfinite(<2 x half> %{{.*}})
+ %res1 = nvvm.convert.f16x2.to.f6x2 %srcA : vector<2xf16> -> i16 (f6E2M3FN)
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e2m3x2.rn.relu.satfinite(<2 x half> %{{.*}})
+ %res2 = nvvm.convert.f16x2.to.f6x2 %srcA {relu = true} : vector<2xf16> -> i16 (f6E2M3FN)
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f16x2_to_fp6x2_e3m2
+llvm.func @convert_f16x2_to_fp6x2_e3m2(%srcA : vector<2xf16>) {
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e3m2x2.rn.satfinite(<2 x half> %{{.*}})
+ %res1 = nvvm.convert.f16x2.to.f6x2 %srcA : vector<2xf16> -> i16 (f6E3M2FN)
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e3m2x2.rn.relu.satfinite(<2 x half> %{{.*}})
+ %res2 = nvvm.convert.f16x2.to.f6x2 %srcA {relu = true} : vector<2xf16> -> i16 (f6E3M2FN)
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f16x2_to_fp6x2_vector
+llvm.func @convert_f16x2_to_fp6x2_vector(%srcA : vector<2xf16>) {
+ // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e2m3x2.rn.satfinite(<2 x half> %{{.*}})
+ // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+ %res1 = nvvm.convert.f16x2.to.f6x2 %srcA : vector<2xf16> -> vector<2xi8> (f6E2M3FN)
+ // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.f16x2.to.e3m2x2.rn.satfinite(<2 x half> %{{.*}})
+ // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
+ %res2 = nvvm.convert.f16x2.to.f6x2 %srcA : vector<2xf16> -> vector<2xi8> (f6E3M2FN)
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @convert_bf16x2_to_fp6x2_e2m3
+llvm.func @convert_bf16x2_to_fp6x2_e2m3(%srcA : vector<2xbf16>, %scale_factor : i16) {
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e2m3x2.rn.satfinite(<2 x bfloat> %{{.*}})
+ %res1 = nvvm.convert.bf16x2.to.f6x2 %srcA : vector<2xbf16> -> i16 (f6E2M3FN)
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e2m3x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
+ %res2 = nvvm.convert.bf16x2.to.f6x2 %srcA {relu = true} : vector<2xbf16> -> i16 (f6E2M3FN)
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_bf16x2_to_fp6x2_e3m2
+llvm.func @convert_bf16x2_to_fp6x2_e3m2(%srcA : vector<2xbf16>, %scale_factor : i16) {
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e3m2x2.rn.satfinite(<2 x bfloat> %{{.*}})
+ %res1 = nvvm.convert.bf16x2.to.f6x2 %srcA : vector<2xbf16> -> i16 (f6E3M2FN)
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e3m2x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
+ %res2 = nvvm.convert.bf16x2.to.f6x2 %srcA {relu = true} : vector<2xbf16> -> i16 (f6E3M2FN)
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_bf16x2_to_fp6x2_vector
+llvm.func @convert_bf16x2_to_fp6x2_vector(%srcA : vector<2xbf16>, %scale_factor : i16) {
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e2m3x2.rn.satfinite(<2 x bfloat> %{{.*}})
+ // CHECK-NEXT: %{{.*}} = bitcast i16 %{{.*}} to <2 x i8>
+ %res1 = nvvm.convert.bf16x2.to.f6x2 %srcA : vector<2xbf16> -> vector<2xi8> (f6E2M3FN)
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e3m2x2.rn.satfinite(<2 x bfloat> %{{.*}})
+ // CHECK-NEXT: %{{.*}} = bitcast i16 %{{.*}} to <2 x i8>
+ %res2 = nvvm.convert.bf16x2.to.f6x2 %srcA : vector<2xbf16> -> vector<2xi8> (f6E3M2FN)
+ llvm.return
+}
+
+// -----
+
// CHECK-LABEL: @convert_f6x2_to_f16x2_e2m3
llvm.func @convert_f6x2_to_f16x2_e2m3(%src : vector<2xi8>) {
// CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2_invalid.mlir
new file mode 100644
index 0000000000000..8d0df5863a931
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2_invalid.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-translate -mlir-to-llvmir -verify-diagnostics %s
+
+// -----
+
+llvm.func @convert_f16x2_to_f6x2_invalid_type(%src : vector<2xf16>) {
+ // expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f16x2 to f6x2.}}
+ %res = nvvm.convert.f16x2.to.f6x2 %src : vector<2xf16> -> vector<2xi8> (f8E4M3FN)
+ llvm.return
+}
+
+// -----
+
+llvm.func @convert_bf16x2_to_f6x2_invalid_type(%src : vector<2xbf16>) {
+ // expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from bf16x2 to f6x2.}}
+ %res = nvvm.convert.bf16x2.to.f6x2 %src : vector<2xbf16> -> vector<2xi8> (f8E4M3FN)
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
index 4afe901bc08e9..d8002d790b6a2 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
@@ -90,6 +90,25 @@ llvm.func @convert_bf16x2_to_f8x2_ue8m0(%src : vector<2xbf16>) {
llvm.return
}
+
+// CHECK-LABEL: @convert_bf16x2_to_f8x2_e4m3
+llvm.func @convert_bf16x2_to_f8x2_e4m3(%srcA : vector<2xbf16>) {
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e4m3x2.rn.satfinite(<2 x bfloat> %{{.*}})
+ %res1 = nvvm.convert.bf16x2.to.f8x2 %srcA {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> i16 (f8E4M3FN)
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e4m3x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
+ %res2 = nvvm.convert.bf16x2.to.f8x2 %srcA {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> i16 (f8E4M3FN)
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_bf16x2_to_f8x2_e5m2
+llvm.func @convert_bf16x2_to_f8x2_e5m2(%srcA : vector<2xbf16>) {
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e5m2x2.rn.satfinite(<2 x bfloat> %{{.*}})
+ %res1 = nvvm.convert.bf16x2.to.f8x2 %srcA {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E5M2)
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e5m2x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
+ %res2 = nvvm.convert.bf16x2.to.f8x2 %srcA {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E5M2)
+ llvm.return
+}
+
// CHECK-LABEL: @convert_bf16x2_to_f8x2_vector_return
llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) {
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}})
@@ -98,6 +117,12 @@ llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) {
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp.satfinite(<2 x bfloat> %{{.*}})
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
%res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU)
+ // CHECK: %[[res3:.*]] = call i16 @llvm.nvvm.bf16x2.to.e4m3x2.rn.satfinite(<2 x bfloat> %{{.*}})
+ // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res3]] to <2 x i8>
+ %res3 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E4M3FN)
+ // CHECK: %[[res4:.*]] = call i16 @llvm.nvvm.bf16x2.to.e5m2x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
+ // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res4]] to <2 x i8>
+ %res4 = nvvm.convert.bf16x2.to.f8x2 %src {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E5M2)
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir
new file mode 100644
index 0000000000000..bef5637839205
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir
@@ -0,0 +1,41 @@
+// RUN: mlir-translate -mlir-to-llvmir -verify-diagnostics %s
+
+// -----
+
+llvm.func @convert_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) {
+ // expected-error @below {{Only 'f8E4M3FN', 'f8E5M2', and 'f8E8M0FNU' types are supported for conversions from bf16x2 to f8x2.}}
+ %res = nvvm.convert.bf16x2.to.f8x2 %src : vector<2xbf16> -> vector<2xi8> (f6E2M3FN)
+ llvm.return
+}
+
+// -----
+
+llvm.func @convert_f16x2_to_f8x2_invalid_rounding_1(%src : vector<2xbf16>) {
+ // expected-error @below {{Only RN rounding mode is supported for conversions from bf16x2 to 'f8E4M3FN' and 'f8E5M2' types}}
+ %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xbf16> -> vector<2xi8> (f8E4M3FN)
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding_2(%src : vector<2xbf16>) {
+ // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to 'f8E8M0FNU' type}}
+ %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU)
+ llvm.return
+}
+
+// -----
+
+llvm.func @convert_bf16x2_to_f8x2_invalid_sat_mode(%src : vector<2xbf16>) {
+ // expected-error @below {{Only SATFINITE saturation mode is supported for conversions from bf16x2 to 'f8E4M3FN' and 'f8E5M2' types}}
+ %res = nvvm.convert.bf16x2.to.f8x2 %src {sat = #nvvm.sat_mode<none>, rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> vector<2xi8> (f8E4M3FN)
+ llvm.return
+}
+
+// -----
+
+llvm.func @convert_bf16x2_to_f8x2_invalid_relu(%src : vector<2xbf16>) {
+ // expected-error @below {{relu not supported for conversions to 'f8E8M0FNU' type}}
+ %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rp>, relu = true} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU)
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_s2f6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_s2f6x2.mlir
new file mode 100644
index 0000000000000..f54b391af96e2
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_s2f6x2.mlir
@@ -0,0 +1,181 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @convert_f32x2_to_s2f6x2(%srcA : f32, %srcB : f32) -> i16 {
+ // CHECK-LABEL: define i16 @convert_f32x2_to_s2f6x2(float %0, float %1) {
+ // CHECK-NEXT: %3 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(float %0, float %1, i16 32639)
+ // CHECK-NEXT: %4 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.relu.satfinite.scale.n2.ue8m0(float %0, float %1, i16 32639)
+ // CHECK-NEXT: %5 = or i16 %3, %4
+ // CHECK-NEXT: ret i16 %5
+ // CHECK-NEXT: }
+ %res1 = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB : i16
+ %res2 = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB {relu = true} : i16
+
+ // Combine results to avoid dead code elimination
+ %final_result = llvm.or %res1, %res2 : i16
+ llvm.return %final_result : i16
+}
+
+llvm.func @convert_f32x2_to_s2f6x2_scale(%srcA : f32, %srcB : f32, %scale : i16) -> i16 {
+ // CHECK-LABEL: define i16 @convert_f32x2_to_s2f6x2_scale(float %0, float %1, i16 %2) {
+ // CHECK-NEXT: %4 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(float %0, float %1, i16 %2)
+ // CHECK-NEXT: %5 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.relu.satfinite.scale.n2.ue8m0(float %0, float %1, i16 %2)
+ // CHECK-NEXT: %6 = or i16 %4, %5
+ // CHECK-NEXT: ret i16 %6
+ // CHECK-NEXT: }
+ %res1 = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB, %scale : i16
+ %res2 = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB, %scale {relu = true} : i16
+
+ // Combine results to avoid dead code elimination
+ %final_result = llvm.or %res1, %res2 : i16
+ llvm.return %final_result : i16
+}
+
+llvm.func @convert_f32x2_to_s2f6x2_vector(%srcA : f32, %srcB : f32) -> vector<2xi8> {
+ // CHECK-LABEL: define <2 x i8> @convert_f32x2_to_s2f6x2_vector(float %0, float %1) {
+ // CHECK-NEXT: %3 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(float %0, float %1, i16 32639)
+ // CHECK-NEXT: %4 = bitcast i16 %3 to <2 x i8>
+ // CHECK-NEXT: ret <2 x i8> %4
+ // CHECK-NEXT: }
+ %res1 = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB : vector<2xi8>
+ llvm.return %res1 : vector<2xi8>
+}
+
+llvm.func @convert_f32x2_to_s2f6x2_vector_scale(%srcA : f32, %srcB : f32, %scale : i16) -> vector<2xi8> {
+ // CHECK-LABEL: define <2 x i8> @convert_f32x2_to_s2f6x2_vector_scale(float %0, float %1, i16 %2) {
+ // CHECK-NEXT: %4 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(float %0, float %1, i16 %2)
+ // CHECK-NEXT: %5 = bitcast i16 %4 to <2 x i8>
+ // CHECK-NEXT: ret <2 x i8> %5
+ // CHECK-NEXT: }
+ %res1 = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB, %scale : vector<2xi8>
+ llvm.return %res1 : vector<2xi8>
+}
+
+llvm.func @convert_bf16x2_to_s2f6x2(%srcA : vector<2xbf16>) -> i16 {
+ // CHECK-LABEL: define i16 @convert_bf16x2_to_s2f6x2(<2 x bfloat> %0) {
+ // CHECK-NEXT: %2 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 32639)
+ // CHECK-NEXT: %3 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.relu.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 32639)
+ // CHECK-NEXT: %4 = or i16 %2, %3
+ // CHECK-NEXT: ret i16 %4
+ // CHECK-NEXT: }
+ %res1 = nvvm.convert.bf16x2.to.s2f6x2 %srcA : vector<2xbf16> -> i16
+ %res2 = nvvm.convert.bf16x2.to.s2f6x2 %srcA {relu = true} : vector<2xbf16> -> i16
+
+ // Combine results to avoid dead code elimination
+ %final_result = llvm.or %res1, %res2 : i16
+ llvm.return %final_result : i16
+}
+
+llvm.func @convert_bf16x2_to_s2f6x2_scale(%srcA : vector<2xbf16>, %scale : i16) -> i16 {
+ // CHECK-LABEL: define i16 @convert_bf16x2_to_s2f6x2_scale(<2 x bfloat> %0, i16 %1) {
+ // CHECK-NEXT: %3 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 %1)
+ // CHECK-NEXT: %4 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.relu.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 %1)
+ // CHECK-NEXT: %5 = or i16 %3, %4
+ // CHECK-NEXT: ret i16 %5
+ // CHECK-NEXT: }
+ %res1 = nvvm.convert.bf16x2.to.s2f6x2 %srcA, %scale : vector<2xbf16> -> i16
+ %res2 = nvvm.convert.bf16x2.to.s2f6x2 %srcA, %scale {relu = true} : vector<2xbf16> -> i16
+
+ // Combine results to avoid dead code elimination
+ %final_result = llvm.or %res1, %res2 : i16
+ llvm.return %final_result : i16
+}
+
+llvm.func @convert_bf16x2_to_s2f6x2_vector(%srcA : vector<2xbf16>) -> vector<2xi8> {
+ // CHECK-LABEL: define <2 x i8> @convert_bf16x2_to_s2f6x2_vector(<2 x bfloat> %0) {
+ // CHECK-NEXT: %2 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 32639)
+ // CHECK-NEXT: %3 = bitcast i16 %2 to <2 x i8>
+ // CHECK-NEXT: ret <2 x i8> %3
+ // CHECK-NEXT: }
+ %res1 = nvvm.convert.bf16x2.to.s2f6x2 %srcA : vector<2xbf16> -> vector<2xi8>
+ llvm.return %res1 : vector<2xi8>
+}
+
+llvm.func @convert_bf16x2_to_s2f6x2_vector_scale(%srcA : vector<2xbf16>, %scale : i16) -> vector<2xi8> {
+ // CHECK-LABEL: define <2 x i8> @convert_bf16x2_to_s2f6x2_vector_scale(<2 x bfloat> %0, i16 %1) {
+ // CHECK-NEXT: %3 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 %1)
+ // CHECK-NEXT: %4 = bitcast i16 %3 to <2 x i8>
+ // CHECK-NEXT: ret <2 x i8> %4
+ // CHECK-NEXT: }
+ %res1 = nvvm.convert.bf16x2.to.s2f6x2 %srcA, %scale : vector<2xbf16> -> vector<2xi8>
+ llvm.return %res1 : vector<2xi8>
+}
+
+// 1. no relu, no scale, no satfinite
+llvm.func @convert_s2f6x2_to_bf16x2(%src : i16) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2(i16 %0) {
+ // CHECK-NEXT: %2 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %0, i16 32639)
+ // CHECK-NEXT: ret <2 x bfloat> %2
+ // CHECK-NEXT: }
+ %res = nvvm.convert.s2f6x2.to.bf16x2 %src : i16 -> vector<2xbf16>
+ llvm.return %res : vector<2xbf16>
+}
+
+// 2. relu, no scale, no satfinite
+llvm.func @convert_s2f6x2_to_bf16x2_relu(%src : i16) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_relu(i16 %0) {
+ // CHECK-NEXT: %2 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %0, i16 32639)
+ // CHECK-NEXT: ret <2 x bfloat> %2
+ // CHECK-NEXT: }
+ %res = nvvm.convert.s2f6x2.to.bf16x2 %src {relu = true} : i16 -> vector<2xbf16>
+ llvm.return %res : vector<2xbf16>
+}
+
+// 3. no relu, with scale, no satfinite
+llvm.func @convert_s2f6x2_to_bf16x2_scale(%src : i16, %scale : i16) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale(i16 %0, i16 %1) {
+ // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %0, i16 %1)
+ // CHECK-NEXT: ret <2 x bfloat> %3
+ // CHECK-NEXT: }
+ %res = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale : i16 -> vector<2xbf16>
+ llvm.return %res : vector<2xbf16>
+}
+
+// 4. relu, with scale, no satfinite
+llvm.func @convert_s2f6x2_to_bf16x2_scale_relu(%src : i16, %scale : i16) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_relu(i16 %0, i16 %1) {
+ // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %0, i16 %1)
+ // CHECK-NEXT: ret <2 x bfloat> %3
+ // CHECK-NEXT: }
+ %res = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale {relu = true} : i16 -> vector<2xbf16>
+ llvm.return %res : vector<2xbf16>
+}
+
+// 5. no relu, no scale, satfinite
+llvm.func @convert_s2f6x2_to_bf16x2_satfinite(%src : i16) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_satfinite(i16 %0) {
+ // CHECK-NEXT: %2 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %0, i16 32639)
+ // CHECK-NEXT: ret <2 x bfloat> %2
+ // CHECK-NEXT: }
+ %res = nvvm.convert.s2f6x2.to.bf16x2 %src {sat = #nvvm.sat_mode<satfinite>} : i16 -> vector<2xbf16>
+ llvm.return %res : vector<2xbf16>
+}
+
+// 6. relu, no scale, satfinite
+llvm.func @convert_s2f6x2_to_bf16x2_relu_satfinite(%src : i16) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_relu_satfinite(i16 %0) {
+ // CHECK-NEXT: %2 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %0, i16 32639)
+ // CHECK-NEXT: ret <2 x bfloat> %2
+ // CHECK-NEXT: }
+ %res = nvvm.convert.s2f6x2.to.bf16x2 %src {relu = true, sat = #nvvm.sat_mode<satfinite>} : i16 -> vector<2xbf16>
+ llvm.return %res : vector<2xbf16>
+}
+
+// 7. no relu, with scale, satfinite
+llvm.func @convert_s2f6x2_to_bf16x2_scale_satfinite(%src : i16, %scale : i16) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_satfinite(i16 %0, i16 %1) {
+ // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %0, i16 %1)
+ // CHECK-NEXT: ret <2 x bfloat> %3
+ // CHECK-NEXT: }
+ %res = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale {sat = #nvvm.sat_mode<satfinite>} : i16 -> vector<2xbf16>
+ llvm.return %res : vector<2xbf16>
+}
+
+// 8. relu, with scale, satfinite
+llvm.func @convert_s2f6x2_to_bf16x2_scale_relu_satfinite(%src : i16, %scale : i16) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_relu_satfinite(i16 %0, i16 %1) {
+ // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %0, i16 %1)
+ // CHECK-NEXT: ret <2 x bfloat> %3
+ // CHECK-NEXT: }
+ %res = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale {relu = true, sat = #nvvm.sat_mode<satfinite>} : i16 -> vector<2xbf16>
+ llvm.return %res : vector<2xbf16>
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index c0fe0fa11f497..6d7221c1393da 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -209,22 +209,6 @@ llvm.func @nvvm_cvt_f16x2_to_f8x2_invalid_type(%src : vector<2xf16>) {
// -----
-llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) {
- // expected-error @below {{Only 'f8E8M0FNU' type is supported for conversions from bf16x2 to f8x2.}}
- %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> i16 (f8E4M3FN)
- llvm.return
-}
-
-// -----
-
-llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
- // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}}
- %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> i16 (f8E8M0FNU)
- llvm.return
-}
-
-// -----
-
llvm.func @nvvm_cvt_f32x2_to_f6x2_invalid_type(%a : f32, %b : f32) {
// expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f32x2 to f6x2.}}
%res = nvvm.convert.f32x2.to.f6x2 %a, %b : i16 (f8E8M0FNU)
>From c7fbdc27f19bd47490cda7d2caf32a016dea22a3 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 11 Mar 2026 15:48:42 +0000
Subject: [PATCH 2/3] address comments
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 134 ++++++++++--------
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 67 ++-------
.../LLVMIR/nvvm/convert_fp4x2_invalid.mlir | 4 +-
.../LLVMIR/nvvm/convert_fp6x2_invalid.mlir | 4 +-
.../LLVMIR/nvvm/convert_fp8x2_invalid.mlir | 2 +-
.../Target/LLVMIR/nvvm/convert_s2f6x2.mlir | 88 ++++++------
6 files changed, 142 insertions(+), 157 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index ef69def599ccc..7ff4e3215dc2b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1974,26 +1974,25 @@ def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
class NVVM_ConvertFPx2ToF4x2Op<string srcType>
: NVVM_Op<"convert."#!tolower(srcType)#"x2.to.f4x2"> {
- let summary = "Convert a pair of "#srcType#" inputs to F4x2";
+ let summary = "Convert an " # !tolower(srcType) # "x2 input to f4x2";
let description = [{
- This Op converts each of the given }]#srcType#[{ inputs to the specified
- F4x2 type.
+ This Op converts each of the given }]#srcType#[{ inputs in an }]#!tolower
+ (srcType)#[{x2 vector to the specified fp4 type.
The result `dst` is returned as an i8 type where the converted values are
- packed such that the value converted from `a` is stored in the upper 4 bits
- of `dst` and the value converted from `b` is stored in the lower 4 bits of
- `dst`.
+ packed such that the value converted from the first element of `src` is
+ stored in the lower 4 bits of `dst` and the value converted from the second
+ element of `src` is stored in the upper 4 bits of `dst`.
The `relu` attribute, when set, lowers to the '.relu' variant of
the cvt instruction.
}];
- let hasVerifier = 1;
let results = (outs I8:$dst);
let arguments = (ins
- VectorOfLengthAndType<[2], [!cast<Type>(srcType)]>:$a,
+ VectorOfLengthAndType<[2], [!cast<Type>(srcType)]>:$src,
DefaultValuedAttr<BoolAttr, "false">:$relu,
- TypeAttr:$dstTy);
+ TypeAttrOf<F4E2M1FN>:$dstTy);
let assemblyFormat =
- "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
+ "$src attr-dict `:` type($src) `->` type($dst) `(` $dstTy `)`";
let extraClassDeclaration = [{
static NVVM::IDArgPair
@@ -2056,30 +2055,35 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
class NVVM_ConvertFPx2ToF6x2Op<string srcType>
: NVVM_Op<"convert."#!tolower(srcType)#"x2.to.f6x2"> {
- let summary = "Convert a pair of "#srcType#" inputs to F6x2";
+ let summary = "Convert an " # !tolower(srcType) # "x2 input to f6x2";
let description = [{
- This Op converts each of the given }]#srcType#[{ inputs to the specified
- F6x2 type.
- The result `dst` is represented as a vector type(vector<2xi8>). Each
- converted value is stored as an i8 element in the vector.
+ This Op converts each of the given }]#srcType#[{ inputs in an }]#!tolower
+ (srcType)#[{x2 vector to the specified fp6 type. The result `dst` is
+ represented either as an i16 type or as a vector of two i8 types.
+ If `dst` is returned as an i16 type, the converted values are packed such
+ that the value converted from the first element of `src` is stored in the
+ lower 8 bits of `dst` with 2 MSB bits padded with zeros and the value
+ converted from the second element of `src` is stored in the upper 8 bits of
+ `dst` with 2 MSB bits padded with zeros.
+ If `dst` is returned as a vector type, each converted value is stored as an
+ i8 element in the vector with 2 MSB bits padded with zeros.
The `relu` attribute, when set, lowers to the '.relu' variant of
the cvt instruction.
}];
- let hasVerifier = 1;
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
- VectorOfLengthAndType<[2], [!cast<Type>(srcType)]>:$a,
+ VectorOfLengthAndType<[2], [!cast<Type>(srcType)]>:$src,
DefaultValuedAttr<BoolAttr, "false">:$relu,
- TypeAttr:$dstTy);
- let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
+ TypeAttrOf<AnyTypeOf<[F6E2M3FN, F6E3M2FN]>>:$dstTy);
+ let assemblyFormat = "$src attr-dict `:` type($src) `->` type($dst) `(` $dstTy `)`";
let extraClassDeclaration = [{
static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy, bool hasRelu);
}];
string llvmBuilder = [{
auto intId = NVVM::Convert}]#srcType#[{x2ToF6x2Op::getIntrinsicID($dstTy, $relu);
- llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
+ llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$src});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
else
@@ -2185,15 +2189,15 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
let summary = "Convert a pair of bf16 inputs to f8x2";
let description = [{
This Op converts the given bf16 inputs in a bf16x2 vector to the specified
- f8 type.
- The result `dst` is represented as an i16 type or as a vector
- of two i8 types.
- If `dst` is returned as an i16 type, the converted values from `a`
- are packed such that the value converted from the first element of `a`
- is stored in the upper 8 bits of `dst` and the value converted from the
- second element of `a` is stored in the lower 8 bits of `dst`.
+ f8 type. The result `dst` is represented either as a packed i16 type or as
+ a vector of two i8 types.
+ If `dst` is returned as an i16 type, the converted values are packed such
+ that the value converted from the first element of `src` is stored in the
+ lower 8 bits of `dst` and the value converted from the second element of
+ `src` is stored in the upper 8 bits of `dst`. Each converted f8 value
+ occupies all 8 bits of its byte, so no zero padding is involved.
 If `dst` is returned as a vector type, each converted value is stored as an
- i8 element in the vector.
+ i8 element in the vector.
The `rnd` and `sat` attributes specify the rounding and saturation modes
respectively.
@@ -2203,12 +2207,12 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
let hasVerifier = 1;
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
- VectorOfLengthAndType<[2], [BF16]>:$a,
+ VectorOfLengthAndType<[2], [BF16]>:$src,
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
DefaultValuedAttr<BoolAttr, "false">:$relu,
- TypeAttr:$dstTy);
- let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
+ TypeAttrOf<AnyTypeOf<[F8E8M0FNU, F8E4M3FN, F8E5M2]>>:$dstTy);
+ let assemblyFormat = "$src attr-dict `:` type($src) `->` type($dst) `(` $dstTy `)`";
let extraClassDeclaration = [{
static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
@@ -2218,7 +2222,7 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
string llvmBuilder = [{
auto intId = NVVM::ConvertBF16x2ToF8x2Op::getIntrinsicID($dstTy, $rnd, $sat, $relu);
- llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
+ llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$src});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
else
@@ -2267,11 +2271,20 @@ def NVVM_ConvertF32x2ToS2F6x2Op : NVVM_Op<"convert.f32x2.to.s2f6x2"> {
let summary = "Convert a pair of f32 inputs to S2F6x2";
let description = [{
This Op converts each of the given f32 inputs to the
- S2F6x2 type. The result `dst` can be an I16 or vector<2xi8>.
+ S2F6x2 type. The result `dst` can be either a packed i16 type or a vector
+ of two i8 types.
+ If `dst` is returned as an i16 type, the converted values are packed such
+ that the value converted from `a` is stored in the upper 8 bits of `dst`
+ and the value converted from `b` is stored in the lower 8 bits of `dst`.
+ If `dst` is returned as a vector type, each converted value is stored as an
+ i8 element in the vector.
The `relu` attribute, when set, lowers to the '.relu' variant
- of the cvt instruction. The optional scaling-factor for the
- conversion is provided through the operand `scaleFactor`.
- Only `ue8m0` is supported as the type of the scale-factor currently.
+ of the cvt instruction.
+ The optional scaling-factors for each of the inputs are provided through
+ the operand `scaleFactor` as a packed i16 type. Only `ue8m0` is supported
+ as the type of the scale-factor currently.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
}];
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
@@ -2299,19 +2312,31 @@ def NVVM_ConvertF32x2ToS2F6x2Op : NVVM_Op<"convert.f32x2.to.s2f6x2"> {
def NVVM_ConvertBF16x2ToS2F6x2Op : NVVM_Op<"convert.bf16x2.to.s2f6x2"> {
let summary = "Convert a pair of BF16 inputs to S2F6x2";
let description = [{
- This Op converts each of the given BF16 inputs to the
- S2F6x2 type. The result `dst` can be an I16 or vector<2xi8>.
+ This Op converts each of the given BF16 inputs in a bf16x2 vector to the
+ S2F6x2 type. The result `dst` can be either a packed i16 type or a vector
+ of two i8 types.
+ If `dst` is returned as an i16 type, the converted values are packed such
+ that the value converted from the first element of `src` is stored in the
+ lower 8 bits of `dst` and the value converted from the second element of
+ `src` is stored in the upper 8 bits of `dst`.
+ If `dst` is returned as a vector type, each converted value is stored as an
+ i8 element in the vector.
The `relu` attribute, when set, lowers to the '.relu' variant
of the cvt instruction.
+ The optional scaling-factors for each of the inputs are provided through
+ the operand `scaleFactor` as a packed i16 type. Only `ue8m0` is supported
+ as the type of the scale-factor currently.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
}];
let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
let arguments = (ins
- VectorOfLengthAndType<[2], [BF16]>:$a,
+ VectorOfLengthAndType<[2], [BF16]>:$src,
Optional<I16>:$scaleFactor,
DefaultValuedAttr<BoolAttr, "false">:$relu);
let assemblyFormat =
- "$a (`,` $scaleFactor^)? attr-dict `:` type($a) `->` type($dst)";
+ "$src (`,` $scaleFactor^)? attr-dict `:` type($src) `->` type($dst)";
let extraClassDeclaration = [{
static IDArgPair getIntrinsicIDAndArgs(Operation &op,
LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
@@ -2328,32 +2353,29 @@ def NVVM_ConvertBF16x2ToS2F6x2Op : NVVM_Op<"convert.bf16x2.to.s2f6x2"> {
}];
}
-def NVVM_ConvertS2F6x2ToBF16x2Op : NVVM_Op<"convert.s2f6x2.to.bf16x2"> {
+def NVVM_ConvertS2F6x2ToBF16x2Op : NVVM_SingleResultIntrinsicOp<"convert.s2f6x2.to.bf16x2", [], "$dst"> {
let summary = "Convert s2f6x2 to bf16x2";
let description = [{
- This Op converts the given s2f6x2 to bf16x2 values.
+ This Op converts the two s2f6 values held in the s2f6x2 input `src` to the
+ bf16x2 type. The result `dst` is a vector of two bf16 elements.
+
The `relu` attribute, when set, lowers to the '.relu' variant
- of the cvt instruction. The optional scaling-factor for the
- conversion is provided through the operand `scaleFactor`.
- Only `ue8m0` is supported as the type of the scale-factor currently.
+ of the cvt instruction.
+
+ The optional scaling-factors for each of the inputs are provided through
+ the operand `scaleFactor` as a packed i16 type. Only `ue8m0` is supported
+ as the type of the scale-factor currently.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
}];
let results = (outs VectorOfLengthAndType<[2], [BF16]>:$dst);
- let arguments = (ins I16:$a,
+ let arguments = (ins VectorOfLengthAndType<[2], [I8]>:$src,
Optional<I16>:$scaleFactor,
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
DefaultValuedAttr<BoolAttr, "false">:$relu);
let assemblyFormat =
- "$a (`,` $scaleFactor^)? attr-dict `:` type($a) `->` type($dst)";
- let extraClassDeclaration = [{
- static IDArgPair getIntrinsicIDAndArgs(Operation &op,
- LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
- }];
-
- string llvmBuilder = [{
- auto [id, args] = NVVM::ConvertS2F6x2ToBF16x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
- $dst = createIntrinsicCall(builder, id, args);
- }];
+ "$src (`,` $scaleFactor^)? attr-dict `:` type($src) `->` type($dst)";
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index a27ca2ebdfe78..b94de66e1c82a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -347,32 +347,6 @@ LogicalResult ConvertF32x2ToF6x2Op::verify() {
return success();
}
-LogicalResult ConvertF16x2ToF6x2Op::verify() {
- mlir::MLIRContext *ctx = getContext();
-
- if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
- return emitOpError("Only ")
- << mlir::Float6E2M3FNType::get(ctx) << " and "
- << mlir::Float6E3M2FNType::get(ctx)
- << " types are supported for conversions from f16x2 to f6x2.";
- }
-
- return success();
-}
-
-LogicalResult ConvertBF16x2ToF6x2Op::verify() {
- mlir::MLIRContext *ctx = getContext();
-
- if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
- return emitOpError("Only ")
- << mlir::Float6E2M3FNType::get(ctx) << " and "
- << mlir::Float6E3M2FNType::get(ctx)
- << " types are supported for conversions from bf16x2 to f6x2.";
- }
-
- return success();
-}
-
LogicalResult ConvertF32x2ToF8x2Op::verify() {
using RndMode = NVVM::FPRoundingMode;
using SatMode = NVVM::SaturationMode;
@@ -476,11 +450,8 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
return success();
})
.Default([&](mlir::Type) -> LogicalResult {
- return emitOpError("Only ")
- << mlir::Float8E4M3FNType::get(ctx) << ", "
- << mlir::Float8E5M2Type::get(ctx) << ", and "
- << mlir::Float8E8M0FNUType::get(ctx)
- << " types are supported for conversions from bf16x2 to f8x2.";
+ llvm_unreachable("Invalid conversion in ConvertBF16x2ToF8x2Op");
+ return failure();
});
}
@@ -495,26 +466,6 @@ LogicalResult ConvertF32x2ToF4x2Op::verify() {
return success();
}
-LogicalResult ConvertF16x2ToF4x2Op::verify() {
- mlir::MLIRContext *ctx = getContext();
-
- if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
- return emitOpError("Only ")
- << mlir::Float4E2M1FNType::get(ctx)
- << " type is supported for conversions from f16x2 to f4x2.";
- return success();
-}
-
-LogicalResult ConvertBF16x2ToF4x2Op::verify() {
- mlir::MLIRContext *ctx = getContext();
-
- if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
- return emitOpError("Only ")
- << mlir::Float4E2M1FNType::get(ctx)
- << " type is supported for conversions from bf16x2 to f4x2.";
- return success();
-}
-
LogicalResult ConvertF8x2ToF16x2Op::verify() {
mlir::MLIRContext *ctx = getContext();
@@ -4275,7 +4226,7 @@ ConvertF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF16x2ToF4x2Op &op,
: llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_satfinite;
llvm::SmallVector<llvm::Value *> args;
- args.push_back(mt.lookupValue(op.getA()));
+ args.push_back(mt.lookupValue(op.getSrc()));
return {intId, std::move(args)};
}
@@ -4294,7 +4245,7 @@ ConvertBF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertBF16x2ToF4x2Op &op,
: llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_satfinite;
llvm::SmallVector<llvm::Value *> args;
- args.push_back(mt.lookupValue(op.getA()));
+ args.push_back(mt.lookupValue(op.getSrc()));
return {intId, std::move(args)};
}
@@ -4403,7 +4354,7 @@ ConvertBF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp_satfinite,
};
- return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
.Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
return hasRelu
? llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_relu_satfinite
@@ -4551,7 +4502,7 @@ NVVM::IDArgPair ConvertBF16x2ToS2F6x2Op::getIntrinsicIDAndArgs(
// Fill the Intrinsic Args
llvm::SmallVector<llvm::Value *> args;
- args.push_back(mt.lookupValue(thisOp.getA()));
+ args.push_back(mt.lookupValue(thisOp.getSrc()));
args.push_back(hasScale ? mt.lookupValue(thisOp.getScaleFactor())
: builder.getInt16(0x7f7f));
return {id, std::move(args)};
@@ -4575,9 +4526,13 @@ NVVM::IDArgPair ConvertS2F6x2ToBF16x2Op::getIntrinsicIDAndArgs(
// Fill the Intrinsic Args
llvm::SmallVector<llvm::Value *> args;
- args.push_back(mt.lookupValue(thisOp.getA()));
+ llvm::Value *packedI16 =
+ builder.CreateBitCast(mt.lookupValue(thisOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+ args.push_back(packedI16);
args.push_back(hasScale ? mt.lookupValue(thisOp.getScaleFactor())
: builder.getInt16(0x7f7f));
+
return {ids[idx], std::move(args)};
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2_invalid.mlir
index 637db9e8ca07a..d179431fe9369 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2_invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2_invalid.mlir
@@ -3,7 +3,7 @@
// -----
llvm.func @convert_f16x2_to_f4x2_invalid_type(%src : vector<2xf16>) {
- // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f16x2 to f4x2.}}
+ // expected-error @below {{attribute 'dstTy' failed to satisfy constraint: type attribute of f4E2M1FN type}}
%res = nvvm.convert.f16x2.to.f4x2 %src : vector<2xf16> -> i8 (f8E4M3FN)
llvm.return
}
@@ -11,7 +11,7 @@ llvm.func @convert_f16x2_to_f4x2_invalid_type(%src : vector<2xf16>) {
// -----
llvm.func @convert_bf16x2_to_f4x2_invalid_type(%src : vector<2xbf16>) {
- // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from bf16x2 to f4x2.}}
+ // expected-error @below {{attribute 'dstTy' failed to satisfy constraint: type attribute of f4E2M1FN type}}
%res = nvvm.convert.bf16x2.to.f4x2 %src : vector<2xbf16> -> i8 (f8E4M3FN)
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2_invalid.mlir
index 8d0df5863a931..e993868cf1c9f 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2_invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2_invalid.mlir
@@ -3,7 +3,7 @@
// -----
llvm.func @convert_f16x2_to_f6x2_invalid_type(%src : vector<2xf16>) {
- // expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f16x2 to f6x2.}}
+ // expected-error @below {{attribute 'dstTy' failed to satisfy constraint: type attribute of f6E2M3FN type or f6E3M2FN type}}
%res = nvvm.convert.f16x2.to.f6x2 %src : vector<2xf16> -> vector<2xi8> (f8E4M3FN)
llvm.return
}
@@ -11,7 +11,7 @@ llvm.func @convert_f16x2_to_f6x2_invalid_type(%src : vector<2xf16>) {
// -----
llvm.func @convert_bf16x2_to_f6x2_invalid_type(%src : vector<2xbf16>) {
- // expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from bf16x2 to f6x2.}}
+ // expected-error @below {{attribute 'dstTy' failed to satisfy constraint: type attribute of f6E2M3FN type or f6E3M2FN type}}
%res = nvvm.convert.bf16x2.to.f6x2 %src : vector<2xbf16> -> vector<2xi8> (f8E4M3FN)
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir
index bef5637839205..4164238eb6e53 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir
@@ -3,7 +3,7 @@
// -----
llvm.func @convert_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) {
- // expected-error @below {{Only 'f8E4M3FN', 'f8E5M2', and 'f8E8M0FNU' types are supported for conversions from bf16x2 to f8x2.}}
+ // expected-error @below {{attribute 'dstTy' failed to satisfy constraint: type attribute of f8E8M0FNU type or f8E4M3FN type or f8E5M2 type}}
%res = nvvm.convert.bf16x2.to.f8x2 %src : vector<2xbf16> -> vector<2xi8> (f6E2M3FN)
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_s2f6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_s2f6x2.mlir
index f54b391af96e2..7c1aa406a47af 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_s2f6x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_s2f6x2.mlir
@@ -101,81 +101,89 @@ llvm.func @convert_bf16x2_to_s2f6x2_vector_scale(%srcA : vector<2xbf16>, %scale
}
// 1. no relu, no scale, no satfinite
-llvm.func @convert_s2f6x2_to_bf16x2(%src : i16) -> vector<2xbf16> {
- // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2(i16 %0) {
- // CHECK-NEXT: %2 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %0, i16 32639)
- // CHECK-NEXT: ret <2 x bfloat> %2
+llvm.func @convert_s2f6x2_to_bf16x2(%src : vector<2xi8>) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2(<2 x i8> %0) {
+ // CHECK-NEXT: %2 = bitcast <2 x i8> %0 to i16
+ // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %2, i16 32639)
+ // CHECK-NEXT: ret <2 x bfloat> %3
// CHECK-NEXT: }
- %res = nvvm.convert.s2f6x2.to.bf16x2 %src : i16 -> vector<2xbf16>
+ %res = nvvm.convert.s2f6x2.to.bf16x2 %src : vector<2xi8> -> vector<2xbf16>
llvm.return %res : vector<2xbf16>
}
// 2. relu, no scale, no satfinite
-llvm.func @convert_s2f6x2_to_bf16x2_relu(%src : i16) -> vector<2xbf16> {
- // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_relu(i16 %0) {
- // CHECK-NEXT: %2 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %0, i16 32639)
- // CHECK-NEXT: ret <2 x bfloat> %2
+llvm.func @convert_s2f6x2_to_bf16x2_relu(%src : vector<2xi8>) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_relu(<2 x i8> %0) {
+ // CHECK-NEXT: %2 = bitcast <2 x i8> %0 to i16
+ // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %2, i16 32639)
+ // CHECK-NEXT: ret <2 x bfloat> %3
// CHECK-NEXT: }
- %res = nvvm.convert.s2f6x2.to.bf16x2 %src {relu = true} : i16 -> vector<2xbf16>
+ %res = nvvm.convert.s2f6x2.to.bf16x2 %src {relu = true} : vector<2xi8> -> vector<2xbf16>
llvm.return %res : vector<2xbf16>
}
// 3. no relu, with scale, no satfinite
-llvm.func @convert_s2f6x2_to_bf16x2_scale(%src : i16, %scale : i16) -> vector<2xbf16> {
- // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale(i16 %0, i16 %1) {
- // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %0, i16 %1)
- // CHECK-NEXT: ret <2 x bfloat> %3
+llvm.func @convert_s2f6x2_to_bf16x2_scale(%src : vector<2xi8>, %scale : i16) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale(<2 x i8> %0, i16 %1) {
+ // CHECK-NEXT: %3 = bitcast <2 x i8> %0 to i16
+ // CHECK-NEXT: %4 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %3, i16 %1)
+ // CHECK-NEXT: ret <2 x bfloat> %4
// CHECK-NEXT: }
- %res = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale : i16 -> vector<2xbf16>
+ %res = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale : vector<2xi8> -> vector<2xbf16>
llvm.return %res : vector<2xbf16>
}
// 4. relu, with scale, no satfinite
-llvm.func @convert_s2f6x2_to_bf16x2_scale_relu(%src : i16, %scale : i16) -> vector<2xbf16> {
- // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_relu(i16 %0, i16 %1) {
- // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %0, i16 %1)
- // CHECK-NEXT: ret <2 x bfloat> %3
+llvm.func @convert_s2f6x2_to_bf16x2_scale_relu(%src : vector<2xi8>, %scale : i16) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_relu(<2 x i8> %0, i16 %1) {
+ // CHECK-NEXT: %3 = bitcast <2 x i8> %0 to i16
+ // CHECK-NEXT: %4 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %3, i16 %1)
+ // CHECK-NEXT: ret <2 x bfloat> %4
// CHECK-NEXT: }
- %res = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale {relu = true} : i16 -> vector<2xbf16>
+ %res = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale {relu = true} : vector<2xi8> -> vector<2xbf16>
llvm.return %res : vector<2xbf16>
}
// 5. no relu, no scale, satfinite
-llvm.func @convert_s2f6x2_to_bf16x2_satfinite(%src : i16) -> vector<2xbf16> {
- // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_satfinite(i16 %0) {
- // CHECK-NEXT: %2 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %0, i16 32639)
- // CHECK-NEXT: ret <2 x bfloat> %2
+llvm.func @convert_s2f6x2_to_bf16x2_satfinite(%src : vector<2xi8>) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_satfinite(<2 x i8> %0) {
+ // CHECK-NEXT: %2 = bitcast <2 x i8> %0 to i16
+ // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %2, i16 32639)
+ // CHECK-NEXT: ret <2 x bfloat> %3
// CHECK-NEXT: }
- %res = nvvm.convert.s2f6x2.to.bf16x2 %src {sat = #nvvm.sat_mode<satfinite>} : i16 -> vector<2xbf16>
+ %res = nvvm.convert.s2f6x2.to.bf16x2 %src {sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> -> vector<2xbf16>
llvm.return %res : vector<2xbf16>
}
// 6. relu, no scale, satfinite
-llvm.func @convert_s2f6x2_to_bf16x2_relu_satfinite(%src : i16) -> vector<2xbf16> {
- // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_relu_satfinite(i16 %0) {
- // CHECK-NEXT: %2 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %0, i16 32639)
- // CHECK-NEXT: ret <2 x bfloat> %2
+llvm.func @convert_s2f6x2_to_bf16x2_relu_satfinite(%src : vector<2xi8>) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_relu_satfinite(<2 x i8> %0) {
+ // CHECK-NEXT: %2 = bitcast <2 x i8> %0 to i16
+ // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %2, i16 32639)
+ // CHECK-NEXT: ret <2 x bfloat> %3
// CHECK-NEXT: }
- %res = nvvm.convert.s2f6x2.to.bf16x2 %src {relu = true, sat = #nvvm.sat_mode<satfinite>} : i16 -> vector<2xbf16>
+ %res = nvvm.convert.s2f6x2.to.bf16x2 %src {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> -> vector<2xbf16>
llvm.return %res : vector<2xbf16>
}
// 7. no relu, with scale, satfinite
-llvm.func @convert_s2f6x2_to_bf16x2_scale_satfinite(%src : i16, %scale : i16) -> vector<2xbf16> {
- // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_satfinite(i16 %0, i16 %1) {
- // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %0, i16 %1)
- // CHECK-NEXT: ret <2 x bfloat> %3
+llvm.func @convert_s2f6x2_to_bf16x2_scale_satfinite(%src : vector<2xi8>, %scale : i16) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_satfinite(<2 x i8> %0, i16 %1) {
+ // CHECK-NEXT: %3 = bitcast <2 x i8> %0 to i16
+ // CHECK-NEXT: %4 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %3, i16 %1)
+ // CHECK-NEXT: ret <2 x bfloat> %4
// CHECK-NEXT: }
- %res = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale {sat = #nvvm.sat_mode<satfinite>} : i16 -> vector<2xbf16>
+ %res = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale {sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> -> vector<2xbf16>
llvm.return %res : vector<2xbf16>
}
// 8. relu, with scale, satfinite
-llvm.func @convert_s2f6x2_to_bf16x2_scale_relu_satfinite(%src : i16, %scale : i16) -> vector<2xbf16> {
- // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_relu_satfinite(i16 %0, i16 %1) {
- // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %0, i16 %1)
- // CHECK-NEXT: ret <2 x bfloat> %3
+llvm.func @convert_s2f6x2_to_bf16x2_scale_relu_satfinite(%src : vector<2xi8>, %scale : i16) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_relu_satfinite(<2 x i8> %0, i16 %1) {
+ // CHECK-NEXT: %3 = bitcast <2 x i8> %0 to i16
+ // CHECK-NEXT: %4 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %3, i16 %1)
+ // CHECK-NEXT: ret <2 x bfloat> %4
// CHECK-NEXT: }
- %res = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale {relu = true, sat = #nvvm.sat_mode<satfinite>} : i16 -> vector<2xbf16>
+ %res = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale {relu = true, sat = #nvvm.sat_mode<satfinite>} : vector<2xi8> -> vector<2xbf16>
llvm.return %res : vector<2xbf16>
}
>From 89187c71908b1a7e7d0add2ed0c5d7f3107a81bd Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Thu, 2 Apr 2026 09:45:28 +0000
Subject: [PATCH 3/3] clean-up
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 7 +++----
mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir | 2 +-
2 files changed, 4 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7ff4e3215dc2b..9b79cf6877dd4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2193,11 +2193,10 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
a vector of two i8 types.
If `dst` is returned as an i16 type, the converted values are packed such
that the value converted from the first element of `a` is stored in the
- lower 8 bits of `dst` with 2 MSB bits padded with zeros and the value
- converted from the second element of `a` is stored in the upper 8 bits of
- `dst` with 2 MSB bits padded with zeros.
+ lower 8 bits of `dst` and the value converted from the second element of
+ `a` is stored in the upper 8 bits of `dst`.
If `dst` is returned as a vector type, each converted value is stored as an
- i8 element in the vector with 2 MSB bits padded with zeros.
+ i8 element in the vector.
The `rnd` and `sat` attributes specify the rounding and saturation modes
respectively.
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir
index 4164238eb6e53..747706dfc3418 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir
@@ -10,7 +10,7 @@ llvm.func @convert_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) {
// -----
-llvm.func @convert_f16x2_to_f8x2_invalid_rounding_1(%src : vector<2xbf16>) {
+llvm.func @convert_bf16x2_to_f8x2_invalid_rounding_1(%src : vector<2xbf16>) {
// expected-error @below {{Only RN rounding mode is supported for conversions from bf16x2 to 'f8E4M3FN' and 'f8E5M2' types}}
%res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xbf16> -> vector<2xi8> (f8E4M3FN)
llvm.return
More information about the Mlir-commits mailing list