[Mlir-commits] [mlir] [MLIR][NVVM] Support packed registers in `inline_ptx` (PR #154904)
Guray Ozen
llvmlistbot at llvm.org
Tue Sep 2 05:05:10 PDT 2025
https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/154904
>From 118deb1c1fc73d41de52194ef89c38435f4428c3 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Fri, 22 Aug 2025 07:31:26 +0000
Subject: [PATCH 1/2] [MLIR][NVVM] Support packed registers in `inline_ptx`
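Add support for packed registers expressed as vectors. Two f16/bf16 lanes or four 8-bit float lanes are passed in an `f32` register, two 8-bit float lanes or four 4-bit float lanes in an `i16` register, and four `i8` lanes in an `i32` register; a single-element vector uses its element's register.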
**Example:**
```mlir
%wo0 = nvvm.inline_ptx
"dp4a.s32.s32 {$w0}, {$r0}, {$r1}, {$r2};"
ro(%src, %mask, %zero : vector<4xi8>, i32, i32)
-> i32
```
Here, `vector<4xi8>` is packed into a single 32-bit register, i.e. the `r` inline-assembly constraint.
---
.../LLVMIR/IR/BasicPtxBuilderInterface.cpp | 118 ++++++++++++++----
.../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 37 ++++++
2 files changed, 129 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index 6d2a64f94e3ca..a8b2663852e59 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -12,10 +12,17 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/Regex.h"
#define DEBUG_TYPE "ptx-builder"
@@ -31,35 +38,87 @@ using namespace NVVM;
static constexpr int64_t kSharedMemorySpace = 3;
-static char getRegisterType(Type type) {
- if (type.isInteger(1))
- return 'b';
- if (type.isInteger(16))
- return 'h';
- if (type.isInteger(32))
- return 'r';
- if (type.isInteger(64))
- return 'l';
- if (type.isF32())
- return 'f';
- if (type.isF64())
- return 'd';
- if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
- // Shared address spaces is addressed with 32-bit pointers.
- if (ptr.getAddressSpace() == kSharedMemorySpace) {
+static FailureOr<char> getRegisterType(Type type, Location loc) {
+ MLIRContext *ctx = type.getContext();
+ auto i16 = IntegerType::get(ctx, 16);
+ auto i32 = IntegerType::get(ctx, 32);
+ auto f32 = Float32Type::get(ctx);
+
+ auto getRegisterTypeForScalar = [&](Type type) -> FailureOr<char> {
+ if (type.isInteger(1))
+ return 'b';
+ if (type.isInteger(16))
+ return 'h';
+ if (type.isInteger(32))
return 'r';
+ if (type.isInteger(64))
+ return 'l';
+ if (type.isF32())
+ return 'f';
+ if (type.isF64())
+ return 'd';
+ if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) {
+ // Shared address spaces is addressed with 32-bit pointers.
+ if (ptr.getAddressSpace() == kSharedMemorySpace) {
+ return 'r';
+ }
+ return 'l';
}
- return 'l';
+ // register type for struct is not supported.
+ mlir::emitError(
+ loc, "The register type could not be deduced from MLIR type. The ")
+ << type
+ << " is not supported. Supported types are: i1, i16, i32, f32, f64, "
+ "pointers.\nSee the constraints from here: "
+ "https://docs.nvidia.com/cuda/inline-ptx-assembly/"
+ "index.html#constraints";
+ return failure();
+ };
+
+ // Packed registers
+ if (auto v = dyn_cast<VectorType>(type)) {
+ assert(v.getNumDynamicDims() == 0 && "Dynamic vectors are not supported");
+
+ int64_t lanes = v.getNumElements();
+ Type elem = v.getElementType();
+
+ // Case 1. Single vector
+ if (lanes <= 1)
+ return getRegisterTypeForScalar(elem);
+
+ // Case 2. Packed registers
+ Type widened = elem;
+ switch (lanes) {
+ // Pack 2x
+ case 2:
+ if (elem.isF16() || elem.isBF16())
+ widened = f32;
+ else if (elem.isFloat(8))
+ widened = i16;
+ break;
+ // Pack 4x
+ case 4:
+ if (elem.isInteger(8))
+ widened = i32;
+ else if (elem.isFloat(8))
+ widened = f32;
+ else if (elem.isFloat(4))
+ widened = i16;
+ break;
+ // Other packing is not supported
+ default:
+ break;
+ }
+ return getRegisterTypeForScalar(widened);
}
- // register type for struct is not supported.
- llvm_unreachable("The register type could not deduced from MLIR type");
- return '?';
+
+ return getRegisterTypeForScalar(type);
}
-static char getRegisterType(Value v) {
+static FailureOr<char> getRegisterType(Value v, Location loc) {
if (v.getDefiningOp<LLVM::ConstantOp>())
return 'n';
- return getRegisterType(v.getType());
+ return getRegisterType(v.getType(), loc);
}
/// Extract every element of a struct value.
@@ -79,6 +138,7 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
LDBG() << v << "\t Modifier : " << itype << "\n";
registerModifiers.push_back(itype);
+ Location loc = interfaceOp->getLoc();
auto getModifier = [&]() -> const char * {
switch (itype) {
case PTXRegisterMod::Read:
@@ -111,21 +171,27 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
}
for (auto [idx, t] : llvm::enumerate(stype.getBody())) {
if (itype != PTXRegisterMod::Write) {
- Value extractValue = LLVM::ExtractValueOp::create(
- rewriter, interfaceOp->getLoc(), v, idx);
+ Value extractValue =
+ LLVM::ExtractValueOp::create(rewriter, loc, v, idx);
addValue(extractValue);
}
if (itype == PTXRegisterMod::ReadWrite) {
ss << idx << ",";
} else {
- ss << getModifier() << getRegisterType(t) << ",";
+ FailureOr<char> regType = getRegisterType(t, loc);
+ if (failed(regType))
+ (void)rewriter.notifyMatchFailure(loc, "failed to get register type");
+ ss << getModifier() << regType.value() << ",";
}
}
return;
}
// Handle Scalars
addValue(v);
- ss << getModifier() << getRegisterType(v) << ",";
+ FailureOr<char> regType = getRegisterType(v, loc);
+ if (failed(regType))
+ (void)rewriter.notifyMatchFailure(loc, "failed to get register type");
+ ss << getModifier() << regType.value() << ",";
}
/// Check if the operation needs to pack and unpack results.
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 2a19c72ab0840..89828b2077d8e 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -756,3 +756,40 @@ llvm.func @nvvm_pmevent() {
nvvm.pmevent id = 4
llvm.return
}
+
+// -----
+
+llvm.func @inline_ptx_pack_4i8(%src : vector<4xi8>) {
+ %mask = arith.constant 0x00000001 : i32
+ %zero = arith.constant 0 : i32
+// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "dp4a.s32.s32 $0, $1, $2, $3;", "=r,r,n,n" %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi8>, i32, i32) -> i32
+ %wo0 = nvvm.inline_ptx "dp4a.s32.s32 {$w0}, {$r0}, {$r1}, {$r2};"
+ ro(%src, %mask, %zero : vector<4xi8>, i32, i32)
+ -> i32
+ llvm.return
+}
+
+llvm.func @inline_ptx_pack_2bf16(%a : f32, %b : f32) {
+ // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "cvt.rn.satfinite.bf16x2.f32 $0, $1, $2;", "=f,f,f" %{{.*}}, %{{.*}} : (f32, f32) -> vector<2xbf16>
+ %wo0 = nvvm.inline_ptx "cvt.rn.satfinite.bf16x2.f32 {$w0}, {$r0}, {$r1};"
+ ro(%a, %b : f32, f32)
+ -> vector<2xbf16>
+ llvm.return
+}
+
+llvm.func @inline_ptx_cvt_rz_ue8m0x2_bf16x2(%a : vector<2xbf16>) {
+// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "cvt.rz.satfinite.ue8m0x2.bf16x2 $0, $1;", "=h,f" %{{.*}} : (vector<2xbf16>) -> i16
+ %wo0 = nvvm.inline_ptx "cvt.rz.satfinite.ue8m0x2.bf16x2 {$w0}, {$r0};"
+ ro(%a : vector<2xbf16>)
+ -> i16
+ llvm.return
+}
+
+llvm.func @cvt_i8_bf16(%a : i8) {
+ // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09.reg .b16 r;\0A\09.reg .b8 s;\0A\09mov.b16 {s,_}, $1;\0A\09cvt.rn.bf16.s8 r, s;\0A\09mov.b16 $0, r;\0A\09}", "=h,h" %{{.*}} : (i16) -> i16
+ %za = llvm.zext %a : i8 to i16
+ %wo0 = nvvm.inline_ptx "{\n\t.reg .b16 r;\n\t.reg .b8 s;\n\tmov.b16 {s,_}, {$r0};\n\tcvt.rn.bf16.s8 r, s;\n\tmov.b16 {$w0}, r;\n\t}"
+ ro(%za : i16)
+ -> i16
+ llvm.return
+}
>From 0596685effd1835beeb973b971af680528499a8b Mon Sep 17 00:00:00 2001
From: Guray Ozen <gozen at nvidia.com>
Date: Tue, 2 Sep 2025 14:04:56 +0200
Subject: [PATCH 2/2] [MLIR][NVVM] Propagate register type deduction failures from `PtxBuilder::insertValue`
---
.../Dialect/LLVMIR/BasicPtxBuilderInterface.h | 4 ++-
mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp | 4 ++-
.../LLVMIR/IR/BasicPtxBuilderInterface.cpp | 29 ++++++++++---------
.../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 6 ++--
4 files changed, 24 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
index 21331e5aa89f3..cb2489335a317 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
@@ -18,6 +18,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
+#include "llvm/Support/LogicalResult.h"
namespace mlir {
namespace NVVM {
@@ -82,7 +83,8 @@ class PtxBuilder {
needsManualRegisterMapping(needsManualRegisterMapping) {}
/// Add an operand with the read/write input type.
- void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read);
+ LogicalResult insertValue(Value v,
+ PTXRegisterMod itype = PTXRegisterMod::Read);
/// Builds the inline assembly Op and returns it. The `insertValue` needs to
/// be called to pass operands before building the PTX.
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index c67ec3642f121..314cbed2e4f33 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -26,6 +26,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "nvvm-to-llvm"
@@ -62,7 +63,8 @@ struct PtxLowering
PtxBuilder generator(op, rewriter, needsManualMapping);
for (auto &[asmValue, modifier] : asmValues) {
LDBG() << asmValue << "\t Modifier : " << modifier;
- generator.insertValue(asmValue, modifier);
+ if (failed(generator.insertValue(asmValue, modifier)))
+ return failure();
}
generator.buildAndReplaceOp();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index a8b2663852e59..7220e10ea84d3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -68,8 +68,10 @@ static FailureOr<char> getRegisterType(Type type, Location loc) {
mlir::emitError(
loc, "The register type could not be deduced from MLIR type. The ")
<< type
- << " is not supported. Supported types are: i1, i16, i32, f32, f64, "
- "pointers.\nSee the constraints from here: "
+ << " is not supported. Supported types are: "
+ "i1, i16, i32, i64, f32, f64, "
+ "pointers.\nPlease use llvm.bitcast if you have a different type."
+ "\nSee the constraints from here: "
"https://docs.nvidia.com/cuda/inline-ptx-assembly/"
"index.html#constraints";
return failure();
@@ -89,20 +91,19 @@ static FailureOr<char> getRegisterType(Type type, Location loc) {
// Case 2. Packed registers
Type widened = elem;
switch (lanes) {
- // Pack 2x
+ // Unlisted lane/element combinations fall back to the element type.
case 2:
- if (elem.isF16() || elem.isBF16())
+ if (elem.isF16() || elem.isBF16()) // vector<2xf16>, vector<2xbf16>
widened = f32;
- else if (elem.isFloat(8))
+ else if (elem.isFloat(8)) // vector<2xf8>
widened = i16;
break;
- // Pack 4x
case 4:
- if (elem.isInteger(8))
+ if (elem.isInteger(8)) // vector<4xi8>
widened = i32;
- else if (elem.isFloat(8))
+ else if (elem.isFloat(8)) // vector<4xf8>
widened = f32;
- else if (elem.isFloat(4))
+ else if (elem.isFloat(4)) // vector<4xf4>
widened = i16;
break;
// Other packing is not supported
@@ -134,7 +135,7 @@ static SmallVector<Value> extractStructElements(PatternRewriter &rewriter,
return elems;
}
-void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
+LogicalResult PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
LDBG() << v << "\t Modifier : " << itype << "\n";
registerModifiers.push_back(itype);
@@ -180,18 +181,20 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
} else {
FailureOr<char> regType = getRegisterType(t, loc);
if (failed(regType))
- (void)rewriter.notifyMatchFailure(loc, "failed to get register type");
+ return rewriter.notifyMatchFailure(loc,
+ "failed to get register type");
ss << getModifier() << regType.value() << ",";
}
}
- return;
+ return success();
}
// Handle Scalars
addValue(v);
FailureOr<char> regType = getRegisterType(v, loc);
if (failed(regType))
- (void)rewriter.notifyMatchFailure(loc, "failed to get register type");
+ return rewriter.notifyMatchFailure(loc, "failed to get register type");
ss << getModifier() << regType.value() << ",";
+ return success();
}
/// Check if the operation needs to pack and unpack results.
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 89828b2077d8e..ce17650d16d32 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -759,10 +759,8 @@ llvm.func @nvvm_pmevent() {
// -----
-llvm.func @inline_ptx_pack_4i8(%src : vector<4xi8>) {
- %mask = arith.constant 0x00000001 : i32
- %zero = arith.constant 0 : i32
-// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "dp4a.s32.s32 $0, $1, $2, $3;", "=r,r,n,n" %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi8>, i32, i32) -> i32
+llvm.func @inline_ptx_pack_4i8(%src : vector<4xi8>, %mask : i32, %zero: i32) {
+// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "dp4a.s32.s32 $0, $1, $2, $3;", "=r,r,r,r" %{{.*}}, %{{.*}}, %{{.*}} : (vector<4xi8>, i32, i32) -> i32
%wo0 = nvvm.inline_ptx "dp4a.s32.s32 {$w0}, {$r0}, {$r1}, {$r2};"
ro(%src, %mask, %zero : vector<4xi8>, i32, i32)
-> i32
More information about the Mlir-commits
mailing list