[Mlir-commits] [mlir] [MLIR][NVVM] Add support for f32x2 to f4x2 conversion (PR #162273)

Srinivasa Ravi llvmlistbot at llvm.org
Thu Oct 9 22:10:09 PDT 2025


https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/162273

>From e5acd779f71e5e32b707cf9b28a35e8de7894b8e Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Mon, 18 Aug 2025 16:02:03 +0530
Subject: [PATCH 1/5] [MLIR][NVVM] Add support for f32x2 to f4x2 conversion

This change adds the `convert.f32x2.to.f4x2` op to the NVVM Dialect
for converting a pair of f32 values to an f4x2 (`e2m1x2`) value.

PTX reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 41 +++++++++++++++++++
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    |  6 +++
 .../Target/LLVMIR/nvvm/convert_fp4x2.mlir     | 12 ++++++
 3 files changed, 59 insertions(+)
 create mode 100644 mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index e2a0331542742..518acc2afbc8e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1655,6 +1655,47 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
   }];
 }
 
+def ConvertFP4E2M1 : I32EnumCase<"E2M1", 0, "e2m1">;
+
+def ConvertFP4Type : I32Enum<"ConvertFP4Type", "NVVM ConvertFP4Type kind",
+                             [ConvertFP4E2M1]> {
+  let cppNamespace = "::mlir::NVVM";
+}
+def ConvertFP4TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP4Type, 
+                                  "convert_fp4_type"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
+  let summary = "Convert a pair of float inputs to f4x2";
+  let description = [{
+    This Op converts each of the given float inputs to the specified fp4 type
+    with round-to-nearest-even and saturation to the finite range
+    (`.rn.satfinite`). The result `dst` is returned as an i8 where the value
+    converted from `a` is stored in the upper 4 bits and the value converted
+    from `b` is stored in the lower 4 bits.
+    The `relu` attribute, when set, lowers to the '.relu' variant of
+    the cvt instruction.
+
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
+  }];
+
+  let results = (outs I8:$dst);
+  let arguments = (ins ConvertFP4TypeAttr:$type, F32:$a, F32:$b,
+                       DefaultValuedAttr<BoolAttr, "false">:$relu);
+  let assemblyFormat = "$type $a `,` $b attr-dict";
+
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP4Type, bool hasRelu);
+  }];
+
+  string llvmBuilder = [{
+    auto intId = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicID($type, $relu);
+    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
+    $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
+  }];
+}
+
 def NVVM_ConvertF32x2ToF6x2Op : NVVM_Op<"convert.f32x2.to.f6x2"> {
   let summary = "Convert a pair of float inputs to f6x2";
   let description = [{
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7f419a062201d..ac6d4f3d361b9 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -2014,6 +2014,12 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
   }
 }
 
+llvm::Intrinsic::ID
+ConvertF32x2ToF4x2Op::getIntrinsicID(NVVM::ConvertFP4Type type, bool hasRelu) {
+  return hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
+                 : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
+}
+
 #define GET_F32x2_TO_F6x2_ID(type, has_relu)                                   \
   has_relu ? llvm::Intrinsic::nvvm_ff_to_##type##_rn_relu_satfinite            \
            : llvm::Intrinsic::nvvm_ff_to_##type##_rn_satfinite
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
new file mode 100644
index 0000000000000..f1599a6b51e3a
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: @convert_f32x2_to_f4x2_e2m1
+llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
+  // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
+  %res1 = nvvm.convert.f32x2.to.f4x2 <e2m1> %srcA, %srcB
+  // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
+  // CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
+  %res2 = nvvm.convert.f32x2.to.f4x2 <e2m1> %srcA, %srcB {relu = true}
+  llvm.return
+}

>From 8d9d1ab53a4acf5fb0389a3116588b03e5e55625 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 7 Oct 2025 12:20:12 +0000
Subject: [PATCH 2/5] use builtin f4E2M1FN type instead of a custom FP4 enum attribute

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 23 ++++++-------------
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 13 ++++++++++-
 .../Target/LLVMIR/nvvm/convert_fp4x2.mlir     |  4 ++--
 mlir/test/Target/LLVMIR/nvvmir-invalid.mlir   |  8 +++++++
 4 files changed, 29 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 518acc2afbc8e..347195dc87443 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1655,17 +1655,6 @@ def NVVM_ConvertFloatToTF32Op : NVVM_Op<"convert.float.to.tf32"> {
   }];
 }
 
