[Mlir-commits] [mlir] [MLIR][NVVM] Add new narrow FP convert Ops (PR #184291)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 2 22:19:15 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Srinivasa Ravi (Wolfram70)

<details>
<summary>Changes</summary>

This change adds or extends the following NVVM Ops to cover the new narrow FP conversions introduced in PTX 9.1:
- `convert.{f32x2/bf16x2}.to.s2f6x2`
- `convert.s2f6x2.to.bf16x2`
- `convert.bf16x2.to.f8x2` (existing Op, extended to support the `f8E4M3FN` and `f8E5M2` destination types)
- `convert.{f16x2/bf16x2}.to.f6x2`
- `convert.{f16x2/bf16x2}.to.f4x2`

PTX ISA Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-cvt

---

Patch is 47.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/184291.diff


11 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+174-3) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+254-23) 
- (modified) mlir/test/Dialect/LLVMIR/nvvm/invalid-convert-stochastic-rounding.mlir (+1-1) 
- (modified) mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir (+28) 
- (added) mlir/test/Target/LLVMIR/nvvm/convert_fp4x2_invalid.mlir (+17) 
- (modified) mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir (+74-3) 
- (added) mlir/test/Target/LLVMIR/nvvm/convert_fp6x2_invalid.mlir (+17) 
- (modified) mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir (+25) 
- (added) mlir/test/Target/LLVMIR/nvvm/convert_fp8x2_invalid.mlir (+41) 
- (added) mlir/test/Target/LLVMIR/nvvm/convert_s2f6x2.mlir (+181) 
- (modified) mlir/test/Target/LLVMIR/nvvmir-invalid.mlir (-16) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 40f631fa0bb2c..aa12e8568a8b3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1951,6 +1951,45 @@ def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
   }];
 }
 
+class NVVM_ConvertFPx2ToF4x2Op<string srcType>
+    : NVVM_Op<"convert."#!tolower(srcType)#"x2.to.f4x2"> {
+  let summary = "Convert a pair of "#srcType#" inputs to F4x2";
+  let description = [{
+    This Op converts each of the given }]#srcType#[{ inputs to the specified 
+    F4x2 type.
+    The result `dst` is returned as an i8 type where the converted values are 
+    packed such that the value converted from `a` is stored in the upper 4 bits 
+    of `dst` and the value converted from `b` is stored in the lower 4 bits of 
+    `dst`.
+    The `relu` attribute, when set, lowers to the '.relu' variant of
+    the cvt instruction.
+    }];
+  let hasVerifier = 1;
+  let results = (outs I8:$dst);
+  let arguments = (ins
+      VectorOfLengthAndType<[2], [!cast<Type>(srcType)]>:$a,
+      DefaultValuedAttr<BoolAttr, "false">:$relu,
+      TypeAttr:$dstTy);
+
+  let assemblyFormat =
+      "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
+
+  let extraClassDeclaration = [{
+    static NVVM::IDArgPair
+    getIntrinsicIDAndArgs(NVVM::Convert}]#srcType#[{x2ToF4x2Op &op, 
+      LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [intId, args] = NVVM::Convert}]#srcType#[{x2ToF4x2Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
+    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, args);
+    $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
+  }];
+}
+
+def NVVM_ConvertF16x2ToF4x2Op : NVVM_ConvertFPx2ToF4x2Op<"F16">;
+def NVVM_ConvertBF16x2ToF4x2Op : NVVM_ConvertFPx2ToF4x2Op<"BF16">;
+
 def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
   let summary = "Convert a pair of float inputs to f6x2";
   let description = [{
@@ -1994,6 +2033,43 @@ def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
   }];
 }
 
