[Mlir-commits] [mlir] [MLIR][NVVM] Support packed registers in `inline_ptx` (PR #154904)
Guray Ozen
llvmlistbot at llvm.org
Fri Aug 22 00:32:15 PDT 2025
https://github.com/grypp created https://github.com/llvm/llvm-project/pull/154904
Add support for packed registers: supported vector types are passed to inline PTX in a single scalar register of the same total bit width.
Example:
```
%wo0 = nvvm.inline_ptx
"dp4a.s32.s32 {$w0}, {$r0}, {$r1}, {$r2};"
ro(%src, %mask, %zero : vector<4xi8>, i32, i32)
-> i32
```
Here, `vector<4xi8>` is packed into a single 32-bit register, i.e., it is passed with the `r` (32-bit integer register) inline-assembly constraint.
>From 118deb1c1fc73d41de52194ef89c38435f4428c3 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 22 Aug 2025 07:31:26 +0000
Subject: [PATCH] [MLIR][NVVM] Support packed registers in `inline_ptx`
Add support for packed registers: supported vector types are passed to inline PTX in a single scalar register of the same total bit width.
**Example:**
```mlir
%wo0 = nvvm.inline_ptx
"dp4a.s32.s32 {$w0}, {$r0}, {$r1}, {$r2};"
ro(%src, %mask, %zero : vector<4xi8>, i32, i32)
-> i32
```
Here, `vector<4xi8>` is packed into a single 32-bit register, i.e., it is passed with the `r` (32-bit integer register) inline-assembly constraint.
---
.../LLVMIR/IR/BasicPtxBuilderInterface.cpp | 118 ++++++++++++++----
.../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 37 ++++++
2 files changed, 129 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index 6d2a64f94e3ca..a8b2663852e59 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -12,10 +12,17 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/Regex.h"
#define DEBUG_TYPE "ptx-builder"
@@ -31,35 +38,87 @@ using namespace NVVM;
static constexpr int64_t kSharedMemorySpace = 3;
-static char getRegisterType(Type type) {
- if (type.isInteger(1))
- return 'b';
- if (type.isInteger(16))
- return 'h';
- if (type.isInteger(32))
- return 'r';
- if (type.isInteger(64))
- return 'l';
- if (type.isF32())
- return 'f';
- if (type.isF64())
- return 'd';
- if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
- // Shared address spaces is addressed with 32-bit pointers.
- if (ptr.getAddressSpace() == kSharedMemorySpace) {
- return 'r';
- }
- return 'l';
- }
- // register type for struct is not supported.
- llvm_unreachable("The register type could not deduced from MLIR type");
- return '?';
-}
+/// Returns the inline-assembly constraint letter for a value of type `type`.
+///
+/// Scalars: i1 -> 'b', i16 -> 'h', i32 -> 'r', i64 -> 'l', f32 -> 'f',
+/// f64 -> 'd'. LLVM pointers use 'r' in the shared address space (32-bit
+/// pointers) and 'l' otherwise.
+///
+/// Vectors are packed registers: a single-lane vector uses its element's
+/// constraint, and the packings listed below are widened to one scalar of the
+/// same total bit width (e.g. vector<4xi8> -> i32 -> 'r'). Every other type,
+/// including unlisted packings and scalable vectors, emits an error at `loc`
+/// and returns failure.
+static FailureOr<char> getRegisterType(Type type, Location loc) {
+  MLIRContext *ctx = type.getContext();
+  auto i16 = IntegerType::get(ctx, 16);
+  auto i32 = IntegerType::get(ctx, 32);
+  auto f32 = Float32Type::get(ctx);
+
+  auto getRegisterTypeForScalar = [&](Type scalar) -> FailureOr<char> {
+    if (scalar.isInteger(1))
+      return 'b';
+    if (scalar.isInteger(16))
+      return 'h';
+    if (scalar.isInteger(32))
+      return 'r';
+    if (scalar.isInteger(64))
+      return 'l';
+    if (scalar.isF32())
+      return 'f';
+    if (scalar.isF64())
+      return 'd';
+    if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(scalar)) {
+      // The shared address space is addressed with 32-bit pointers.
+      if (ptr.getAddressSpace() == kSharedMemorySpace)
+        return 'r';
+      return 'l';
+    }
+    // Structs and any other unlisted type are not supported.
+    mlir::emitError(
+        loc, "The register type could not be deduced from MLIR type. The ")
+        << scalar
+        << " is not supported. Supported types are: i1, i16, i32, i64, f32, "
+           "f64, pointers, and packed vectors.\nSee the constraints from here: "
+           "https://docs.nvidia.com/cuda/inline-ptx-assembly/"
+           "index.html#constraints";
+    return failure();
+  };
+
+  // Packed registers.
+  if (auto v = dyn_cast<VectorType>(type)) {
+    // A scalable vector has no compile-time bit width, so it cannot be packed
+    // into a fixed-size register.
+    if (v.isScalable()) {
+      mlir::emitError(loc, "Scalable vector ")
+          << type << " cannot be passed to inline PTX.";
+      return failure();
+    }
+
+    int64_t lanes = v.getNumElements();
+    Type elem = v.getElementType();
+
+    // Case 1. Single-lane vector: use the element's register.
+    if (lanes <= 1)
+      return getRegisterTypeForScalar(elem);
+
+    // Case 2. Packed registers: pick the scalar with the same total width.
+    Type widened;
+    switch (lanes) {
+    case 2: // Pack 2x
+      if (elem.isF16() || elem.isBF16())
+        widened = f32;
+      else if (elem.isFloat(8))
+        widened = i16;
+      break;
+    case 4: // Pack 4x
+      if (elem.isInteger(8))
+        widened = i32;
+      else if (elem.isFloat(8))
+        widened = f32;
+      else if (elem.isFloat(4))
+        widened = i16;
+      break;
+    default:
+      break;
+    }
+
+    // Any other packing is unsupported. Falling back to the element's
+    // register would describe only one lane of the vector.
+    if (!widened) {
+      mlir::emitError(loc, "The register type could not be deduced from "
+                           "MLIR type. The packed vector ")
+          << type
+          << " is not supported. Supported packings are: 2 x f16, 2 x bf16, "
+             "2 x 8-bit float, 4 x i8, 4 x 8-bit float, 4 x 4-bit float.";
+      return failure();
+    }
+    return getRegisterTypeForScalar(widened);
+  }
+
+  return getRegisterTypeForScalar(type);
+}
-static char getRegisterType(Value v) {
+/// Returns the constraint letter for operand `v`. Values defined by an
+/// `LLVM::ConstantOp` are passed as immediates ('n'). All other values use
+/// the register class of their type (see the `Type` overload above). On
+/// failure, an error has already been emitted at `loc`.
+static FailureOr<char> getRegisterType(Value v, Location loc) {
if (v.getDefiningOp<LLVM::ConstantOp>())
return 'n';
- return getRegisterType(v.getType());
+ return getRegisterType(v.getType(), loc);
}
/// Extract every element of a struct value.
@@ -79,6 +138,7 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
LDBG() << v << "\t Modifier : " << itype << "\n";
registerModifiers.push_back(itype);
+ Location loc = interfaceOp->getLoc();
auto getModifier = [&]() -> const char * {
switch (itype) {
case PTXRegisterMod::Read:
@@ -111,21 +171,27 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
}
for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
if (itype != PTXRegisterMod::Write) {
- Value extractValue = LLVM::ExtractValueOp::create(
- rewriter, interfaceOp->getLoc(), v, idx);
+ Value extractValue =
+ LLVM::ExtractValueOp::create(rewriter, loc, v, idx);
addValue(extractValue);
}
if (itype == PTXRegisterMod::ReadWrite) {
ss << idx << ",";
} else {
- ss << getModifier() << getRegisterType(t) << ",";
+ FailureOr<char> regType = getRegisterType(t, loc);
+ if (failed(regType))
+ (void)rewriter.notifyMatchFailure(loc, "failed to get register type");
+ ss << getModifier() << regType.value() << ",";
}
}
return;
}
// Handle Scalars
addValue(v);
- ss << getModifier() << getRegisterType(v) << ",";
+ FailureOr<char> regType = getRegisterType(v, loc);
+ if (failed(regType))
+ (void)rewriter.notifyMatchFailure(loc, "failed to get register type");
+ ss << getModifier() << regType.value() << ",";
}
/// Check if the operation needs to pack and unpack results.
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 2a19c72ab0840..89828b2077d8e 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -756,3 +756,40 @@ llvm.func @nvvm_pmevent() {
nvvm.pmevent id = 4
llvm.return
}
+
+// -----
+
+// 4 x i8 is packed into one 32-bit register, so the vector<4xi8> operand is
+// expected to get the 'r' constraint.
+// NOTE(review): the 'n' constraints expected for %mask and %zero rely on them
+// being llvm.mlir.constant ops when the NVVM lowering runs; confirm the test's
+// run line lowers arith.constant first.
+llvm.func @inline_ptx_pack_4i8(%src : vector<4xi8>) {
+ %mask = arith.constant 0x00000001 : i32
+ %zero = arith.constant 0 : i32
+// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "dp4a.s32.s32 $0, $1, $2, $3;", "=r,r,n,n" %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi8>, i32, i32) -> i32
+ %wo0 = nvvm.inline_ptx "dp4a.s32.s32 {$w0}, {$r0}, {$r1}, {$r2};"
+ ro(%src, %mask, %zero : vector<4xi8>, i32, i32)
+ -> i32
+ llvm.return
+}
+
+// 2 x bf16 is packed into one 32-bit register. The builder widens it to f32,
+// so the vector<2xbf16> result is expected to get the '=f' constraint.
+llvm.func @inline_ptx_pack_2bf16(%a : f32, %b : f32) {
+ // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "cvt.rn.satfinite.bf16x2.f32 $0, $1, $2;", "=f,f,f" %{{.*}}, %{{.*}} : (f32, f32) -> vector<2xbf16>
+ %wo0 = nvvm.inline_ptx "cvt.rn.satfinite.bf16x2.f32 {$w0}, {$r0}, {$r1};"
+ ro(%a, %b : f32, f32)
+ -> vector<2xbf16>
+ llvm.return
+}
+
+// cvt.rz.satfinite.ue8m0x2.bf16x2 reads a packed bf16x2 value (32 bits) and
+// writes a packed ue8m0x2 value (16 bits). The input is therefore
+// vector<2xbf16>, which the builder widens to f32 ('f'). The result is i16
+// ('=h').
+llvm.func @inline_ptx_cvt_rz_ue8m0x2_bf16x2(%a : vector<2xbf16>) {
+// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "cvt.rz.satfinite.ue8m0x2.bf16x2 $0, $1;", "=h,f" %{{.*}} : (vector<2xbf16>) -> i16
+  %wo0 = nvvm.inline_ptx "cvt.rz.satfinite.ue8m0x2.bf16x2 {$w0}, {$r0};"
+    ro(%a : vector<2xbf16>)
+    -> i16
+  llvm.return
+}
+
+// Converts an s8 value (zero-extended to i16 for the 'h' register) to bf16.
+// The source is read from {$r0} and the bf16 result is written to {$w0}. The
+// scoped PTX block opened with '{' is closed with a matching '}'.
+llvm.func @cvt_i8_bf16(%a : i8) {
+  // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .b16 r;\0A\09.reg .b8 s;\0A\09mov.b16 {s,_}, $1;\0A\09cvt.rn.bf16.s8 r, s;\0A\09mov.b16 $0, r;\0A\09}", "=h,h" %{{.*}} : (i16) -> i16
+  %za = llvm.zext %a : i8 to i16
+  %wo0 = nvvm.inline_ptx "{\n\t.reg .b16 r;\n\t.reg .b8 s;\n\tmov.b16 {s,_}, {$r0};\n\tcvt.rn.bf16.s8 r, s;\n\tmov.b16 {$w0}, r;\n\t}"
+    ro(%za : i16)
+    -> i16
+  llvm.return
+}
More information about the Mlir-commits
mailing list