[Mlir-commits] [mlir] [MLIR][NVVM] Improve inline_ptx, add readwrite support (PR #154358)

Guray Ozen llvmlistbot at llvm.org
Thu Aug 21 00:13:59 PDT 2025


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/154358

>From edaa2d4052b5b261b619feafbf78f35709c982be Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 19 Aug 2025 15:13:06 +0000
Subject: [PATCH 1/4] =?UTF-8?q?[MLIR][NVVM]=20Improve=20inline=5Fptx,=20ad?=
 =?UTF-8?q?d=20readwrite=20support=20Key=20Features=201.=20Multiple=20SSA?=
 =?UTF-8?q?=20returns=20=E2=80=93=20no=20struct=20packing/unpacking=20requ?=
 =?UTF-8?q?ired.=202.=20Automatic=20struct=20unpacking=20=E2=80=93=20value?=
 =?UTF-8?q?s=20are=20directly=20usable.=203.=20Readable=20register=20mappi?=
 =?UTF-8?q?ng=20=20=20=20=20*=20{$rwN}=20=E2=86=92=20read-write=20=20=20?=
 =?UTF-8?q?=20=20*=20{$roN}=20=E2=86=92=20read-only=20=20=20=20=20*=20{$wo?=
 =?UTF-8?q?N}=20=E2=86=92=20write-only=204.=20Full=20read-write=20support?=
 =?UTF-8?q?=20(+=20modifier).=205.=20Simplified=20operand=20specification?=
 =?UTF-8?q?=20=E2=80=93=20avoids=20cryptic=20"=3Dr,=3Dr,=3Df,=3Df,f,f,0,1"?=
 =?UTF-8?q?=20constraints.=206.=20Predicate=20support:=20PTX=20@p=20predic?=
 =?UTF-8?q?ation=20support?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

IR Example:
```
%wo0, %wo1 = nvvm.inline_ptx """
 .reg .pred p;
 setp.ge.s32 p,   {$r0}, {$r1};
 selp.s32 {$rw0}, {$r0}, {$r1}, p;
 selp.s32 {$rw1}, {$r0}, {$r1}, p;
 selp.s32 {$w0},  {$r0}, {$r1}, p;
 selp.s32 {$w1},  {$r0}, {$r1}, p;
""" ro(%a, %b : f32, f32) rw(%c, %d : i32, i32) -> f32, f32
```

After lowering
```
 %0 = llvm.inline_asm has_side_effects asm_dialect = att
 "{
                              .reg .pred p;\
                              setp.ge.s32 p, $4, $5;   \
                              selp.s32   $0, $4, $5, p;\
                              selp.s32   $1, $4, $5, p;\
                              selp.s32   $2, $4, $5, p;\
                              selp.s32   $3, $4, $5, p;\
   }"
   "=r,=r,=f,=f,f,f,0,1"
   %c500_i32, %c400_i32, %cst, %cst_0
   : (i32, i32, f32, f32)
   -> !llvm.struct<(i32, i32, f32, f32)>

 %1 = llvm.extractvalue %0[0] : !llvm.struct<(i32, i32, f32, f32)>
 %2 = llvm.extractvalue %0[1] : !llvm.struct<(i32, i32, f32, f32)>
 %3 = llvm.extractvalue %0[2] : !llvm.struct<(i32, i32, f32, f32)>
 %4 = llvm.extractvalue %0[3] : !llvm.struct<(i32, i32, f32, f32)>

 // read-write (unpacked from the inline_asm result struct)
 %5 = arith.addi %1, %2 : i32
 // read only
 %6 = arith.addf %cst, %cst_0 : f32
 // write only
 %7 = arith.addf %3, %4 : f32
```
---
 .../Dialect/LLVMIR/BasicPtxBuilderInterface.h |  14 +-
 .../LLVMIR/BasicPtxBuilderInterface.td        |  13 +-
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   |  22 +-
 mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp |   4 +-
 .../LLVMIR/IR/BasicPtxBuilderInterface.cpp    | 321 ++++++++++++++++--
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    |  19 +-
 .../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir   |  60 +++-
 mlir/test/python/dialects/nvvm.py             |  41 +++
 8 files changed, 435 insertions(+), 59 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
index 3e3fcd7d1fb82..99b1d9709e3e1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
@@ -26,9 +26,9 @@ namespace NVVM {
 enum class PTXRegisterMod {
   /// Read register with no modifier
   Read = 0,
-  /// Read register with '+' modifier
+  /// Write register with '=' modifier
   Write = 2,
-  /// Read register with '=' modifier.
+  /// Read-write register with '+' modifier.
   /// Note that, this is not natively supported by LLVM, but it is possible to
   /// set read and write for the same operand.
   ReadWrite = 1,
@@ -67,13 +67,17 @@ class PtxBuilder {
   SmallVector<Value> ptxOperands;
   // Register constraints (read, write, readwrite) and register data types
   std::string registerConstraints;
-
+  // Modifiers
+  SmallVector<PTXRegisterMod> registerModifiers;
   bool hasResult = false;
+  bool needsManualMapping = false;
 
 public:
   /// Single constructor that only initializes members.
-  PtxBuilder(Operation *op, PatternRewriter &rewriter)
-      : interfaceOp(op), rewriter(rewriter) {}
+  PtxBuilder(Operation *op, PatternRewriter &rewriter,
+             bool needsManualMapping = false)
+      : interfaceOp(op), rewriter(rewriter),
+        needsManualMapping(needsManualMapping) {}
 
   /// Add an operand with the read/write input type.
   void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
index e98b94b5b3052..8e36749cdb361 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
@@ -124,19 +124,21 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
             following this order:
              1) Adds results 
              2) Adds operands 
-             3) Adds attributes             
+             3) Adds attributes
+             Returns true if it does the mapping manually
           }],
-         /*retType=*/"void",
+         /*retType=*/"bool",
          /*methodName=*/"getAsmValues",
          /*args=*/(ins "::mlir::RewriterBase &":$rewriter, 
-         "llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues),
+         "llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues         
+         ),
          /*methodBody=*/"",
          /*defaultImpl=*/ [{         
            mlir::Operation* op = $_op;
            
            // Step 1. Add results
-           for (auto val : op->getResults()) 
-            asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
+            for (auto val : op->getResults()) 
+              asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
 
            // Step 2. Add operands
            for (auto val : op->getOperands()) 
@@ -149,6 +151,7 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
              asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
              }
            }
+           return false; // No manual mapping needed
          }]
        >
   ];
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index f9cd58de8915f..786d42cf15666 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -315,16 +315,19 @@ def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx",
   }];
 
   let arguments = (ins Variadic<AnyType>:$readOnlyArgs, 
+                       Variadic<AnyType>:$readWriteArgs, 
                        StrAttr:$ptxCode,
                        PtxPredicate:$predicate);
                       
   let results = (outs Variadic<AnyType>:$writeOnlyArgs);