-def ConvertFP4E2M1 : I32EnumCase<"E2M1", 0, "e2m1">;
-
-def ConvertFP4Type : I32Enum<"ConvertFP4Type", "NVVM ConvertFP4Type kind",
-                             [ConvertFP4E2M1]> {
-  let cppNamespace = "::mlir::NVVM";
-}
-def ConvertFP4TypeAttr : EnumAttr<NVVM_Dialect, ConvertFP4Type, 
-                                  "convert_fp4_type"> {
-  let assemblyFormat = "`<` $value `>`";
-}
-
 def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
   let summary = "Convert a pair of float inputs to f4x2";
   let description = [{
@@ -1681,16 +1670,18 @@ def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
   }];
 
   let results = (outs I8:$dst);
-  let arguments = (ins ConvertFP4TypeAttr:$type, F32:$a, F32:$b,
-                       DefaultValuedAttr<BoolAttr, "false">:$relu);
-  let assemblyFormat = "$type $a `,` $b attr-dict";
+  let arguments = (ins F32:$a, F32:$b,
+                       DefaultValuedAttr<BoolAttr, "false">:$relu,
+                       TypeAttr:$dstTy);
+  let assemblyFormat = "$a `,` $b attr-dict `:` type($dst) `(` $dstTy `)`";
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
-    static llvm::Intrinsic::ID getIntrinsicID(NVVM::ConvertFP4Type, bool hasRelu);
+    static llvm::Intrinsic::ID getIntrinsicID(bool hasRelu);
   }];
 
   string llvmBuilder = [{
-    auto intId = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicID($type, $relu);
+    auto intId = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicID($relu);
     llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
     $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
   }];
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index ac6d4f3d361b9..cf92856461cc8 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -309,6 +309,17 @@ LogicalResult ConvertBF16x2ToF8x2Op::verify() {
   return success();
 }
 
+LogicalResult ConvertF32x2ToF4x2Op::verify() {
+  mlir::MLIRContext *ctx = getContext();
+
+  if (!llvm::isa<mlir::Float4E2M1FNType>(getDstTy()))
+    return emitOpError("Only ")
+           << mlir::Float4E2M1FNType::get(ctx)
+           << " type is supported for conversions from f32x2 to f4x2.";
+
+  return success();
+}
+
 LogicalResult BulkStoreOp::verify() {
   if (getInitVal() != 0)
     return emitOpError("only 0 is supported for initVal, got ") << getInitVal();
@@ -2015,7 +2026,7 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
 }
 
 llvm::Intrinsic::ID
