[Mlir-commits] [mlir] [MLIR][NVVM] Support packed registers in `inline_ptx` (PR #154904)

Guray Ozen llvmlistbot at llvm.org
Tue Sep 2 05:05:10 PDT 2025


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/154904

>From 118deb1c1fc73d41de52194ef89c38435f4428c3 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 22 Aug 2025 07:31:26 +0000
Subject: [PATCH 1/2] [MLIR][NVVM] Support packed registers in `inline_ptx`

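Add support for packed registers: a short vector operand (for example `vector<4xi8>` or `vector<2xbf16>`) is mapped to a single, wider PTX register.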

**Example:**

```mlir
%wo0 = nvvm.inline_ptx
          "dp4a.s32.s32 {$w0}, {$r0}, {$r1}, {$r2};"
          ro(%src, %mask, %zero : vector<4xi8>, i32, i32)
          -> i32
```

Here, `vector<4xi8>` is lowered to a single 32-bit integer register (i.e., the `r` constraint in the inline PTX constraint string).
---
 .../LLVMIR/IR/BasicPtxBuilderInterface.cpp    | 118 ++++++++++++++----
 .../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir   |  37 ++++++
 2 files changed, 129 insertions(+), 26 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index 6d2a64f94e3ca..a8b2663852e59 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -12,10 +12,17 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
 
+#include "mlir/Support/LLVM.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/Regex.h"
 
 #define DEBUG_TYPE "ptx-builder"
@@ -31,35 +38,87 @@ using namespace NVVM;
 
 static constexpr int64_t kSharedMemorySpace = 3;
 
-static char getRegisterType(Type type) {
-  if (type.isInteger(1))
-    return 'b';
-  if (type.isInteger(16))
-    return 'h';
-  if (type.isInteger(32))
-    return 'r';
-  if (type.isInteger(64))
-    return 'l';
-  if (type.isF32())
-    return 'f';
-  if (type.isF64())
-    return 'd';
-  if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
-    // Shared address spaces is addressed with 32-bit pointers.
-    if (ptr.getAddressSpace() == kSharedMemorySpace) {
+static FailureOr<char> getRegisterType(Type type, Location loc) {
+  MLIRContext *ctx = type.getContext();
+  auto i16 = IntegerType::get(ctx, 16);
+  auto i32 = IntegerType::get(ctx, 32);
+  auto f32 = Float32Type::get(ctx);
+
+  auto getRegisterTypeForScalar = [&](Type type) -> FailureOr<char> {
+    if (type.isInteger(1))
+      return 'b';
+    if (type.isInteger(16))
+      return 'h';
+    if (type.isInteger(32))
       return 'r';
+    if (type.isInteger(64))
+      return 'l';
+    if (type.isF32())
+      return 'f';
+    if (type.isF64())
+      return 'd';
+    if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
+      // Shared address spaces is addressed with 32-bit pointers.
+      if (ptr.getAddressSpace() == kSharedMemorySpace) {
+        return 'r';
+      }
+      return 'l';
     }
-    return 'l';
+    // register type for struct is not supported.
+    mlir::emitError(
+        loc, "The register type could not be deduced from MLIR type. The ")
+        << type
+        << " is not supported. Supported types are: i1, i16, i32, f32, f64, "
+           "pointers.\nSee the constraints from here: "
+           "https://docs.nvidia.com/cuda/inline-ptx-assembly/"
+           "index.html#constraints";
+    return failure();
+  };
+
+  // Packed registers
+  if (auto v = dyn_cast<VectorType>(type)) {
+    assert(v.getNumDynamicDims() == 0 && "Dynamic vectors are not supported");
+
+    int64_t lanes = v.getNumElements();
+    Type elem = v.getElementType();
+
+    // Case 1. Single vector
+    if (lanes <= 1)
+      return getRegisterTypeForScalar(elem);
+
+    // Case 2. Packed registers
+    Type widened = elem;
+    switch (lanes) {
+      // Pack 2x
+    case 2:
+      if (elem.isF16() || elem.isBF16())
+        widened = f32;
+      else if (elem.isFloat(8))
+        widened = i16;
+      break;
+      // Pack 4x
+    case 4:
+      if (elem.isInteger(8))
+        widened = i32;
+      else if (elem.isFloat(8))
+        widened = f32;
+      else if (elem.isFloat(4))
+        widened = i16;
+      break;
+      // Other packing is not supported
+    default:
+      break;
+    }
+    return getRegisterTypeForScalar(widened);
   }
-  // register type for struct is not supported.
-  llvm_unreachable("The register type could not deduced from MLIR type");
-  return '?';
+
+  return getRegisterTypeForScalar(type);
 }
 
-static char getRegisterType(Value v) {
+static FailureOr<char> getRegisterType(Value v, Location loc) {
   if (v.getDefiningOp<LLVM::ConstantOp>())
     return 'n';
-  return getRegisterType(v.getType());
+  return getRegisterType(v.getType(), loc);
 }
 
 /// Extract every element of a struct value.
@@ -79,6 +138,7 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
   LDBG() << v << "\t Modifier : " << itype << "\n";
   registerModifiers.push_back(itype);
 
+  Location loc = interfaceOp->getLoc();
   auto getModifier = [&]() -> const char * {
     switch (itype) {
     case PTXRegisterMod::Read:
@@ -111,21 +171,27 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
     }
     for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
       if (itype != PTXRegisterMod::Write) {
-        Value extractValue = LLVM::ExtractValueOp::create(
-            rewriter, interfaceOp->getLoc(), v, idx);
+        Value extractValue =
+            LLVM::ExtractValueOp::create(rewriter, loc, v, idx);
         addValue(extractValue);
       }
       if (itype == PTXRegisterMod::ReadWrite) {
         ss << idx << ",";
       } else {
-        ss << getModifier() << getRegisterType(t) << ",";
+        FailureOr<char> regType = getRegisterType(t, loc);
+        if (failed(regType))
+          (void)rewriter.notifyMatchFailure(loc, "failed to get register type");
+        ss << getModifier() << regType.value() << ",";
       }
     }
     return;
   }
   // Handle Scalars
   addValue(v);
-  ss << getModifier() << getRegisterType(v) << ",";
+  FailureOr<char> regType = getRegisterType(v, loc);
+  if (failed(regType))
+    (void)rewriter.notifyMatchFailure(loc, "failed to get register type");
+  ss << getModifier() << regType.value() << ",";
 }
 
 /// Check if the operation needs to pack and unpack results.
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 2a19c72ab0840..89828b2077d8e 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -756,3 +756,40 @@ llvm.func @nvvm_pmevent() {
   nvvm.pmevent id = 4
   llvm.return
 }
+
+// -----
+
+llvm.func @inline_ptx_pack_4i8(%src : vector<4xi8>)  {
+ %mask = arith.constant 0x00000001 : i32
+ %zero = arith.constant 0 : i32
+// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "dp4a.s32.s32 $0, $1, $2, $3;", "=r,r,n,n" %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi8>, i32, i32) -> i32
+ %wo0 = nvvm.inline_ptx "dp4a.s32.s32 {$w0}, {$r0}, {$r1}, {$r2};" 
+                        ro(%src, %mask, %zero : vector<4xi8>, i32, i32) 
+                        -> i32
+ llvm.return  
+}
+
+llvm.func @inline_ptx_pack_2bf16(%a : f32, %b : f32)  {
+ // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "cvt.rn.satfinite.bf16x2.f32 $0, $1, $2;", "=f,f,f" %{{.*}}, %{{.*}} : (f32, f32) -> vector<2xbf16>
+ %wo0 = nvvm.inline_ptx "cvt.rn.satfinite.bf16x2.f32 {$w0}, {$r0}, {$r1};" 
+                        ro(%a, %b : f32, f32) 
+                        -> vector<2xbf16>
+ llvm.return  
+}
+
+llvm.func @inline_ptx_cvt_rn_e4m3x2_f16x2(%a : i16)  {
+// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "cvt.rz.satfinite.ue8m0x2.bf16x2 $0, $1", "=f,h" %{{.*}} : (i16) -> vector<2xbf16>
+ %wo0 = nvvm.inline_ptx "cvt.rz.satfinite.ue8m0x2.bf16x2 {$w0}, {$r0}" 
+                        ro(%a : i16) 
+                        -> vector<2xbf16>
+ llvm.return  
+}
+
+llvm.func @cvt_i8_bf16(%a : i8)  {
+  // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .b16 r;\0A\09.reg .b8 s;\0A\09mov.b16 {s,_}, $0;\0A\09cvt.rn.bf16.s8 r, s;\0A\09mov.b16 $1, r;\0A\09", "=h,h" %{{.*}} : (i16) -> i16
+  %za = llvm.zext %a : i8 to i16
+  %wo0 = nvvm.inline_ptx "{\n\t.reg .b16 r;\n\t.reg .b8 s;\n\tmov.b16 {s,_}, {$w0};\n\tcvt.rn.bf16.s8 r, s;\n\tmov.b16 {$r0}, r;\n\t" 
+                          ro(%za : i16) 
+                          -> i16
+  llvm.return  
+}

>From 0596685effd1835beeb973b971af680528499a8b Mon Sep 17 00:00:00 2001
From: Guray Ozen <gozen at nvidia.com>
Date: Tue, 2 Sep 2025 14:04:56 +0200
Subject: [PATCH 2/2] Propagate register-type deduction failures from `insertValue`

---
 .../Dialect/LLVMIR/BasicPtxBuilderInterface.h |  4 ++-
 mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp |  4 ++-
 .../LLVMIR/IR/BasicPtxBuilderInterface.cpp    | 29 ++++++++++---------
 .../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir   |  6 ++--
 4 files changed, 24 insertions(+), 19 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
index 21331e5aa89f3..cb2489335a317 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
@@ -18,6 +18,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Value.h"
+#include "llvm/Support/LogicalResult.h"
 
 namespace mlir {
 namespace NVVM {
@@ -82,7 +83,8 @@ class PtxBuilder {
         needsManualRegisterMapping(needsManualRegisterMapping) {}
 
   /// Add an operand with the read/write input type.
-  void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read);
+  LogicalResult insertValue(Value v,
+                            PTXRegisterMod itype = PTXRegisterMod::Read);
 
   /// Builds the inline assembly Op and returns it. The `insertValue` needs to
   /// be called to pass operands before building the PTX.
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index c67ec3642f121..314cbed2e4f33 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -26,6 +26,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/Support/DebugLog.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/raw_ostream.h"
 
 #define DEBUG_TYPE "nvvm-to-llvm"
@@ -62,7 +63,8 @@ struct PtxLowering
     PtxBuilder generator(op, rewriter, needsManualMapping);
     for (auto &[asmValue, modifier] : asmValues) {
       LDBG() << asmValue << "\t Modifier : " << modifier;
-      generator.insertValue(asmValue, modifier);
+      if (failed(generator.insertValue(asmValue, modifier)))
+        return failure();
     }
 
     generator.buildAndReplaceOp();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index a8b2663852e59..7220e10ea84d3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -68,8 +68,10 @@ static FailureOr<char> getRegisterType(Type type, Location loc) {
     mlir::emitError(
         loc, "The register type could not be deduced from MLIR type. The ")
         << type
-        << " is not supported. Supported types are: i1, i16, i32, f32, f64, "
-           "pointers.\nSee the constraints from here: "
+        << " is not supported. Supported types are:"
+           "i1, i16, i32, i64, f32, f64,"
+           "pointers.\nPlease use llvm.bitcast if you have different type. "
+           "\nSee the constraints from here: "
            "https://docs.nvidia.com/cuda/inline-ptx-assembly/"
            "index.html#constraints";
     return failure();
@@ -89,20 +91,19 @@ static FailureOr<char> getRegisterType(Type type, Location loc) {
     // Case 2. Packed registers
     Type widened = elem;
     switch (lanes) {
-      // Pack 2x
+
     case 2:
-      if (elem.isF16() || elem.isBF16())
+      if (elem.isF16() || elem.isBF16()) // vector<2xf16>
         widened = f32;
-      else if (elem.isFloat(8))
+      else if (elem.isFloat(8)) // vector<2xf8>
         widened = i16;
       break;
-      // Pack 4x
     case 4:
-      if (elem.isInteger(8))
+      if (elem.isInteger(8)) // vector<i8x4>
         widened = i32;
-      else if (elem.isFloat(8))
+      else if (elem.isFloat(8)) // vector<f8x4>
         widened = f32;
-      else if (elem.isFloat(4))
+      else if (elem.isFloat(4)) // vector<f4x4>
         widened = i16;
       break;
       // Other packing is not supported
@@ -134,7 +135,7 @@ static SmallVector<Value> extractStructElements(PatternRewriter &rewriter,
   return elems;
 }
 
-void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
+LogicalResult PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
   LDBG() << v << "\t Modifier : " << itype << "\n";
   registerModifiers.push_back(itype);
 
@@ -180,18 +181,20 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
       } else {
         FailureOr<char> regType = getRegisterType(t, loc);
         if (failed(regType))
-          (void)rewriter.notifyMatchFailure(loc, "failed to get register type");
+          return rewriter.notifyMatchFailure(loc,
+                                             "failed to get register type");
         ss << getModifier() << regType.value() << ",";
       }
     }
-    return;
+    return success();
   }
   // Handle Scalars
   addValue(v);
   FailureOr<char> regType = getRegisterType(v, loc);
   if (failed(regType))
-    (void)rewriter.notifyMatchFailure(loc, "failed to get register type");
+    return rewriter.notifyMatchFailure(loc, "failed to get register type");
   ss << getModifier() << regType.value() << ",";
+  return success();
 }
 
 /// Check if the operation needs to pack and unpack results.
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 89828b2077d8e..ce17650d16d32 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -759,10 +759,8 @@ llvm.func @nvvm_pmevent() {
 
 // -----
 
-llvm.func @inline_ptx_pack_4i8(%src : vector<4xi8>)  {
- %mask = arith.constant 0x00000001 : i32
- %zero = arith.constant 0 : i32
-// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "dp4a.s32.s32 $0, $1, $2, $3;", "=r,r,n,n" %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi8>, i32, i32) -> i32
+llvm.func @inline_ptx_pack_4i8(%src : vector<4xi8>,  %mask : i32, %zero: i32)  {
+// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "dp4a.s32.s32 $0, $1, $2, $3;", "=r,r,r,r" %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi8>, i32, i32) -> i32
  %wo0 = nvvm.inline_ptx "dp4a.s32.s32 {$w0}, {$r0}, {$r1}, {$r2};" 
                         ro(%src, %mask, %zero : vector<4xi8>, i32, i32) 
                         -> i32



More information about the Mlir-commits mailing list