-
-  let assemblyFormat = [{ 
-    $ptxCode `(` $readOnlyArgs `)` 
-    (`,` `predicate` `=` $predicate^)? attr-dict 
-    `:` type(operands) 
-    (`->` type($writeOnlyArgs)^)?
+  
+  let assemblyFormat = [{
+    $ptxCode
+    ( `ro` `(` $readOnlyArgs^ `:` type($readOnlyArgs) `)` )?
+    ( `rw` `(` $readWriteArgs^ `:` type($readWriteArgs) `)` )?
+    (`,` `predicate` `=` $predicate^)? 
+    attr-dict
+    ( `->` type($writeOnlyArgs)^ )?
   }];
   
   let extraClassDefinition = [{
@@ -333,6 +336,10 @@ def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx",
       return std::string(ptxInstStr.data());
     }
   }];
+
+  let extraClassDeclaration = [{
+    bool getAsmValues(RewriterBase &, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &);
+  }];
 }
 
 //===----------------------------------------------------------------------===//
@@ -3027,8 +3034,7 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
-    void getAsmValues(RewriterBase &rewriter, 
-        llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues);
+    bool getAsmValues(RewriterBase &, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &);
   }];
 }
 
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index e0144bff4d371..c67ec3642f121 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -57,9 +57,9 @@ struct PtxLowering
 
     SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
     LDBG() << op.getPtx();
-    PtxBuilder generator(op, rewriter);
 
-    op.getAsmValues(rewriter, asmValues);
+    bool needsManualMapping = op.getAsmValues(rewriter, asmValues);
+    PtxBuilder generator(op, rewriter, needsManualMapping);
     for (auto &[asmValue, modifier] : asmValues) {
       LDBG() << asmValue << "\t Modifier : " << modifier;
       generator.insertValue(asmValue, modifier);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index e004d5f64733e..3cad9d3bd16e3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -12,6 +12,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Regex.h"
 
 #define DEBUG_TYPE "ptx-builder"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -59,12 +62,28 @@ static char getRegisterType(Value v) {
   return getRegisterType(v.getType());
 }
 
+/// Extract every element of a struct value.
+static SmallVector<Value> extractStructElements(PatternRewriter &rewriter,
+                                                Location loc, Value agg) {
+  auto structTy = cast<LLVM::LLVMStructType>(agg.getType());
+  SmallVector<Value> elems;
+  elems.reserve(structTy.getBody().size());
+  for (auto [i, t] : llvm::enumerate(structTy.getBody())) {
+    (void)t;
+    Value e = LLVM::ExtractValueOp::create(rewriter, loc, agg, i);
+    elems.push_back(e);
+  }
+  return elems;
+}
+
 void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
-  LLVM_DEBUG(DBGS() << v << "\t Modifier : " << &itype << "\n");
+  LLVM_DEBUG(DBGS() << v << "\t Modifier : " << itype << "\n");
+  registerModifiers.push_back(itype);
+
   auto getModifier = [&]() -> const char * {
     if (itype == PTXRegisterMod::ReadWrite) {
-      assert(false && "Read-Write modifier is not supported. Try setting the "
-                      "same value as Write and Read separately.");
+      // The read-write ('+') modifier is not natively supported by LLVM;
+      // build() canonicalizes it later into "=X" plus a tied operand index.
       return "+";
     }
     if (itype == PTXRegisterMod::Write) {
@@ -72,6 +91,7 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
     }
     return "";
   };
