[Mlir-commits] [mlir] [MLIR][NVVM] Add support for f32x2 to f4x2 conversion (PR #162273)
Srinivasa Ravi
llvmlistbot at llvm.org
Tue Oct 7 05:26:02 PDT 2025
https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/162273
>From 12ceea55d55f050cd1a850b7150140c349f40806 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Mon, 18 Aug 2025 16:02:03 +0530
Subject: [PATCH 1/3] [MLIR][NVVM] Add support for f32x2 to f4x2 conversion
This change adds the `convert.f32x2.to.f4x2` op to the NVVM dialect
for converting a pair of f32 values to an f4x2 (`e2m1x2`) value,
returned packed in an i8.
PTX reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 41 +++++++++++++++++++
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 6 +++
.../Target/LLVMIR/nvvm/convert_fp4x2.mlir | 12 ++++++
3 files changed, 59 insertions(+)
create mode 100644 mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index f8e3167b42c35..700b232013498 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1654,6 +1654,47 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
}];
}
+def ConvertFP4E2M1 : I32EnumCase<"E2M1", 0, "e2m1">;
+
+def ConvertFP4Type : I32Enum<"ConvertFP4Type", "NVVM ConvertFP4Type kind",
+ [ConvertFP4E2M1]> {
+ let cppNamespace = "::mlir::NVVM";
+}
+def ConvertFP4TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP4Type,
+ "convert_fp4_type"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
+ let summary = "Convert a pair of float inputs to f4x2";
+ let description = [{
+ This Op converts each of the given float inputs to the specified fp4 type.
+ The result `dst` is returned as an i8 type where the converted values are
+ packed such that the value converted from `a` is stored in the upper 4 bits
+ of `dst` and the value converted from `b` is stored in the lower 4 bits of
+ `dst`.
+ The `relu` attribute, when set, lowers to the '.relu' variant of
+ the cvt instruction.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+ }];
+
+ let results = (outs I8:$dst);
+ let arguments = (ins ConvertFP4TypeAttr:$type, F32:$a, F32:$b,
+ DefaultValuedAttr<BoolAttr, "false">:$relu);
+ let assemblyFormat = "$type $a `,` $b attr-dict";
+
+ let extraClassDeclaration = [{
+ static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP4Type, bool hasRelu);
+ }];
+
+ string llvmBuilder = [{
+ auto intId = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicID($type, $relu);
+ llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
+ $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
+ }];
+}
+
def ConvertFP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">;
def ConvertFP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">;
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index e8f8824d47de0..824079fdfbc51 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1976,6 +1976,12 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
+llvm::Intrinsic::ID
+ConvertF32x2ToF4x2Op::getIntrinsicID(NVVM::ConvertFP4Type type, bool hasRelu) {
+ return hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
+ : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
+}
+
#define GET_F32x2_TO_F6x2_ID(type, has_relu) \
has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite \
: llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
new file mode 100644
index 0000000000000..f1599a6b51e3a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @convert_f32x2_to_f4x2_e2m1
+llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
+ // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+ // CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
+ %res1 = nvvm.convert.f32x2.to.f4x2 <e2m1> %srcA, %srcB
+ // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+ // CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
+ %res2 = nvvm.convert.f32x2.to.f4x2 <e2m1> %srcA, %srcB {relu = true}
+ llvm.return
+}
>From fae7b3a666932a1d641b3df197577550616b033d Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 7 Oct 2025 12:20:12 +0000
Subject: [PATCH 2/3] Replace ConvertFP4Type enum with a builtin f4E2M1FN type attr
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 23 ++++++-------------
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 13 ++++++++++-
.../Target/LLVMIR/nvvm/convert_fp4x2.mlir | 4 ++--
mlir/test/Target/LLVMIR/nvvmir-invalid.mlir | 8 +++++++
4 files changed, 29 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 700b232013498..2bba893148318 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1654,17 +1654,6 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
}];
}
-def ConvertFP4E2M1 : I32EnumCase<"E2M1", 0, "e2m1">;
-
-def ConvertFP4Type : I32Enum<"ConvertFP4Type", "NVVM ConvertFP4Type kind",
- [ConvertFP4E2M1]> {
- let cppNamespace = "::mlir::NVVM";
-}
-def ConvertFP4TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP4Type,
- "convert_fp4_type"> {
- let assemblyFormat = "`<` $value `>`";
-}
-
def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
let summary = "Convert a pair of float inputs to f4x2";
let description = [{
@@ -1680,16 +1669,18 @@ def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
}];
let results = (outs I8:$dst);
- let arguments = (ins ConvertFP4TypeAttr:$type, F32:$a, F32:$b,
- DefaultValuedAttr<BoolAttr, "false">:$relu);
- let assemblyFormat = "$type $a `,` $b attr-dict";
+ let arguments = (ins F32:$a, F32:$b,
+ DefaultValuedAttr<BoolAttr, "false">:$relu,
+ TypeAttr:$dstTy);
+ let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";
+ let hasVerifier = 1;
let extraClassDeclaration = [{
- static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP4Type, bool hasRelu);
+ static llvm::Intrinsic::ID getIntrinsicID(bool hasRelu);
}];
string llvmBuilder = [{
- auto intId = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicID($type, $relu);
+ auto intId = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicID($relu);
llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
$dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
}];
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 824079fdfbc51..36988c53a39cd 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -271,6 +271,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
return success();
}
+LogicalResult ConvertF32x2ToF4x2Op::verify() {
+ mlir::MLIRContext *ctx = getContext();
+
+ if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+ return emitOpError("Only ")
+ << mlir::Float4E2M1FNType::get(ctx)
+ << " type is supported for conversions from f32x2 to f4x2.";
+
+ return success();
+}
+
LogicalResult BulkStoreOp::verify() {
if (getInitVal() != 0)
return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -1977,7 +1988,7 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
llvm::Intrinsic::ID
-ConvertF32x2ToF4x2Op::getIntrinsicID(NVVM::ConvertFP4Type type, bool hasRelu) {
+ConvertF32x2ToF4x2Op::getIntrinsicID(bool hasRelu) {
return hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
: llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
}
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
index f1599a6b51e3a..04e2ddff802a9 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
@@ -4,9 +4,9 @@
llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
// CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %{{.*}}, float %{{.*}})
// CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
- %res1 = nvvm.convert.f32x2.to.f4x2 <e2m1> %srcA, %srcB
+ %res1 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB : i8 (f4E2M1FN)
// CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
// CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
- %res2 = nvvm.convert.f32x2.to.f4x2 <e2m1> %srcA, %srcB {relu = true}
+ %res2 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB {relu = true} : i8 (f4E2M1FN)
llvm.return
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 383f4829f3287..79ad6e847fa5c 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -246,6 +246,14 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
// -----
+llvm.func @nvvm_cvt_f32x2_to_f4x2_invalid_type(%a : f32, %b : f32) {
+ // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f32x2 to f4x2.}}
+ %res = nvvm.convert.f32x2.to.f4x2 %a, %b : i8 (f8E4M3FN)
+ llvm.return
+}
+
+// -----
+
llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
// expected-error @below {{cache eviction priority supported only for cache level L2}}
nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>
>From 0c8e78787e4e2aeadf7075eae292f2e96a2cedcd Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 7 Oct 2025 12:25:41 +0000
Subject: [PATCH 3/3] Fix formatting of ConvertF32x2ToF4x2Op::getIntrinsicID
---
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 36988c53a39cd..732f790809198 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1987,8 +1987,7 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
}
}
-llvm::Intrinsic::ID
-ConvertF32x2ToF4x2Op::getIntrinsicID(bool hasRelu) {
+llvm::Intrinsic::ID ConvertF32x2ToF4x2Op::getIntrinsicID(bool hasRelu) {
return hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
: llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
}
More information about the Mlir-commits
mailing list