[Mlir-commits] [mlir] [MLIR][NVVM] Add new narrow FP convert Ops (PR #184291)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 2 22:19:15 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Srinivasa Ravi (Wolfram70)
<details>
<summary>Changes</summary>
This change adds the following NVVM Ops for new narrow FP conversions introduced in PTX 9.1:
- `convert.{f32x2/bf16x2}.to.s2f6x2`
- `convert.s2f6x2.to.bf16x2`
- `convert.bf16x2.to.f8x2` (extended for `f8E4M3FN` and `f8E5M2` types)
- `convert.{f16x2/bf16x2}.to.f6x2`
- `convert.{f16x2/bf16x2}.to.f4x2`
PTX ISA Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt
---
Patch is 47.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/184291.diff
11 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+174-3)
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+254-23)
- (modified) mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir (+1-1)
- (modified) mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir (+28)
- (added) mlir/test/Target/LLVMIR/nvvm/convert_fp4x2_invalid.mlir (+17)
- (modified) mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir (+74-3)
- (added) mlir/test/Target/LLVMIR/nvvm/convert_fp6x2_invalid.mlir (+17)
- (modified) mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir (+25)
- (added) mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir (+41)
- (added) mlir/test/Target/LLVMIR/nvvm/convert_s2f6x2.mlir (+181)
- (modified) mlir/test/Target/LLVMIR/nvvmir-invalid.mlir (-16)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 40f631fa0bb2c..aa12e8568a8b3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1951,6 +1951,45 @@ def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
}];
}
+class NVVM_ConvertFPx2ToF4x2Op<string srcType>
+ : NVVM_Op<"convert."#!tolower(srcType)#"x2.to.f4x2"> {
+ let summary = "Convert a pair of "#srcType#" inputs to F4x2";
+ let description = [{
+ This Op converts each of the given }]#srcType#[{ inputs to the specified
+ F4x2 type.
+ The result `dst` is returned as an i8 type where the converted values are
+ packed such that the value converted from `a` is stored in the upper 4 bits
+ of `dst` and the value converted from `b` is stored in the lower 4 bits of
+ `dst`.
+ The `relu` attribute, when set, lowers to the '.relu' variant of
+ the cvt instruction.
+ }];
+ let hasVerifier = 1;
+ let results = (outs I8:$dst);
+ let arguments = (ins
+ VectorOfLengthAndType<[2], [!cast<Type>(srcType)]>:$a,
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$dstTy);
+
+ let assemblyFormat =
+ "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
+
+ let extraClassDeclaration = [{
+ static NVVM::IDArgPair
+ getIntrinsicIDAndArgs(NVVM::Convert}]#srcType#[{x2ToF4x2Op &op,
+ LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+ }];
+
+ string llvmBuilder = [{
+ auto [intId, args] = NVVM::Convert}]#srcType#[{x2ToF4x2Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
+ llvm::Value *packedI16 = createIntrinsicCall(builder, intId, args);
+ $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
+ }];
+}
+
+def NVVM_ConvertF16x2ToF4x2Op : NVVM_ConvertFPx2ToF4x2Op<"F16">;
+def NVVM_ConvertBF16x2ToF4x2Op : NVVM_ConvertFPx2ToF4x2Op<"BF16">;
+
def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
let summary = "Convert a pair of float inputs to f6x2";
let description = [{
@@ -1994,6 +2033,43 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
}];
}
+class NVVM_ConvertFPx2ToF6x2Op<string srcType>
+    : NVVM_Op<"convert."#!tolower(srcType)#"x2.to.f6x2"> {
+  let summary = "Convert a pair of "#srcType#" inputs to F6x2";
+  let description = [{
+    This Op converts each of the given }]#srcType#[{ inputs to the specified
+    F6x2 type.
+    The result `dst` is either an i16 or a vector<2xi8>; in the vector form,
+    each converted value is stored as an i8 element.
+    The `relu` attribute, when set, lowers to the '.relu' variant of
+    the cvt instruction.
+  }];
+
+  let hasVerifier = 1;
+  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+  let arguments = (ins
+    VectorOfLengthAndType<[2], [!cast<Type>(srcType)]>:$a,
+    DefaultValuedAttr<BoolAttr, "false">:$relu,
+    TypeAttr:$dstTy);
+  let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy, bool hasRelu);
+  }];
+
+  string llvmBuilder = [{
+    auto intId = NVVM::Convert}]#srcType#[{x2ToF6x2Op::getIntrinsicID($dstTy, $relu);
+    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
+    if(op.getDst().getType().isInteger(16))
+      $dst = packedI16;
+    else
+      $dst = builder.CreateBitCast(packedI16,
+          llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
+  }];
+}
+
+def NVVM_ConvertF16x2ToF6x2Op : NVVM_ConvertFPx2ToF6x2Op<"F16">;
+def NVVM_ConvertBF16x2ToF6x2Op : NVVM_ConvertFPx2ToF6x2Op<"BF16">;
+
def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> {
let summary = "Convert a pair of float inputs to f8x2";
let description = [{
@@ -2109,16 +2185,18 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
VectorOfLengthAndType<[2], [BF16]>:$a,
DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
TypeAttr:$dstTy);
let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
let extraClassDeclaration = [{
- static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd,
- NVVM::SaturationMode sat);
+ static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
+ NVVM::FPRoundingMode rnd,
+ NVVM::SaturationMode sat, bool hasRelu);
}];
string llvmBuilder = [{
- auto intId = NVVM::ConvertBF16x2ToF8x2Op::getIntrinsicID($rnd, $sat);
+ auto intId = NVVM::ConvertBF16x2ToF8x2Op::getIntrinsicID($dstTy, $rnd, $sat, $relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
if(op.getDst().getType().isInteger(16))
$dst = packedI16;
@@ -2164,6 +2242,99 @@ def NVVM_ConvertF6x2ToF16x2Op :
def NVVM_ConvertF4x2ToF16x2Op :
NVVM_ConvertToFP16x2Op_Base<"F4", I8, "F16">;
+def NVVM_ConvertF32x2ToS2F6x2Op : NVVM_Op<"convert.f32x2.to.s2f6x2"> {
+ let summary = "Convert a pair of f32 inputs to S2F6x2";
+ let description = [{
+ This Op converts each of the given f32 inputs to the
+ S2F6x2 type. The result `dst` can be an I16 or vector<2xi8>.
+ The `relu` attribute, when set, lowers to the '.relu' variant
+ of the cvt instruction. The optional scaling-factor for the
+ conversion is provided through the operand `scaleFactor`.
+ Only `ue8m0` is supported as the type of the scale-factor currently.
+ }];
+
+ let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+ let arguments = (ins F32:$a, F32:$b,
+ Optional<I16>:$scaleFactor,
+ DefaultValuedAttr<BoolAttr, "false">:$relu);
+ let assemblyFormat =
+ "$a `,` $b (`,` $scaleFactor^)? attr-dict `:` type($dst)";
+ let extraClassDeclaration = [{
+ static IDArgPair getIntrinsicIDAndArgs(Operation &op,
+ LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+ }];
+
+ string llvmBuilder = [{
+ auto [id, args] = NVVM::ConvertF32x2ToS2F6x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+ llvm::Value *packedI16 = createIntrinsicCall(builder, id, args);
+ if(op.getDst().getType().isInteger(16))
+ $dst = packedI16;
+ else
+ $dst = builder.CreateBitCast(packedI16,
+ llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
+ }];
+}
+
+def NVVM_ConvertBF16x2ToS2F6x2Op : NVVM_Op<"convert.bf16x2.to.s2f6x2"> {
+  let summary = "Convert a pair of BF16 inputs to S2F6x2";
+  let description = [{
+    This Op converts each of the given BF16 inputs to the S2F6x2 type. The
+    result `dst` can be an I16 or vector<2xi8>. The `relu` attribute, when
+    set, lowers to the '.relu' variant of the cvt instruction. The optional
+    scaling-factor comes from `scaleFactor`; only `ue8m0` is supported now.
+  }];
+
+  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+  let arguments = (ins
+    VectorOfLengthAndType<[2], [BF16]>:$a,
+    Optional<I16>:$scaleFactor,
+    DefaultValuedAttr<BoolAttr, "false">:$relu);
+  let assemblyFormat =
+    "$a (`,` $scaleFactor^)? attr-dict `:` type($a) `->` type($dst)";
+  let extraClassDeclaration = [{
+    static IDArgPair getIntrinsicIDAndArgs(Operation &op,
+        LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::ConvertBF16x2ToS2F6x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+    llvm::Value *packedI16 = createIntrinsicCall(builder, id, args);
+    if(op.getDst().getType().isInteger(16))
+      $dst = packedI16;
+    else
+      $dst = builder.CreateBitCast(packedI16,
+          llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
+  }];
+}
+
+def NVVM_ConvertS2F6x2ToBF16x2Op : NVVM_Op<"convert.s2f6x2.to.bf16x2"> {
+  let summary = "Convert s2f6x2 to bf16x2";
+  let description = [{
+    This Op converts the s2f6x2 pair held in the i16 operand `a` to bf16x2
+    values. The `sat` attribute selects the saturation mode (none by
+    default), and `relu`, when set, lowers to the '.relu' variant of the
+    cvt instruction. The optional scaling-factor is provided through the
+    operand `scaleFactor`; only `ue8m0` is supported currently.
+  }];
+
+  let results = (outs VectorOfLengthAndType<[2], [BF16]>:$dst);
+  let arguments = (ins I16:$a,
+    Optional<I16>:$scaleFactor,
+    DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+    DefaultValuedAttr<BoolAttr, "false">:$relu);
+  let assemblyFormat =
+    "$a (`,` $scaleFactor^)? attr-dict `:` type($a) `->` type($dst)";
+  let extraClassDeclaration = [{
+    static IDArgPair getIntrinsicIDAndArgs(Operation &op,
+        LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::ConvertS2F6x2ToBF16x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+    $dst = createIntrinsicCall(builder, id, args);
+  }];
+}
+
//===----------------------------------------------------------------------===//
// NVVM Stochastic Rounding Conversion Ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index fa085d407d6ec..456e5ce08f840 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -347,6 +347,32 @@ LogicalResult ConvertF32x2ToF6x2Op::verify() {
return success();
}
+LogicalResult ConvertF16x2ToF6x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
+ return emitOpError("Only ")
+ << mlir::Float6E2M3FNType::get(ctx) << " and "
+ << mlir::Float6E3M2FNType::get(ctx)
+ << " types are supported for conversions from f16x2 to f6x2.";
+ }
+
+ return success();
+}
+
+LogicalResult ConvertBF16x2ToF6x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
+ return emitOpError("Only ")
+ << mlir::Float6E2M3FNType::get(ctx) << " and "
+ << mlir::Float6E3M2FNType::get(ctx)
+ << " types are supported for conversions from bf16x2 to f6x2.";
+ }
+
+ return success();
+}
+
LogicalResult ConvertF32x2ToF8x2Op::verify() {
using RndMode = NVVM::FPRoundingMode;
using SatMode = NVVM::SaturationMode;
@@ -414,18 +440,48 @@ LogicalResult ConvertF16x2ToF8x2Op::verify() {
LogicalResult ConvertBF16x2ToF8x2Op::verify() {
using RndMode = NVVM::FPRoundingMode;
+ using SatMode = NVVM::SaturationMode;
- if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
- return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext())
- << " type is supported for conversions from "
- "bf16x2 to f8x2.";
+ bool isRoundingModeRN = getRnd() == RndMode::RN;
+ bool isRoundingModeRZ = getRnd() == RndMode::RZ;
+ bool isRoundingModeRP = getRnd() == RndMode::RP;
+ bool isSatFinite = getSat() == SatMode::SATFINITE;
+ bool hasRelu = getRelu();
- auto rnd = getRnd();
- if (rnd != RndMode::RZ && rnd != RndMode::RP)
- return emitOpError("Only RZ and RP rounding modes are supported for "
- "conversions from bf16x2 to f8x2.");
+ mlir::MLIRContext *ctx = getContext();
- return success();
+ return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy())
+ .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
+ [&](mlir::Type) -> LogicalResult {
+ if (!isRoundingModeRN)
+ return emitOpError("Only RN rounding mode is supported for "
+ "conversions from bf16x2 to ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx) << " types";
+ if (!isSatFinite)
+ return emitOpError("Only SATFINITE saturation mode is supported "
+ "for conversions from bf16x2 to ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx) << " types";
+ return success();
+ })
+ .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
+ if (!(isRoundingModeRZ || isRoundingModeRP))
+ return emitOpError("Only RZ and RP rounding modes are supported for "
+ "conversions from bf16x2 to ")
+ << mlir::Float8E8M0FNUType::get(ctx) << " type";
+ if (hasRelu)
+ return emitOpError("relu not supported for conversions to ")
+ << mlir::Float8E8M0FNUType::get(ctx) << " type";
+ return success();
+ })
+ .Default([&](mlir::Type) -> LogicalResult {
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(ctx) << ", "
+ << mlir::Float8E5M2Type::get(ctx) << ", and "
+ << mlir::Float8E8M0FNUType::get(ctx)
+ << " types are supported for conversions from bf16x2 to f8x2.";
+ });
}
LogicalResult ConvertF32x2ToF4x2Op::verify() {
@@ -439,6 +495,26 @@ LogicalResult ConvertF32x2ToF4x2Op::verify() {
return success();
}
+LogicalResult ConvertF16x2ToF4x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from f16x2 to f4x2.";
+ return success();
+}
+
+LogicalResult ConvertBF16x2ToF4x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from bf16x2 to f4x2.";
+ return success();
+}
+
LogicalResult ConvertF8x2ToF16x2Op::verify() {
mlir::MLIRContext *ctx = getContext();
@@ -4116,6 +4192,80 @@ llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
});
}
+NVVM::IDArgPair
+ConvertF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF16x2ToF4x2Op &op,
+                                            LLVM::ModuleTranslation &mt,
+                                            llvm::IRBuilderBase &builder) {
+  // The verifier only admits f4E2M1FN; anything else is a translation bug.
+  if (!llvm::isa<mlir::Float4E2M1FNType>(op.getDstTy()))
+    llvm_unreachable("Invalid conversion in ConvertF16x2ToF4x2Op");
+
+  // Both candidates are .rn.satfinite; `relu` picks the '.relu' variant.
+  llvm::Intrinsic::ID intId =
+      op.getRelu() ? llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_relu_satfinite
+                   : llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_satfinite;
+
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(op.getA()));
+
+  return {intId, std::move(args)};
+}
+
+NVVM::IDArgPair
+ConvertBF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertBF16x2ToF4x2Op &op,
+                                             LLVM::ModuleTranslation &mt,
+                                             llvm::IRBuilderBase &builder) {
+  // The verifier only admits f4E2M1FN; anything else is a translation bug.
+  if (!llvm::isa<mlir::Float4E2M1FNType>(op.getDstTy()))
+    llvm_unreachable("Invalid conversion in ConvertBF16x2ToF4x2Op");
+
+  // Both candidates are .rn.satfinite; `relu` picks the '.relu' variant.
+  llvm::Intrinsic::ID intId =
+      op.getRelu() ? llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_relu_satfinite
+                   : llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_satfinite;
+
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(op.getA()));
+
+  return {intId, std::move(args)};
+}
+
+llvm::Intrinsic::ID ConvertF16x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
+ bool hasRelu) {
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_satfinite;
+ })
+ .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
+ return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_satfinite;
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid conversion in ConvertF16x2ToF6x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+}
+
+llvm::Intrinsic::ID ConvertBF16x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
+ bool hasRelu) {
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
+ return hasRelu
+ ? llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_satfinite;
+ })
+ .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
+ return hasRelu
+ ? llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_satfinite;
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid conversion in ConvertBF16x2ToF6x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+}
+
#define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf) \
has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite \
: llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
@@ -4171,22 +4321,39 @@ llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
});
}
-#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf) \
- has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite \
- : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
-
llvm::Intrinsic::ID
-ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
- NVVM::SaturationMode sat) {
+ConvertBF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
+ NVVM::FPRoundingMode rnd,
+ NVVM::SaturationMode sat, bool hasRelu) {
bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
- switch (rnd) {
- case NVVM::FPRoundingMode::RZ:
- return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
- case NVVM::FPRoundingMode::RP:
- return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
- default:
- llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
- }
+
+ static constexpr llvm::Intrinsic::ID ue8m0x2IDs[] = {
+ llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz,
+ llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp,
+ llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz_satfinite,
+ llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp_satfinite,
+ };
+
+ return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+ .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
+ return hasRelu
+ ? llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_satfinite;
+ })
+ .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
+ return hasRelu
+ ? llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_satfinite;
+ })
+ .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
+ bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
+ unsigned index = (hasSatFinite << 1) | hasRoundingModeRP;
+ return ue8m0x2IDs[index];
+ })
+ .Default([](mlir::Type) {
+ llvm_unreachable("Invalid conversion in ConvertBF16x2ToF8x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
}
NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
@@ -4281,6 +4448,70 @@ NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
return {intId, {extendedI16}};
}
+NVVM::IDAr...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/184291
More information about the Mlir-commits mailing list