[Mlir-commits] [mlir] [MLIR][NVVM] Add support for converting fp4/6/8 to fp16x2 (PR #162439)
Srinivasa Ravi
llvmlistbot at llvm.org
Fri Oct 10 01:18:54 PDT 2025
https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/162439
>From d2ebae85ca5e44728b8117d1f47c5d512ce960b1 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 8 Oct 2025 06:26:01 +0000
Subject: [PATCH 1/6] [MLIR][NVVM] Add support for converting fp4/6/8 to fp16x2
This change adds the following NVVM dialect Ops for converting fp4/6/8
to f16x2 and bf16x2:
- convert.f4x2.to.f16x2
- convert.f6x2.to.f16x2
- convert.f8x2.to.f16x2
- convert.f8x2.to.bf16x2
Tests are added in `convert_fp4x2.mlir`, `convert_fp6x2.mlir`, and
`convert_fp8x2.mlir`.
PTX Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 107 ++++++++++++++
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 139 ++++++++++++++++++
.../Target/LLVMIR/nvvm/convert_fp4x2.mlir | 14 ++
.../Target/LLVMIR/nvvm/convert_fp6x2.mlir | 24 +++
.../Target/LLVMIR/nvvm/convert_fp8x2.mlir | 34 +++++
5 files changed, 318 insertions(+)
create mode 100644 mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index e2a0331542742..5020af3992173 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1832,6 +1832,113 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
}];
}
+class NVVM_ConvertF8x2ToFP16x2Op_Base <string dstType>
+: NVVM_Op<"convert.f8x2.to." # !tolower(dstType) # "x2"> {
+ let summary = "Convert a pair of f8 inputs to " # !tolower(dstType) # "x2";
+ let description = [{
+ This Op converts the given f8 inputs in a i8x2 vector to }] # !tolower(dstType) # [{.
+
+ The result `dst` is represented as a vector of }] # !tolower(dstType) # [{ elements.
+
+ The `relu` attribute, when set, lowers to the '.relu' variant of
+ the cvt instruction.
+
+ For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ }];
+ let results = (outs VectorOfLengthAndType<[2], [!cast<Type>(dstType)]>:$dst);
+ let arguments = !if(!eq(dstType, "F16"),
+ (ins VectorOfLengthAndType<[2], [I8]>:$src,
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$srcType),
+ (ins VectorOfLengthAndType<[2], [I8]>:$src,
+ TypeAttr:$srcType));
+ let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)";
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ static IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder);
+ }];
+
+ string llvmBuilder = [{
+ auto [intId, args] =
+ NVVM::ConvertF8x2To}] # dstType # [{x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+ $dst = createIntrinsicCall(builder, intId, args);
+ }];
+}
+
+def NVVM_ConvertF8x2ToF16x2Op : NVVM_ConvertF8x2ToFP16x2Op_Base<"F16">;
+def NVVM_ConvertF8x2ToBF16x2Op : NVVM_ConvertF8x2ToFP16x2Op_Base<"BF16">;
+
+
+def NVVM_ConvertF6x2ToF16x2Op : NVVM_Op<"convert.f6x2.to.f16x2"> {
+ let summary = "Convert a pair of f6 inputs to f16x2";
+ let description = [{
+ This Op converts the given f6 inputs in a i8x2 vector to f16x2.
+
+ The result `dst` is represented as a vector of f16 elements.
+
+ The `relu` attribute, when set, lowers to the '.relu' variant of
+ the cvt instruction.
+
+ For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ }];
+ let results = (outs VectorOfLengthAndType<[2], [F16]>:$dst);
+ let arguments = (ins VectorOfLengthAndType<[2], [I8]>:$src,
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$srcType);
+ let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)";
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ static IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder);
+ }];
+
+ string llvmBuilder = [{
+ auto [intId, args] =
+ NVVM::ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+ $dst = createIntrinsicCall(builder, intId, args);
+ }];
+}
+
+def NVVM_ConvertF4x2ToF16x2Op : NVVM_Op<"convert.f4x2.to.f16x2"> {
+ let summary = "Convert a pair of f4 inputs to f16x2";
+ let description = [{
+ This Op converts the given f4 inputs packed in an i8 to f16x2.
+
+ The result `dst` is represented as a vector of f16 elements. The value
+ converted from the lower 4 bits of `src` is stored in the first element of
+ `dst` and the value converted from the upper 4 bits of `src` is stored in
+ the second element of `dst`.
+
+ The `relu` attribute, when set, lowers to the '.relu' variant of
+ the cvt instruction.
+
+ For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ }];
+ let results = (outs VectorOfLengthAndType<[2], [F16]>:$dst);
+ let arguments = (ins I8:$src,
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$srcType);
+ let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)";
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ static IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder);
+ }];
+
+ string llvmBuilder = [{
+ auto [intId, args] =
+ NVVM::ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+ $dst = createIntrinsicCall(builder, intId, args);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// NVVM MMA Ops
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7f419a062201d..bd38db52179a6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -309,6 +309,51 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
return success();
}
+LogicalResult ConvertF8x2ToF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<Float8E4M3FNType, Float8E5M2Type>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float8E4M3FNType::get(ctx) << " and "
+ << mlir::Float8E5M2Type::get(ctx)
+ << " types are supported for conversions from f8x2 to f16x2.";
+
+ return success();
+}
+
+LogicalResult ConvertF8x2ToBF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+ if (!llvm::isa<Float8E8M0FNUType>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float8E8M0FNUType::get(ctx)
+ << " type is supported for conversions from f8x2 to bf16x2.";
+
+ return success();
+}
+
+LogicalResult ConvertF6x2ToF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<Float6E2M3FNType, Float6E3M2FNType>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float6E2M3FNType::get(ctx) << " and "
+ << mlir::Float6E3M2FNType::get(ctx)
+ << " types are supported for conversions from f6x2 to f16x2.";
+
+ return success();
+}
+
+LogicalResult ConvertF4x2ToF16x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<Float4E2M1FNType>(getSrcType()))
+ return emitOpError("Only ")
+ << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from f4x2 to f16x2.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -2106,6 +2151,100 @@ ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
+NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
+
+ bool hasRelu = curOp.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case<Float8E4M3FNType>([&](Float8E4M3FNType type) {
+ return hasRelu
+ ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
+ })
+ .Case<Float8E5M2Type>([&](Float8E5M2Type type) {
+ return hasRelu
+ ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
+ })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+
+ llvm::Value *packedI16 =
+ builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF8x2ToBF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF8x2ToBF16x2Op>(op);
+
+ llvm::Intrinsic::ID intId = llvm::Intrinsic::nvvm_ue8m0x2_to_bf16x2;
+ llvm::Value *packedI16 =
+ builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF6x2ToF16x2Op>(op);
+
+ bool hasRelu = curOp.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case<Float6E2M3FNType>([&](Float6E2M3FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e2m3x2_to_f16x2_rn;
+ })
+ .Case<Float6E3M2FNType>([&](Float6E3M2FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e3m2x2_to_f16x2_rn;
+ })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF6x2ToF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+
+ llvm::Value *packedI16 =
+ builder.CreateBitCast(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {packedI16}};
+}
+
+NVVM::IDArgPair ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::ConvertF4x2ToF16x2Op>(op);
+
+ bool hasRelu = curOp.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
+ .Case<Float4E2M1FNType>([&](Float4E2M1FNType type) {
+ return hasRelu ? llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e2m1x2_to_f16x2_rn;
+ })
+ .Default([](mlir::Type type) {
+ llvm_unreachable("Invalid type for ConvertF4x2ToF16x2Op");
+ return llvm::Intrinsic::not_intrinsic;
+ });
+
+ llvm::Value *extendedI16 =
+ builder.CreateZExt(mt.lookupValue(curOp.getSrc()),
+ llvm::Type::getInt16Ty(builder.getContext()));
+
+ return {intId, {extendedI16}};
+}
+
llvm::Intrinsic::ID
Tcgen05AllocOp::getIntrinsicIDAndArgs(Operation &op,
LLVM::ModuleTranslation &mt,
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
new file mode 100644
index 0000000000000..e43dea4065c08
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: @convert_f4x2_to_f16x2
+llvm.func @convert_f4x2_to_f16x2(%src : i8) {
+ // CHECK: %[[res1:.*]] = zext i8 %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn(i16 %[[res1]])
+ %res1 = nvvm.convert.f4x2.to.f16x2 %src : i8 (f4E2M1FN)-> vector<2xf16>
+ // CHECK: %[[res2:.*]] = zext i8 %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m1x2.to.f16x2.rn.relu(i16 %[[res2]])
+ %res2 = nvvm.convert.f4x2.to.f16x2 %src {relu = true} : i8 (f4E2M1FN)-> vector<2xf16>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
index 99289923b58b1..61a7a48f40d54 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp6x2.mlir
@@ -19,3 +19,27 @@ llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) {
%res2 = nvvm.convert.f32x2.to.f6x2 %srcA, %srcB : vector<2xi8> (f6E3M2FN)
llvm.return
}
+
+// -----
+
+// CHECK-LABEL: @convert_f6x2_to_f16x2_e2m3
+llvm.func @convert_f6x2_to_f16x2_e2m3(%src : vector<2xi8>) {
+ // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m3x2.to.f16x2.rn(i16 %[[res1]])
+ %res1 = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f6E2M3FN)-> vector<2xf16>
+ // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e2m3x2.to.f16x2.rn.relu(i16 %[[res2]])
+ %res2 = nvvm.convert.f6x2.to.f16x2 %src {relu = true} : vector<2xi8> (f6E2M3FN)-> vector<2xf16>
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f6x2_to_f16x2_e3m2
+llvm.func @convert_f6x2_to_f16x2_e3m2(%src : vector<2xi8>) {
+ // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e3m2x2.to.f16x2.rn(i16 %[[res1]])
+ %res1 = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f6E3M2FN)-> vector<2xf16>
+ // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e3m2x2.to.f16x2.rn.relu(i16 %[[res2]])
+ %res2 = nvvm.convert.f6x2.to.f16x2 %src {relu = true} : vector<2xi8> (f6E3M2FN)-> vector<2xf16>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
index de21826445afb..4afe901bc08e9 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp8x2.mlir
@@ -100,3 +100,37 @@ llvm.func @convert_bf16x2_to_f8x2_vector_return(%src : vector<2xbf16>) {
%res2 = nvvm.convert.bf16x2.to.f8x2 %src {rnd = #nvvm.fp_rnd_mode<rp>, sat = #nvvm.sat_mode<satfinite>} : vector<2xbf16> -> vector<2xi8> (f8E8M0FNU)
llvm.return
}
+
+// -----
+
+// CHECK-LABEL: @convert_f8x2_to_f16x2_e4m3
+llvm.func @convert_f8x2_to_f16x2_e4m3(%src : vector<2xi8>) {
+ // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn(i16 %[[res1]])
+ %res1 = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E4M3FN)-> vector<2xf16>
+ // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e4m3x2.to.f16x2.rn.relu(i16 %[[res2]])
+ %res2 = nvvm.convert.f8x2.to.f16x2 %src {relu = true} : vector<2xi8> (f8E4M3FN)-> vector<2xf16>
+ llvm.return
+}
+
+// CHECK-LABEL: @convert_f8x2_to_f16x2_e5m2
+llvm.func @convert_f8x2_to_f16x2_e5m2(%src : vector<2xi8>) {
+ // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn(i16 %[[res1]])
+ %res1 = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E5M2)-> vector<2xf16>
+ // CHECK: %[[res2:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x half> @llvm.nvvm.e5m2x2.to.f16x2.rn.relu(i16 %[[res2]])
+ %res2 = nvvm.convert.f8x2.to.f16x2 %src {relu = true} : vector<2xi8> (f8E5M2)-> vector<2xf16>
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @convert_f8x2_to_bf16x2_ue8m0
+llvm.func @convert_f8x2_to_bf16x2_ue8m0(%src : vector<2xi8>) {
+ // CHECK: %[[res1:.*]] = bitcast <2 x i8> %{{.*}} to i16
+ // CHECK-NEXT: %{{.*}} = call <2 x bfloat> @llvm.nvvm.ue8m0x2.to.bf16x2(i16 %[[res1]])
+ %res1 = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E8M0FNU)-> vector<2xbf16>
+ llvm.return
+}
>From 75a3137bc0dc4a52ef642dc17336f5e8f2101dd8 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 8 Oct 2025 06:42:13 +0000
Subject: [PATCH 2/6] add invalid test cases
---
mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 32 +++++++++++++++++++++
1 file changed, 32 insertions(+)
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 0b3615487716d..c5f71cfeaba8b 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -254,6 +254,38 @@ llvm.func @nvvm_cvt_f32x2_to_f6x2_invalid_type(%a : f32, %b : f32) {
// -----
+llvm.func @nvvm_cvt_f8x2_to_f16x2_invalid_type(%src : vector<2xi8>) {
+ // expected-error @below {{Only 'f8E4M3FN' and 'f8E5M2' types are supported for conversions from f8x2 to f16x2.}}
+ %res = nvvm.convert.f8x2.to.f16x2 %src : vector<2xi8> (f8E4M3) -> vector<2xf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f8x2_to_bf16x2_invalid_type(%src : vector<2xi8>) {
+ // expected-error @below {{Only 'f8E8M0FNU' type is supported for conversions from f8x2 to bf16x2.}}
+ %res = nvvm.convert.f8x2.to.bf16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xbf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f6x2_to_f16x2_invalid_type(%src : vector<2xi8>) {
+ // expected-error @below {{Only 'f6E2M3FN' and 'f6E3M2FN' types are supported for conversions from f6x2 to f16x2.}}
+ %res = nvvm.convert.f6x2.to.f16x2 %src : vector<2xi8> (f8E4M3FN) -> vector<2xf16>
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_cvt_f4x2_to_f16x2_invalid_type(%src : i8) {
+ // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f4x2 to f16x2.}}
+ %res = nvvm.convert.f4x2.to.f16x2 %src : i8 (f6E2M3FN) -> vector<2xf16>
+ llvm.return
+}
+
+// -----
+
llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
// expected-error @below {{cache eviction priority supported only for cache level L2}}
nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>
>From 36d41ee036f7c3c8717bc52e372c80d91dd5f0b8 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 8 Oct 2025 06:44:25 +0000
Subject: [PATCH 3/6] fix formatting
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 8 +++-----
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 2 +-
2 files changed, 4 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 5020af3992173..e6cfa24ccc3bb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1843,7 +1843,7 @@ class NVVM_ConvertF8x2ToFP16x2Op_Base <string dstType>
The `relu` attribute, when set, lowers to the '.relu' variant of
the cvt instruction.
- For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
}];
let results = (outs VectorOfLengthAndType<[2], [!cast<Type>(dstType)]>:$dst);
let arguments = !if(!eq(dstType, "F16"),
@@ -1867,11 +1867,9 @@ class NVVM_ConvertF8x2ToFP16x2Op_Base <string dstType>
$dst = createIntrinsicCall(builder, intId, args);
}];
}
-
def NVVM_ConvertF8x2ToF16x2Op : NVVM_ConvertF8x2ToFP16x2Op_Base<"F16">;
def NVVM_ConvertF8x2ToBF16x2Op : NVVM_ConvertF8x2ToFP16x2Op_Base<"BF16">;
-
def NVVM_ConvertF6x2ToF16x2Op : NVVM_Op<"convert.f6x2.to.f16x2"> {
let summary = "Convert a pair of f6 inputs to f16x2";
let description = [{
@@ -1882,7 +1880,7 @@ def NVVM_ConvertF6x2ToF16x2Op : NVVM_Op<"convert.f6x2.to.f16x2"> {
The `relu` attribute, when set, lowers to the '.relu' variant of
the cvt instruction.
- For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
}];
let results = (outs VectorOfLengthAndType<[2], [F16]>:$dst);
let arguments = (ins VectorOfLengthAndType<[2], [I8]>:$src,
@@ -1917,7 +1915,7 @@ def NVVM_ConvertF4x2ToF16x2Op : NVVM_Op<"convert.f4x2.to.f16x2"> {
The `relu` attribute, when set, lowers to the '.relu' variant of
the cvt instruction.
- For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
}];
let results = (outs VectorOfLengthAndType<[2], [F16]>:$dst);
let arguments = (ins I8:$src,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index bd38db52179a6..de7b9c31f7623 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -2154,7 +2154,7 @@ ConvertBF16x2ToF8x2Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto curOp = cast<NVVM::ConvertF8x2ToF16x2Op>(op);
-
+
bool hasRelu = curOp.getRelu();
llvm::Intrinsic::ID intId =
>From 23afb336e4a66bf49116ff67e03a6b0451bed7cf Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 8 Oct 2025 08:48:10 +0000
Subject: [PATCH 4/6] fix formatting
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 10 ++++------
1 file changed, 4 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index de7b9c31f7623..229a1f3a4ad66 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -2160,14 +2160,12 @@ NVVM::IDArgPair ConvertF8x2ToF16x2Op::getIntrinsicIDAndArgs(
llvm::Intrinsic::ID intId =
llvm::TypeSwitch<mlir::Type, llvm::Intrinsic::ID>(curOp.getSrcType())
.Case<Float8E4M3FNType>([&](Float8E4M3FNType type) {
- return hasRelu
- ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
- : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
+ return hasRelu ? llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e4m3x2_to_f16x2_rn;
})
.Case<Float8E5M2Type>([&](Float8E5M2Type type) {
- return hasRelu
- ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
- : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
+ return hasRelu ? llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn_relu
+ : llvm::Intrinsic::nvvm_e5m2x2_to_f16x2_rn;
})
.Default([](mlir::Type type) {
llvm_unreachable("Invalid type for ConvertF8x2ToF16x2Op");
>From 41e2da8173dec1d64e03bbb7d30fe1bb35b94018 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Wed, 8 Oct 2025 10:27:02 +0000
Subject: [PATCH 5/6] refactor tablegen
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 100 +++++---------------
1 file changed, 22 insertions(+), 78 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index e6cfa24ccc3bb..ed1308fba5578 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1832,25 +1832,28 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
}];
}
-class NVVM_ConvertF8x2ToFP16x2Op_Base <string dstType>
-: NVVM_Op<"convert.f8x2.to." # !tolower(dstType) # "x2"> {
-  let summary = "Convert a pair of f8 inputs to " # !tolower(dstType) # "x2";
+class NVVM_ConvertToFP16x2Op_Base <string srcType, Type srcArgType, string dstType>
+: NVVM_Op<"convert." # srcType # "x2.to." # !tolower(dstType) # "x2"> {
+  let summary = "Convert a pair of " # srcType # " inputs to " # !tolower(dstType) # "x2";
let description = [{
-  This Op converts the given f8 inputs in a i8x2 vector to }] # !tolower(dstType) # [{.
+  This Op converts the given }] # srcType # [{ inputs in a }] #
+  !if(!eq(srcType, "f4"), "packed i8", "i8x2 vector") # [{ to }] #
+  !tolower(dstType) # [{.
The result `dst` is represented as a vector of }] # !tolower(dstType) # [{ elements.
-
-  The `relu` attribute, when set, lowers to the '.relu' variant of
-  the cvt instruction.
+  }] #
+  !if(!eq(dstType, "F16"),
+      "\n  " # [{The `relu` attribute, when set, lowers to the '.relu' variant of
+      the cvt instruction.}] # "\n", "") # [{
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
}];
let results = (outs VectorOfLengthAndType<[2], [!cast<Type>(dstType)]>:$dst);
let arguments = !if(!eq(dstType, "F16"),
-  (ins VectorOfLengthAndType<[2], [I8]>:$src,
+  (ins srcArgType:$src,
DefaultValuedAttr<BoolAttr, "false">:$relu,
TypeAttr:$srcType),
-  (ins VectorOfLengthAndType<[2], [I8]>:$src,
+  (ins srcArgType:$src,
TypeAttr:$srcType));
let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)";
let hasVerifier = 1;
@@ -1863,79 +1866,20 @@ class NVVM_ConvertF8x2ToFP16x2Op_Base <string dstType>
string llvmBuilder = [{
auto [intId, args] =
- NVVM::ConvertF8x2To}] # dstType # [{x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
- $dst = createIntrinsicCall(builder, intId, args);
- }];
-}
-def NVVM_ConvertF8x2ToF16x2Op : NVVM_ConvertF8x2ToFP16x2Op_Base<"F16">;
-def NVVM_ConvertF8x2ToBF16x2Op : NVVM_ConvertF8x2ToFP16x2Op_Base<"BF16">;
-
-def NVVM_ConvertF6x2ToF16x2Op : NVVM_Op<"convert.f6x2.to.f16x2"> {
- let summary = "Convert a pair of f6 inputs to f16x2";
- let description = [{
- This Op converts the given f6 inputs in a i8x2 vector to f16x2.
-
- The result `dst` is represented as a vector of f16 elements.
-
- The `relu` attribute, when set, lowers to the '.relu' variant of
- the cvt instruction.
-
- [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
- }];
- let results = (outs VectorOfLengthAndType<[2], [F16]>:$dst);
- let arguments = (ins VectorOfLengthAndType<[2], [I8]>:$src,
- DefaultValuedAttr<BoolAttr, "false">:$relu,
- TypeAttr:$srcType);
- let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)";
- let hasVerifier = 1;
-
- let extraClassDeclaration = [{
- static IDArgPair
- getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
- llvm::IRBuilderBase &builder);
- }];
-
- string llvmBuilder = [{
- auto [intId, args] =
- NVVM::ConvertF6x2ToF16x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
+ NVVM::Convert}] # !toupper(srcType) # [{x2To}] # dstType #
+ [{x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
$dst = createIntrinsicCall(builder, intId, args);
}];
}
-def NVVM_ConvertF4x2ToF16x2Op : NVVM_Op<"convert.f4x2.to.f16x2"> {
- let summary = "Convert a pair of f4 inputs to f16x2";
- let description = [{
- This Op converts the given f4 inputs packed in an i8 to f16x2.
-
- The result `dst` is represented as a vector of f16 elements. The value
- converted from the lower 4 bits of `src` is stored in the first element of
- `dst` and the value converted from the upper 4 bits of `src` is stored in
- the second element of `dst`.
-
- The `relu` attribute, when set, lowers to the '.relu' variant of
- the cvt instruction.
-
- [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
- }];
- let results = (outs VectorOfLengthAndType<[2], [F16]>:$dst);
- let arguments = (ins I8:$src,
- DefaultValuedAttr<BoolAttr, "false">:$relu,
- TypeAttr:$srcType);
- let assemblyFormat = "$src attr-dict `:` type($src) `(` $srcType `)` `->` type($dst)";
- let hasVerifier = 1;
-
- let extraClassDeclaration = [{
- static IDArgPair
- getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
- llvm::IRBuilderBase &builder);
- }];
-
- string llvmBuilder = [{
- auto [intId, args] =
- NVVM::ConvertF4x2ToF16x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
- $dst = createIntrinsicCall(builder, intId, args);
- }];
-}
+def NVVM_ConvertF8x2ToF16x2Op :
+ NVVM_ConvertToFP16x2Op_Base<"f8", VectorOfLengthAndType<[2], [I8]>, "F16">;
+def NVVM_ConvertF8x2ToBF16x2Op :
+ NVVM_ConvertToFP16x2Op_Base<"f8", VectorOfLengthAndType<[2], [I8]>, "BF16">;
+def NVVM_ConvertF6x2ToF16x2Op :
+ NVVM_ConvertToFP16x2Op_Base<"f6", VectorOfLengthAndType<[2], [I8]>, "F16">;
+def NVVM_ConvertF4x2ToF16x2Op :
+ NVVM_ConvertToFP16x2Op_Base<"f4", I8, "F16">;
//===----------------------------------------------------------------------===//
// NVVM MMA Ops
>From a349a4f0da142436f176bc4fb89615822ddeef04 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Fri, 10 Oct 2025 06:44:53 +0000
Subject: [PATCH 6/6] match argument cases
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 18 +++++++++---------
1 file changed, 9 insertions(+), 9 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index ed1308fba5578..0287a20e7ed10 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1833,11 +1833,11 @@ def NVVM_ConvertBF16x2ToF8x2Op : NVVM_Op<"convert.bf16x2.to.f8x2"> {
}
class NVVM_ConvertToFP16x2Op_Base <string srcType, Type srcArgType, string dstType>
-: NVVM_Op<"convert." # srcType # "x2.to." # !tolower(dstType) # "x2"> {
-  let summary = "Convert a pair of " # srcType # " inputs to " # !tolower(dstType) # "x2";
+: NVVM_Op<"convert." # !tolower(srcType) # "x2.to." # !tolower(dstType) # "x2"> {
+  let summary = "Convert a pair of " # !tolower(srcType) # " inputs to " # !tolower(dstType) # "x2";
let description = [{
-  This Op converts the given }] # srcType # [{ inputs in a }] #
-  !if(!eq(srcType, "f4"), "packed i8", "i8x2 vector") # [{ to }] #
+  This Op converts the given }] # !tolower(srcType) # [{ inputs in }] #
+  !if(!eq(srcType, "F4"), "a packed i8 (the lower 4 bits of `src` convert to the first element of `dst`, the upper 4 bits to the second)", "an i8x2 vector") # [{ to }] #
!tolower(dstType) # [{.
The result `dst` is represented as a vector of }] # !tolower(dstType) # [{ elements.
@@ -1866,20 +1866,20 @@ class NVVM_ConvertToFP16x2Op_Base <string srcType, Type srcArgType, string dstTy
string llvmBuilder = [{
auto [intId, args] =
- NVVM::Convert}] # !toupper(srcType) # [{x2To}] # dstType #
+ NVVM::Convert}] # srcType # [{x2To}] # dstType #
[{x2Op::getIntrinsicIDAndArgs(*op, moduleTranslation, builder);
$dst = createIntrinsicCall(builder, intId, args);
}];
}
def NVVM_ConvertF8x2ToF16x2Op :
- NVVM_ConvertToFP16x2Op_Base<"f8", VectorOfLengthAndType<[2], [I8]>, "F16">;
+ NVVM_ConvertToFP16x2Op_Base<"F8", VectorOfLengthAndType<[2], [I8]>, "F16">;
def NVVM_ConvertF8x2ToBF16x2Op :
- NVVM_ConvertToFP16x2Op_Base<"f8", VectorOfLengthAndType<[2], [I8]>, "BF16">;
+ NVVM_ConvertToFP16x2Op_Base<"F8", VectorOfLengthAndType<[2], [I8]>, "BF16">;
def NVVM_ConvertF6x2ToF16x2Op :
- NVVM_ConvertToFP16x2Op_Base<"f6", VectorOfLengthAndType<[2], [I8]>, "F16">;
+ NVVM_ConvertToFP16x2Op_Base<"F6", VectorOfLengthAndType<[2], [I8]>, "F16">;
def NVVM_ConvertF4x2ToF16x2Op :
- NVVM_ConvertToFP16x2Op_Base<"f4", I8, "F16">;
+ NVVM_ConvertToFP16x2Op_Base<"F4", I8, "F16">;
//===----------------------------------------------------------------------===//
// NVVM MMA Ops
More information about the Mlir-commits
mailing list