[Mlir-commits] [mlir] [MLIR][NVVM] Improve inline_ptx, add readwrite support (PR #154358)

Durgadoss R llvmlistbot at llvm.org
Thu Aug 21 06:50:40 PDT 2025


================
@@ -108,38 +129,246 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
 }
 
 /// Check if the operation needs to pack and unpack results.
-static bool needsPackUnpack(BasicPtxBuilderInterface interfaceOp) {
-  return interfaceOp->getNumResults() > 1;
+///
+/// Returns false whenever \p needsManualRegisterMapping is set. Otherwise,
+/// returns true when the op's declared results plus the operands marked
+/// PTXRegisterMod::ReadWrite in \p registerModifiers number more than one.
+/// NOTE(review): \p registerModifiers is only read here; taking an
+/// ArrayRef<PTXRegisterMod> would make that explicit.
+static bool
+needsPackUnpack(BasicPtxBuilderInterface interfaceOp,
+                bool needsManualRegisterMapping,
+                SmallVectorImpl<PTXRegisterMod> &registerModifiers) {
+  if (needsManualRegisterMapping)
+    return false;
+  const unsigned writeOnlyVals = interfaceOp->getNumResults();
+  // Every ReadWrite operand is counted as an output alongside the declared
+  // results when deciding whether more than one value is produced.
+  const unsigned readWriteVals =
+      llvm::count_if(registerModifiers, [](PTXRegisterMod m) {
+        return m == PTXRegisterMod::ReadWrite;
+      });
+  return (writeOnlyVals + readWriteVals) > 1;
 }
 
 /// Pack the result types of the interface operation.
 /// If the operation has multiple results, it packs them into a struct
 /// type. Otherwise, it returns the original result types.
-static SmallVector<Type> packResultTypes(MLIRContext *ctx,
-                                         BasicPtxBuilderInterface interfaceOp) {
-  TypeRange results = interfaceOp->getResultTypes();
+///
+/// With ReadWrite support the rules are:
+///  - If needsPackUnpack() is false and the op has exactly one declared
+///    result, that result type is returned as-is.
+///  - Else, if needsPackUnpack() is false, the type of the first operand
+///    marked PTXRegisterMod::ReadWrite in \p registerModifiers (paired
+///    positionally with \p ptxOperands) is returned, if there is one.
+///  - In every remaining case, the ReadWrite operand types (in operand order)
+///    followed by the declared result types are wrapped in one literal,
+///    non-packed LLVM struct type. An empty vector is returned when there
+///    are no such types at all.
+static SmallVector<Type>
+packResultTypes(BasicPtxBuilderInterface interfaceOp,
+                bool needsManualRegisterMapping,
+                SmallVectorImpl<PTXRegisterMod> &registerModifiers,
+                SmallVectorImpl<Value> &ptxOperands) {
+  MLIRContext *ctx = interfaceOp->getContext();
+  TypeRange resultRange = interfaceOp->getResultTypes();
+
+  if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping,
+                       registerModifiers)) {
+    // Single value path:
+    if (interfaceOp->getResults().size() == 1)
+      return SmallVector<Type>{resultRange.front()};
+
+    // No declared results: if there is an RW, forward its type.
+    // NOTE(review): this loop runs whenever the op does not have exactly one
+    // declared result. With needsManualRegisterMapping set, that includes ops
+    // with two or more results. If any operand is ReadWrite, those result
+    // types are then dropped; confirm that combination cannot occur.
+    for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+      if (m == PTXRegisterMod::ReadWrite)
+        return SmallVector<Type>{v.getType()};
+  }
+
+  // Struct layout: ReadWrite operand types first, then the declared results.
+  SmallVector<Type> packed;
+  for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+    if (m == PTXRegisterMod::ReadWrite)
+      packed.push_back(v.getType());
+  for (Type t : resultRange)
+    packed.push_back(t);
+
+  if (packed.empty())
+    return {};

-  if (!needsPackUnpack(interfaceOp))
-    return llvm::to_vector<1>(results);
+  auto sTy = LLVM::LLVMStructType::getLiteral(ctx, packed, /*isPacked=*/false);
+  return SmallVector<Type>{sTy};
+}
+
+/// Canonicalize the register constraints:
+///  - Turn every "+X" into "=X"
+///  - Append (at the very end) the 0-based indices of tokens that were "+X"
+///  - Trim surrounding whitespace from every token
+/// Examples:
+///  "+f,+f,+r,=r,=r,r,r" -> "=f,=f,=r,=r,=r,r,r,0,1,2"
+///  "+f,+f,+r,=r,=r"     -> "=f,=f,=r,=r,=r,0,1,2"
+/// NOTE(review): the appended numbers are token positions within \p csv. They
+/// presumably serve as inline-asm tied-operand constraints, which assumes the
+/// "+X" tokens sit in the leading output list. TODO confirm.
+static std::string canonicalizeRegisterConstraints(llvm::StringRef csv) {
+  SmallVector<llvm::StringRef> toks;
+  SmallVector<std::string> out;
+  SmallVector<unsigned> plusIdx;
+
+  csv.split(toks, ',');
+  // Capacity hint only; the "+ 8" is presumably headroom for the appended
+  // indices, whose count is not known yet at this point.
+  out.reserve(toks.size() + 8);
+
+  for (unsigned i = 0, e = toks.size(); i < e; ++i) {
+    StringRef t = toks[i].trim();
+    if (t.consume_front("+")) {
+      plusIdx.push_back(i);
+      out.push_back(("=" + t).str());
+    } else {
+      out.push_back(t.str());
+    }
+  }
+
+  // Append indices of original "+X" tokens.
+  for (unsigned idx : plusIdx)
+    out.push_back(std::to_string(idx));
+
+  // Join back to CSV.
+  std::string result;
+  // Capacity hint only: budgets two extra characters (comma + one digit) per
+  // appended index.
+  result.reserve(csv.size() + plusIdx.size() * 2);
+  llvm::raw_string_ostream os(result);
+  for (size_t i = 0; i < out.size(); ++i) {
+    if (i)
+      os << ',';
+    os << out[i];
+  }
+  return os.str();
+}
 