+
   auto addValue = [&](Value v) {
     if (itype == PTXRegisterMod::Read) {
       ptxOperands.push_back(v);
@@ -108,38 +128,222 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
 }
 
 /// Check if the operation needs to pack and unpack results.
-static bool needsPackUnpack(BasicPtxBuilderInterface interfaceOp) {
-  return interfaceOp->getNumResults() > 1;
+static bool
+needsPackUnpack(BasicPtxBuilderInterface interfaceOp, bool needsManualMapping,
+                SmallVectorImpl<PTXRegisterMod> &registerModifiers) {
+  if (needsManualMapping)
+    return false;
+  const unsigned writeOnly = interfaceOp->getNumResults();
+  const unsigned readWrite =
+      llvm::count_if(registerModifiers, [](PTXRegisterMod m) {
+        return m == PTXRegisterMod::ReadWrite;
+      });
+  return (writeOnly + readWrite) > 1;
 }
 
 /// Pack the result types of the interface operation.
 /// If the operation has multiple results, it packs them into a struct
 /// type. Otherwise, it returns the original result types.
-static SmallVector<Type> packResultTypes(MLIRContext *ctx,
-                                         BasicPtxBuilderInterface interfaceOp) {
-  TypeRange results = interfaceOp->getResultTypes();
+static SmallVector<Type>
+packResultTypes(BasicPtxBuilderInterface interfaceOp, bool needsManualMapping,
+                SmallVectorImpl<PTXRegisterMod> &registerModifiers,
+                SmallVectorImpl<Value> &ptxOperands) {
+  MLIRContext *ctx = interfaceOp->getContext();
+  TypeRange resultRange = interfaceOp->getResultTypes();
+
+  if (!needsPackUnpack(interfaceOp, needsManualMapping, registerModifiers)) {
+    // Single value path:
+    if (interfaceOp->getResults().size() == 1)
+      return SmallVector<Type>{resultRange.front()};
+
+    // No declared results: if there is an RW, forward its type.
+    for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+      if (m == PTXRegisterMod::ReadWrite)
+        return SmallVector<Type>{v.getType()};
+  }
+
+  SmallVector<Type> packed;
+  for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+    if (m == PTXRegisterMod::ReadWrite)
+      packed.push_back(v.getType());
+  for (Type t : resultRange)
+    packed.push_back(t);
+
+  if (packed.empty())
+    return {};
+
+  auto sTy = LLVM::LLVMStructType::getLiteral(ctx, packed, /*isPacked=*/false);
+  return SmallVector<Type>{sTy};
+}
+
+/// Canonicalize the register constraints:
+///  - Turn every "+X" into "=X"
+///  - Append (at the very end) the 0-based indices of tokens that were "+X"
+/// Examples:
+///  "+f,+f,+r,=r,=r,r,r" -> "=f,=f,=r,=r,=r,r,r,0,1,2"
+///  "+f,+f,+r,=r,=r"     -> "=f,=f,=r,=r,=r,0,1,2"
+static std::string canonicalizeRegisterConstraints(llvm::StringRef csv) {
+  SmallVector<llvm::StringRef> toks;
+  SmallVector<std::string> out;
+  SmallVector<unsigned> plusIdx;
+
+  csv.split(toks, ',');
+  out.reserve(toks.size() + 8);
+
+  for (unsigned i = 0, e = toks.size(); i < e; ++i) {
+    StringRef t = toks[i].trim();
+    if (t.consume_front("+")) {
+      plusIdx.push_back(i);
+      out.push_back(("=" + t).str());
+    } else {
+      out.push_back(t.str());
+    }
+  }
+
+  // Append indices of original "+X" tokens.
+  for (unsigned idx : plusIdx)
+    out.push_back(std::to_string(idx));
+
+  // Join back to CSV.
+  std::string result;
+  result.reserve(csv.size() + plusIdx.size() * 2);
+  llvm::raw_string_ostream os(result);
+  for (size_t i = 0; i < out.size(); ++i) {
+    if (i)
+      os << ',';
+    os << out[i];
+  }
+  return os.str();
+}
+
+constexpr llvm::StringLiteral kReadWrite{"rw"};
+constexpr llvm::StringLiteral kWriteOnly{"w"};
+constexpr llvm::StringLiteral kReadOnly{"r"};
+
+/// Rewrites placeholders of the form `{$rN}`, `{$wN}`, `{$rwN}` in `asmText`
+/// to compact `$K` indices where all `rw*` come first (ascending N), then `w*`,
+/// then `r*`. Duplicates are de-duplicated when assigning numbers.
+/// Unknown text is preserved verbatim.
+///
+/// Example Input:
+/// "{
+///       .reg .pred p;
+///       setp.ge.s32 p,   {$r0}, {$r1};
+///       selp.s32 {$rw0}, {$r0}, {$r1}, p;
+///       selp.s32 {$rw1}, {$r0}, {$r1}, p;
+///       selp.s32 {$w0},  {$r0}, {$r1}, p;
+///       selp.s32 {$w1},  {$r0}, {$r1}, p;
+/// }\n"
+/// Example Output:
+/// "{
+///       .reg .pred p;
+///       setp.ge.s32 p, $4, $5;
+///       selp.s32 $0,   $4, $5, p;
+///       selp.s32 $1,   $4, $5, p;
+///       selp.s32 $2,   $4, $5, p;
+///       selp.s32 $3,   $4, $5, p;
+/// }\n"
+static std::string rewriteAsmPlaceholders(llvm::StringRef asmText) {
+  // Match {$rwN}, {$wN}, {$rN}
+  llvm::Regex rx(llvm::formatv(R"(\{\$({0}|{1}|{2})([0-9]+)\})", kReadWrite,
+                               kWriteOnly, kReadOnly)
+                     .str());
+
+  llvm::SmallDenseSet<unsigned> seenRW, seenW, seenR;
+  llvm::SmallVector<unsigned> rwNums, wNums, rNums;
+
+  {
+    StringRef rest = asmText;
+    SmallVector<StringRef, 3> m; // 0: full, 1: kind, 2: number
+    while (!rest.empty() && rx.match(rest, &m)) {
+      unsigned num = 0;
+      (void)m[2].getAsInteger(10, num);
+
+      if (m[1].equals_insensitive(kReadWrite)) {
+        if (seenRW.insert(num).second)
+          rwNums.push_back(num);
+      } else if (m[1].equals_insensitive(kWriteOnly)) {
+        if (seenW.insert(num).second)
+          wNums.push_back(num);
+      } else {
+        if (seenR.insert(num).second)
+          rNums.push_back(num);
+      }
+
+      const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
+      rest = rest.drop_front(advance);
+    }
+  }
+
+  llvm::sort(rwNums);
+  llvm::sort(wNums);
+  llvm::sort(rNums);
+
+  llvm::DenseMap<unsigned, unsigned> rwMap, wMap, rMap;
+  unsigned nextId = 0;
+  for (unsigned n : rwNums)
+    rwMap[n] = nextId++;
+  for (unsigned n : wNums)
+    wMap[n] = nextId++;
+  for (unsigned n : rNums)
+    rMap[n] = nextId++;
+
+  std::string out;
+  out.reserve(asmText.size());
 
-  if (!needsPackUnpack(interfaceOp))
-    return llvm::to_vector<1>(results);
+  size_t prev = 0;
+  StringRef rest = asmText;
+  SmallVector<StringRef, 3> m;
+  while (!rest.empty() && rx.match(rest, &m)) {
+    // Compute absolute match bounds in the original buffer.
+    size_t absStart = (size_t)(m[0].data() - asmText.data());
+    size_t absEnd = absStart + m[0].size();
 
-  SmallVector<mlir::Type> elems(results.begin(), results.end());
-  auto sTy = LLVM::LLVMStructType::getLiteral(ctx, elems, /*isPacked=*/false);
-  return {sTy};
+    // Emit text before the match.
+    out.append(asmText.data() + prev, asmText.data() + absStart);
+
+    // Emit compact $K
+    unsigned num = 0;
+    (void)m[2].getAsInteger(10, num);
+    unsigned id = 0;
+    if (m[1].equals_insensitive(kReadWrite))
+      id = rwMap.lookup(num);
+    else if (m[1].equals_insensitive(kWriteOnly))
+      id = wMap.lookup(num);
+    else
+      id = rMap.lookup(num);
+
+    out.push_back('$');
+    out += std::to_string(id);
+
+    prev = absEnd;
+
+    // Advance search window.
+    const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
+    rest = rest.drop_front(advance);
+  }
+
+  // Tail.
+  out.append(asmText.data() + prev, asmText.data() + asmText.size());
+  return out;
 }
 
 LLVM::InlineAsmOp PtxBuilder::build() {
-  MLIRContext *ctx = interfaceOp->getContext();
   auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
                                                   LLVM::AsmDialect::AD_ATT);
 
-  SmallVector<Type> resultTypes = packResultTypes(ctx, interfaceOp);
+  SmallVector<Type> resultTypes = packResultTypes(
+      interfaceOp, needsManualMapping, registerModifiers, ptxOperands);
 
   // Remove the last comma from the constraints string.
   if (!registerConstraints.empty() &&
       registerConstraints[registerConstraints.size() - 1] == ',')
     registerConstraints.pop_back();
+  registerConstraints = canonicalizeRegisterConstraints(registerConstraints);
 
   std::string ptxInstruction = interfaceOp.getPtx();
+  if (!needsManualMapping)
+    ptxInstruction = rewriteAsmPlaceholders(ptxInstruction);
 
   // Add the predicate to the asm string.
   if (interfaceOp.getPredicate().has_value() &&
@@ -169,33 +373,86 @@ void PtxBuilder::buildAndReplaceOp() {
   LLVM::InlineAsmOp inlineAsmOp = build();
   LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
 
-  // Case 1: no result
-  if (inlineAsmOp->getNumResults() == 0) {
+  // Case 0: no result at all → just erase wrapper op.
+  if (!hasResult) {
     rewriter.eraseOp(interfaceOp);
     return;
   }
 
-  // Case 2: single result, forward it directly
-  if (!needsPackUnpack(interfaceOp)) {
+  if (needsManualMapping) {
     rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
     return;
   }
 
-  // Case 3: multiple results were packed; unpack the struct.
-  assert(mlir::LLVM::LLVMStructType::classof(
-             inlineAsmOp.getResultTypes().front()) &&
-         "Expected result type to be LLVMStructType when unpacking multiple "
-         "results");
-  auto structTy = llvm::cast<mlir::LLVM::LLVMStructType>(
-      inlineAsmOp.getResultTypes().front());
+  // Case 1: Simple path, return single scalar
+  if (!needsPackUnpack(interfaceOp, needsManualMapping, registerModifiers)) {
+    if (inlineAsmOp->getNumResults() > 0) {
+      rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
+    } else {
+      // RW-only case with no declared results: forward the RW value.
+      SmallVector<Value> results;
+      for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+        if (m == PTXRegisterMod::ReadWrite) {
+          results.push_back(v);
+          break;
+        }
+      rewriter.replaceOp(interfaceOp, results);
+    }
+    return;
+  }
+
+  const bool hasRW = llvm::any_of(registerModifiers, [](PTXRegisterMod m) {
+    return m == PTXRegisterMod::ReadWrite;
+  });
 
-  SmallVector<mlir::Value> unpacked;
+  // All multi-value paths produce a single struct result we need to unpack.
+  assert(LLVM::LLVMStructType::classof(inlineAsmOp.getResultTypes().front()) &&
+         "expected struct return for multi-result inline asm");
   Value structVal = inlineAsmOp.getResult(0);
-  for (auto [idx, elemTy] : llvm::enumerate(structTy.getBody())) {
-    Value unpackedValue = LLVM::ExtractValueOp::create(
-        rewriter, interfaceOp->getLoc(), structVal, idx);
-    unpacked.push_back(unpackedValue);
+  SmallVector<Value> unpacked =
+      extractStructElements(rewriter, interfaceOp->getLoc(), structVal);
+
+  // Case 2: only declared results (no RW): replace the op with all unpacked.
+  if (!hasRW && interfaceOp->getResults().size() > 0) {
+    rewriter.replaceOp(interfaceOp, unpacked);
+    return;
+  }
+
+  // Case 3: RW-only (no declared results): update RW uses and erase wrapper.
+  if (hasRW && interfaceOp->getResults().size() == 0) {
+    unsigned idx = 0;
+    for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
+      if (m != PTXRegisterMod::ReadWrite)
+        continue;
+      Value repl = unpacked[idx++];
+      v.replaceUsesWithIf(repl, [&](OpOperand &use) {
+        Operation *owner = use.getOwner();
+        return owner != interfaceOp && owner != inlineAsmOp;
+      });
+    }
+    rewriter.eraseOp(interfaceOp);
+    return;
   }
 
-  rewriter.replaceOp(interfaceOp, unpacked);
+  // Case 4: mixed (RW + declared results).
+  {
+    // First rewrite RW operands in place.
+    unsigned idx = 0;
+    for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
+      if (m != PTXRegisterMod::ReadWrite)
+        continue;
+      Value repl = unpacked[idx++];
+      v.replaceUsesWithIf(repl, [&](OpOperand &use) {
+        Operation *owner = use.getOwner();
+        return owner != interfaceOp && owner != inlineAsmOp;
+      });
+    }
+    // The remaining unpacked values correspond to the declared results.
+    SmallVector<Value> tail;
+    tail.reserve(unpacked.size() - idx);
+    for (unsigned i = idx, e = unpacked.size(); i < e; ++i)
+      tail.push_back(unpacked[i]);
+
+    rewriter.replaceOp(interfaceOp, tail);
+  }
 }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index dbcc738b4419f..ae9134458095f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1123,7 +1123,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
   return ptx;
 }
 
-void NVVM::WgmmaMmaAsyncOp::getAsmValues(
+bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
     RewriterBase &rewriter,
     llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
         &asmValues) {
@@ -1154,7 +1154,9 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
         {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
          mlir::NVVM::PTXRegisterMod::Read});
   }
+  return true; // Has manual mapping
 }
+
 LogicalResult NVVM::FenceProxyOp::verify() {
   if (getKind() == NVVM::ProxyKind::TENSORMAP)
     return emitOpError() << "tensormap proxy is not a supported proxy kind";
@@ -1870,6 +1872,21 @@ llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
   }
 }
 
+bool NVVM::InlinePtxOp::getAsmValues(
+    RewriterBase &rewriter,
+    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+        &asmValues) {
+  for (auto arg : getReadWriteArgs())
+    asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
+  for (auto arg : getResults())
+    asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
+  for (auto arg : getReadOnlyArgs())
+    asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
+  if (getPredicate())
+    asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read});
+  return false; // No manual mapping: PtxBuilder rewrites the placeholders
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index b38347c7cd1b7..2a19c72ab0840 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -667,34 +667,82 @@ llvm.func @init_mbarrier(
     %count : i32, 
     %pred : i1) {
   // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.init.b64 [$0], $1;", "l,r" 
-  nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count)  : !llvm.ptr, i32 
+  nvvm.inline_ptx "mbarrier.init.b64 [{$r0}], {$r1};" ro (%barrier_gen, %count : !llvm.ptr, i32)
   // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b"