-ConvertF32x2ToF4x2Op::getIntrinsicID(NVVM::ConvertFP4Type type, bool hasRelu) {
+ConvertF32x2ToF4x2Op::getIntrinsicID(bool hasRelu) {
   return hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
                  : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
 }
diff --git a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
index f1599a6b51e3a..04e2ddff802a9 100644
--- a/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/convert_fp4x2.mlir
@@ -4,9 +4,9 @@
 llvm.func @convert_f32x2_to_f4x2_e2m1(%srcA : f32, %srcB : f32) {
   // CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.satfinite(float %{{.*}}, float %{{.*}})
   // CHECK-NEXT: %{{.*}} = trunc i16 %[[res1]] to i8
-  %res1 = nvvm.convert.f32x2.to.f4x2 <e2m1> %srcA, %srcB
+  %res1 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB : i8 (f4E2M1FN)
   // CHECK: %[[res2:.*]] = call i16 @llvm.nvvm.ff.to.e2m1x2.rn.relu.satfinite(float %{{.*}}, float %{{.*}})
   // CHECK-NEXT: %{{.*}} = trunc i16 %[[res2]] to i8
-  %res2 = nvvm.convert.f32x2.to.f4x2 <e2m1> %srcA, %srcB {relu = true}
+  %res2 = nvvm.convert.f32x2.to.f4x2 %srcA, %srcB {relu = true} : i8 (f4E2M1FN)
   llvm.return
 }
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 0b3615487716d..78e1e659ed85d 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -254,6 +254,14 @@ llvm.func @nvvm_cvt_f32x2_to_f6x2_invalid_type(%a : f32, %b : f32) {
 
 // -----
 
+llvm.func @nvvm_cvt_f32x2_to_f4x2_invalid_type(%a : f32, %b : f32) {
+  // expected-error @below {{Only 'f4E2M1FN' type is supported for conversions from f32x2 to f4x2.}}
+  %res = nvvm.convert.f32x2.to.f4x2 %a, %b : i8 (f8E4M3FN)
+  llvm.return
+}
+
+// -----
+
 llvm.func @nvvm_prefetch_L1_with_evict_priority(%global_ptr: !llvm.ptr<1>) {
   // expected-error @below {{cache eviction priority supported only for cache level L2}}
   nvvm.prefetch level = L1, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>

>From 587f17eea938cd065218a9d059e96d5594e8acdb Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Tue, 7 Oct 2025 12:25:41 +0000
Subject: [PATCH 3/5] fix formatting

---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index cf92856461cc8..23fb350881611 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -2025,8 +2025,7 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
   }
 }
 
-llvm::Intrinsic::ID
-ConvertF32x2ToF4x2Op::getIntrinsicID(bool hasRelu) {
+llvm::Intrinsic::ID ConvertF32x2ToF4x2Op::getIntrinsicID(bool hasRelu) {
   return hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
                  : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
 }

>From 6cf5a899741b5ff96ba14d6bb6477f5b2bf0c613 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Fri, 10 Oct 2025 05:05:14 +0000
Subject: [PATCH 4/5] use getIntrinsicIDAndArgs

---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td |  8 +++++---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp  | 18 +++++++++++++++---
 2 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 347195dc87443..3a65555204c36 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1677,12 +1677,14 @@ def NVVM_ConvertF32x2ToF4x2Op : NVVM_Op<"convert.f32x2.to.f4x2"> {
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
-    static llvm::Intrinsic::ID getIntrinsicID(bool hasRelu);
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
+      LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder);
   }];
 
   string llvmBuilder = [{
-    auto intId = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicID($relu);
-    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
+    auto [intId, args] = NVVM::ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(op, moduleTranslation, builder);
+    llvm::Value *packedI16 = createIntrinsicCall(builder, intId, args);
     $dst = builder.CreateTruncOrBitCast(packedI16, llvm::Type::getInt8Ty(builder.getContext()));
   }];
 }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 23fb350881611..37b4168386da8 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -2025,9 +2025,21 @@ ConvertFloatToTF32Op::getIntrinsicID(NVVM::FPRoundingMode rnd,
   }
 }
 
-llvm::Intrinsic::ID ConvertF32x2ToF4x2Op::getIntrinsicID(bool hasRelu) {
-  return hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
-                 : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
+NVVM::IDArgPair
+ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
+                                            LLVM::ModuleTranslation &mt,
+                                            llvm::IRBuilderBase &builder) {
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(mt.lookupValue(op.getA()));
+  args.push_back(mt.lookupValue(op.getB()));
+  
+  bool hasRelu = op.getRelu();
+  
+  llvm::Intrinsic::ID intId =
+      hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
+              : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
+              
+  return {intId, std::move(args)};
 }
 
 #define GET_F32x2_TO_F6x2_ID(type, has_relu)                                   \

>From 44691fb6fa47d058da2babaababf0a14d4a161ee Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <srinivasar at nvidia.com>
Date: Fri, 10 Oct 2025 05:09:40 +0000
Subject: [PATCH 5/5] fix formatting

---
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 37b4168386da8..3e4c7cd6826fc 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -2032,13 +2032,13 @@ ConvertF32x2ToF4x2Op::getIntrinsicIDAndArgs(NVVM::ConvertF32x2ToF4x2Op op,
   llvm::SmallVector<llvm::Value *> args;
   args.push_back(mt.lookupValue(op.getA()));
   args.push_back(mt.lookupValue(op.getB()));
-  
+
   bool hasRelu = op.getRelu();
-  
+
   llvm::Intrinsic::ID intId =
       hasRelu ? llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_relu_satfinite
               : llvm::Intrinsic::nvvm_ff_to_e2m1x2_rn_satfinite;
-              
+
   return {intId, std::move(args)};
 }
 



More information about the Mlir-commits mailing list