[Mlir-commits] [mlir] [MLIR][NVVM] Add support for f32x2 to f4x2 conversion (PR #162273)

Srinivasa Ravi llvmlistbot at llvm.org
Tue Oct 7 05:26:02 PDT 2025


https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/162273

>From 12ceea55d55f050cd1a850b7150140c349f40806 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Mon, 18 Aug 2025 16:02:03 +0530
Subject: [PATCH 1/3] [MLIR][NVVM] Add support for f32x2 to f4x2 conversion

This change adds the `convert.f32x2.to.f4x2` op to the NVVM dialect
for converting a pair of f32 values to an f4x2 (`e2m1x2`) value, which
is returned packed into an i8.

PTX reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 41 +++++++++++++++++++
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    |  6 +++
 .../Target/LLVMIR/nvvm/convert_fp4x2.mlir     | 12 ++++++
 3 files changed, 59 insertions(+)
 create mode 100644 mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index f8e3167b42c35..700b232013498 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1654,6 +1654,47 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
   }];
 }
 
+def ConvertFP4E2M1 : I32EnumCase<"E2M1", 0, "e2m1">;
+
+def ConvertFP4Type : I32Enum<"ConvertFP4Type", "NVVM ConvertFP4Type kind",
+                             [ConvertFP4E2M1]> {
+  let cppNamespace = "::mlir::NVVM";
+}
+def ConvertFP4TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP4Type, 
+                                  "convert_fp4_type"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
+  let summary = "Convert a pair of float inputs to f4x2";
+  let description = [{
+    This Op converts each of the given float inputs to the specified fp4 type.
+    The result `dst` is returned as an i8 type where the converted values are 
+    packed such that the value converted from `a` is stored in the upper 4 bits 
+    of `dst` and the value converted from `b` is stored in the lower 4 bits of 
+    `dst`.
+    The `relu` attribute, when set, lowers to the '.relu' variant of
+    the cvt instruction.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+  }];
+
+  let results = (outs I8:$dst);
+  let arguments = (ins ConvertFP4TypeAttr:$type, F32:$a, F32:$b,
+                       DefaultValuedAttr<BoolAttr, "false">:$relu);
+  let assemblyFormat = "$type $a `,` $b attr-dict";
+
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP4Type, bool hasRelu);
+  }];
+
+  string llvmBuilder = [{
+    auto intId = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicID($type, $relu);
+    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
+    $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
+  }];
+}
+
 def ConvertFP6E2M3 : I32EnumAttrCase<"E2M3", 0, "e2m3">;
 def ConvertFP6E3M2 : I32EnumAttrCase<"E3M2", 1, "e3m2">;
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index e8f8824d47de0..824079fdfbc51 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1976,6 +1976,12 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
   }
 }
 
+llvm::Intrinsic::ID
+ConvertF32x2ToF4x2Op::getIntrinsicID(NVVM::ConvertFP4Type type, bool hasRelu) {
+  return hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
+                 : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
+}
+
 #define GET_F32x2_TO_F6x2_ID(type, has_relu)                                   \
   has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite            \
            : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
new file mode 100644
index 0000000000000..f1599a6b51e3a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @convert_f32x2_to_f4x2_e2m1
+llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
+  %res1 = nvvm.convert.f32x2.to.f4x2 <e2m1> %srcA, %srcB
+  // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
+  %res2 = nvvm.convert.f32x2.to.f4x2 <e2m1> %srcA, %srcB {relu = true}
+  llvm.return
+}

>From fae7b3a666932a1d641b3df197577550616b033d Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 7 Oct 2025 12:20:12 +0000
Subject: [PATCH 2/3] Use builtin f4E2M1FN type instead of a custom FP4 enum attribute

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 23 ++++++-------------
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 13 ++++++++++-
 .../Target/LLVMIR/nvvm/convert_fp4x2.mlir     |  4 ++--
 mlir/test/Target/LLVMIR/nvvmir-invalid.mlir   |  8 +++++++
 4 files changed, 29 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 700b232013498..2bba893148318 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1654,17 +1654,6 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
   }];
 }
 