-  nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count), predicate = %pred : !llvm.ptr, i32, i1
+  nvvm.inline_ptx "mbarrier.init.b64 [{$r0}], {$r1};" ro (%barrier_gen, %count : !llvm.ptr, i32), predicate = %pred
   llvm.return
 }
 // -----
 
 llvm.func @ex2(%input : f32, %pred : i1) {
   // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "ex2.approx.ftz.f32 $0, $1;", "=f,f" %{{.*}} : (f32) -> f32
-  %0 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input) : f32 -> f32
+  %0 = nvvm.inline_ptx "ex2.approx.ftz.f32 {$w0}, {$r0};" ro (%input : f32) -> f32
   
   // CHECK: %{{.*}} =  llvm.inline_asm has_side_effects asm_dialect = att "@$1 ex2.approx.ftz.f32 $0, $1;", "=f,f,b" %{{.*}}, %{{.*}} : (f32, i1) -> f32
-  %1 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input), predicate = %pred  : f32, i1 -> f32
+  %1 = nvvm.inline_ptx "ex2.approx.ftz.f32 {$w0}, {$r0};" ro (%input : f32), predicate = %pred  -> f32
   llvm.return
 }
 
 // CHECK-LABEL: @multi_return(
 // CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32)
 llvm.func @multi_return(%a : i32, %b : i32) -> i32 {
-  // CHECK: %[[S1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09 .reg .pred p;\0A\09 setp.ge.s32 p, $2, $3;\0A\09 selp.s32 $0, $2, $3, p;\0A\09 selp.s32 $1, $2, $3, !p;\0A\09}\0A", "=r,=r,r,r" %[[arg0]], %[[arg1]] : (i32, i32) -> !llvm.struct<(i32, i32)>
+  // CHECK: %[[S1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{.reg .pred p; setp.ge.s32 p, $2, $3; selp.s32 $0, $2,$3, p; selp.s32 $1, $2,$3, p;}", "=r,=r,r,r" %[[arg0]], %[[arg1]] : (i32, i32) -> !llvm.struct<(i32, i32)>
   // CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][0] : !llvm.struct<(i32, i32)> 
   // CHECK: %[[S3:.+]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(i32, i32)> 
   // CHECK: %[[S4:.+]] = llvm.add %[[S2]], %[[S3]] : i32
   // CHECK: llvm.return %[[S4]] : i32
-   %r1, %r2 = nvvm.inline_ptx  "{\n\t .reg .pred p;\n\t setp.ge.s32 p, $2, $3;\n\t selp.s32 $0, $2, $3, p;\n\t selp.s32 $1, $2, $3, !p;\n\t}\n" (%a, %b) : i32,i32 -> i32,i32
+   %r1, %r2 = nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$w0}, {$r0},{$r1}, p; selp.s32 {$w1}, {$r0},{$r1}, p;}"
+                  ro (%a, %b : i32,i32) -> i32,i32
    %r3 = llvm.add %r1, %r2 : i32
    llvm.return %r3 : i32
 }
