[Mlir-commits] [mlir] [MLIR][NVVM] Add support for multiple return values in `inline_ptx` (PR #153774)
Guray Ozen
llvmlistbot at llvm.org
Fri Aug 15 03:04:11 PDT 2025
https://github.com/grypp created https://github.com/llvm/llvm-project/pull/153774
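This PR adds the ability for `nvvm.inline_ptx` to return multiple values, matching PTX semantics (a single asm block can write several output operands, such as `$0` and `$1` in the example below) while respecting LLVM’s constraints on inline assembly.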
LLVM’s `inline_asm` op does not natively support multiple results; instead, the results must be packed into an LLVM `struct` and then extracted individually (with `llvm.extractvalue`). This PR implements this packing/unpacking automatically, so multiple return values can be expressed naturally in MLIR without extra user boilerplate.
**Example**
MLIR:
```
%r1, %r2 = nvvm.inline_ptx "{
.reg .pred p;
setp.ge.s32 p, $2, $3;
selp.s32 $0, $2, $3, p;
selp.s32 $1, $2, $3, !p;
}" (%a, %b) : i32, i32 -> i32, i32
%r3 = llvm.add %r1, %r2 : i32
```
Lowered to the LLVM dialect:
```
%1 = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09 .reg .pred p;\0A\09 setp.ge.s32 p, $2, $3;\0A\09 selp.s32 $0, $2, $3, p;\0A\09 selp.s32 $1, $2, $3, !p;\0A\09}\0A", "=r,=r,r,r" %a, %b : (i32, i32) -> !llvm.struct<(i32, i32)>
%2 = llvm.extractvalue %1[0] : !llvm.struct<(i32, i32)>
%3 = llvm.extractvalue %1[1] : !llvm.struct<(i32, i32)>
%4 = llvm.add %2, %3 : i32
```
From 8e305119dcc1ab2b343f939c9548871895ea14ca Mon Sep 17 00:00:00 2001
From: Guray Ozen <gozen at nvidia.com>
Date: Fri, 15 Aug 2025 10:03:26 +0000
Subject: [PATCH] [MLIR][NVVM] Add support for multiple return values in
`inline_ptx`
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
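This PR adds the ability for `nvvm.inline_ptx` to return multiple values, matching PTX semantics (a single asm block can write several output operands, such as `$0` and `$1` in the example below) while respecting LLVM’s constraints on inline assembly.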
LLVM’s `inline_asm` op does not natively support multiple results; instead, the results must be packed into an LLVM `struct` and then extracted individually (with `llvm.extractvalue`). This PR implements this packing/unpacking automatically, so multiple return values can be expressed naturally in MLIR without extra user boilerplate.
**Example**
MLIR:
```
%r1, %r2 = nvvm.inline_ptx "{
.reg .pred p;
setp.ge.s32 p, $2, $3;
selp.s32 $0, $2, $3, p;
selp.s32 $1, $2, $3, !p;
}" (%a, %b) : i32, i32 -> i32, i32
%r3 = llvm.add %r1, %r2 : i32
```
Lowered to the LLVM dialect:
```
%1 = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09 .reg .pred p;\0A\09 setp.ge.s32 p, $2, $3;\0A\09 selp.s32 $0, $2, $3, p;\0A\09 selp.s32 $1, $2, $3, !p;\0A\09}\0A", "=r,=r,r,r" %a, %b : (i32, i32) -> !llvm.struct<(i32, i32)>
%2 = llvm.extractvalue %1[0] : !llvm.struct<(i32, i32)>
%3 = llvm.extractvalue %1[1] : !llvm.struct<(i32, i32)>
%4 = llvm.add %2, %3 : i32
```
---
mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp | 2 +-
.../LLVMIR/IR/BasicPtxBuilderInterface.cpp | 48 +++++++++++++++++--
.../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 12 +++++
3 files changed, 56 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 91788f9848fe6..3fe9c87a79039 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -61,7 +61,7 @@ struct PtxLowering
op.getAsmValues(rewriter, asmValues);
for (auto &[asmValue, modifier] : asmValues) {
- LDBG() << asmValue << "\t Modifier : " << &modifier;
+ LDBG() << asmValue << "\t Modifier : " << (int)modifier;
generator.insertValue(asmValue, modifier);
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index 894de4408c375..7d64d6b1fe18a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -107,11 +107,28 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
ss << getModifier() << getRegisterType(v) << ",";
}
+static bool needsPackUnpack(BasicPtxBuilderInterface interfaceOp) { // True when the op has more than one result; those results are then packed into one LLVM struct because llvm.inline_asm cannot return several values directly.
+ return interfaceOp->getNumResults() > 1; // Zero or one result needs no packing.
+}
+
+static SmallVector<Type> // Computes the result types for the generated llvm.inline_asm op.
+coalesceResultTypes(MLIRContext *ctx, BasicPtxBuilderInterface interfaceOp) { // ctx is only used to create the struct type.
+ TypeRange results = interfaceOp->getResultTypes();
+
+ if (!needsPackUnpack(interfaceOp)) // Zero or one result: reuse the op's result types unchanged.
+ return llvm::to_vector<1>(results);
+
+ SmallVector<mlir::Type> elems(results.begin(), results.end()); // Several results: one literal, non-packed struct whose fields follow result order.
+ auto sTy = LLVM::LLVMStructType::getLiteral(ctx, elems, /*isPacked=*/false);
+ return {sTy}; // A single struct-typed result.
+}
+
LLVM::InlineAsmOp PtxBuilder::build() {
+ MLIRContext *ctx = interfaceOp->getContext();
auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
LLVM::AsmDialect::AD_ATT);
- auto resultTypes = interfaceOp->getResultTypes();
+ SmallVector<Type> resultTypes = coalesceResultTypes(ctx, interfaceOp);
// Remove the last comma from the constraints string.
if (!registerConstraints.empty() &&
@@ -136,7 +153,7 @@ LLVM::InlineAsmOp PtxBuilder::build() {
rewriter, interfaceOp->getLoc(),
/*result types=*/resultTypes,
/*operands=*/ptxOperands,
- /*asm_string=*/llvm::StringRef(ptxInstruction),
+ /*asm_string=*/StringRef(ptxInstruction),
/*constraints=*/registerConstraints.data(),
/*has_side_effects=*/interfaceOp.hasSideEffect(),
/*is_align_stack=*/false, LLVM::TailCallKind::None,
@@ -147,9 +164,30 @@ LLVM::InlineAsmOp PtxBuilder::build() {
void PtxBuilder::buildAndReplaceOp() {
LLVM::InlineAsmOp inlineAsmOp = build();
LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
- if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) {
- rewriter.replaceOp(interfaceOp, inlineAsmOp);
- } else {
+ // Replace the interface op according to the result shape produced by coalesceResultTypes().
+ // Case 1: no result; there is nothing to forward, so just erase the op.
+ if (inlineAsmOp->getNumResults() == 0) {
rewriter.eraseOp(interfaceOp);
+ return;
}
+
+ // Case 2: single result, returned unpacked by build(); forward it directly.
+ if (!needsPackUnpack(interfaceOp)) {
+ rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
+ return;
+ }
+
+ // Case 3: multiple results were packed into one struct; extract field i as the replacement for result i.
+ auto structTy = llvm::cast<mlir::LLVM::LLVMStructType>(
+ inlineAsmOp.getResultTypes().front());
+
+ SmallVector<mlir::Value> unpacked;
+ Value structVal = inlineAsmOp.getResult(0);
+ for (auto [idx, elemTy] : llvm::enumerate(structTy.getBody())) { // NOTE(review): elemTy is unused; only the field index is needed.
+ Value unpackedValue = LLVM::ExtractValueOp::create(
+ rewriter, interfaceOp->getLoc(), structVal, idx);
+ unpacked.push_back(unpackedValue);
+ }
+
+ rewriter.replaceOp(interfaceOp, unpacked); // Extracted values are in field order, which matches the original result order.
}
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 24873340d7122..b38347c7cd1b7 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -683,6 +683,18 @@ llvm.func @ex2(%input : f32, %pred : i1) {
llvm.return
}
+// CHECK-LABEL: @multi_return(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32)
+llvm.func @multi_return(%a : i32, %b : i32) -> i32 { // Two results from one inline_ptx: expect one struct-returning inline_asm followed by two extractvalue ops.
+ // CHECK: %[[S1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09 .reg .pred p;\0A\09 setp.ge.s32 p, $2, $3;\0A\09 selp.s32 $0, $2, $3, p;\0A\09 selp.s32 $1, $2, $3, !p;\0A\09}\0A", "=r,=r,r,r" %[[arg0]], %[[arg1]] : (i32, i32) -> !llvm.struct<(i32, i32)>
+ // CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][0] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[S3:.+]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[S4:.+]] = llvm.add %[[S2]], %[[S3]] : i32
+ // CHECK: llvm.return %[[S4]] : i32
+ %r1, %r2 = nvvm.inline_ptx "{\n\t .reg .pred p;\n\t setp.ge.s32 p, $2, $3;\n\t selp.s32 $0, $2, $3, p;\n\t selp.s32 $1, $2, $3, !p;\n\t}\n" (%a, %b) : i32,i32 -> i32,i32
+ %r3 = llvm.add %r1, %r2 : i32 // Consumes both unpacked results.
+ llvm.return %r3 : i32
+}
// -----
// CHECK-LABEL: @nvvm_pmevent
More information about the Mlir-commits
mailing list