[Mlir-commits] [mlir] [MLIR][NVVM] Add support for multiple return values in `inline_ptx` (PR #153774)

Guray Ozen llvmlistbot at llvm.org
Fri Aug 15 03:04:11 PDT 2025


https://github.com/grypp created https://github.com/llvm/llvm-project/pull/153774

This PR adds the ability for `nvvm.inline_ptx` to return multiple values, matching the expected semantics in PTX while respecting LLVM’s constraints.

LLVM’s `inline_asm` op cannot produce multiple results directly; instead, the results must be packed into a single LLVM `struct` and then extracted from it. This PR performs that packing/unpacking automatically, so multiple return values can be expressed naturally in MLIR without extra user boilerplate.

**Example**
MLIR:

```
%r1, %r2 = nvvm.inline_ptx  "{
   .reg .pred p;
   setp.ge.s32 p, $2, $3;
   selp.s32 $0, $2, $3, p;
   selp.s32 $1, $2, $3, !p;
}" (%a, %b) : i32, i32 -> i32, i32

%r3 = llvm.add %r1, %r2 : i32
```

Lowered to the LLVM dialect:

```
%1 = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09 .reg .pred p;\0A\09 setp.ge.s32 p, $2, $3;\0A\09 selp.s32 $0, $2, $3, p;\0A\09 selp.s32 $1, $2, $3, !p;\0A\09}\0A", "=r,=r,r,r" %a, %b : (i32, i32) -> !llvm.struct<(i32, i32)>
%2 = llvm.extractvalue %1[0] : !llvm.struct<(i32, i32)>
%3 = llvm.extractvalue %1[1] : !llvm.struct<(i32, i32)>
%4 = llvm.add %2, %3 : i32
```

>From 8e305119dcc1ab2b343f939c9548871895ea14ca Mon Sep 17 00:00:00 2001
From: Guray Ozen <gozen at nvidia.com>
Date: Fri, 15 Aug 2025 10:03:26 +0000
Subject: [PATCH] [MLIR][NVVM] Add support for multiple return values in
 `inline_ptx`
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This PR adds the ability for `nvvm.inline_ptx` to return multiple values, matching the expected semantics in PTX while respecting LLVM’s constraints.

LLVM’s `inline_asm` op cannot produce multiple results directly; instead, the results must be packed into a single LLVM `struct` and then extracted from it. This PR performs that packing/unpacking automatically, so multiple return values can be expressed naturally in MLIR without extra user boilerplate.

**Example**
MLIR:

```
%r1, %r2 = nvvm.inline_ptx  "{
   .reg .pred p;
   setp.ge.s32 p, $2, $3;
   selp.s32 $0, $2, $3, p;
   selp.s32 $1, $2, $3, !p;
}" (%a, %b) : i32, i32 -> i32, i32

%r3 = llvm.add %r1, %r2 : i32
```

Lowered to the LLVM dialect:

```
%1 = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09 .reg .pred p;\0A\09 setp.ge.s32 p, $2, $3;\0A\09 selp.s32 $0, $2, $3, p;\0A\09 selp.s32 $1, $2, $3, !p;\0A\09}\0A", "=r,=r,r,r" %a, %b : (i32, i32) -> !llvm.struct<(i32, i32)>
%2 = llvm.extractvalue %1[0] : !llvm.struct<(i32, i32)>
%3 = llvm.extractvalue %1[1] : !llvm.struct<(i32, i32)>
%4 = llvm.add %2, %3 : i32
```
---
 mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp |  2 +-
 .../LLVMIR/IR/BasicPtxBuilderInterface.cpp    | 48 +++++++++++++++++--
 .../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir   | 12 +++++
 3 files changed, 56 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 91788f9848fe6..3fe9c87a79039 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -61,7 +61,7 @@ struct PtxLowering
 
     op.getAsmValues(rewriter, asmValues);
     for (auto &[asmValue, modifier] : asmValues) {
-      LDBG() << asmValue << "\t Modifier : " << &modifier;
+      LDBG() << asmValue << "\t Modifier : " << static_cast<int>(modifier);
       generator.insertValue(asmValue, modifier);
     }
 
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index 894de4408c375..7d64d6b1fe18a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -107,11 +107,28 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
   ss << getModifier() << getRegisterType(v) << ",";
 }
 