+
+// CHECK-LABEL: @inline_ptx_multi_rw(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32, %[[arg2:[a-zA-Z0-9_]+]]: f32, %[[arg3:[a-zA-Z0-9_]+]]: f32)
+llvm.func @inline_ptx_multi_rw(%a : i32, %b : i32,  %rw_c : f32, %rw_d : f32) -> f32 {
+// CHECK: %[[S0:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{.reg .pred p; setp.ge.s32 p, $2, $3; selp.s32 $0, $2,$3, p; selp.s32 $1, $2,$3, p;}", 
+// CHECK-SAME: "=f,=f,r,r,0,1" 
+// CHECK-SAME: %[[arg2]], %[[arg3]], %[[arg0]], %[[arg1]] 
+// CHECK-SAME: : (f32, f32, i32, i32) -> !llvm.struct<(f32, f32)>
+// CHECK: %[[S1:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct<(f32, f32)> 
+// CHECK: %[[S2:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct<(f32, f32)> 
+// CHECK: %[[S3:.+]] = llvm.fadd %[[S1]], %[[S2]] : f32
+// CHECK: llvm.return %[[S3]] : f32
+    nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$rw0}, {$r0},{$r1}, p; selp.s32 {$rw1}, {$r0},{$r1}, p;}"
+    ro (%a, %b : i32,i32) 
+    rw (%rw_c, %rw_d: f32,f32)
+   %r4 = llvm.fadd %rw_c, %rw_d : f32
+   llvm.return %r4 : f32
+}
+
+// CHECK-LABEL: @inline_ptx_multi_rw_r(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32, %[[arg2:[a-zA-Z0-9_]+]]: f32, %[[arg3:[a-zA-Z0-9_]+]]: f32)
+llvm.func @inline_ptx_multi_rw_r(%a : i32, %b : i32,  %rw_c : f32, %rw_d : f32) -> f32 {
+// CHECK: %[[S0:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{.reg .pred p; setp.ge.s32 p, $4, $5; selp.s32 $0, $4,$5, p; selp.s32 $1, $4,$5, p; selp.s32 $2, $4,$5, p; selp.s32 $3, $4,$5, p;}", 
+// CHECK-SAME: "=f,=f,=r,=r,r,r,0,1" 
+// CHECK-SAME: %[[arg2]], %[[arg3]], %[[arg0]], %[[arg1]] : 
+// CHECK-SAME: (f32, f32, i32, i32) -> !llvm.struct<(f32, f32, i32, i32)>
+// CHECK: %[[S1:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct<(f32, f32, i32, i32)> 
+// CHECK: %[[S2:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct<(f32, f32, i32, i32)> 
+// CHECK: %[[S3:.+]] = llvm.extractvalue %[[S0]][2] : !llvm.struct<(f32, f32, i32, i32)> 
+// CHECK: %[[S4:.+]] = llvm.extractvalue %[[S0]][3] : !llvm.struct<(f32, f32, i32, i32)> 
+// CHECK: %[[S5:.+]] = llvm.add %[[S3]], %[[S4]] : i32
+// CHECK: %[[S6:.+]] = llvm.sitofp %[[S5]] : i32 to f32
+// CHECK: %[[S7:.+]] = llvm.fadd %[[S1]], %[[S2]] : f32
+// CHECK: %[[S8:.+]] = llvm.fadd %[[S6]], %[[S2]] : f32
+// CHECK: llvm.return %[[S8]] : f32
+
+  %wo0, %wo1 = nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$rw0}, {$r0},{$r1}, p; selp.s32 {$rw1}, {$r0},{$r1}, p; selp.s32 {$w0}, {$r0},{$r1}, p; selp.s32 {$w1}, {$r0},{$r1}, p;}"
+      ro (%a, %b : i32,i32) 
+      rw (%rw_c, %rw_d: f32,f32) -> i32,i32
+   %r3 = llvm.add %wo0, %wo1 : i32
+   %r3f = llvm.sitofp %r3 : i32 to f32
+   %r4 = llvm.fadd %rw_c, %rw_d : f32
+   %r5 = llvm.fadd %r3f, %rw_d : f32
+   llvm.return %r5 : f32
+}
+
+
 // -----
 
 // CHECK-LABEL: @nvvm_pmevent
diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py
index 0eef97d95479a..3eb62bef50de9 100644
--- a/mlir/test/python/dialects/nvvm.py
+++ b/mlir/test/python/dialects/nvvm.py
@@ -5,6 +5,8 @@
 from mlir.dialects import nvvm
 from mlir.dialects import llvm
 from mlir.dialects import func
+import mlir.extras.types as T
+from mlir.dialects import arith
 
 
 def constructAndPrintInModule(f):
@@ -25,6 +27,7 @@ def testSmoke():
         "!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>"
     )
     shape_attr = Attribute.parse("#nvvm.shape<m = 64, n = 32, k = 16>")
