[Mlir-commits] [mlir] [MLIR][NVVM] Add new narrow FP convert Ops (PR #184291)
Srinivasa Ravi
llvmlistbot at llvm.org
Mon Mar 2 22:18:45 PST 2026
https://github.com/Wolfram70 created https://github.com/llvm/llvm-project/pull/184291
This change adds the following NVVM Ops for new narrow FP conversions introduced in PTX 9.1:
- `convert.{f32x2/bf16x2}.to.s2f6x2`
- `convert.s2f6x2.to.bf16x2`
- `convert.bf16x2.to.f8x2` (extended for `f8E4M3FN` and `f8E5M2` types)
- `convert.{f16x2/bf16x2}.to.f6x2`
- `convert.{f16x2/bf16x2}.to.f4x2`
PTX ISA Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
>From d2debb661f5b85c842fce55422c89a376f43b259 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Mon, 2 Mar 2026 11:13:22 +0000
Subject: [PATCH] [MLIR][NVVM] Add new narrow FP convert Ops
This change adds the following NVVM Ops for new narrow FP conversions
introduced in PTX 9.1:
- `convert.{f32x2/bf16x2}.to.s2f6x2`
- `convert.s2f6x2.to.bf16x2`
- `convert.bf16x2.to.f8x2` (extended for `f8E4M3FN` and `f8E5M2` types)
- `convert.{f16x2/bf16x2}.to.f6x2`
- `convert.{f16x2/bf16x2}.to.f4x2`
PTX ISA Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 177 ++++++++++-
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 277 ++++++++++++++++--
.../invalid-convert-stochastic-rounding.mlir | 2 +-
.../Target/LLVMIR/nvvm/convert_fp4x2.mlir | 28 ++
.../LLVMIR/nvvm/convert_fp4x2_invalid.mlir | 17 ++
.../Target/LLVMIR/nvvm/convert_fp6x2.mlir | 77 ++++-
.../LLVMIR/nvvm/convert_fp6x2_invalid.mlir | 17 ++
.../Target/LLVMIR/nvvm/convert_fp8x2.mlir | 25 ++
.../LLVMIR/nvvm/convert_fp8x2_invalid.mlir | 41 +++
.../Target/LLVMIR/nvvm/convert_s2f6x2.mlir | 181 ++++++++++++
mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 16 -
11 files changed, 812 insertions(+), 46 deletions(-)
create mode 100644 mlir/test/Target/LLVMIR/nvvm/convert_fp4x2_invalid.mlir
create mode 100644 mlir/test/Target/LLVMIR/nvvm/convert_fp6x2_invalid.mlir
create mode 100644 mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir
create mode 100644 mlir/test/Target/LLVMIR/nvvm/convert_s2f6x2.mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 40f631fa0bb2c..aa12e8568a8b3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1951,6 +1951,45 @@ def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
}];
}
+class NVVM_ConvertFPx2ToF4x2Op<string srcType>
+    : NVVM_Op<"convert."#!tolower(srcType)#"x2.to.f4x2"> {
+  let summary = "Convert a pair of "#srcType#" inputs to F4x2";
+  let description = [{
+    This Op converts each of the given }]#srcType#[{ inputs to the specified
+    F4x2 type.
+    The two elements of the vector operand `a` are converted and their 4-bit
+    results are packed into the i8 result `dst` (the low byte of the i16
+    value returned by the underlying intrinsic; element-to-nibble ordering
+    follows the PTX `cvt` semantics for packed sources).
+    The `relu` attribute, when set, lowers to the '.relu' variant of
+    the cvt instruction.
+  }];
+  let hasVerifier = 1;
+  let results = (outs I8:$dst);
+  let arguments = (ins
+    VectorOfLengthAndType<[2], [!cast<Type>(srcType)]>:$a,
+    DefaultValuedAttr<BoolAttr, "false">:$relu,
+    TypeAttr:$dstTy);
+
+  let assemblyFormat =
+    "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
+
+  let extraClassDeclaration = [{
+    static NVVM::IDArgPair
+    getIntrinsicIDAndArgs(NVVM::Convert}]#srcType#[{x2ToF4x2Op &op,
+        LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [intId, args] = NVVM::Convert}]#srcType#[{x2ToF4x2Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
+    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, args);
+    $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
+  }];
+}
+
+def NVVM_ConvertF16x2ToF4x2Op : NVVM_ConvertFPx2ToF4x2Op<"F16">;
+def NVVM_ConvertBF16x2ToF4x2Op : NVVM_ConvertFPx2ToF4x2Op<"BF16">;
+
def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
let summary = "Convert a pair of float inputs to f6x2";
let description = [{
@@ -1994,6 +2033,43 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
}];
}
+class NVVM_ConvertFPx2ToF6x2Op<string srcType>
+    : NVVM_Op<"convert."#!tolower(srcType)#"x2.to.f6x2"> {
+  let summary = "Convert a pair of "#srcType#" inputs to F6x2";
+  let description = [{
+    This Op converts each of the given }]#srcType#[{ inputs to the specified
+    F6x2 type.
+    The result `dst` is either an i16 or a vector<2xi8>; in the vector form
+    each converted value is stored in its own i8 element.
+    The `relu` attribute, when set, lowers to the '.relu' variant of
+    the cvt instruction.
+  }];
+
+  let hasVerifier = 1;
+  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+  let arguments = (ins
+    VectorOfLengthAndType<[2], [!cast<Type>(srcType)]>:$a,
+    DefaultValuedAttr<BoolAttr, "false">:$relu,
+    TypeAttr:$dstTy);
+  let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy, bool hasRelu);
+  }];
+
+  string llvmBuilder = [{
+    auto intId = NVVM::Convert}]#srcType#[{x2ToF6x2Op::getIntrinsicID($dstTy, $relu);
+    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
+    if(op.getDst().getType().isInteger(16))
+      $dst = packedI16;
+    else
+      $dst = builder.CreateBitCast(packedI16,
+          llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
+  }];
+}
+
+def NVVM_ConvertF16x2ToF6x2Op : NVVM_ConvertFPx2ToF6x2Op<"F16">;
+def NVVM_ConvertBF16x2ToF6x2Op : NVVM_ConvertFPx2ToF6x2Op<"BF16">;
+
def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> {
let summary = "Convert a pair of float inputs to f8x2";
let description = [{
@@ -2109,16 +2185,18 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
VectorOfLengthAndType<[2], [BF16]>:$a,
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
TypeAttr:$dstTy);
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
let extraClassDeclaration = [{
- static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd,
- NVVM::SaturationMode sat);
+ static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
+ NVVM::FPRoundingMode rnd,
+ NVVM::SaturationMode sat, bool hasRelu);
}];
string llvmBuilder = [{
- auto intId = NVVM::ConvertBF16x2ToF8x2Op::getIntrinsicID($rnd, $sat);
+ auto intId = NVVM::ConvertBF16x2ToF8x2Op::getIntrinsicID($dstTy, $rnd, $sat, $relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
@@ -2164,6 +2242,99 @@ def NVVM_ConvertF6x2ToF16x2Op :
def NVVM_ConvertF4x2ToF16x2Op :
NVVM_ConvertToFP16x2Op_Base<"F4", I8, "F16">;
+def NVVM_ConvertF32x2ToS2F6x2Op : NVVM_Op<"convert.f32x2.to.s2f6x2"> {
+  let summary = "Convert a pair of f32 inputs to S2F6x2";
+  let description = [{
+    This Op converts each of the given f32 inputs to the
+    S2F6x2 type. The result `dst` can be an I16 or vector<2xi8>.
+    The `relu` attribute, when set, lowers to the '.relu' variant
+    of the cvt instruction. The optional scaling-factor for the
+    conversion is provided through the operand `scaleFactor` (0x7f7f, a unit
+    ue8m0 scale, when omitted). Only `ue8m0` is supported currently.
+  }];
+
+  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+  let arguments = (ins F32:$a, F32:$b,
+    Optional<I16>:$scaleFactor,
+    DefaultValuedAttr<BoolAttr, "false">:$relu);
+  let assemblyFormat =
+    "$a `,` $b (`,` $scaleFactor^)? attr-dict `:` type($dst)";
+  let extraClassDeclaration = [{
+    static IDArgPair getIntrinsicIDAndArgs(Operation &op,
+        LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::ConvertF32x2ToS2F6x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+    llvm::Value *packedI16 = createIntrinsicCall(builder, id, args);
+    if(op.getDst().getType().isInteger(16))
+      $dst = packedI16;
+    else
+      $dst = builder.CreateBitCast(packedI16,
+          llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
+  }];
+}
+
+def NVVM_ConvertBF16x2ToS2F6x2Op : NVVM_Op<"convert.bf16x2.to.s2f6x2"> {
+  let summary = "Convert a pair of BF16 inputs to S2F6x2";
+  let description = [{
+    This Op converts each of the given BF16 inputs to the S2F6x2 type. The
+    result `dst` can be an I16 or vector<2xi8>. The `relu` attribute, when
+    set, lowers to the '.relu' variant of the cvt instruction. The optional
+    ue8m0 scale-factor is read from `scaleFactor` (0x7f7f, unit scale, if absent).
+  }];
+
+  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+  let arguments = (ins
+    VectorOfLengthAndType<[2], [BF16]>:$a,
+    Optional<I16>:$scaleFactor,
+    DefaultValuedAttr<BoolAttr, "false">:$relu);
+  let assemblyFormat =
+    "$a (`,` $scaleFactor^)? attr-dict `:` type($a) `->` type($dst)";
+  let extraClassDeclaration = [{
+    static IDArgPair getIntrinsicIDAndArgs(Operation &op,
+        LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::ConvertBF16x2ToS2F6x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+    llvm::Value *packedI16 = createIntrinsicCall(builder, id, args);
+    if(op.getDst().getType().isInteger(16))
+      $dst = packedI16;
+    else
+      $dst = builder.CreateBitCast(packedI16,
+          llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
+  }];
+}
+
+def NVVM_ConvertS2F6x2ToBF16x2Op : NVVM_Op<"convert.s2f6x2.to.bf16x2"> {
+  let summary = "Convert s2f6x2 to bf16x2";
+  let description = [{
+    This Op converts the given s2f6x2 value (packed in an i16) to bf16x2.
+    The `relu` attribute, when set, lowers to the '.relu' variant, and `sat`
+    set to `satfinite` lowers to the '.satfinite' variant of the cvt
+    instruction. The optional ue8m0 scaling-factor is provided through the
+    operand `scaleFactor`; when it is omitted, 0x7f7f (a unit scale) is used.
+  }];
+
+  let results = (outs VectorOfLengthAndType<[2], [BF16]>:$dst);
+  let arguments = (ins I16:$a,
+    Optional<I16>:$scaleFactor,
+    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+    DefaultValuedAttr<BoolAttr, "false">:$relu);
+  let assemblyFormat =
+    "$a (`,` $scaleFactor^)? attr-dict `:` type($a) `->` type($dst)";
+  let extraClassDeclaration = [{
+    static IDArgPair getIntrinsicIDAndArgs(Operation &op,
+        LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::ConvertS2F6x2ToBF16x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+    $dst = createIntrinsicCall(builder, id, args);
+  }];
+}
+
//===----------------------------------------------------------------------===//
// NVVM Stochastic Rounding Conversion Ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index fa085d407d6ec..456e5ce08f840 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -347,6 +347,32 @@ LogicalResult ConvertF32x2ToF6x2Op::verify() {
return success();
}
+LogicalResult ConvertF16x2ToF6x2Op::verify() {
+  mlir::MLIRContext *ctx = getContext();
+  // Only f6E2M3FN and f6E3M2FN destinations are valid; relu needs no check.
+  if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
+    return emitOpError("Only ")
+           << mlir::Float6E2M3FNType::get(ctx) << " and "
+           << mlir::Float6E3M2FNType::get(ctx)
+           << " types are supported for conversions from f16x2 to f6x2.";
+  }
+
+  return success();
+}
+
+LogicalResult ConvertBF16x2ToF6x2Op::verify() {
+  mlir::MLIRContext *ctx = getContext();
+  // Only f6E2M3FN and f6E3M2FN destinations are valid; relu needs no check.
+  if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
+    return emitOpError("Only ")
+           << mlir::Float6E2M3FNType::get(ctx) << " and "
+           << mlir::Float6E3M2FNType::get(ctx)
+           << " types are supported for conversions from bf16x2 to f6x2.";
+  }
+
+  return success();
+}
+
LogicalResult ConvertF32x2ToF8x2Op::verify() {
using RndMode = NVVM::FPRoundingMode;
using SatMode = NVVM::SaturationMode;
@@ -414,18 +440,48 @@ LogicalResult ConvertF16x2ToF8x2Op::verify() {
 LogicalResult ConvertBF16x2ToF8x2Op::verify() {
   using RndMode = NVVM::FPRoundingMode;
+  using SatMode = NVVM::SaturationMode;
-  if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
-    return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext())
-                                << " type is supported for conversions from "
-                                   "bf16x2 to f8x2.";
+  bool isRoundingModeRN = getRnd() == RndMode::RN;
+  bool isRoundingModeRZ = getRnd() == RndMode::RZ;
+  bool isRoundingModeRP = getRnd() == RndMode::RP;
+  bool isSatFinite = getSat() == SatMode::SATFINITE;
+  bool hasRelu = getRelu();
-  auto rnd = getRnd();
-  if (rnd != RndMode::RZ && rnd != RndMode::RP)
-    return emitOpError("Only RZ and RP rounding modes are supported for "
-                       "conversions from bf16x2 to f8x2.");
+  mlir::MLIRContext *ctx = getContext(); // only used to print type names
-  return success();
+  return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy())
+      .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
+          [&](mlir::Type) -> LogicalResult {
+            if (!isRoundingModeRN) // e4m3/e5m2 require RN + SATFINITE
+              return emitOpError("Only RN rounding mode is supported for "
+                                 "conversions from bf16x2 to ")
+                     << mlir::Float8E4M3FNType::get(ctx) << " and "
+                     << mlir::Float8E5M2Type::get(ctx) << " types";
+            if (!isSatFinite)
+              return emitOpError("Only SATFINITE saturation mode is supported "
+                                 "for conversions from bf16x2 to ")
+                     << mlir::Float8E4M3FNType::get(ctx) << " and "
+                     << mlir::Float8E5M2Type::get(ctx) << " types";
+            return success();
+          })
+      .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
+        if (!(isRoundingModeRZ || isRoundingModeRP)) // ue8m0: RZ/RP, no relu
+          return emitOpError("Only RZ and RP rounding modes are supported for "
+                             "conversions from bf16x2 to ")
+                 << mlir::Float8E8M0FNUType::get(ctx) << " type";
+        if (hasRelu)
+          return emitOpError("relu not supported for conversions to ")
+                 << mlir::Float8E8M0FNUType::get(ctx) << " type";
+        return success();
+      })
+      .Default([&](mlir::Type) -> LogicalResult {
+        return emitOpError("Only ") // any other destination type is rejected
+               << mlir::Float8E4M3FNType::get(ctx) << ", "
+               << mlir::Float8E5M2Type::get(ctx) << ", and "
+               << mlir::Float8E8M0FNUType::get(ctx)
+               << " types are supported for conversions from bf16x2 to f8x2.";
+      });
 }
LogicalResult ConvertF32x2ToF4x2Op::verify() {
@@ -439,6 +495,26 @@ LogicalResult ConvertF32x2ToF4x2Op::verify() {
return success();
}
+LogicalResult ConvertF16x2ToF4x2Op::verify() {
+  mlir::MLIRContext *ctx = getContext();
+  // f4E2M1FN is the only accepted destination type; relu needs no check.
+  if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+    return emitOpError("Only ")
+           << mlir::Float4E2M1FNType::get(ctx)
+           << " type is supported for conversions from f16x2 to f4x2.";
+  return success();
+}
+
+LogicalResult ConvertBF16x2ToF4x2Op::verify() {
+  mlir::MLIRContext *ctx = getContext();
+  // f4E2M1FN is the only accepted destination type; relu needs no check.
+  if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+    return emitOpError("Only ")
+           << mlir::Float4E2M1FNType::get(ctx)
+           << " type is supported for conversions from bf16x2 to f4x2.";
+  return success();
+}
+
LogicalResult ConvertF8x2ToF16x2Op::verify() {
mlir::MLIRContext *ctx = getContext();
@@ -4116,6 +4192,80 @@ llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
});
}
+NVVM::IDArgPair
+ConvertF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF16x2ToF4x2Op &op,
+                                             LLVM::ModuleTranslation &mt,
+                                             llvm::IRBuilderBase &builder) {
+  // The verifier only admits f4E2M1FN, so any other type is unreachable here.
+  if (!llvm::isa<mlir::Float4E2M1FNType>(op.getDstTy()))
+    llvm_unreachable("Invalid conversion in ConvertF16x2ToF4x2Op");
+
+  llvm::Intrinsic::ID intId =
+      op.getRelu() ? llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_relu_satfinite
+                   : llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_satfinite;
+
+  // The intrinsic takes the packed <2 x half> source as its only operand.
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(op.getA()));
+
+  return {intId, std::move(args)};
+}
+
+NVVM::IDArgPair
+ConvertBF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertBF16x2ToF4x2Op &op,
+                                              LLVM::ModuleTranslation &mt,
+                                              llvm::IRBuilderBase &builder) {
+  // The verifier only admits f4E2M1FN, so any other type is unreachable here.
+  if (!llvm::isa<mlir::Float4E2M1FNType>(op.getDstTy()))
+    llvm_unreachable("Invalid conversion in ConvertBF16x2ToF4x2Op");
+
+  llvm::Intrinsic::ID intId =
+      op.getRelu() ? llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_relu_satfinite
+                   : llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_satfinite;
+
+  // The intrinsic takes the packed <2 x bfloat> source as its only operand.
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(op.getA()));
+
+  return {intId, std::move(args)};
+}
+
+llvm::Intrinsic::ID ConvertF16x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
+                                                         bool hasRelu) {
+  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy) // .rn.satfinite only
+      .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
+        return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_relu_satfinite
+                       : llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_satfinite;
+      })
+      .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
+        return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_relu_satfinite
+                       : llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_satfinite;
+      })
+      .Default([](mlir::Type) { // other types are rejected by the op verifier
+        llvm_unreachable("Invalid conversion in ConvertF16x2ToF6x2Op");
+        return llvm::Intrinsic::not_intrinsic;
+      });
+}
+
+llvm::Intrinsic::ID ConvertBF16x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
+                                                          bool hasRelu) {
+  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy) // .rn.satfinite only
+      .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
+        return hasRelu
+                   ? llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_relu_satfinite
+                   : llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_satfinite;
+      })
+      .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
+        return hasRelu
+                   ? llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_relu_satfinite
+                   : llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_satfinite;
+      })
+      .Default([](mlir::Type) { // other types are rejected by the op verifier
+        llvm_unreachable("Invalid conversion in ConvertBF16x2ToF6x2Op");
+        return llvm::Intrinsic::not_intrinsic;
+      });
+}
+
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
: llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
@@ -4171,22 +4321,39 @@ llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
});
}
-#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
-  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
-           : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
-
 llvm::Intrinsic::ID
-ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
-                                      NVVM::SaturationMode sat) {
+ConvertBF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
+                                      NVVM::FPRoundingMode rnd,
+                                      NVVM::SaturationMode sat, bool hasRelu) {
   bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
-  switch (rnd) {
-  case NVVM::FPRoundingMode::RZ:
-    return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
-  case NVVM::FPRoundingMode::RP:
-    return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
-  default:
-    llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
-  }
+  // ue8m0x2 intrinsics, indexed by (hasSatFinite << 1) | (rnd == RP).
+  static constexpr llvm::Intrinsic::ID ue8m0x2IDs[] = {
+      llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz,
+      llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp,
+      llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz_satfinite,
+      llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp_satfinite,
+  };
+  // e4m3/e5m2 always use the .rn.satfinite forms; only relu varies there.
+  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+      .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
+        return hasRelu
+                   ? llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_relu_satfinite
+                   : llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_satfinite;
+      })
+      .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
+        return hasRelu
+                   ? llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_relu_satfinite
+                   : llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_satfinite;
+      })
+      .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
+        bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
+        unsigned index = (hasSatFinite << 1) | hasRoundingModeRP;
+        return ue8m0x2IDs[index];
+      })
+      .Default([](mlir::Type) {
+        llvm_unreachable("Invalid conversion in ConvertBF16x2ToF8x2Op");
+        return llvm::Intrinsic::not_intrinsic;
+      });
 }
NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
@@ -4281,6 +4448,70 @@ NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
return {intId, {extendedI16}};
}
+NVVM::IDArgPair ConvertF32x2ToS2F6x2Op::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::ConvertF32x2ToS2F6x2Op>(op);
+  bool hasRelu = thisOp.getRelu();
+  bool hasScale = static_cast<bool>(thisOp.getScaleFactor());
+  // Always the .rn.satfinite.scale.n2.ue8m0 form; only relu varies.
+  llvm::Intrinsic::ID id =
+      hasRelu
+          ? llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0
+          : llvm::Intrinsic::nvvm_ff_to_s2f6x2_rn_satfinite_scale_n2_ue8m0;
+
+  // Intrinsic args: a, b, then the i16 scale (0x7f7f when not provided).
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(thisOp.getA()));
+  args.push_back(mt.lookupValue(thisOp.getB()));
+  args.push_back(hasScale ? mt.lookupValue(thisOp.getScaleFactor())
+                          : builder.getInt16(0x7f7f));
+  return {id, std::move(args)};
+}
+
+NVVM::IDArgPair ConvertBF16x2ToS2F6x2Op::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::ConvertBF16x2ToS2F6x2Op>(op);
+  bool hasRelu = thisOp.getRelu();
+  bool hasScale = static_cast<bool>(thisOp.getScaleFactor());
+  // Always the .rn.satfinite.scale.n2.ue8m0 form; only relu varies.
+  llvm::Intrinsic::ID id =
+      hasRelu
+          ? llvm::Intrinsic::
+                nvvm_bf16x2_to_s2f6x2_rn_relu_satfinite_scale_n2_ue8m0
+          : llvm::Intrinsic::nvvm_bf16x2_to_s2f6x2_rn_satfinite_scale_n2_ue8m0;
+
+  // Intrinsic args: the packed bf16x2 source, then the i16 scale (0x7f7f if absent).
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(thisOp.getA()));
+  args.push_back(hasScale ? mt.lookupValue(thisOp.getScaleFactor())
+                          : builder.getInt16(0x7f7f));
+  return {id, std::move(args)};
+}
+
+NVVM::IDArgPair ConvertS2F6x2ToBF16x2Op::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::ConvertS2F6x2ToBF16x2Op>(op);
+  bool hasRelu = thisOp.getRelu();
+  bool hasScale = static_cast<bool>(thisOp.getScaleFactor());
+  bool hasSat = thisOp.getSat() == NVVM::SaturationMode::SATFINITE;
+  // Indexed by (hasSat << 1) | hasRelu; any sat other than SATFINITE maps to the non-satfinite form.
+  static constexpr llvm::Intrinsic::ID ids[] = {
+      llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_scale_n2_ue8m0,
+      llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_relu_scale_n2_ue8m0,
+      llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_satfinite_scale_n2_ue8m0,
+      llvm::Intrinsic::nvvm_s2f6x2_to_bf16x2_rn_relu_satfinite_scale_n2_ue8m0,
+  };
+
+  unsigned idx = (hasSat << 1) | hasRelu;
+
+  // Intrinsic args: the i16 source, then the i16 scale (0x7f7f when not provided).
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(thisOp.getA()));
+  args.push_back(hasScale ? mt.lookupValue(thisOp.getScaleFactor())
+                          : builder.getInt16(0x7f7f));
+  return {ids[idx], std::move(args)};
+}
+
llvm::Intrinsic::ID
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
LLVM::ModuleTranslation &mt,
diff --git a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
index 506b81e1e7048..3e763a036f2d3 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir
@@ -64,7 +64,7 @@ llvm.func @convert_f32x2_to_f8x2_rs_not_supported(%a : f32, %b : f32) {
// -----
llvm.func @convert_bf16x2_to_f8x2_rs_not_supported(%src : vector<2xbf16>) {
- // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}}
+ // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to 'f8E8M0FNU' type}}
%res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rs>} : vector<2xbf16> -> i16 (f8E8M0FNU)
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
index 451475ca76027..3d3bd714fa8fa 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
@@ -11,6 +11,34 @@ llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
llvm.return
}
+// -----
+// f16x2 sources: relu picks the .relu intrinsic; the i16 result is truncated to i8.
+// CHECK-LABEL: @convert_f16x2_to_f4x2
+llvm.func @convert_f16x2_to_f4x2(%srcA : vector<2xf16>) {
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e2m1x2.rn.satfinite(<2 x half> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
+  %res1 = nvvm.convert.f16x2.to.f4x2 %srcA : vector<2xf16> -> i8 (f4E2M1FN)
+  // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.f16x2.to.e2m1x2.rn.relu.satfinite(<2 x half> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
+  %res2 = nvvm.convert.f16x2.to.f4x2 %srcA {relu = true} : vector<2xf16> -> i8 (f4E2M1FN)
+  llvm.return
+}
+
+// -----
+// bf16x2 sources lower the same way through the bf16x2.to.e2m1x2 intrinsics.
+// CHECK-LABEL: @convert_bf16x2_to_f4x2
+llvm.func @convert_bf16x2_to_f4x2(%srcA : vector<2xbf16>) {
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.e2m1x2.rn.satfinite(<2 x bfloat> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
+  %res1 = nvvm.convert.bf16x2.to.f4x2 %srcA : vector<2xbf16> -> i8 (f4E2M1FN)
+  // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.bf16x2.to.e2m1x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
+  %res2 = nvvm.convert.bf16x2.to.f4x2 %srcA {relu = true} : vector<2xbf16> -> i8 (f4E2M1FN)
+  llvm.return
+}
+
+// -----
+
// CHECK-LABEL: @convert_f4x2_to_f16x2
llvm.func @convert_f4x2_to_f16x2(%src : i8) {
// CHECK: %[[res1:.*]] = zext i8 %{{.*}} to i16
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2_invalid.mlir
new file mode 100644
index 0000000000000..637db9e8ca07a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2_invalid.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-translate -mlir-to-llvmir -verify-diagnostics -split-input-file %s
+
+// -----
+// Destination types other than f4E2M1FN must be rejected by the verifier.
+llvm.func @convert_f16x2_to_f4x2_invalid_type(%src : vector<2xf16>) {
+  // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f16x2 to f4x2.}}
+  %res = nvvm.convert.f16x2.to.f4x2 %src : vector<2xf16> -> i8 (f8E4M3FN)
+  llvm.return
+}
+
+// -----
+// Same check for bf16x2 sources.
+llvm.func @convert_bf16x2_to_f4x2_invalid_type(%src : vector<2xbf16>) {
+  // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from bf16x2 to f4x2.}}
+  %res = nvvm.convert.bf16x2.to.f4x2 %src : vector<2xbf16> -> i8 (f8E4M3FN)
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
index 61a7a48f40d54..8d9e5ff2a6a82 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
@@ -1,11 +1,20 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
-// CHECK-LABEL: @convert_f32x2_to_fp6x2_packed
-llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) {
+// CHECK-LABEL: @convert_f32x2_to_fp6x2_e2m3
+llvm.func @convert_f32x2_to_fp6x2_e2m3(%srcA : f32, %srcB : f32) {
   //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}})
   %res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E2M3FN)
+  //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+  %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB {relu = true} : i16 (f6E2M3FN)
+  llvm.return
+}
+// Same relu / no-relu pair for the f6E3M2FN destination.
+// CHECK-LABEL: @convert_f32x2_to_fp6x2_e3m2
+llvm.func @convert_f32x2_to_fp6x2_e3m2(%srcA : f32, %srcB : f32) {
   //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}})
-  %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E3M2FN)
+  %res1 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : i16 (f6E3M2FN)
+  //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+  %res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB {relu = true} : i16 (f6E3M2FN)
   llvm.return
 }
@@ -22,6 +31,68 @@ llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
// -----
+// CHECK-LABEL: @convert_f16x2_to_fp6x2_e2m3
+llvm.func @convert_f16x2_to_fp6x2_e2m3(%srcA : vector<2xf16>) {
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e2m3x2.rn.satfinite(<2 x half> %{{.*}})
+  %res1 = nvvm.convert.f16x2.to.f6x2 %srcA : vector<2xf16> -> i16 (f6E2M3FN)
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e2m3x2.rn.relu.satfinite(<2 x half> %{{.*}})
+  %res2 = nvvm.convert.f16x2.to.f6x2 %srcA {relu = true} : vector<2xf16> -> i16 (f6E2M3FN)
+  llvm.return
+}
+// The f6E3M2FN destination selects the e3m2x2 intrinsics the same way.
+// CHECK-LABEL: @convert_f16x2_to_fp6x2_e3m2
+llvm.func @convert_f16x2_to_fp6x2_e3m2(%srcA : vector<2xf16>) {
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e3m2x2.rn.satfinite(<2 x half> %{{.*}})
+  %res1 = nvvm.convert.f16x2.to.f6x2 %srcA : vector<2xf16> -> i16 (f6E3M2FN)
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.f16x2.to.e3m2x2.rn.relu.satfinite(<2 x half> %{{.*}})
+  %res2 = nvvm.convert.f16x2.to.f6x2 %srcA {relu = true} : vector<2xf16> -> i16 (f6E3M2FN)
+  llvm.return
+}
+// A vector<2xi8> result is a bitcast of the packed i16 intrinsic result.
+// CHECK-LABEL: @convert_f16x2_to_fp6x2_vector
+llvm.func @convert_f16x2_to_fp6x2_vector(%srcA : vector<2xf16>) {
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.f16x2.to.e2m3x2.rn.satfinite(<2 x half> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+  %res1 = nvvm.convert.f16x2.to.f6x2 %srcA : vector<2xf16> -> vector<2xi8> (f6E2M3FN)
+  // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.f16x2.to.e3m2x2.rn.satfinite(<2 x half> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
+  %res2 = nvvm.convert.f16x2.to.f6x2 %srcA : vector<2xf16> -> vector<2xi8> (f6E3M2FN)
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @convert_bf16x2_to_fp6x2_e2m3
+llvm.func @convert_bf16x2_to_fp6x2_e2m3(%srcA : vector<2xbf16>) {
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e2m3x2.rn.satfinite(<2 x bfloat> %{{.*}})
+  %res1 = nvvm.convert.bf16x2.to.f6x2 %srcA : vector<2xbf16> -> i16 (f6E2M3FN)
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e2m3x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
+  %res2 = nvvm.convert.bf16x2.to.f6x2 %srcA {relu = true} : vector<2xbf16> -> i16 (f6E2M3FN)
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_bf16x2_to_fp6x2_e3m2
+llvm.func @convert_bf16x2_to_fp6x2_e3m2(%srcA : vector<2xbf16>) {
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e3m2x2.rn.satfinite(<2 x bfloat> %{{.*}})
+  %res1 = nvvm.convert.bf16x2.to.f6x2 %srcA : vector<2xbf16> -> i16 (f6E3M2FN)
+  // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e3m2x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
+  %res2 = nvvm.convert.bf16x2.to.f6x2 %srcA {relu = true} : vector<2xbf16> -> i16 (f6E3M2FN)
+  llvm.return
+}
+
+// CHECK-LABEL: @convert_bf16x2_to_fp6x2_vector
+llvm.func @convert_bf16x2_to_fp6x2_vector(%srcA : vector<2xbf16>) {
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.e2m3x2.rn.satfinite(<2 x bfloat> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8>
+  %res1 = nvvm.convert.bf16x2.to.f6x2 %srcA : vector<2xbf16> -> vector<2xi8> (f6E2M3FN)
+  // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.bf16x2.to.e3m2x2.rn.satfinite(<2 x bfloat> %{{.*}})
+  // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
+  %res2 = nvvm.convert.bf16x2.to.f6x2 %srcA : vector<2xbf16> -> vector<2xi8> (f6E3M2FN)
+  llvm.return
+}
+
+// -----
+
// CHECK-LABEL: @convert_f6x2_to_f16x2_e2m3
llvm.func @convert_f6x2_to_f16x2_e2m3(%src : vector<2xi8>) {
// CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2_invalid.mlir
new file mode 100644
index 0000000000000..8d0df5863a931
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2_invalid.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-translate -mlir-to-llvmir -verify-diagnostics %s
+
+// -----
+
+llvm.func @convert_f16x2_to_f6x2_invalid_type(%src : vector<2xf16>) {
+ // expected-error @+1 {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f16x2 to f6x2.}}
+ %bad = nvvm.convert.f16x2.to.f6x2 %src : vector<2xf16> -> vector<2xi8> (f8E4M3FN)  // f8 destination type is rejected
+ llvm.return
+}
+
+// -----
+
+llvm.func @convert_bf16x2_to_f6x2_invalid_type(%src : vector<2xbf16>) {
+ // expected-error @+1 {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from bf16x2 to f6x2.}}
+ %bad = nvvm.convert.bf16x2.to.f6x2 %src : vector<2xbf16> -> vector<2xi8> (f8E4M3FN)  // f8 destination type is rejected
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
index 4afe901bc08e9..d8002d790b6a2 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
@@ -90,6 +90,25 @@ llvm.func @convert_bf16x2_to_f8x2_ue8m0(%src : vector<2xbf16>) {
llvm.return
}
+
+// CHECK-LABEL: @convert_bf16x2_to_f8x2_e4m3
+llvm.func @convert_bf16x2_to_f8x2_e4m3(%srcA : vector<2xbf16>) {
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e4m3x2.rn.satfinite(<2 x bfloat> %{{.*}})
+ %rn_sat = nvvm.convert.bf16x2.to.f8x2 %srcA {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> i16 (f8E4M3FN)
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e4m3x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
+ %rn_relu_sat = nvvm.convert.bf16x2.to.f8x2 %srcA {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xbf16> -> i16 (f8E4M3FN)
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_bf16x2_to_f8x2_e5m2
+llvm.func @convert_bf16x2_to_f8x2_e5m2(%srcA : vector<2xbf16>) {
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e5m2x2.rn.satfinite(<2 x bfloat> %{{.*}})
+ %rn_sat = nvvm.convert.bf16x2.to.f8x2 %srcA {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> vector<2xi8> (f8E5M2)
+ // CHECK: %{{.*}} = call i16 @llvm.nvvm.bf16x2.to.e5m2x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
+ %rn_relu_sat = nvvm.convert.bf16x2.to.f8x2 %srcA {sat = #nvvm.sat_mode<satfinite>, rnd = #nvvm.fp_rnd_mode<rn>, relu = true} : vector<2xbf16> -> vector<2xi8> (f8E5M2)
+ llvm.return
+}
+
// CHECK-LABEL: @convert_bf16x2_to_f8x2_vector_return
llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) {
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rz(<2 x bfloat> %{{.*}})
@@ -98,6 +117,12 @@ llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) {
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.bf16x2.to.ue8m0x2.rp.satfinite(<2 x bfloat> %{{.*}})
// CHECK-NEXT: %{{.*}} = bitcast i16 %[[res2]] to <2 x i8>
%res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU)
+ // CHECK: %[[res3:.*]] = call i16 @llvm.nvvm.bf16x2.to.e4m3x2.rn.satfinite(<2 x bfloat> %{{.*}})
+ // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res3]] to <2 x i8>
+ %res3 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E4M3FN)
+ // CHECK: %[[res4:.*]] = call i16 @llvm.nvvm.bf16x2.to.e5m2x2.rn.relu.satfinite(<2 x bfloat> %{{.*}})
+ // CHECK-NEXT: %{{.*}} = bitcast i16 %[[res4]] to <2 x i8>
+ %res4 = nvvm.convert.bf16x2.to.f8x2 %src {relu = true, rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E5M2)
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir
new file mode 100644
index 0000000000000..bef5637839205
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir
@@ -0,0 +1,41 @@
+// RUN: mlir-translate -mlir-to-llvmir -verify-diagnostics %s
+
+// -----
+
+llvm.func @convert_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) {
+ // expected-error @+1 {{Only 'f8E4M3FN', 'f8E5M2', and 'f8E8M0FNU' types are supported for conversions from bf16x2 to f8x2.}}
+ %bad = nvvm.convert.bf16x2.to.f8x2 %src : vector<2xbf16> -> vector<2xi8> (f6E2M3FN)  // f6 destination type is rejected
+ llvm.return
+}
+
+// -----
+
+llvm.func @convert_bf16x2_to_f8x2_invalid_rounding_1(%src : vector<2xbf16>) {
+ // expected-error @below {{Only RN rounding mode is supported for conversions from bf16x2 to 'f8E4M3FN' and 'f8E5M2' types}}
+ %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rm>} : vector<2xbf16> -> vector<2xi8> (f8E4M3FN)  // rm is rejected for e4m3/e5m2 targets
+ llvm.return
+}
+
+// -----
+
+llvm.func @convert_bf16x2_to_f8x2_invalid_rounding_2(%src : vector<2xbf16>) {
+ // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to 'f8E8M0FNU' type}}
+ %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU)  // rn is rejected for the ue8m0 target
+ llvm.return
+}
+
+// -----
+
+llvm.func @convert_bf16x2_to_f8x2_invalid_sat_mode(%src : vector<2xbf16>) {
+ // expected-error @+1 {{Only SATFINITE saturation mode is supported for conversions from bf16x2 to 'f8E4M3FN' and 'f8E5M2' types}}
+ %bad = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rn>, sat = #nvvm.sat_mode<none>} : vector<2xbf16> -> vector<2xi8> (f8E4M3FN)  // sat none is rejected
+ llvm.return
+}
+
+// -----
+
+llvm.func @convert_bf16x2_to_f8x2_invalid_relu(%src : vector<2xbf16>) {
+ // expected-error @+1 {{relu not supported for conversions to 'f8E8M0FNU' type}}
+ %bad = nvvm.convert.bf16x2.to.f8x2 %src {relu = true, rnd = #nvvm.fp_rnd_mode<rp>} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU)  // relu is rejected for ue8m0
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_s2f6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_s2f6x2.mlir
new file mode 100644
index 0000000000000..f54b391af96e2
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_s2f6x2.mlir
@@ -0,0 +1,181 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @convert_f32x2_to_s2f6x2(%srcA : f32, %srcB : f32) -> i16 {
+ // CHECK-LABEL: define i16 @convert_f32x2_to_s2f6x2(float %0, float %1) {
+ // CHECK-NEXT: %3 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(float %0, float %1, i16 32639)
+ // CHECK-NEXT: %4 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.relu.satfinite.scale.n2.ue8m0(float %0, float %1, i16 32639)
+ // CHECK-NEXT: %5 = or i16 %3, %4
+ // CHECK-NEXT: ret i16 %5
+ // CHECK-NEXT: }
+ %plain = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB : i16  // no scale operand: i16 32639 is passed
+ %relu = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB {relu = true} : i16
+
+ // OR both conversions together so each one feeds the returned value
+ %both = llvm.or %plain, %relu : i16
+ llvm.return %both : i16
+}
+
+llvm.func @convert_f32x2_to_s2f6x2_scale(%srcA : f32, %srcB : f32, %scale : i16) -> i16 {
+ // CHECK-LABEL: define i16 @convert_f32x2_to_s2f6x2_scale(float %0, float %1, i16 %2) {
+ // CHECK-NEXT: %4 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(float %0, float %1, i16 %2)
+ // CHECK-NEXT: %5 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.relu.satfinite.scale.n2.ue8m0(float %0, float %1, i16 %2)
+ // CHECK-NEXT: %6 = or i16 %4, %5
+ // CHECK-NEXT: ret i16 %6
+ // CHECK-NEXT: }
+ %plain = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB, %scale : i16  // explicit scale is forwarded unchanged
+ %relu = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB, %scale {relu = true} : i16
+
+ // OR both conversions together so each one feeds the returned value
+ %both = llvm.or %plain, %relu : i16
+ llvm.return %both : i16
+}
+
+llvm.func @convert_f32x2_to_s2f6x2_vector(%srcA : f32, %srcB : f32) -> vector<2xi8> {
+ // CHECK-LABEL: define <2 x i8> @convert_f32x2_to_s2f6x2_vector(float %0, float %1) {
+ // CHECK-NEXT: %3 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(float %0, float %1, i16 32639)
+ // CHECK-NEXT: %4 = bitcast i16 %3 to <2 x i8>
+ // CHECK-NEXT: ret <2 x i8> %4
+ // CHECK-NEXT: }
+ %packed = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB : vector<2xi8>  // vector result: i16 intrinsic value is bitcast
+ llvm.return %packed : vector<2xi8>
+}
+
+llvm.func @convert_f32x2_to_s2f6x2_vector_scale(%srcA : f32, %srcB : f32, %scale : i16) -> vector<2xi8> {
+ // CHECK-LABEL: define <2 x i8> @convert_f32x2_to_s2f6x2_vector_scale(float %0, float %1, i16 %2) {
+ // CHECK-NEXT: %4 = call i16 @llvm.nvvm.ff.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(float %0, float %1, i16 %2)
+ // CHECK-NEXT: %5 = bitcast i16 %4 to <2 x i8>
+ // CHECK-NEXT: ret <2 x i8> %5
+ // CHECK-NEXT: }
+ %packed = nvvm.convert.f32x2.to.s2f6x2 %srcA, %srcB, %scale : vector<2xi8>  // explicit scale, bitcast vector result
+ llvm.return %packed : vector<2xi8>
+}
+
+llvm.func @convert_bf16x2_to_s2f6x2(%srcA : vector<2xbf16>) -> i16 {
+ // CHECK-LABEL: define i16 @convert_bf16x2_to_s2f6x2(<2 x bfloat> %0) {
+ // CHECK-NEXT: %2 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 32639)
+ // CHECK-NEXT: %3 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.relu.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 32639)
+ // CHECK-NEXT: %4 = or i16 %2, %3
+ // CHECK-NEXT: ret i16 %4
+ // CHECK-NEXT: }
+ %plain = nvvm.convert.bf16x2.to.s2f6x2 %srcA : vector<2xbf16> -> i16  // no scale operand: i16 32639 is passed
+ %relu = nvvm.convert.bf16x2.to.s2f6x2 %srcA {relu = true} : vector<2xbf16> -> i16
+
+ // OR both conversions together so each one feeds the returned value
+ %both = llvm.or %plain, %relu : i16
+ llvm.return %both : i16
+}
+
+llvm.func @convert_bf16x2_to_s2f6x2_scale(%srcA : vector<2xbf16>, %scale : i16) -> i16 {
+ // CHECK-LABEL: define i16 @convert_bf16x2_to_s2f6x2_scale(<2 x bfloat> %0, i16 %1) {
+ // CHECK-NEXT: %3 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 %1)
+ // CHECK-NEXT: %4 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.relu.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 %1)
+ // CHECK-NEXT: %5 = or i16 %3, %4
+ // CHECK-NEXT: ret i16 %5
+ // CHECK-NEXT: }
+ %plain = nvvm.convert.bf16x2.to.s2f6x2 %srcA, %scale : vector<2xbf16> -> i16  // explicit scale is forwarded unchanged
+ %relu = nvvm.convert.bf16x2.to.s2f6x2 %srcA, %scale {relu = true} : vector<2xbf16> -> i16
+
+ // OR both conversions together so each one feeds the returned value
+ %both = llvm.or %plain, %relu : i16
+ llvm.return %both : i16
+}
+
+llvm.func @convert_bf16x2_to_s2f6x2_vector(%srcA : vector<2xbf16>) -> vector<2xi8> {
+ // CHECK-LABEL: define <2 x i8> @convert_bf16x2_to_s2f6x2_vector(<2 x bfloat> %0) {
+ // CHECK-NEXT: %2 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 32639)
+ // CHECK-NEXT: %3 = bitcast i16 %2 to <2 x i8>
+ // CHECK-NEXT: ret <2 x i8> %3
+ // CHECK-NEXT: }
+ %packed = nvvm.convert.bf16x2.to.s2f6x2 %srcA : vector<2xbf16> -> vector<2xi8>  // vector result: i16 intrinsic value is bitcast
+ llvm.return %packed : vector<2xi8>
+}
+
+llvm.func @convert_bf16x2_to_s2f6x2_vector_scale(%srcA : vector<2xbf16>, %scale : i16) -> vector<2xi8> {
+ // CHECK-LABEL: define <2 x i8> @convert_bf16x2_to_s2f6x2_vector_scale(<2 x bfloat> %0, i16 %1) {
+ // CHECK-NEXT: %3 = call i16 @llvm.nvvm.bf16x2.to.s2f6x2.rn.satfinite.scale.n2.ue8m0(<2 x bfloat> %0, i16 %1)
+ // CHECK-NEXT: %4 = bitcast i16 %3 to <2 x i8>
+ // CHECK-NEXT: ret <2 x i8> %4
+ // CHECK-NEXT: }
+ %packed = nvvm.convert.bf16x2.to.s2f6x2 %srcA, %scale : vector<2xbf16> -> vector<2xi8>  // explicit scale, bitcast vector result
+ llvm.return %packed : vector<2xi8>
+}
+
+// Case 1: relu off, default scale (i16 32639), satfinite off
+llvm.func @convert_s2f6x2_to_bf16x2(%src : i16) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2(i16 %0) {
+ // CHECK-NEXT: %2 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %0, i16 32639)
+ // CHECK-NEXT: ret <2 x bfloat> %2
+ // CHECK-NEXT: }
+ %out = nvvm.convert.s2f6x2.to.bf16x2 %src : i16 -> vector<2xbf16>
+ llvm.return %out : vector<2xbf16>
+}
+
+// Case 2: relu on, default scale (i16 32639), satfinite off
+llvm.func @convert_s2f6x2_to_bf16x2_relu(%src : i16) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_relu(i16 %0) {
+ // CHECK-NEXT: %2 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %0, i16 32639)
+ // CHECK-NEXT: ret <2 x bfloat> %2
+ // CHECK-NEXT: }
+ %out = nvvm.convert.s2f6x2.to.bf16x2 %src {relu = true} : i16 -> vector<2xbf16>
+ llvm.return %out : vector<2xbf16>
+}
+
+// Case 3: relu off, explicit scale operand, satfinite off
+llvm.func @convert_s2f6x2_to_bf16x2_scale(%src : i16, %scale : i16) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale(i16 %0, i16 %1) {
+ // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.scale.n2.ue8m0(i16 %0, i16 %1)
+ // CHECK-NEXT: ret <2 x bfloat> %3
+ // CHECK-NEXT: }
+ %out = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale : i16 -> vector<2xbf16>
+ llvm.return %out : vector<2xbf16>
+}
+
+// Case 4: relu on, explicit scale operand, satfinite off
+llvm.func @convert_s2f6x2_to_bf16x2_scale_relu(%src : i16, %scale : i16) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_relu(i16 %0, i16 %1) {
+ // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.scale.n2.ue8m0(i16 %0, i16 %1)
+ // CHECK-NEXT: ret <2 x bfloat> %3
+ // CHECK-NEXT: }
+ %out = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale {relu = true} : i16 -> vector<2xbf16>
+ llvm.return %out : vector<2xbf16>
+}
+
+// Case 5: relu off, default scale (i16 32639), satfinite on
+llvm.func @convert_s2f6x2_to_bf16x2_satfinite(%src : i16) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_satfinite(i16 %0) {
+ // CHECK-NEXT: %2 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %0, i16 32639)
+ // CHECK-NEXT: ret <2 x bfloat> %2
+ // CHECK-NEXT: }
+ %out = nvvm.convert.s2f6x2.to.bf16x2 %src {sat = #nvvm.sat_mode<satfinite>} : i16 -> vector<2xbf16>
+ llvm.return %out : vector<2xbf16>
+}
+
+// Case 6: relu on, default scale (i16 32639), satfinite on
+llvm.func @convert_s2f6x2_to_bf16x2_relu_satfinite(%src : i16) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_relu_satfinite(i16 %0) {
+ // CHECK-NEXT: %2 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %0, i16 32639)
+ // CHECK-NEXT: ret <2 x bfloat> %2
+ // CHECK-NEXT: }
+ %out = nvvm.convert.s2f6x2.to.bf16x2 %src {sat = #nvvm.sat_mode<satfinite>, relu = true} : i16 -> vector<2xbf16>
+ llvm.return %out : vector<2xbf16>
+}
+
+// Case 7: relu off, explicit scale operand, satfinite on
+llvm.func @convert_s2f6x2_to_bf16x2_scale_satfinite(%src : i16, %scale : i16) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_satfinite(i16 %0, i16 %1) {
+ // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.satfinite.scale.n2.ue8m0(i16 %0, i16 %1)
+ // CHECK-NEXT: ret <2 x bfloat> %3
+ // CHECK-NEXT: }
+ %out = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale {sat = #nvvm.sat_mode<satfinite>} : i16 -> vector<2xbf16>
+ llvm.return %out : vector<2xbf16>
+}
+
+// Case 8: relu on, explicit scale operand, satfinite on
+llvm.func @convert_s2f6x2_to_bf16x2_scale_relu_satfinite(%src : i16, %scale : i16) -> vector<2xbf16> {
+ // CHECK-LABEL: define <2 x bfloat> @convert_s2f6x2_to_bf16x2_scale_relu_satfinite(i16 %0, i16 %1) {
+ // CHECK-NEXT: %3 = call <2 x bfloat> @llvm.nvvm.s2f6x2.to.bf16x2.rn.relu.satfinite.scale.n2.ue8m0(i16 %0, i16 %1)
+ // CHECK-NEXT: ret <2 x bfloat> %3
+ // CHECK-NEXT: }
+ %out = nvvm.convert.s2f6x2.to.bf16x2 %src, %scale {sat = #nvvm.sat_mode<satfinite>, relu = true} : i16 -> vector<2xbf16>
+ llvm.return %out : vector<2xbf16>
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index c0fe0fa11f497..6d7221c1393da 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -209,22 +209,6 @@ llvm.func @nvvm_cvt_f16x2_to_f8x2_invalid_type(%src : vector<2xf16>) {
// -----
-llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_type(%src : vector<2xbf16>) {
- // expected-error @below {{Only 'f8E8M0FNU' type is supported for conversions from bf16x2 to f8x2.}}
- %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rz>} : vector<2xbf16> -> i16 (f8E4M3FN)
- llvm.return
-}
-
-// -----
-
-llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
- // expected-error @below {{Only RZ and RP rounding modes are supported for conversions from bf16x2 to f8x2.}}
- %res = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rn>} : vector<2xbf16> -> i16 (f8E8M0FNU)
- llvm.return
-}
-
-// -----
-
llvm.func @nvvm_cvt_f32x2_to_f6x2_invalid_type(%a : f32, %b : f32) {
// expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f32x2 to f6x2.}}
%res = nvvm.convert.f32x2.to.f6x2 %a, %b : i16 (f8E8M0FNU)
More information about the Mlir-commits
mailing list