[Mlir-commits] [mlir] [MLIR][NVVM] Improve inline_ptx, add readwrite support (PR #154358)
Durgadoss R
llvmlistbot at llvm.org
Thu Aug 21 08:21:31 PDT 2025
================
@@ -169,33 +399,87 @@ void PtxBuilder::buildAndReplaceOp() {
LLVM::InlineAsmOp inlineAsmOp = build();
LDBG() << "\n Generated PTX \n\t" << inlineAsmOp;
- // Case 1: no result
- if (inlineAsmOp->getNumResults() == 0) {
+ // Case 0: no result at all → just erase wrapper op.
+ if (!hasResult) {
rewriter.eraseOp(interfaceOp);
return;
}
- // Case 2: single result, forward it directly
- if (!needsPackUnpack(interfaceOp)) {
+ if (needsManualRegisterMapping) {
rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
return;
}
- // Case 3: multiple results were packed; unpack the struct.
- assert(mlir::LLVM::LLVMStructType::classof(
- inlineAsmOp.getResultTypes().front()) &&
- "Expected result type to be LLVMStructType when unpacking multiple "
- "results");
- auto structTy = llvm::cast<mlir::LLVM::LLVMStructType>(
- inlineAsmOp.getResultTypes().front());
+ // Case 1: Simple path, return single scalar
+ if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping,
+ registerModifiers)) {
+ if (inlineAsmOp->getNumResults() > 0) {
+ rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
+ } else {
+ // RW-only case with no declared results: forward the RW value.
+ SmallVector<Value> results;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+ if (m == PTXRegisterMod::ReadWrite) {
+ results.push_back(v);
+ break;
+ }
+ rewriter.replaceOp(interfaceOp, results);
+ }
+ return;
+ }
+
+ const bool hasRW = llvm::any_of(registerModifiers, [](PTXRegisterMod m) {
+ return m == PTXRegisterMod::ReadWrite;
+ });
- SmallVector<mlir::Value> unpacked;
+ // All multi-value paths produce a single struct result we need to unpack.
+ assert(LLVM::LLVMStructType::classof(inlineAsmOp.getResultTypes().front()) &&
+ "expected struct return for multi-result inline asm");
Value structVal = inlineAsmOp.getResult(0);
- for (auto [idx, elemTy] : llvm::enumerate(structTy.getBody())) {
- Value unpackedValue = LLVM::ExtractValueOp::create(
- rewriter, interfaceOp->getLoc(), structVal, idx);
- unpacked.push_back(unpackedValue);
+ SmallVector<Value> unpacked =
+ extractStructElements(rewriter, interfaceOp->getLoc(), structVal);
+
+ // Case 2: only declared results (no RW): replace the op with all unpacked.
+ if (!hasRW && interfaceOp->getResults().size() > 0) {
+ rewriter.replaceOp(interfaceOp, unpacked);
+ return;
}
- rewriter.replaceOp(interfaceOp, unpacked);
+ // Case 3: RW-only (no declared results): update RW uses and erase wrapper.
+ if (hasRW && interfaceOp->getResults().size() == 0) {
+ unsigned idx = 0;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
+ if (m != PTXRegisterMod::ReadWrite)
+ continue;
----------------
durga4github wrote:
OK, for my understanding:
Couldn't we `break` here (instead of `continue`), since we fill the AsmValues in RW, WO, RO order?
(I see that this is currently not the case for the wgmma Op's getAsmValues(), but it is for the others...)
https://github.com/llvm/llvm-project/pull/154358
More information about the Mlir-commits mailing list