+
     # CHECK-LABEL: func @wgmma_f32_f16_f16(%arg0: i64, %arg1: i64)
     @func.FuncOp.from_py_func(i64, i64)
     def wgmma_f32_f16_f16(desc_a, desc_b):
@@ -48,3 +51,41 @@ def wgmma_f32_f16_f16(desc_a, desc_b):
             layoutA=nvvm.MMALayout.col,
             layoutB=nvvm.MMALayout.col,
         )
+
+
+# CHECK-LABEL: TEST: test_inline_ptx
+# CHECK-LABEL: func.func @my_inline_ptx(
+# CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: f32, %[[arg1:[a-zA-Z0-9_]+]]: f32, %[[arg2:[a-zA-Z0-9_]+]]: i32, %[[arg3:[a-zA-Z0-9_]+]]: i32)
+# CHECK: %[[S0:.+]]:2 = nvvm.inline_ptx
+# CHECK-SAME: ro(%[[arg0]], %[[arg1]] : f32, f32) rw(%[[arg2]], %[[arg3]] : i32, i32) -> f32, f32
+# CHECK: %[[S1:.+]] = arith.addf %[[arg0]], %[[arg1]] : f32
+# CHECK: %[[S2:.+]] = arith.addi %[[arg2]], %[[arg3]] : i32
+# CHECK: %[[S3:.+]] = arith.addf %[[S0]]#0, %[[S0]]#1 : f32
+
+
+@constructAndPrintInModule
+def test_inline_ptx():
+    i32 = T.i32()
+    f32 = T.f32()
+
+    @func.FuncOp.from_py_func(f32, f32, i32, i32)
+    def my_inline_ptx(a, b, c, d):
+        ptx = r"""
+            {
+                .reg .pred p;
+                setp.ge.s32   p,      {$r0}, {$r1};
+                selp.s32      {$r0},  {$r0}, {$r1}, p;
+                selp.s32      {$r1},  {$r0}, {$r1}, p;
+                selp.s32      {$rw0}, {$r0}, {$r1}, p;
+                selp.s32      {$rw1}, {$r0}, {$r1}, p;
+            }
+            """
+        wo0, wo1 = nvvm.inline_ptx(
+            read_only_args=[a, b],
+            read_write_args=[c, d],
+            write_only_args=[f32, f32],
+            ptx_code=ptx,
+        )
+        arith.addf(a, b)
+        arith.addi(c, d)
+        arith.addf(wo0, wo1)

>From f5c3f5e50bcd0e7581ea790883bc5f70fd85492c Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Wed, 20 Aug 2025 11:56:44 +0000
Subject: [PATCH 2/4] address comments

---
 .../Dialect/LLVMIR/BasicPtxBuilderInterface.h | 16 ++++++++-------
 .../LLVMIR/BasicPtxBuilderInterface.td        |  4 ++--
 .../LLVMIR/IR/BasicPtxBuilderInterface.cpp    | 20 +++++++++++--------
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    |  2 +-
 4 files changed, 24 insertions(+), 18 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
index 99b1d9709e3e1..f506754ead7ad 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
@@ -26,11 +26,11 @@ namespace NVVM {
 enum class PTXRegisterMod {
   /// Read register with no modifier
   Read = 0,
-  /// Read register with '=' modifier
+  /// Write register with '=' modifier
   Write = 2,
-  /// Read register with '+' modifier.
-  /// Note that, this is not natively supported by LLVM, but it is possible to
-  /// set read and write for the same operand.
+  /// ReadWrite register with '+' modifier.
+  /// Note that this is not natively supported by LLVM; the interface does the
+  /// mapping.
   ReadWrite = 1,
 };
 
@@ -69,15 +69,17 @@ class PtxBuilder {
   std::string registerConstraints;
   // Modifiers
   SmallVector<PTXRegisterMod> registerModifiers;
+  // Has return value as write-only or read-write
   bool hasResult = false;
-  bool needsManualMapping = false;
+  // Indicates if the Op will handle the register mapping manually.
+  bool needsManualRegisterMapping = false;
 
 public:
   /// Single constructor that only initializes members.
   PtxBuilder(Operation *op, PatternRewriter &rewriter,
-             bool needsManualMapping = false)
+             bool needsManualRegisterMapping = false)
       : interfaceOp(op), rewriter(rewriter),
-        needsManualMapping(needsManualMapping) {}
+        needsManualRegisterMapping(needsManualRegisterMapping) {}
 
   /// Add an operand with the read/write input type.
   void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
index 8e36749cdb361..086cdccb01221 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
@@ -125,7 +125,7 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
              1) Adds results 
              2) Adds operands 
              3) Adds attributes
-             Returns true if it does the mapping manually
+             Returns true if the OP is going to do register mapping itself
           }],
          /*retType=*/"bool",
          /*methodName=*/"getAsmValues",
@@ -151,7 +151,7 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
              asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
              }
            }
-           return false; // No needs manual mapping
+           return false; // No manual mapping needed
          }]
        >
   ];
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index 3cad9d3bd16e3..75ab126d61a34 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -129,9 +129,10 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
 
 /// Check if the operation needs to pack and unpack results.
 static bool
