[Mlir-commits] [mlir] [MLIR][NVVM] Support packed registers in `inline_ptx` (PR #154904)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Aug 22 00:32:47 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Guray Ozen (grypp)

<details>
<summary>Changes</summary>

Add support for packed registers in `nvvm.inline_ptx`: short vector operands are passed in a single scalar register as wide as the whole vector.

Example:

```
%wo0 = nvvm.inline_ptx
          "dp4a.s32.s32 {$w0}, {$r0}, {$r1}, {$r2};"
          ro(%src, %mask, %zero : vector<4xi8>, i32, i32)
          -> i32
```

Here, the `vector<4xi8>` operand is passed in a single 32-bit integer register (inline-asm constraint `r` in PTX).

---
Full diff: https://github.com/llvm/llvm-project/pull/154904.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp (+92-26) 
- (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+37) 


``````````diff
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index 6d2a64f94e3ca..a8b2663852e59 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -12,10 +12,17 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
 
+#include "mlir/Support/LLVM.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/DebugLog.h"
 #include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/Regex.h"
 
 #define DEBUG_TYPE "ptx-builder"
@@ -31,35 +38,107 @@ using namespace NVVM;
 
 static constexpr int64_t kSharedMemorySpace = 3;
 
-static char getRegisterType(Type type) {
-  if (type.isInteger(1))
-    return 'b';
-  if (type.isInteger(16))
-    return 'h';
-  if (type.isInteger(32))
-    return 'r';
-  if (type.isInteger(64))
-    return 'l';
-  if (type.isF32())
-    return 'f';
-  if (type.isF64())
-    return 'd';
-  if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
-    // Shared address spaces is addressed with 32-bit pointers.
-    if (ptr.getAddressSpace() == kSharedMemorySpace) {
+/// Returns the inline-assembly constraint letter used to pass a value of
+/// `type` in a PTX register. Short vectors are "packed": they travel in a
+/// single scalar register as wide as the whole vector (e.g. vector<4xi8>
+/// uses a 32-bit 'r' register). Emits an error at `loc` and returns failure
+/// when no register class fits.
+static FailureOr<char> getRegisterType(Type type, Location loc) {
+  MLIRContext *ctx = type.getContext();
+  auto i16 = IntegerType::get(ctx, 16);
+  auto i32 = IntegerType::get(ctx, 32);
+  auto f32 = Float32Type::get(ctx);
+
+  // Constraint letter for a non-vector type.
+  auto getRegisterTypeForScalar = [&](Type scalarType) -> FailureOr<char> {
+    if (scalarType.isInteger(1))
+      return 'b';
+    if (scalarType.isInteger(16))
+      return 'h';
+    if (scalarType.isInteger(32))
       return 'r';
+    if (scalarType.isInteger(64))
+      return 'l';
+    if (scalarType.isF32())
+      return 'f';
+    if (scalarType.isF64())
+      return 'd';
+    if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(scalarType)) {
+      // The shared address space is addressed with 32-bit pointers.
+      if (ptr.getAddressSpace() == kSharedMemorySpace) {
+        return 'r';
+      }
+      return 'l';
     }
-    return 'l';
+    // Structs and any other type have no register class.
+    mlir::emitError(
+        loc, "The register type could not be deduced from MLIR type. The ")
+        << scalarType
+        << " is not supported. Supported types are: i1, i16, i32, i64, f32, "
+           "f64, pointers.\nSee the constraints from here: "
+           "https://docs.nvidia.com/cuda/inline-ptx-assembly/"
+           "index.html#constraints";
+    return failure();
+  };
+
+  // Packed registers
+  if (auto v = dyn_cast<VectorType>(type)) {
+    // Scalable vectors have no fixed bit width, so they cannot be packed.
+    if (v.isScalable()) {
+      mlir::emitError(loc) << "scalable vector " << type
+                           << " cannot be passed in a PTX register";
+      return failure();
+    }
+
+    int64_t lanes = v.getNumElements();
+    Type elem = v.getElementType();
+
+    // Case 1. Single-element vector: use the element's register.
+    if (lanes <= 1)
+      return getRegisterTypeForScalar(elem);
+
+    // Case 2. Packed registers: pick the scalar as wide as all lanes.
+    Type widened;
+    switch (lanes) {
+      // Pack 2x
+    case 2:
+      if (elem.isF16() || elem.isBF16())
+        widened = f32;
+      else if (elem.isFloat(8))
+        widened = i16;
+      break;
+      // Pack 4x
+    case 4:
+      if (elem.isInteger(8))
+        widened = i32;
+      else if (elem.isFloat(8))
+        widened = f32;
+      else if (elem.isFloat(4))
+        widened = i16;
+      break;
+      // Other packing is not supported
+    default:
+      break;
+    }
+    // Falling back to the element's register would silently pass a value
+    // wider than the register (e.g. vector<2xi32> as 'r'); reject instead.
+    if (!widened) {
+      mlir::emitError(loc) << "packed register type for " << type
+                           << " is not supported";
+      return failure();
+    }
+    return getRegisterTypeForScalar(widened);
   }
-  // register type for struct is not supported.
-  llvm_unreachable("The register type could not deduced from MLIR type");
-  return '?';
+
+  return getRegisterTypeForScalar(type);
 }
 
-static char getRegisterType(Value v) {
+/// Like getRegisterType(Type, Location), but values produced by
+/// llvm.mlir.constant use the immediate constraint 'n'.
+static FailureOr<char> getRegisterType(Value v, Location loc) {
   if (v.getDefiningOp<LLVM::ConstantOp>())
     return 'n';
-  return getRegisterType(v.getType());
+  return getRegisterType(v.getType(), loc);
 }
 
 /// Extract every element of a struct value.
@@ -79,6 +158,7 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
   LDBG() << v << "\t Modifier : " << itype << "\n";
   registerModifiers.push_back(itype);
 
+  Location loc = interfaceOp->getLoc();
   auto getModifier = [&]() -> const char * {
     switch (itype) {
     case PTXRegisterMod::Read:
@@ -111,21 +191,31 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
     }
     for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
       if (itype != PTXRegisterMod::Write) {
-        Value extractValue = LLVM::ExtractValueOp::create(
-            rewriter, interfaceOp->getLoc(), v, idx);
+        Value extractValue =
+            LLVM::ExtractValueOp::create(rewriter, loc, v, idx);
         addValue(extractValue);
       }
       if (itype == PTXRegisterMod::ReadWrite) {
         ss << idx << ",";
       } else {
-        ss << getModifier() << getRegisterType(t) << ",";
+        FailureOr<char> regType = getRegisterType(t, loc);
+        // The error was already emitted; stop rather than dereference a
+        // failed result (insertValue cannot propagate a LogicalResult).
+        if (failed(regType))
+          return (void)rewriter.notifyMatchFailure(
+              loc, "failed to get register type");
+        ss << getModifier() << regType.value() << ",";
       }
     }
     return;
   }
   // Handle Scalars
   addValue(v);
-  ss << getModifier() << getRegisterType(v) << ",";
+  FailureOr<char> regType = getRegisterType(v, loc);
+  if (failed(regType))
+    return (void)rewriter.notifyMatchFailure(
+        loc, "failed to get register type");
+  ss << getModifier() << regType.value() << ",";
 }
 
 /// Check if the operation needs to pack and unpack results.
 
 /// Check if the operation needs to pack and unpack results.
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 2a19c72ab0840..89828b2077d8e 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -756,3 +756,40 @@ llvm.func @nvvm_pmevent() {
   nvvm.pmevent id = 4
   llvm.return
 }
+
+// -----
+
+llvm.func @inline_ptx_pack_4i8(%src : vector<4xi8>)  {
+ %mask = arith.constant 0x00000001 : i32
+ %zero = arith.constant 0 : i32
+// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "dp4a.s32.s32 $0, $1, $2, $3;", "=r,r,n,n" %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi8>, i32, i32) -> i32
+ %wo0 = nvvm.inline_ptx "dp4a.s32.s32 {$w0}, {$r0}, {$r1}, {$r2};"
+                        ro(%src, %mask, %zero : vector<4xi8>, i32, i32)
+                        -> i32
+ llvm.return
+}
+
+llvm.func @inline_ptx_pack_2bf16(%a : f32, %b : f32)  {
+ // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "cvt.rn.satfinite.bf16x2.f32 $0, $1, $2;", "=f,f,f" %{{.*}}, %{{.*}} : (f32, f32) -> vector<2xbf16>
+ %wo0 = nvvm.inline_ptx "cvt.rn.satfinite.bf16x2.f32 {$w0}, {$r0}, {$r1};"
+                        ro(%a, %b : f32, f32)
+                        -> vector<2xbf16>
+ llvm.return
+}
+
+llvm.func @inline_ptx_cvt_rn_bf16x2_ue8m0x2(%a : i16)  {
+// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "cvt.rn.bf16x2.ue8m0x2 $0, $1;", "=f,h" %{{.*}} : (i16) -> vector<2xbf16>
+ %wo0 = nvvm.inline_ptx "cvt.rn.bf16x2.ue8m0x2 {$w0}, {$r0};"
+                        ro(%a : i16)
+                        -> vector<2xbf16>
+ llvm.return
+}
+
+llvm.func @cvt_i8_bf16(%a : i8)  {
+  // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .b16 r;\0A\09.reg .b8 s;\0A\09mov.b16 {s,_}, $1;\0A\09cvt.rn.bf16.s8 r, s;\0A\09mov.b16 $0, r;\0A\09}", "=h,h" %{{.*}} : (i16) -> i16
+  %za = llvm.zext %a : i8 to i16
+  %wo0 = nvvm.inline_ptx "{\n\t.reg .b16 r;\n\t.reg .b8 s;\n\tmov.b16 {s,_}, {$r0};\n\tcvt.rn.bf16.s8 r, s;\n\tmov.b16 {$w0}, r;\n\t}"
+                          ro(%za : i16)
+                          -> i16
+  llvm.return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/154904


More information about the Mlir-commits mailing list