+/// True when the op has more than one result and needs struct pack/unpack.
+static bool needsPackUnpack(BasicPtxBuilderInterface interfaceOp) {
+  return interfaceOp->getNumResults() > 1;
+}
+
+/// Result types for the emitted inline asm: multi-result ops are folded into
+/// a single literal (non-packed) LLVM struct; others keep their own types.
+static SmallVector<Type>
+coalesceResultTypes(MLIRContext *ctx, BasicPtxBuilderInterface interfaceOp) {
+  TypeRange results = interfaceOp->getResultTypes();
+  if (!needsPackUnpack(interfaceOp))
+    return llvm::to_vector<1>(results);
+  return {LLVM::LLVMStructType::getLiteral(ctx, llvm::to_vector(results),
+                                           /*isPacked=*/false)};
+}
+
 LLVM::InlineAsmOp PtxBuilder::build() {
+  MLIRContext *ctx = interfaceOp->getContext();
   auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
                                                   LLVM::AsmDialect::AD_ATT);
 
-  auto resultTypes = interfaceOp->getResultTypes();
+  SmallVector<Type> resultTypes = coalesceResultTypes(ctx, interfaceOp);
 
   // Remove the last comma from the constraints string.
   if (!registerConstraints.empty() &&
@@ -136,7 +153,7 @@ LLVM::InlineAsmOp PtxBuilder::build() {
       rewriter, interfaceOp->getLoc(),
       /*result types=*/resultTypes,
       /*operands=*/ptxOperands,
-      /*asm_string=*/llvm::StringRef(ptxInstruction),
+      /*asm_string=*/StringRef(ptxInstruction),
       /*constraints=*/registerConstraints.data(),
       /*has_side_effects=*/interfaceOp.hasSideEffect(),
       /*is_align_stack=*/false, LLVM::TailCallKind::None,
@@ -147,9 +164,30 @@ LLVM::InlineAsmOp PtxBuilder::build() {
 void PtxBuilder::buildAndReplaceOp() {
   LLVM::InlineAsmOp inlineAsmOp = build();
   LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
-  if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) {
-    rewriter.replaceOp(interfaceOp, inlineAsmOp);
-  } else {
+
+  // Case 1: no results; the original op is simply erased.
+  if (inlineAsmOp->getNumResults() == 0) {
     rewriter.eraseOp(interfaceOp);
+    return;
   }
+
+  // Case 2: exactly one result; forward it directly.
+  if (!needsPackUnpack(interfaceOp)) {
+    rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
+    return;
+  }
+
+  // Case 3: multiple results were packed into one literal struct by build();
+  // extract each field in order to recover the original result values.
+  auto structTy =
+      llvm::cast<LLVM::LLVMStructType>(inlineAsmOp.getResultTypes().front());
+  Value structVal = inlineAsmOp.getResult(0);
+  SmallVector<Value> unpacked;
+  unpacked.reserve(structTy.getBody().size());
+  for (int64_t idx = 0, e = structTy.getBody().size(); idx < e; ++idx) {
+    unpacked.push_back(LLVM::ExtractValueOp::create(
+        rewriter, interfaceOp->getLoc(), structVal, idx));
+  }
+
+  rewriter.replaceOp(interfaceOp, unpacked);
 }
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 24873340d7122..b38347c7cd1b7 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -683,6 +683,18 @@ llvm.func @ex2(%input : f32, %pred : i1) {
   llvm.return
 }
 
+// CHECK-LABEL: @multi_return(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32)
+llvm.func @multi_return(%a : i32, %b : i32) -> i32 {
+  // CHECK: %[[S1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09 .reg .pred p;\0A\09 setp.ge.s32 p, $2, $3;\0A\09 selp.s32 $0, $2, $3, p;\0A\09 selp.s32 $1, $2, $3, !p;\0A\09}\0A", "=r,=r,r,r" %[[arg0]], %[[arg1]] : (i32, i32) -> !llvm.struct<(i32, i32)>
+  // CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][0] : !llvm.struct<(i32, i32)>
+  // CHECK: %[[S3:.+]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(i32, i32)>
+  // CHECK: %[[S4:.+]] = llvm.add %[[S2]], %[[S3]] : i32
+  // CHECK: llvm.return %[[S4]] : i32
+  %r1, %r2 = nvvm.inline_ptx "{\n\t .reg .pred p;\n\t setp.ge.s32 p, $2, $3;\n\t selp.s32 $0, $2, $3, p;\n\t selp.s32 $1, $2, $3, !p;\n\t}\n" (%a, %b) : i32, i32 -> i32, i32
+  %r3 = llvm.add %r1, %r2 : i32
+  llvm.return %r3 : i32
+}
 // -----
 
 // CHECK-LABEL: @nvvm_pmevent



More information about the Mlir-commits mailing list