[Mlir-commits] [mlir] [MLIR][NVVM] Add support for f32x2 to f4x2 conversion (PR #162273)
Srinivasa Ravi
llvmlistbot at llvm.org
Thu Oct 9 22:10:09 PDT 2025
https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/162273
>From e5acd779f71e5e32b707cf9b28a35e8de7894b8e Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Mon, 18 Aug 2025 16:02:03 +0530
Subject: [PATCH 1/5] [MLIR][NVVM] Add support for f32x2 to f4x2 conversion
This change adds the `convert.f32x2.to.f4x2` op to the NVVM Dialect
for converting a pair of f32 values to an f4x2 (`e2m1x2`) value.
PTX reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 41 +++++++++++++++++++
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 6 +++
.../Target/LLVMIR/nvvm/convert_fp4x2.mlir | 12 ++++++
3 files changed, 59 insertions(+)
create mode 100644 mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index e2a0331542742..518acc2afbc8e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1655,6 +1655,47 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
}];
}
+def ConvertFP4E2M1 : I32EnumCase<"E2M1", 0, "e2m1">;
+
+def ConvertFP4Type : I32Enum<"ConvertFP4Type", "NVVM ConvertFP4Type kind",
+ [ConvertFP4E2M1]> {
+ let cppNamespace = "::mlir::NVVM";
+}
+def ConvertFP4TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP4Type,
+ "convert_fp4_type"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
+ let summary = "Convert a pair of float inputs to f4x2";
+ let description = [{
+ This Op converts each of the given float inputs to the specified fp4 type.
+ The result `dst` is returned as an i8 type where the converted values are
+ packed such that the value converted from `a` is stored in the upper 4 bits
+ of `dst` and the value converted from `b` is stored in the lower 4 bits of
+ `dst`.
+ The `relu` attribute, when set, lowers to the '.relu' variant of
+ the cvt instruction.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ }];
+
+ let results = (outs I8:$dst);
+ let arguments = (ins ConvertFP4TypeAttr:$type, F32:$a, F32:$b,
+ DefaultValuedAttr<BoolAttr, "false">:$relu);
+ let assemblyFormat = "$type $a `,` $b attr-dict";
+
+ let extraClassDeclaration = [{
+ static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP4Type, bool hasRelu);
+ }];
+
+ string llvmBuilder = [{
+ auto intId = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicID($type, $relu);
+ llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
+ $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
+ }];
+}
+
def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
let summary = "Convert a pair of float inputs to f6x2";
let description = [{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7f419a062201d..ac6d4f3d361b9 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -2014,6 +2014,12 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
+llvm::Intrinsic::ID
+ConvertF32x2ToF4x2Op::getIntrinsicID(NVVM::ConvertFP4Type type, bool hasRelu) {
+ return hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
+}
+
#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
new file mode 100644
index 0000000000000..f1599a6b51e3a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @convert_f32x2_to_f4x2_e2m1
+llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
+ // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+ // CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
+ %res1 = nvvm.convert.f32x2.to.f4x2 <e2m1> %srcA, %srcB
+ // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+ // CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
+ %res2 = nvvm.convert.f32x2.to.f4x2 <e2m1> %srcA, %srcB {relu = true}
+ llvm.return
+}
>From 8d9d1ab53a4acf5fb0389a3116588b03e5e55625 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 7 Oct 2025 12:20:12 +0000
Subject: [PATCH 2/5] update to use builtin types
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 23 ++++++-------------
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 13 ++++++++++-
.../Target/LLVMIR/nvvm/convert_fp4x2.mlir | 4 ++--
mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 8 +++++++
4 files changed, 29 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 518acc2afbc8e..347195dc87443 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1655,17 +1655,6 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
}];
}
-def ConvertFP4E2M1 : I32EnumCase<"E2M1", 0, "e2m1">;
-
-def ConvertFP4Type : I32Enum<"ConvertFP4Type", "NVVM ConvertFP4Type kind",
- [ConvertFP4E2M1]> {
- let cppNamespace = "::mlir::NVVM";
-}
-def ConvertFP4TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP4Type,
- "convert_fp4_type"> {
- let assemblyFormat = "`<` $value `>`";
-}
-
def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
let summary = "Convert a pair of float inputs to f4x2";
let description = [{
@@ -1681,16 +1670,18 @@ def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
}];
let results = (outs I8:$dst);
- let arguments = (ins ConvertFP4TypeAttr:$type, F32:$a, F32:$b,
- DefaultValuedAttr<BoolAttr, "false">:$relu);
- let assemblyFormat = "$type $a `,` $b attr-dict";
+ let arguments = (ins F32:$a, F32:$b,
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$dstTy);
+ let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";
+ let hasVerifier = 1;
let extraClassDeclaration = [{
- static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP4Type, bool hasRelu);
+ static llvm::Intrinsic::ID getIntrinsicID(bool hasRelu);
}];
string llvmBuilder = [{
- auto intId = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicID($type, $relu);
+ auto intId = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicID($relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
$dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
}];
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index ac6d4f3d361b9..cf92856461cc8 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -309,6 +309,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
return success();
}
+LogicalResult ConvertF32x2ToF4x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from f32x2 to f4x2.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -2015,7 +2026,7 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
llvm::Intrinsic::ID
-ConvertF32x2ToF4x2Op::getIntrinsicID(NVVM::ConvertFP4Type type, bool hasRelu) {
+ConvertF32x2ToF4x2Op::getIntrinsicID(bool hasRelu) {
return hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
: llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
index f1599a6b51e3a..04e2ddff802a9 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
@@ -4,9 +4,9 @@
llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %{{.*}}, float %{{.*}})
// CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
- %res1 = nvvm.convert.f32x2.to.f4x2 <e2m1> %srcA, %srcB
+ %res1 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB : i8 (f4E2M1FN)
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
// CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
- %res2 = nvvm.convert.f32x2.to.f4x2 <e2m1> %srcA, %srcB {relu = true}
+ %res2 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB {relu = true} : i8 (f4E2M1FN)
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 0b3615487716d..78e1e659ed85d 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -254,6 +254,14 @@ llvm.func @nvvm_cvt_f32x2_to_f6x2_invalid_type(%a : f32, %b : f32) {
// -----
+llvm.func @nvvm_cvt_f32x2_to_f4x2_invalid_type(%a : f32, %b : f32) {
+ // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f32x2 to f4x2.}}
+ %res = nvvm.convert.f32x2.to.f4x2 %a, %b : i8 (f8E4M3FN)
+ llvm.return
+}
+
+// -----
+
llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
// expected-error @below {{cache eviction priority supported only for cache level L2}}
nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>
>From 587f17eea938cd065218a9d059e96d5594e8acdb Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 7 Oct 2025 12:25:41 +0000
Subject: [PATCH 3/5] fix formatting
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index cf92856461cc8..23fb350881611 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -2025,8 +2025,7 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
-llvm::Intrinsic::ID
-ConvertF32x2ToF4x2Op::getIntrinsicID(bool hasRelu) {
+llvm::Intrinsic::ID ConvertF32x2ToF4x2Op::getIntrinsicID(bool hasRelu) {
return hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
: llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
}
>From 6cf5a899741b5ff96ba14d6bb6477f5b2bf0c613 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Fri, 10 Oct 2025 05:05:14 +0000
Subject: [PATCH 4/5] use getIntrinsicIDAndArgs
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 8 +++++---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 18 +++++++++++++++---
2 files changed, 20 insertions(+), 6 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 347195dc87443..3a65555204c36 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1677,12 +1677,14 @@ def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
let hasVerifier = 1;
let extraClassDeclaration = [{
- static llvm::Intrinsic::ID getIntrinsicID(bool hasRelu);
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
+ LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
}];
string llvmBuilder = [{
- auto intId = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicID($relu);
- llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
+ auto [intId, args] = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
+ llvm::Value *packedI16 = createIntrinsicCall(builder, intId, args);
$dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
}];
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 23fb350881611..37b4168386da8 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -2025,9 +2025,21 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
-llvm::Intrinsic::ID ConvertF32x2ToF4x2Op::getIntrinsicID(bool hasRelu) {
- return hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
- : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
+NVVM::IDArgPair
+ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) {
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(mt.lookupValue(op.getA()));
+ args.push_back(mt.lookupValue(op.getB()));
+
+ bool hasRelu = op.getRelu();
+
+ llvm::Intrinsic::ID intId =
+ hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
+
+ return {intId, std::move(args)};
}
#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
>From 44691fb6fa47d058da2babaababf0a14d4a161ee Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Fri, 10 Oct 2025 05:09:40 +0000
Subject: [PATCH 5/5] fix formatting
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 37b4168386da8..3e4c7cd6826fc 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -2032,13 +2032,13 @@ ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
llvm::SmallVector<llvm::Value *> args;
args.push_back(mt.lookupValue(op.getA()));
args.push_back(mt.lookupValue(op.getB()));
-
+
bool hasRelu = op.getRelu();
-
+
llvm::Intrinsic::ID intId =
hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
: llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
-
+
return {intId, std::move(args)};
}
More information about the Mlir-commits
mailing list