+class NVVM_ConvertFPx2ToF6x2Op<string srcType>
+    : NVVM_Op<"convert."#!tolower(srcType)#"x2.to.f6x2"> {
+  let summary = "Convert a pair of "#srcType#" inputs to F6x2";
+  let description = [{
+    This Op converts each of the given }]#srcType#[{ inputs to the specified 
+    F6x2 type.
+    The result `dst` is represented as a vector type(vector<2xi8>). Each 
+    converted value is stored as an i8 element in the vector.
+    The `relu` attribute, when set, lowers to the '.relu' variant of
+    the cvt instruction.
+  }];
+
+  let hasVerifier = 1;
+  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+  let arguments = (ins
+      VectorOfLengthAndType<[2], [!cast<Type>(srcType)]>:$a,
+      DefaultValuedAttr<BoolAttr, "false">:$relu,
+      TypeAttr:$dstTy);
+  let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy, bool hasRelu);
+  }];
+
+  string llvmBuilder = [{
+    auto intId = NVVM::Convert}]#srcType#[{x2ToF6x2Op::getIntrinsicID($dstTy, $relu);
+    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
+    if(op.getDst().getType().isInteger(16))
+      $dst = packedI16;
+    else
+      $dst = builder.CreateBitCast(packedI16,
+                        llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
+  }];
+}
+
+def NVVM_ConvertF16x2ToF6x2Op : NVVM_ConvertFPx2ToF6x2Op<"F16">;
+def NVVM_ConvertBF16x2ToF6x2Op : NVVM_ConvertFPx2ToF6x2Op<"BF16">;
+
 def NVVM_ConvertF32x2ToF8x2Op : NVVM_Op<"convert.f32x2.to.f8x2"> {
   let summary = "Convert a pair of float inputs to f8x2";
   let description = [{
@@ -2109,16 +2185,18 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
     VectorOfLengthAndType<[2], [BF16]>:$a,
     DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
     DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+    DefaultValuedAttr<BoolAttr, "false">:$relu,
     TypeAttr:$dstTy);
   let assemblyFormat = "$a attr-dict `:` type($a) `->` type($dst) `(` $dstTy `)`";
   
   let extraClassDeclaration = [{
-    static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd,
-                                              NVVM::SaturationMode sat);
+    static llvm::Intrinsic::ID getIntrinsicID(mlir::Type dstTy,
+                                  NVVM::FPRoundingMode rnd,
+                                  NVVM::SaturationMode sat, bool hasRelu);
   }];
 
   string llvmBuilder = [{
-    auto intId = NVVM::ConvertBF16x2ToF8x2Op::getIntrinsicID($rnd, $sat);
+    auto intId = NVVM::ConvertBF16x2ToF8x2Op::getIntrinsicID($dstTy, $rnd, $sat, $relu);
     llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
     if(op.getDst().getType().isInteger(16))
       $dst = packedI16;
@@ -2164,6 +2242,99 @@ def NVVM_ConvertF6x2ToF16x2Op :
 def NVVM_ConvertF4x2ToF16x2Op :
   NVVM_ConvertToFP16x2Op_Base<"F4", I8, "F16">;
 
+def NVVM_ConvertF32x2ToS2F6x2Op : NVVM_Op<"convert.f32x2.to.s2f6x2"> {
+  let summary = "Convert a pair of f32 inputs to S2F6x2";
+  let description = [{
+    This Op converts each of the given f32 inputs to the
+    S2F6x2 type. The result `dst` can be an I16 or vector<2xi8>.
+    The `relu` attribute, when set, lowers to the '.relu' variant
+    of the cvt instruction. The optional scaling-factor for the
+    conversion is provided through the operand `scaleFactor`.
+    Only `ue8m0` is supported as the type of the scale-factor currently.
+  }];
+
+  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+  let arguments = (ins F32:$a, F32:$b,
+      Optional<I16>:$scaleFactor,
+      DefaultValuedAttr<BoolAttr, "false">:$relu);
+  let assemblyFormat =
+      "$a `,` $b (`,` $scaleFactor^)? attr-dict `:` type($dst)";
+  let extraClassDeclaration = [{
+    static IDArgPair getIntrinsicIDAndArgs(Operation &op,
+      LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::ConvertF32x2ToS2F6x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+    llvm::Value *packedI16 = createIntrinsicCall(builder, id, args);
+    if(op.getDst().getType().isInteger(16))
+      $dst = packedI16;
+    else
+      $dst = builder.CreateBitCast(packedI16,
+                      llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
+  }];
+}
+
+def NVVM_ConvertBF16x2ToS2F6x2Op : NVVM_Op<"convert.bf16x2.to.s2f6x2"> {
+  let summary = "Convert a pair of BF16 inputs to S2F6x2";
+  let description = [{
+    This Op converts each of the given BF16 inputs to the
+    S2F6x2 type. The result `dst` can be an I16 or vector<2xi8>.
+    The `relu` attribute, when set, lowers to the '.relu' variant
+    of the cvt instruction.
+  }];
+
+  let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
+  let arguments = (ins
+      VectorOfLengthAndType<[2], [BF16]>:$a,
+      Optional<I16>:$scaleFactor,
+      DefaultValuedAttr<BoolAttr, "false">:$relu);
+  let assemblyFormat =
+      "$a (`,` $scaleFactor^)? attr-dict `:` type($a) `->` type($dst)";
+  let extraClassDeclaration = [{
+    static IDArgPair getIntrinsicIDAndArgs(Operation &op,
+      LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::ConvertBF16x2ToS2F6x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+    llvm::Value *packedI16 = createIntrinsicCall(builder, id, args);
+    if(op.getDst().getType().isInteger(16))
+      $dst = packedI16;
+    else
+      $dst = builder.CreateBitCast(packedI16,
+                      llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
+  }];
+}
+
+def NVVM_ConvertS2F6x2ToBF16x2Op : NVVM_Op<"convert.s2f6x2.to.bf16x2"> {
+  let summary = "Convert s2f6x2 to bf16x2";
+  let description = [{
+    This Op converts the given s2f6x2 to bf16x2 values.
+    The `relu` attribute, when set, lowers to the '.relu' variant
+    of the cvt instruction. The optional scaling-factor for the
+    conversion is provided through the operand `scaleFactor`.
+    Only `ue8m0` is supported as the type of the scale-factor currently.
+  }];
+
+  let results = (outs VectorOfLengthAndType<[2], [BF16]>:$dst);
+  let arguments = (ins I16:$a,
+      Optional<I16>:$scaleFactor,
+      DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
+      DefaultValuedAttr<BoolAttr, "false">:$relu);
+  let assemblyFormat =
+      "$a (`,` $scaleFactor^)? attr-dict `:` type($a) `->` type($dst)";
+  let extraClassDeclaration = [{
+    static IDArgPair getIntrinsicIDAndArgs(Operation &op,
+      LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::ConvertS2F6x2ToBF16x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+    $dst = createIntrinsicCall(builder, id, args);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM Stochastic Rounding Conversion Ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index fa085d407d6ec..456e5ce08f840 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -347,6 +347,32 @@ LogicalResult ConvertF32x2ToF6x2Op::verify() {
   return success();
 }
 
+LogicalResult ConvertF16x2ToF6x2Op::verify() {
+  mlir::MLIRContext *ctx = getContext();
+
+  if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
+    return emitOpError("Only ")
+           << mlir::Float6E2M3FNType::get(ctx) << " and "
+           << mlir::Float6E3M2FNType::get(ctx)
+           << " types are supported for conversions from f16x2 to f6x2.";
+  }
+
+  return success();
+}
+
+LogicalResult ConvertBF16x2ToF6x2Op::verify() {
+  mlir::MLIRContext *ctx = getContext();
+
+  if (!llvm::isa<mlir::Float6E2M3FNType, mlir::Float6E3M2FNType>(getDstTy())) {
+    return emitOpError("Only ")
+           << mlir::Float6E2M3FNType::get(ctx) << " and "
+           << mlir::Float6E3M2FNType::get(ctx)
+           << " types are supported for conversions from bf16x2 to f6x2.";
+  }
+
+  return success();
+}
+
 LogicalResult ConvertF32x2ToF8x2Op::verify() {
   using RndMode = NVVM::FPRoundingMode;
   using SatMode = NVVM::SaturationMode;
@@ -414,18 +440,48 @@ LogicalResult ConvertF16x2ToF8x2Op::verify() {
 
 LogicalResult ConvertBF16x2ToF8x2Op::verify() {
   using RndMode = NVVM::FPRoundingMode;
+  using SatMode = NVVM::SaturationMode;
 
-  if (!llvm::isa<mlir::Float8E8M0FNUType>(getDstTy()))
-    return emitOpError("Only ") << mlir::Float8E8M0FNUType::get(getContext())
-                                << " type is supported for conversions from "
-                                   "bf16x2 to f8x2.";
+  bool isRoundingModeRN = getRnd() == RndMode::RN;
+  bool isRoundingModeRZ = getRnd() == RndMode::RZ;
+  bool isRoundingModeRP = getRnd() == RndMode::RP;
+  bool isSatFinite = getSat() == SatMode::SATFINITE;
+  bool hasRelu = getRelu();
 
-  auto rnd = getRnd();
-  if (rnd != RndMode::RZ && rnd != RndMode::RP)
-    return emitOpError("Only RZ and RP rounding modes are supported for "
-                       "conversions from bf16x2 to f8x2.");
+  mlir::MLIRContext *ctx = getContext();
 
-  return success();
+  return llvm::TypeSwitch<mlir::Type, LogicalResult>(getDstTy())
+      .Case<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>(
+          [&](mlir::Type) -> LogicalResult {
+            if (!isRoundingModeRN)
+              return emitOpError("Only RN rounding mode is supported for "
+                                 "conversions from bf16x2 to ")
+                     << mlir::Float8E4M3FNType::get(ctx) << " and "
+                     << mlir::Float8E5M2Type::get(ctx) << " types";
+            if (!isSatFinite)
+              return emitOpError("Only SATFINITE saturation mode is supported "
+                                 "for conversions from bf16x2 to ")
+                     << mlir::Float8E4M3FNType::get(ctx) << " and "
+                     << mlir::Float8E5M2Type::get(ctx) << " types";
+            return success();
+          })
+      .Case<mlir::Float8E8M0FNUType>([&](mlir::Type) -> LogicalResult {
+        if (!(isRoundingModeRZ || isRoundingModeRP))
+          return emitOpError("Only RZ and RP rounding modes are supported for "
+                             "conversions from bf16x2 to ")
+                 << mlir::Float8E8M0FNUType::get(ctx) << " type";
+        if (hasRelu)
+          return emitOpError("relu not supported for conversions to ")
+                 << mlir::Float8E8M0FNUType::get(ctx) << " type";
+        return success();
+      })
+      .Default([&](mlir::Type) -> LogicalResult {
+        return emitOpError("Only ")
+               << mlir::Float8E4M3FNType::get(ctx) << ", "
+               << mlir::Float8E5M2Type::get(ctx) << ", and "
+               << mlir::Float8E8M0FNUType::get(ctx)
+               << " types are supported for conversions from bf16x2 to f8x2.";
+      });
 }
 
 LogicalResult ConvertF32x2ToF4x2Op::verify() {
@@ -439,6 +495,26 @@ LogicalResult ConvertF32x2ToF4x2Op::verify() {
   return success();
 }
 
+LogicalResult ConvertF16x2ToF4x2Op::verify() {
+  mlir::MLIRContext *ctx = getContext();
+
+  if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+    return emitOpError("Only ")
+           << mlir::Float4E2M1FNType::get(ctx)
+           << " type is supported for conversions from f16x2 to f4x2.";
+  return success();
+}
+
+LogicalResult ConvertBF16x2ToF4x2Op::verify() {
+  mlir::MLIRContext *ctx = getContext();
+
+  if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+    return emitOpError("Only ")
+           << mlir::Float4E2M1FNType::get(ctx)
+           << " type is supported for conversions from bf16x2 to f4x2.";
+  return success();
+}
+
 LogicalResult ConvertF8x2ToF16x2Op::verify() {
   mlir::MLIRContext *ctx = getContext();
 
@@ -4116,6 +4192,80 @@ llvm::Intrinsic::ID ConvertF32x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
       });
 }
 
+NVVM::IDArgPair
+ConvertF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF16x2ToF4x2Op &op,
+                                            LLVM::ModuleTranslation &mt,
+                                            llvm::IRBuilderBase &builder) {
+  mlir::Type dstTy = op.getDstTy();
+  bool hasRelu = op.getRelu();
+
+  llvm::Intrinsic::ID intId = llvm::Intrinsic::not_intrinsic;
+
+  if (llvm::isa<mlir::Float4E2M1FNType>(dstTy))
+    intId = hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_relu_satfinite
+                    : llvm::Intrinsic::nvvm_f16x2_to_e2m1x2_rn_satfinite;
+
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(op.getA()));
+
+  return {intId, std::move(args)};
+}
+
+NVVM::IDArgPair
+ConvertBF16x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertBF16x2ToF4x2Op &op,
+                                             LLVM::ModuleTranslation &mt,
+                                             llvm::IRBuilderBase &builder) {
+  mlir::Type dstTy = op.getDstTy();
+  bool hasRelu = op.getRelu();
+
+  llvm::Intrinsic::ID intId = llvm::Intrinsic::not_intrinsic;
+
+  if (llvm::isa<mlir::Float4E2M1FNType>(dstTy))
+    intId = hasRelu ? llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_relu_satfinite
+                    : llvm::Intrinsic::nvvm_bf16x2_to_e2m1x2_rn_satfinite;
+
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(op.getA()));
+
+  return {intId, std::move(args)};
+}
+
+llvm::Intrinsic::ID ConvertF16x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
+                                                         bool hasRelu) {
+  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+      .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
+        return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_relu_satfinite
+                       : llvm::Intrinsic::nvvm_f16x2_to_e2m3x2_rn_satfinite;
+      })
+      .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
+        return hasRelu ? llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_relu_satfinite
+                       : llvm::Intrinsic::nvvm_f16x2_to_e3m2x2_rn_satfinite;
+      })
+      .Default([](mlir::Type) {
+        llvm_unreachable("Invalid conversion in ConvertF16x2ToF6x2Op");
+        return llvm::Intrinsic::not_intrinsic;
+      });
+}
+
+llvm::Intrinsic::ID ConvertBF16x2ToF6x2Op::getIntrinsicID(mlir::Type dstTy,
+                                                          bool hasRelu) {
+  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy)
+      .Case<mlir::Float6E2M3FNType>([&](mlir::Float6E2M3FNType) {
+        return hasRelu
+                   ? llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_relu_satfinite
+                   : llvm::Intrinsic::nvvm_bf16x2_to_e2m3x2_rn_satfinite;
+      })
+      .Case<mlir::Float6E3M2FNType>([&](mlir::Float6E3M2FNType) {
+        return hasRelu
+                   ? llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_relu_satfinite
+                   : llvm::Intrinsic::nvvm_bf16x2_to_e3m2x2_rn_satfinite;
+      })
+      .Default([](mlir::Type) {
+        llvm_unreachable("Invalid conversion in ConvertBF16x2ToF6x2Op");
+        return llvm::Intrinsic::not_intrinsic;
+      });
+}
+
 #define GET_F32x2_TO_F8X2_US_ID(rnd, has_satf)                                 \
   has_satf ? llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd##_satfinite             \
            : llvm::Intrinsic::nvvm_ff_to_ue8m0x2_##rnd
@@ -4171,22 +4321,39 @@ llvm::Intrinsic::ID ConvertF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
       });
 }
 