-needsPackUnpack(BasicPtxBuilderInterface interfaceOp, bool needsManualMapping,
+needsPackUnpack(BasicPtxBuilderInterface interfaceOp,
+                bool needsManualRegisterMapping,
                 SmallVectorImpl<PTXRegisterMod> &registerModifiers) {
-  if (needsManualMapping)
+  if (needsManualRegisterMapping)
     return false;
   const unsigned writeOnly = interfaceOp->getNumResults();
   const unsigned readWrite =
@@ -145,13 +146,15 @@ needsPackUnpack(BasicPtxBuilderInterface interfaceOp, bool needsManualMapping,
 /// If the operation has multiple results, it packs them into a struct
 /// type. Otherwise, it returns the original result types.
 static SmallVector<Type>
-packResultTypes(BasicPtxBuilderInterface interfaceOp, bool needsManualMapping,
+packResultTypes(BasicPtxBuilderInterface interfaceOp,
+                bool needsManualRegisterMapping,
                 SmallVectorImpl<PTXRegisterMod> &registerModifiers,
                 SmallVectorImpl<Value> &ptxOperands) {
   MLIRContext *ctx = interfaceOp->getContext();
   TypeRange resultRange = interfaceOp->getResultTypes();
 
-  if (!needsPackUnpack(interfaceOp, needsManualMapping, registerModifiers)) {
+  if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping,
+                       registerModifiers)) {
     // Single value path:
     if (interfaceOp->getResults().size() == 1)
       return SmallVector<Type>{resultRange.front()};
@@ -333,7 +336,7 @@ LLVM::InlineAsmOp PtxBuilder::build() {
                                                   LLVM::AsmDialect::AD_ATT);
 
   SmallVector<Type> resultTypes = packResultTypes(
-      interfaceOp, needsManualMapping, registerModifiers, ptxOperands);
+      interfaceOp, needsManualRegisterMapping, registerModifiers, ptxOperands);
 
   // Remove the last comma from the constraints string.
   if (!registerConstraints.empty() &&
@@ -342,7 +345,7 @@ LLVM::InlineAsmOp PtxBuilder::build() {
   registerConstraints = canonicalizeRegisterConstraints(registerConstraints);
 
   std::string ptxInstruction = interfaceOp.getPtx();
-  if (!needsManualMapping)
+  if (!needsManualRegisterMapping)
     ptxInstruction = rewriteAsmPlaceholders(ptxInstruction);
 
   // Add the predicate to the asm string.
@@ -379,13 +382,14 @@ void PtxBuilder::buildAndReplaceOp() {
     return;
   }
 
-  if (needsManualMapping) {
+  if (needsManualRegisterMapping) {
     rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
     return;
   }
 
   // Case 1: Simple path, return single scalar
-  if (!needsPackUnpack(interfaceOp, needsManualMapping, registerModifiers)) {
+  if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping,
+                       registerModifiers)) {
     if (inlineAsmOp->getNumResults() > 0) {
       rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
     } else {
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index ae9134458095f..042103cad83ca 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1884,7 +1884,7 @@ bool NVVM::InlinePtxOp::getAsmValues(
     asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
   if (getPredicate())
     asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read});
-  return false; // Needs manual mapping
+  return false; // No manual mapping needed
 }
 
 //===----------------------------------------------------------------------===//

>From 98dfb038c99f7ca56ff6d95c6bd6c0b6539a05f3 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Wed, 20 Aug 2025 11:59:35 +0000
Subject: [PATCH 3/4] fx

---
 mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index 75ab126d61a34..997c893f6e6de 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -62,7 +62,7 @@ static char getRegisterType(Value v) {
   return getRegisterType(v.getType());
 }
 
-/// Extract every elements of a struct value.
+/// Extract every element of a struct value.
 static SmallVector<Value> extractStructElements(PatternRewriter &rewriter,
                                                 Location loc, Value agg) {
   auto structTy = cast<LLVM::LLVMStructType>(agg.getType());

>From 341f7f270b7870b820bd5612b202dda8bf21e7cc Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Thu, 21 Aug 2025 07:13:41 +0000
Subject: [PATCH 4/4] address comments, make code more readable

---
 .../Dialect/LLVMIR/BasicPtxBuilderInterface.h |  10 ++
 .../LLVMIR/IR/BasicPtxBuilderInterface.cpp    | 148 ++++++++++--------
 2 files changed, 95 insertions(+), 63 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
index f506754ead7ad..21331e5aa89f3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
@@ -93,6 +93,16 @@ class PtxBuilder {
   void buildAndReplaceOp();
 };
 
+/// Collect the distinct placeholder numbers N used by `{$rN}`, `{$wN}`, and
+/// `{$rwN}` in the PTX code, in order of first appearance.
+void countPlaceholderNumbers(StringRef ptxCode,
+                             llvm::SmallDenseSet<unsigned> &seenRW,
+                             llvm::SmallDenseSet<unsigned> &seenW,
+                             llvm::SmallDenseSet<unsigned> &seenR,
+                             llvm::SmallVectorImpl<unsigned> &rwNums,
+                             llvm::SmallVectorImpl<unsigned> &wNums,
+                             llvm::SmallVectorImpl<unsigned> &rNums);
+
 } // namespace NVVM
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index 997c893f6e6de..5bdfd2b995d54 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -64,15 +64,14 @@ static char getRegisterType(Value v) {
 
 /// Extract every element of a struct value.
 static SmallVector<Value> extractStructElements(PatternRewriter &rewriter,
-                                                Location loc, Value agg) {
-  auto structTy = cast<LLVM::LLVMStructType>(agg.getType());
+                                                Location loc, Value structVal) {
+  auto structTy = dyn_cast<LLVM::LLVMStructType>(structVal.getType());
+  assert(structTy && "expected LLVM struct");
+
   SmallVector<Value> elems;
-  elems.reserve(structTy.getBody().size());
-  for (auto [i, t] : llvm::enumerate(structTy.getBody())) {
-    (void)t;
-    Value e = LLVM::ExtractValueOp::create(rewriter, loc, agg, i);
-    elems.push_back(e);
-  }
+  for (unsigned i : llvm::seq<unsigned>(0, structTy.getBody().size()))
+    elems.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, structVal, i));
+
   return elems;
 }
 
@@ -81,15 +80,17 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
   registerModifiers.push_back(itype);
 
   auto getModifier = [&]() -> const char * {
-    if (itype == PTXRegisterMod::ReadWrite) {
-      // "Read-Write modifier is not supported
-      // Interface canonicalize it later
-      return "+";
-    }
-    if (itype == PTXRegisterMod::Write) {
+    switch (itype) {
+    case PTXRegisterMod::Read:
+      return "";
+    case PTXRegisterMod::Write:
       return "=";
+    case PTXRegisterMod::ReadWrite:
+      // The "+" (read-write) modifier is not actually supported;
+      // the interface will change it to "=" later and add an integer mapping.
+      return "+";
     }
-    return "";
+    llvm_unreachable("Unknown PTX register modifier");
   };
 
   auto addValue = [&](Value v) {
@@ -134,12 +135,12 @@ needsPackUnpack(BasicPtxBuilderInterface interfaceOp,
                 SmallVectorImpl<PTXRegisterMod> &registerModifiers) {
   if (needsManualRegisterMapping)
     return false;
-  const unsigned writeOnly = interfaceOp->getNumResults();
-  const unsigned readWrite =
+  const unsigned writeOnlyVals = interfaceOp->getNumResults();
+  const unsigned readWriteVals =
       llvm::count_if(registerModifiers, [](PTXRegisterMod m) {
         return m == PTXRegisterMod::ReadWrite;
       });
-  return (writeOnly + readWrite) > 1;
+  return (writeOnlyVals + readWriteVals) > 1;
 }
 
 /// Pack the result types of the interface operation.
@@ -219,14 +220,58 @@ static std::string canonicalizeRegisterConstraints(llvm::StringRef csv) {
   return os.str();
 }
 
-constexpr llvm::StringLiteral kReadWrite{"rw"};
-constexpr llvm::StringLiteral kWriteOnly{"w"};
-constexpr llvm::StringLiteral kReadOnly{"r"};
+constexpr llvm::StringLiteral kReadWritePrefix{"rw"};
+constexpr llvm::StringLiteral kWriteOnlyPrefix{"w"};
+constexpr llvm::StringLiteral kReadOnlyPrefix{"r"};
 
-/// Rewrites placeholders of the form `{$rN}`, `{$wN}`, `{$rwN}` in `asmText`
-/// to compact `$K` indices where all `rw*` come first (ascending N), then `w*`,
-/// then `r*`. Duplicates are de-duplicated when assigning numbers.
-/// Unknown text is preserved verbatim.
+/// Returns a regex that matches {$rwN}, {$wN}, {$rN}
+static llvm::Regex getPredicateMappingRegex() {
+  llvm::Regex rx(llvm::formatv(R"(\{\$({0}|{1}|{2})([0-9]+)\})",
+                               kReadWritePrefix, kWriteOnlyPrefix,
+                               kReadOnlyPrefix)
+                     .str());
+  return rx;
+}
+
+void mlir::NVVM::countPlaceholderNumbers(
+    StringRef ptxCode, llvm::SmallDenseSet<unsigned int> &seenRW,
+    llvm::SmallDenseSet<unsigned int> &seenW,
+    llvm::SmallDenseSet<unsigned int> &seenR,
+    llvm::SmallVectorImpl<unsigned int> &rwNums,
+    llvm::SmallVectorImpl<unsigned int> &wNums,
+    llvm::SmallVectorImpl<unsigned int> &rNums) {
+
+  llvm::Regex rx = getPredicateMappingRegex();
+  StringRef rest = ptxCode;
+
+  SmallVector<StringRef, 3> m; // 0: full, 1: kind, 2: number
+  while (!rest.empty() && rx.match(rest, &m)) {
+    unsigned num = 0;
+    (void)m[2].getAsInteger(10, num);
+
+    if (m[1].equals_insensitive(kReadWritePrefix)) {
+      if (seenRW.insert(num).second)
+        rwNums.push_back(num);
+    } else if (m[1].equals_insensitive(kWriteOnlyPrefix)) {
+      if (seenW.insert(num).second)
+        wNums.push_back(num);
+    } else {
+      if (seenR.insert(num).second)
+        rNums.push_back(num);
+    }
+
+    const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
+    rest = rest.drop_front(advance);
+  }
+}
+
+/// Rewrites `{$rwN}`, `{$wN}`, and `{$rN}` placeholders in `ptxCode` into
+/// compact `$K` indices:
+///   - All `rw*` first (sorted by N),
+///   - Then `w*`,
+///   - Then `r*`.
+/// If there is a predicate, it always comes at the end.
+/// Each number is assigned once; duplicates are ignored.
 ///
 /// Example Input:
 /// "{
@@ -246,42 +291,19 @@ constexpr llvm::StringLiteral kReadOnly{"r"};
 ///       selp.s32 $2,   $4, $5, p;
 ///       selp.s32 $3,   $4, $5, p;
 /// }\n"
-static std::string rewriteAsmPlaceholders(llvm::StringRef asmText) {
-  // Match {$rwN}, {$wN}, {$rN}
-  llvm::Regex rx(llvm::formatv(R"(\{\$({0}|{1}|{2})([0-9]+)\})", kReadWrite,
-                               kWriteOnly, kReadOnly)
-                     .str());
-
+static std::string rewriteAsmPlaceholders(llvm::StringRef ptxCode) {
   llvm::SmallDenseSet<unsigned> seenRW, seenW, seenR;
   llvm::SmallVector<unsigned> rwNums, wNums, rNums;
 
-  {
-    StringRef rest = asmText;
-    SmallVector<StringRef, 3> m; // 0: full, 1: kind, 2: number
-    while (!rest.empty() && rx.match(rest, &m)) {
-      unsigned num = 0;
-      (void)m[2].getAsInteger(10, num);
-
-      if (m[1].equals_insensitive(kReadWrite)) {
-        if (seenRW.insert(num).second)
-          rwNums.push_back(num);
-      } else if (m[1].equals_insensitive(kWriteOnly)) {
-        if (seenW.insert(num).second)
-          wNums.push_back(num);
-      } else {
-        if (seenR.insert(num).second)
-          rNums.push_back(num);
-      }
-
-      const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
-      rest = rest.drop_front(advance);
-    }
-  }
+  // Step 1. Count Register Placeholder numbers
+  countPlaceholderNumbers(ptxCode, seenRW, seenW, seenR, rwNums, wNums, rNums);
 
+  // Step 2. Sort the Register Placeholder numbers
   llvm::sort(rwNums);
   llvm::sort(wNums);
   llvm::sort(rNums);
 
+  // Step 3. Create mapping from original to new IDs
   llvm::DenseMap<unsigned, unsigned> rwMap, wMap, rMap;
   unsigned nextId = 0;
   for (unsigned n : rwNums)
@@ -291,27 +313,28 @@ static std::string rewriteAsmPlaceholders(llvm::StringRef asmText) {
   for (unsigned n : rNums)
     rMap[n] = nextId++;
 
+  // Step 4. Rewrite the PTX code with new IDs
   std::string out;
-  out.reserve(asmText.size());
-
+  out.reserve(ptxCode.size());
   size_t prev = 0;
-  StringRef rest = asmText;
+  StringRef rest = ptxCode;
   SmallVector<StringRef, 3> m;
+  llvm::Regex rx = getPredicateMappingRegex();
   while (!rest.empty() && rx.match(rest, &m)) {
     // Compute absolute match bounds in the original buffer.
-    size_t absStart = (size_t)(m[0].data() - asmText.data());
+    size_t absStart = (size_t)(m[0].data() - ptxCode.data());
     size_t absEnd = absStart + m[0].size();
 
     // Emit text before the match.
-    out.append(asmText.data() + prev, asmText.data() + absStart);
+    out.append(ptxCode.data() + prev, ptxCode.data() + absStart);
 
     // Emit compact $K
     unsigned num = 0;
     (void)m[2].getAsInteger(10, num);
     unsigned id = 0;
-    if (m[1].equals_insensitive(kReadWrite))
+    if (m[1].equals_insensitive(kReadWritePrefix))
       id = rwMap.lookup(num);
-    else if (m[1].equals_insensitive(kWriteOnly))
+    else if (m[1].equals_insensitive(kWriteOnlyPrefix))
       id = wMap.lookup(num);
     else
       id = rMap.lookup(num);
@@ -321,13 +344,12 @@ static std::string rewriteAsmPlaceholders(llvm::StringRef asmText) {
 
     prev = absEnd;
 
-    // Advance search window.
     const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
     rest = rest.drop_front(advance);
   }
 
-  // Tail.
-  out.append(asmText.data() + prev, asmText.data() + asmText.size());
+  // Step 5. Tail.
+  out.append(ptxCode.data() + prev, ptxCode.data() + ptxCode.size());
   return out;
 }
 



More information about the Mlir-commits mailing list