-  SmallVector<mlir::Type> elems(results.begin(), results.end());
-  auto sTy = LLVM::LLVMStructType::getLiteral(ctx, elems, /*isPacked=*/false);
-  return {sTy};
+/// Prefixes of the placeholders recognized in inline PTX code: {$rwN}
+/// (read-write), {$wN} (write-only) and {$rN} (read-only). These constants are
+/// spliced into the pattern built by getPredicateMappingRegex().
+constexpr llvm::StringLiteral kReadWritePrefix{"rw"};
+constexpr llvm::StringLiteral kWriteOnlyPrefix{"w"};
+constexpr llvm::StringLiteral kReadOnlyPrefix{"r"};
+
+/// Returns a regex that matches {$rwN}, {$wN}, {$rN}
+/// where N is a run of decimal digits. Capture group 1 holds the prefix
+/// ("rw", "w" or "r") and capture group 2 holds N. The alternatives are
+/// formatted in from the k*Prefix constants, so the pattern stays in sync with
+/// them. A fresh llvm::Regex is constructed on every call.
+static llvm::Regex getPredicateMappingRegex() {
+  llvm::Regex rx(llvm::formatv(R"(\{\$({0}|{1}|{2})([0-9]+)\})",
+                               kReadWritePrefix, kWriteOnlyPrefix,
+                               kReadOnlyPrefix)
+                     .str());
+  return rx;
+}
+
+void mlir::NVVM::countPlaceholderNumbers(
+    StringRef ptxCode, llvm::SmallDenseSet<unsigned int> &seenRW,
+    llvm::SmallDenseSet<unsigned int> &seenW,
+    llvm::SmallDenseSet<unsigned int> &seenR,
+    llvm::SmallVectorImpl<unsigned int> &rwNums,
+    llvm::SmallVectorImpl<unsigned int> &wNums,
+    llvm::SmallVectorImpl<unsigned int> &rNums) {
+
+  llvm::Regex rx = getPredicateMappingRegex();
+  StringRef rest = ptxCode;
+
+  SmallVector<StringRef, 3> m; // 0: full, 1: kind, 2: number
+  while (!rest.empty() && rx.match(rest, &m)) {
+    unsigned num = 0;
+    (void)m[2].getAsInteger(10, num);
+
+    if (m[1].equals_insensitive(kReadWritePrefix)) {
+      if (seenRW.insert(num).second)
----------------
durga4github wrote:

Can we add a short comment next to this line, e.g.:
"We insert the number into the vector only the first time we see it."

https://github.com/llvm/llvm-project/pull/154358


More information about the Mlir-commits mailing list