-#define GET_BF16X2_TO_F8X2_ID(rnd, has_satf)                                   \
-  has_satf ? llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd##_satfinite         \
-           : llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_##rnd
-
 llvm::Intrinsic::ID
-ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
-                                      NVVM::SaturationMode sat) {
+ConvertBF16x2ToF8x2Op::getIntrinsicID(mlir::Type dstTy,
+                                      NVVM::FPRoundingMode rnd,
+                                      NVVM::SaturationMode sat, bool hasRelu) {
   bool hasSatFinite = (sat == NVVM::SaturationMode::SATFINITE);
-  switch (rnd) {
-  case NVVM::FPRoundingMode::RZ:
-    return GET_BF16X2_TO_F8X2_ID(rz, hasSatFinite);
-  case NVVM::FPRoundingMode::RP:
-    return GET_BF16X2_TO_F8X2_ID(rp, hasSatFinite);
-  default:
-    llvm_unreachable("Invalid rounding mode for CvtBF16x2ToF8x2Op");
-  }
+
+  static constexpr llvm::Intrinsic::ID ue8m0x2IDs[] = {
+      llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz,
+      llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp,
+      llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rz_satfinite,
+      llvm::Intrinsic::nvvm_bf16x2_to_ue8m0x2_rp_satfinite,
+  };
+
+  return llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(dstTy) 
+      .Case<mlir::Float8E4M3FNType>([&](mlir::Float8E4M3FNType) {
+        return hasRelu
+                   ? llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_relu_satfinite
+                   : llvm::Intrinsic::nvvm_bf16x2_to_e4m3x2_rn_satfinite;
+      })
+      .Case<mlir::Float8E5M2Type>([&](mlir::Float8E5M2Type) {
+        return hasRelu
+                   ? llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_relu_satfinite
+                   : llvm::Intrinsic::nvvm_bf16x2_to_e5m2x2_rn_satfinite;
+      })
+      .Case<mlir::Float8E8M0FNUType>([&](mlir::Float8E8M0FNUType) {
+        bool hasRoundingModeRP = (rnd == NVVM::FPRoundingMode::RP);
+        unsigned index = (hasSatFinite << 1) | hasRoundingModeRP;
+        return ue8m0x2IDs[index];
+      })
+      .Default([](mlir::Type) {
+        llvm_unreachable("Invalid conversion in ConvertBF16x2ToF8x2Op");
+        return llvm::Intrinsic::not_intrinsic;
+      });
 }
 
 NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
@@ -4281,6 +4448,70 @@ NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
   return {intId, {extendedI16}};
 }
 
+NVVM::IDAr...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/184291


More information about the Mlir-commits mailing list