[Mlir-commits] [mlir] [MLIR][NVVM] Improve inline_ptx, add readwrite support (PR #154358)
Durgadoss R
llvmlistbot at llvm.org
Thu Aug 21 06:50:08 PDT 2025
================
@@ -108,38 +129,246 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
}
/// Check if the operation needs to pack and unpack results.
-static bool needsPackUnpack(BasicPtxBuilderInterface interfaceOp) {
- return interfaceOp->getNumResults() > 1;
+static bool
+needsPackUnpack(BasicPtxBuilderInterface interfaceOp,
+ bool needsManualRegisterMapping,
+ SmallVectorImpl<PTXRegisterMod> &registerModifiers) {
+ if (needsManualRegisterMapping)
+ return false;
+ const unsigned writeOnlyVals = interfaceOp->getNumResults();
+ const unsigned readWriteVals =
+ llvm::count_if(registerModifiers, [](PTXRegisterMod m) {
+ return m == PTXRegisterMod::ReadWrite;
+ });
+ return (writeOnlyVals + readWriteVals) > 1;
}
/// Pack the result types of the interface operation.
/// If the operation has multiple results, it packs them into a struct
/// type. Otherwise, it returns the original result types.
-static SmallVector<Type> packResultTypes(MLIRContext *ctx,
- BasicPtxBuilderInterface interfaceOp) {
- TypeRange results = interfaceOp->getResultTypes();
+static SmallVector<Type>
+packResultTypes(BasicPtxBuilderInterface interfaceOp,
+ bool needsManualRegisterMapping,
+ SmallVectorImpl<PTXRegisterMod> &registerModifiers,
+ SmallVectorImpl<Value> &ptxOperands) {
+ MLIRContext *ctx = interfaceOp->getContext();
+ TypeRange resultRange = interfaceOp->getResultTypes();
+
+ if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping,
+ registerModifiers)) {
+ // Single value path:
+ if (interfaceOp->getResults().size() == 1)
+ return SmallVector<Type>{resultRange.front()};
+
+ // No declared results: if there is an RW, forward its type.
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+ if (m == PTXRegisterMod::ReadWrite)
+ return SmallVector<Type>{v.getType()};
+ }
+
+ SmallVector<Type> packed;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+ if (m == PTXRegisterMod::ReadWrite)
+ packed.push_back(v.getType());
+ for (Type t : resultRange)
+ packed.push_back(t);
+
+ if (packed.empty())
+ return {};
- if (!needsPackUnpack(interfaceOp))
- return llvm::to_vector<1>(results);
+ auto sTy = LLVM::LLVMStructType::getLiteral(ctx, packed, /*isPacked=*/false);
+ return SmallVector<Type>{sTy};
+}
+
+/// Canonicalize the register constraints:
+/// - Turn every "+X" into "=X"
+/// - Append (at the very end) the 0-based indices of tokens that were "+X"
+/// Examples:
+/// "+f,+f,+r,=r,=r,r,r" -> "=f,=f,=r,=r,=r,r,r,0,1,2"
+/// "+f,+f,+r,=r,=r" -> "=f,=f,=r,=r,=r,0,1,2"
+static std::string canonicalizeRegisterConstraints(llvm::StringRef csv) {
+ SmallVector<llvm::StringRef> toks;
+ SmallVector<std::string> out;
+ SmallVector<unsigned> plusIdx;
+
+ csv.split(toks, ',');
+ out.reserve(toks.size() + 8);
+
+ for (unsigned i = 0, e = toks.size(); i < e; ++i) {
+ StringRef t = toks[i].trim();
+ if (t.consume_front("+")) {
+ plusIdx.push_back(i);
+ out.push_back(("=" + t).str());
+ } else {
+ out.push_back(t.str());
+ }
+ }
+
+ // Append indices of original "+X" tokens.
+ for (unsigned idx : plusIdx)
+ out.push_back(std::to_string(idx));
+
+ // Join back to CSV.
+ std::string result;
+ result.reserve(csv.size() + plusIdx.size() * 2);
+ llvm::raw_string_ostream os(result);
+ for (size_t i = 0; i < out.size(); ++i) {
+ if (i)
+ os << ',';
+ os << out[i];
+ }
+ return os.str();
+}
- SmallVector<mlir::Type> elems(results.begin(), results.end());
- auto sTy = LLVM::LLVMStructType::getLiteral(ctx, elems, /*isPacked=*/false);
- return {sTy};
+constexpr llvm::StringLiteral kReadWritePrefix{"rw"};
+constexpr llvm::StringLiteral kWriteOnlyPrefix{"w"};
+constexpr llvm::StringLiteral kReadOnlyPrefix{"r"};
+
+/// Returns a regex that matches {$rwN}, {$wN}, {$rN}
+static llvm::Regex getPredicateMappingRegex() {
+ llvm::Regex rx(llvm::formatv(R"(\{\$({0}|{1}|{2})([0-9]+)\})",
+ kReadWritePrefix, kWriteOnlyPrefix,
+ kReadOnlyPrefix)
+ .str());
+ return rx;
+}
+
+void mlir::NVVM::countPlaceholderNumbers(
+ StringRef ptxCode, llvm::SmallDenseSet<unsigned int> &seenRW,
+ llvm::SmallDenseSet<unsigned int> &seenW,
+ llvm::SmallDenseSet<unsigned int> &seenR,
+ llvm::SmallVectorImpl<unsigned int> &rwNums,
+ llvm::SmallVectorImpl<unsigned int> &wNums,
+ llvm::SmallVectorImpl<unsigned int> &rNums) {
+
+ llvm::Regex rx = getPredicateMappingRegex();
+ StringRef rest = ptxCode;
+
+ SmallVector<StringRef, 3> m; // 0: full, 1: kind, 2: number
----------------
durga4github wrote:
Thanks for this comment on the parts!
https://github.com/llvm/llvm-project/pull/154358
More information about the Mlir-commits
mailing list