-def ConvertFP4E2M1 : I32EnumCase<"E2M1", 0, "e2m1">;
-
-def ConvertFP4Type : I32Enum<"ConvertFP4Type", "NVVM ConvertFP4Type kind",
-                             [ConvertFP4E2M1]> {
-  let cppNamespace = "::mlir::NVVM";
-}
-def ConvertFP4TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP4Type, 
-                                  "convert_fp4_type"> {
-  let assemblyFormat = "`<` $value `>`";
-}
-
 def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
   let summary = "Convert a pair of float inputs to f4x2";
   let description = [{
@@ -1680,16 +1669,18 @@ def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
   }];
 
   let results = (outs I8:$dst);
-  let arguments = (ins ConvertFP4TypeAttr:$type, F32:$a, F32:$b,
-                       DefaultValuedAttr<BoolAttr, "false">:$relu);
-  let assemblyFormat = "$type $a `,` $b attr-dict";
+  let arguments = (ins F32:$a, F32:$b,
+                       DefaultValuedAttr<BoolAttr, "false">:$relu,
+                       TypeAttr:$dstTy);
+  let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
-    static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP4Type, bool hasRelu);
+    static llvm::Intrinsic::ID getIntrinsicID(bool hasRelu);
   }];
 
   string llvmBuilder = [{
-    auto intId = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicID($type, $relu);
+    auto intId = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicID($relu);
     llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
     $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
   }];
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 824079fdfbc51..36988c53a39cd 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -271,6 +271,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
   return success();
 }
 
+LogicalResult ConvertF32x2ToF4x2Op::verify() {
+  mlir::MLIRContext *ctx = getContext();
+
+  if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+    return emitOpError("Only ")
+           << mlir::Float4E2M1FNType::get(ctx)
+           << " type is supported for conversions from f32x2 to f4x2.";
+
+  return success();
+}
+
 LogicalResult BulkStoreOp::verify() {
   if (getInitVal() != 0)
     return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -1977,7 +1988,7 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
 }
 
 llvm::Intrinsic::ID
-ConvertF32x2ToF4x2Op::getIntrinsicID(NVVM::ConvertFP4Type type, bool hasRelu) {
+ConvertF32x2ToF4x2Op::getIntrinsicID(bool hasRelu) {
   return hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
                  : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
 }
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
index f1599a6b51e3a..04e2ddff802a9 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
@@ -4,9 +4,9 @@
 llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
   // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %{{.*}}, float %{{.*}})
   // CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
-  %res1 = nvvm.convert.f32x2.to.f4x2 <e2m1> %srcA, %srcB
+  %res1 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB : i8 (f4E2M1FN)
   // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
   // CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
-  %res2 = nvvm.convert.f32x2.to.f4x2 <e2m1> %srcA, %srcB {relu = true}
+  %res2 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB {relu = true} : i8 (f4E2M1FN)
   llvm.return
 }
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 383f4829f3287..79ad6e847fa5c 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -246,6 +246,14 @@ llvm.func @nvvm_cvt_bf16x2_to_f8x2_invalid_rounding(%src : vector<2xbf16>) {
 
 // -----
 
+llvm.func @nvvm_cvt_f32x2_to_f4x2_invalid_type(%a : f32, %b : f32) {
+  // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f32x2 to f4x2.}}
+  %res = nvvm.convert.f32x2.to.f4x2 %a, %b : i8 (f8E4M3FN)
+  llvm.return
+}
+
+// -----
+
 llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
   // expected-error @below {{cache eviction priority supported only for cache level L2}}
   nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>

>From 0c8e78787e4e2aeadf7075eae292f2e96a2cedcd Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 7 Oct 2025 12:25:41 +0000
Subject: [PATCH 3/3] Fix formatting of ConvertF32x2ToF4x2Op::getIntrinsicID

---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 36988c53a39cd..732f790809198 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1987,8 +1987,7 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
   }
 }
 
-llvm::Intrinsic::ID
-ConvertF32x2ToF4x2Op::getIntrinsicID(bool hasRelu) {
+llvm::Intrinsic::ID ConvertF32x2ToF4x2Op::getIntrinsicID(bool hasRelu) {
   return hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
                  : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
 }



More information about the Mlir-commits mailing list