[Mlir-commits] [mlir] [MLIR][NVVM] Improve inline_ptx, add readwrite support (PR #154358)
Guray Ozen
llvmlistbot at llvm.org
Tue Aug 19 08:13:42 PDT 2025
https://github.com/grypp created https://github.com/llvm/llvm-project/pull/154358
Key Features
1. Multiple SSA returns – no struct packing/unpacking required.
2. Automatic struct unpacking – values are directly usable.
3. Readable register mapping
* {$rwN} → read-write
* {$rN} → read-only
* {$wN} → write-only
4. Full read-write support (+ modifier).
5. Simplified operand specification – avoids cryptic "=r,=r,=f,=f,f,f,0,1" constraints.
6. Predicate support: PTX @p predication support
IR Example:
```
%wo0, %wo1 = nvvm.inline_ptx """
.reg .pred p;
setp.ge.s32 p, {$r0}, {$r1};
selp.s32 {$rw0}, {$r0}, {$r1}, p;
selp.s32 {$rw1}, {$r0}, {$r1}, p;
selp.s32 {$w0}, {$r0}, {$r1}, p;
selp.s32 {$w1}, {$r0}, {$r1}, p;
""" ro(%a, %b : f32, f32) rw(%c, %d : i32, i32) -> f32, f32
```
After lowering
```
%0 = llvm.inline_asm has_side_effects asm_dialect = att
"{
.reg .pred p;\
setp.ge.s32 p, $4, $5; \
selp.s32 $0, $4, $5, p;\
selp.s32 $1, $4, $5, p;\
selp.s32 $2, $4, $5, p;\
selp.s32 $3, $4, $5, p;\
}"
"=r,=r,=f,=f,f,f,0,1"
%c500_i32, %c400_i32, %cst, %cst_0
: (i32, i32, f32, f32)
-> !llvm.struct<(i32, i32, f32, f32)>
%1 = llvm.extractvalue %0[0] : !llvm.struct<(i32, i32, f32, f32)>
%2 = llvm.extractvalue %0[1] : !llvm.struct<(i32, i32, f32, f32)>
%3 = llvm.extractvalue %0[2] : !llvm.struct<(i32, i32, f32, f32)>
%4 = llvm.extractvalue %0[3] : !llvm.struct<(i32, i32, f32, f32)>
// Unpacked result from nvvm.inline_ptx
%5 = arith.addi %1, %2 : i32
// read only
%6 = arith.addf %cst, %cst_0 : f32
// write only
%7 = arith.addf %3, %4 : f32
```
>From edaa2d4052b5b261b619feafbf78f35709c982be Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 19 Aug 2025 15:13:06 +0000
Subject: [PATCH] =?UTF-8?q?[MLIR][NVVM]=20Improve=20inline=5Fptx,=20add=20?=
=?UTF-8?q?readwrite=20support=20Key=20Features=201.=20Multiple=20SSA=20re?=
=?UTF-8?q?turns=20=E2=80=93=20no=20struct=20packing/unpacking=20required.?=
=?UTF-8?q?=202.=20Automatic=20struct=20unpacking=20=E2=80=93=20values=20a?=
=?UTF-8?q?re=20directly=20usable.=203.=20Readable=20register=20mapping=20?=
=?UTF-8?q?=20=20=20=20*=20{$rwN}=20=E2=86=92=20read-write=20=20=20=20=20*?=
=?UTF-8?q?=20{$roN}=20=E2=86=92=20read-only=20=20=20=20=20*=20{$woN}=20?=
=?UTF-8?q?=E2=86=92=20write-only=204.=20Full=20read-write=20support=20(+?=
=?UTF-8?q?=20modifier).=205.=20Simplified=20operand=20specification=20?=
=?UTF-8?q?=E2=80=93=20avoids=20cryptic=20"=3Dr,=3Dr,=3Df,=3Df,f,f,0,1"=20?=
=?UTF-8?q?constraints.=206.=20Predicate=20support:=20PTX=20 at p=20predicati?=
=?UTF-8?q?on=20support?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
IR Example:
```
%wo0, %wo1 = nvvm.inline_ptx """
.reg .pred p;
setp.ge.s32 p, {$r0}, {$r1};
selp.s32 {$rw0}, {$r0}, {$r1}, p;
selp.s32 {$rw1}, {$r0}, {$r1}, p;
selp.s32 {$w0}, {$r0}, {$r1}, p;
selp.s32 {$w1}, {$r0}, {$r1}, p;
""" ro(%a, %b : f32, f32) rw(%c, %d : i32, i32) -> f32, f32
```
After lowering
```
%0 = llvm.inline_asm has_side_effects asm_dialect = att
"{
.reg .pred p;\
setp.ge.s32 p, $4, $5; \
selp.s32 $0, $4, $5, p;\
selp.s32 $1, $4, $5, p;\
selp.s32 $2, $4, $5, p;\
selp.s32 $3, $4, $5, p;\
}"
"=r,=r,=f,=f,f,f,0,1"
%c500_i32, %c400_i32, %cst, %cst_0
: (i32, i32, f32, f32)
-> !llvm.struct<(i32, i32, f32, f32)>
%1 = llvm.extractvalue %0[0] : !llvm.struct<(i32, i32, f32, f32)>
%2 = llvm.extractvalue %0[1] : !llvm.struct<(i32, i32, f32, f32)>
%3 = llvm.extractvalue %0[2] : !llvm.struct<(i32, i32, f32, f32)>
%4 = llvm.extractvalue %0[3] : !llvm.struct<(i32, i32, f32, f32)>
// Unpacked result from nvvm.inline_ptx
%5 = arith.addi %1, %2 : i32
// read only
%6 = arith.addf %cst, %cst_0 : f32
// write only
%7 = arith.addf %3, %4 : f32
```
---
.../Dialect/LLVMIR/BasicPtxBuilderInterface.h | 14 +-
.../LLVMIR/BasicPtxBuilderInterface.td | 13 +-
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 22 +-
mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp | 4 +-
.../LLVMIR/IR/BasicPtxBuilderInterface.cpp | 321 ++++++++++++++++--
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 19 +-
.../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir | 60 +++-
mlir/test/python/dialects/nvvm.py | 41 +++
8 files changed, 435 insertions(+), 59 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
index 3e3fcd7d1fb82..99b1d9709e3e1 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h
@@ -26,9 +26,9 @@ namespace NVVM {
enum class PTXRegisterMod {
/// Read register with no modifier
Read = 0,
- /// Read register with '+' modifier
+ /// Read register with '=' modifier
Write = 2,
- /// Read register with '=' modifier.
+ /// Read register with '+' modifier.
/// Note that, this is not natively supported by LLVM, but it is possible to
/// set read and write for the same operand.
ReadWrite = 1,
@@ -67,13 +67,17 @@ class PtxBuilder {
SmallVector<Value> ptxOperands;
// Register constraints (read, write, readwrite) and register data types
std::string registerConstraints;
-
+ // Modifiers
+ SmallVector<PTXRegisterMod> registerModifiers;
bool hasResult = false;
+ bool needsManualMapping = false;
public:
/// Single constructor that only initializes members.
- PtxBuilder(Operation *op, PatternRewriter &rewriter)
- : interfaceOp(op), rewriter(rewriter) {}
+ PtxBuilder(Operation *op, PatternRewriter &rewriter,
+ bool needsManualMapping = false)
+ : interfaceOp(op), rewriter(rewriter),
+ needsManualMapping(needsManualMapping) {}
/// Add an operand with the read/write input type.
void insertValue(Value v, PTXRegisterMod itype = PTXRegisterMod::Read);
diff --git a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
index e98b94b5b3052..8e36749cdb361 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.td
@@ -124,19 +124,21 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
following this order:
1) Adds results
2) Adds operands
- 3) Adds attributes
+ 3) Adds attributes
+ Returns true if it does the mapping manually
}],
- /*retType=*/"void",
+ /*retType=*/"bool",
/*methodName=*/"getAsmValues",
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
- "llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues),
+ "llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>&" : $asmValues
+ ),
/*methodBody=*/"",
/*defaultImpl=*/ [{
mlir::Operation* op = $_op;
// Step 1. Add results
- for (auto val : op->getResults())
- asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
+ for (auto val : op->getResults())
+ asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Write});
// Step 2. Add operands
for (auto val : op->getOperands())
@@ -149,6 +151,7 @@ def BasicPtxBuilderOpInterface : OpInterface<"BasicPtxBuilderInterface"> {
asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
}
}
+ return false; // No needs manual mapping
}]
>
];
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index f9cd58de8915f..786d42cf15666 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -315,16 +315,19 @@ def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx",
}];
let arguments = (ins Variadic<AnyType>:$readOnlyArgs,
+ Variadic<AnyType>:$readWriteArgs,
StrAttr:$ptxCode,
PtxPredicate:$predicate);
let results = (outs Variadic<AnyType>:$writeOnlyArgs);
-
- let assemblyFormat = [{
- $ptxCode `(` $readOnlyArgs `)`
- (`,` `predicate` `=` $predicate^)? attr-dict
- `:` type(operands)
- (`->` type($writeOnlyArgs)^)?
+
+ let assemblyFormat = [{
+ $ptxCode
+ ( `ro` `(` $readOnlyArgs^ `:` type($readOnlyArgs) `)` )?
+ ( `rw` `(` $readWriteArgs^ `:` type($readWriteArgs) `)` )?
+ (`,` `predicate` `=` $predicate^)?
+ attr-dict
+ ( `->` type($writeOnlyArgs)^ )?
}];
let extraClassDefinition = [{
@@ -333,6 +336,10 @@ def NVVM_InlinePtxOp : NVVM_Op<"inline_ptx",
return std::string(ptxInstStr.data());
}
}];
+
+ let extraClassDeclaration = [{
+ bool getAsmValues(RewriterBase &, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &);
+ }];
}
//===----------------------------------------------------------------------===//
@@ -3027,8 +3034,7 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
let hasVerifier = 1;
let extraClassDeclaration = [{
- void getAsmValues(RewriterBase &rewriter,
- llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues);
+ bool getAsmValues(RewriterBase &, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &);
}];
}
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index e0144bff4d371..c67ec3642f121 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -57,9 +57,9 @@ struct PtxLowering
SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
LDBG() << op.getPtx();
- PtxBuilder generator(op, rewriter);
- op.getAsmValues(rewriter, asmValues);
+ bool needsManualMapping = op.getAsmValues(rewriter, asmValues);
+ PtxBuilder generator(op, rewriter, needsManualMapping);
for (auto &[asmValue, modifier] : asmValues) {
LDBG() << asmValue << "\t Modifier : " << modifier;
generator.insertValue(asmValue, modifier);
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index e004d5f64733e..3cad9d3bd16e3 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -12,6 +12,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Regex.h"
#define DEBUG_TYPE "ptx-builder"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -59,12 +62,28 @@ static char getRegisterType(Value v) {
return getRegisterType(v.getType());
}
+/// Extract every element of a struct value.
+static SmallVector<Value> extractStructElements(PatternRewriter &rewriter,
+ Location loc, Value agg) {
+ auto structTy = cast<LLVM::LLVMStructType>(agg.getType());
+ SmallVector<Value> elems;
+ elems.reserve(structTy.getBody().size());
+ for (auto [i, t] : llvm::enumerate(structTy.getBody())) {
+ (void)t;
+ Value e = LLVM::ExtractValueOp::create(rewriter, loc, agg, i);
+ elems.push_back(e);
+ }
+ return elems;
+}
+
void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
- LLVM_DEBUG(DBGS() << v << "\t Modifier : " << &itype << "\n");
+ LLVM_DEBUG(DBGS() << v << "\t Modifier : " << itype << "\n");
+ registerModifiers.push_back(itype);
+
auto getModifier = [&]() -> const char * {
if (itype == PTXRegisterMod::ReadWrite) {
- assert(false && "Read-Write modifier is not supported. Try setting the "
- "same value as Write and Read separately.");
+ // The read-write ('+') modifier is not natively supported by LLVM;
+ // the constraint string is canonicalized later, when the op is built.
return "+";
}
if (itype == PTXRegisterMod::Write) {
@@ -72,6 +91,7 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
}
return "";
};
+
auto addValue = [&](Value v) {
if (itype == PTXRegisterMod::Read) {
ptxOperands.push_back(v);
@@ -108,38 +128,222 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
}
/// Check if the operation needs to pack and unpack results.
-static bool needsPackUnpack(BasicPtxBuilderInterface interfaceOp) {
- return interfaceOp->getNumResults() > 1;
+static bool
+needsPackUnpack(BasicPtxBuilderInterface interfaceOp, bool needsManualMapping,
+ SmallVectorImpl<PTXRegisterMod> &registerModifiers) {
+ if (needsManualMapping)
+ return false;
+ const unsigned writeOnly = interfaceOp->getNumResults();
+ const unsigned readWrite =
+ llvm::count_if(registerModifiers, [](PTXRegisterMod m) {
+ return m == PTXRegisterMod::ReadWrite;
+ });
+ return (writeOnly + readWrite) > 1;
}
/// Pack the result types of the interface operation.
/// If the operation has multiple results, it packs them into a struct
/// type. Otherwise, it returns the original result types.
-static SmallVector<Type> packResultTypes(MLIRContext *ctx,
- BasicPtxBuilderInterface interfaceOp) {
- TypeRange results = interfaceOp->getResultTypes();
+static SmallVector<Type>
+packResultTypes(BasicPtxBuilderInterface interfaceOp, bool needsManualMapping,
+ SmallVectorImpl<PTXRegisterMod> &registerModifiers,
+ SmallVectorImpl<Value> &ptxOperands) {
+ MLIRContext *ctx = interfaceOp->getContext();
+ TypeRange resultRange = interfaceOp->getResultTypes();
+
+ if (!needsPackUnpack(interfaceOp, needsManualMapping, registerModifiers)) {
+ // Single value path:
+ if (interfaceOp->getResults().size() == 1)
+ return SmallVector<Type>{resultRange.front()};
+
+ // No declared results: if there is an RW, forward its type.
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+ if (m == PTXRegisterMod::ReadWrite)
+ return SmallVector<Type>{v.getType()};
+ }
+
+ SmallVector<Type> packed;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+ if (m == PTXRegisterMod::ReadWrite)
+ packed.push_back(v.getType());
+ for (Type t : resultRange)
+ packed.push_back(t);
+
+ if (packed.empty())
+ return {};
+
+ auto sTy = LLVM::LLVMStructType::getLiteral(ctx, packed, /*isPacked=*/false);
+ return SmallVector<Type>{sTy};
+}
+
+/// Canonicalize the register constraints:
+/// - Turn every "+X" into "=X"
+/// - Append (at the very end) the 0-based indices of tokens that were "+X"
+/// Examples:
+/// "+f,+f,+r,=r,=r,r,r" -> "=f,=f,=r,=r,=r,r,r,0,1,2"
+/// "+f,+f,+r,=r,=r" -> "=f,=f,=r,=r,=r,0,1,2"
+static std::string canonicalizeRegisterConstraints(llvm::StringRef csv) {
+ SmallVector<llvm::StringRef> toks;
+ SmallVector<std::string> out;
+ SmallVector<unsigned> plusIdx;
+
+ csv.split(toks, ',');
+ out.reserve(toks.size() + 8);
+
+ for (unsigned i = 0, e = toks.size(); i < e; ++i) {
+ StringRef t = toks[i].trim();
+ if (t.consume_front("+")) {
+ plusIdx.push_back(i);
+ out.push_back(("=" + t).str());
+ } else {
+ out.push_back(t.str());
+ }
+ }
+
+ // Append indices of original "+X" tokens.
+ for (unsigned idx : plusIdx)
+ out.push_back(std::to_string(idx));
+
+ // Join back to CSV.
+ std::string result;
+ result.reserve(csv.size() + plusIdx.size() * 2);
+ llvm::raw_string_ostream os(result);
+ for (size_t i = 0; i < out.size(); ++i) {
+ if (i)
+ os << ',';
+ os << out[i];
+ }
+ return os.str();
+}
+
+constexpr llvm::StringLiteral kReadWrite{"rw"};
+constexpr llvm::StringLiteral kWriteOnly{"w"};
+constexpr llvm::StringLiteral kReadOnly{"r"};
+
+/// Rewrites placeholders of the form `{$rN}`, `{$wN}`, `{$rwN}` in `asmText`
+/// to compact `$K` indices where all `rw*` come first (ascending N), then `w*`,
+/// then `r*`. Duplicates are de-duplicated when assigning numbers.
+/// Unknown text is preserved verbatim.
+///
+/// Example Input:
+/// "{
+/// reg .pred p;
+/// setp.ge.s32 p, {$r0}, {$r1};"
+/// selp.s32 {$rw0}, {$r0}, {$r1}, p;
+/// selp.s32 {$rw1}, {$r0}, {$r1}, p;
+/// selp.s32 {$w0}, {$r0}, {$r1}, p;
+/// selp.s32 {$w1}, {$r0}, {$r1}, p;
+/// }\n"
+/// Example Output:
+/// "{
+/// reg .pred p;
+/// setp.ge.s32 p, $4, $5;"
+/// selp.s32 $0, $4, $5, p;
+/// selp.s32 $1, $4, $5, p;
+/// selp.s32 $2, $4, $5, p;
+/// selp.s32 $3, $4, $5, p;
+/// }\n"
+static std::string rewriteAsmPlaceholders(llvm::StringRef asmText) {
+ // Match {$rwN}, {$wN}, {$rN}
+ llvm::Regex rx(llvm::formatv(R"(\{\$({0}|{1}|{2})([0-9]+)\})", kReadWrite,
+ kWriteOnly, kReadOnly)
+ .str());
+
+ llvm::SmallDenseSet<unsigned> seenRW, seenW, seenR;
+ llvm::SmallVector<unsigned> rwNums, wNums, rNums;
+
+ {
+ StringRef rest = asmText;
+ SmallVector<StringRef, 3> m; // 0: full, 1: kind, 2: number
+ while (!rest.empty() && rx.match(rest, &m)) {
+ unsigned num = 0;
+ (void)m[2].getAsInteger(10, num);
+
+ if (m[1].equals_insensitive(kReadWrite)) {
+ if (seenRW.insert(num).second)
+ rwNums.push_back(num);
+ } else if (m[1].equals_insensitive(kWriteOnly)) {
+ if (seenW.insert(num).second)
+ wNums.push_back(num);
+ } else {
+ if (seenR.insert(num).second)
+ rNums.push_back(num);
+ }
+
+ const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
+ rest = rest.drop_front(advance);
+ }
+ }
+
+ llvm::sort(rwNums);
+ llvm::sort(wNums);
+ llvm::sort(rNums);
+
+ llvm::DenseMap<unsigned, unsigned> rwMap, wMap, rMap;
+ unsigned nextId = 0;
+ for (unsigned n : rwNums)
+ rwMap[n] = nextId++;
+ for (unsigned n : wNums)
+ wMap[n] = nextId++;
+ for (unsigned n : rNums)
+ rMap[n] = nextId++;
+
+ std::string out;
+ out.reserve(asmText.size());
- if (!needsPackUnpack(interfaceOp))
- return llvm::to_vector<1>(results);
+ size_t prev = 0;
+ StringRef rest = asmText;
+ SmallVector<StringRef, 3> m;
+ while (!rest.empty() && rx.match(rest, &m)) {
+ // Compute absolute match bounds in the original buffer.
+ size_t absStart = (size_t)(m[0].data() - asmText.data());
+ size_t absEnd = absStart + m[0].size();
- SmallVector<mlir::Type> elems(results.begin(), results.end());
- auto sTy = LLVM::LLVMStructType::getLiteral(ctx, elems, /*isPacked=*/false);
- return {sTy};
+ // Emit text before the match.
+ out.append(asmText.data() + prev, asmText.data() + absStart);
+
+ // Emit compact $K
+ unsigned num = 0;
+ (void)m[2].getAsInteger(10, num);
+ unsigned id = 0;
+ if (m[1].equals_insensitive(kReadWrite))
+ id = rwMap.lookup(num);
+ else if (m[1].equals_insensitive(kWriteOnly))
+ id = wMap.lookup(num);
+ else
+ id = rMap.lookup(num);
+
+ out.push_back('$');
+ out += std::to_string(id);
+
+ prev = absEnd;
+
+ // Advance search window.
+ const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size();
+ rest = rest.drop_front(advance);
+ }
+
+ // Tail.
+ out.append(asmText.data() + prev, asmText.data() + asmText.size());
+ return out;
}
LLVM::InlineAsmOp PtxBuilder::build() {
- MLIRContext *ctx = interfaceOp->getContext();
auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
LLVM::AsmDialect::AD_ATT);
- SmallVector<Type> resultTypes = packResultTypes(ctx, interfaceOp);
+ SmallVector<Type> resultTypes = packResultTypes(
+ interfaceOp, needsManualMapping, registerModifiers, ptxOperands);
// Remove the last comma from the constraints string.
if (!registerConstraints.empty() &&
registerConstraints[registerConstraints.size() - 1] == ',')
registerConstraints.pop_back();
+ registerConstraints = canonicalizeRegisterConstraints(registerConstraints);
std::string ptxInstruction = interfaceOp.getPtx();
+ if (!needsManualMapping)
+ ptxInstruction = rewriteAsmPlaceholders(ptxInstruction);
// Add the predicate to the asm string.
if (interfaceOp.getPredicate().has_value() &&
@@ -169,33 +373,86 @@ void PtxBuilder::buildAndReplaceOp() {
LLVM::InlineAsmOp inlineAsmOp = build();
LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
- // Case 1: no result
- if (inlineAsmOp->getNumResults() == 0) {
+ // Case 0: no result at all → just erase wrapper op.
+ if (!hasResult) {
rewriter.eraseOp(interfaceOp);
return;
}
- // Case 2: single result, forward it directly
- if (!needsPackUnpack(interfaceOp)) {
+ if (needsManualMapping) {
rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
return;
}
- // Case 3: multiple results were packed; unpack the struct.
- assert(mlir::LLVM::LLVMStructType::classof(
- inlineAsmOp.getResultTypes().front()) &&
- "Expected result type to be LLVMStructType when unpacking multiple "
- "results");
- auto structTy = llvm::cast<mlir::LLVM::LLVMStructType>(
- inlineAsmOp.getResultTypes().front());
+ // Case 1: Simple path, return single scalar
+ if (!needsPackUnpack(interfaceOp, needsManualMapping, registerModifiers)) {
+ if (inlineAsmOp->getNumResults() > 0) {
+ rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
+ } else {
+ // RW-only case with no declared results: forward the RW value.
+ SmallVector<Value> results;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands))
+ if (m == PTXRegisterMod::ReadWrite) {
+ results.push_back(v);
+ break;
+ }
+ rewriter.replaceOp(interfaceOp, results);
+ }
+ return;
+ }
+
+ const bool hasRW = llvm::any_of(registerModifiers, [](PTXRegisterMod m) {
+ return m == PTXRegisterMod::ReadWrite;
+ });
- SmallVector<mlir::Value> unpacked;
+ // All multi-value paths produce a single struct result we need to unpack.
+ assert(LLVM::LLVMStructType::classof(inlineAsmOp.getResultTypes().front()) &&
+ "expected struct return for multi-result inline asm");
Value structVal = inlineAsmOp.getResult(0);
- for (auto [idx, elemTy] : llvm::enumerate(structTy.getBody())) {
- Value unpackedValue = LLVM::ExtractValueOp::create(
- rewriter, interfaceOp->getLoc(), structVal, idx);
- unpacked.push_back(unpackedValue);
+ SmallVector<Value> unpacked =
+ extractStructElements(rewriter, interfaceOp->getLoc(), structVal);
+
+ // Case 2: only declared results (no RW): replace the op with all unpacked.
+ if (!hasRW && interfaceOp->getResults().size() > 0) {
+ rewriter.replaceOp(interfaceOp, unpacked);
+ return;
+ }
+
+ // Case 3: RW-only (no declared results): update RW uses and erase wrapper.
+ if (hasRW && interfaceOp->getResults().size() == 0) {
+ unsigned idx = 0;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
+ if (m != PTXRegisterMod::ReadWrite)
+ continue;
+ Value repl = unpacked[idx++];
+ v.replaceUsesWithIf(repl, [&](OpOperand &use) {
+ Operation *owner = use.getOwner();
+ return owner != interfaceOp && owner != inlineAsmOp;
+ });
+ }
+ rewriter.eraseOp(interfaceOp);
+ return;
}
- rewriter.replaceOp(interfaceOp, unpacked);
+ // Case 4: mixed (RW + declared results).
+ {
+ // First rewrite RW operands in place.
+ unsigned idx = 0;
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) {
+ if (m != PTXRegisterMod::ReadWrite)
+ continue;
+ Value repl = unpacked[idx++];
+ v.replaceUsesWithIf(repl, [&](OpOperand &use) {
+ Operation *owner = use.getOwner();
+ return owner != interfaceOp && owner != inlineAsmOp;
+ });
+ }
+ // The remaining unpacked values correspond to the declared results.
+ SmallVector<Value> tail;
+ tail.reserve(unpacked.size() - idx);
+ for (unsigned i = idx, e = unpacked.size(); i < e; ++i)
+ tail.push_back(unpacked[i]);
+
+ rewriter.replaceOp(interfaceOp, tail);
+ }
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index dbcc738b4419f..ae9134458095f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1123,7 +1123,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() {
return ptx;
}
-void NVVM::WgmmaMmaAsyncOp::getAsmValues(
+bool NVVM::WgmmaMmaAsyncOp::getAsmValues(
RewriterBase &rewriter,
llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
&asmValues) {
@@ -1154,7 +1154,9 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues(
{makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())),
mlir::NVVM::PTXRegisterMod::Read});
}
+ return true; // Has manual mapping
}
+
LogicalResult NVVM::FenceProxyOp::verify() {
if (getKind() == NVVM::ProxyKind::TENSORMAP)
return emitOpError() << "tensormap proxy is not a supported proxy kind";
@@ -1870,6 +1872,21 @@ llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
}
}
+bool NVVM::InlinePtxOp::getAsmValues(
+ RewriterBase &rewriter,
+ llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+ &asmValues) {
+ for (auto arg : getReadWriteArgs())
+ asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite});
+ for (auto arg : getResults())
+ asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write});
+ for (auto arg : getReadOnlyArgs())
+ asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read});
+ if (getPredicate())
+ asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read});
+ return false; // Needs manual mapping
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index b38347c7cd1b7..2a19c72ab0840 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -667,34 +667,82 @@ llvm.func @init_mbarrier(
%count : i32,
%pred : i1) {
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.init.b64 [$0], $1;", "l,r"
- nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count) : !llvm.ptr, i32
+ nvvm.inline_ptx "mbarrier.init.b64 [{$r0}], {$r1};" ro (%barrier_gen, %count : !llvm.ptr, i32)
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.init.b64 [$0], $1;", "l,r,b"
- nvvm.inline_ptx "mbarrier.init.b64 [$0], $1;" (%barrier_gen, %count), predicate = %pred : !llvm.ptr, i32, i1
+ nvvm.inline_ptx "mbarrier.init.b64 [{$r0}], {$r1};" ro (%barrier_gen, %count : !llvm.ptr, i32), predicate = %pred
llvm.return
}
// -----
llvm.func @ex2(%input : f32, %pred : i1) {
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "ex2.approx.ftz.f32 $0, $1;", "=f,f" %{{.*}} : (f32) -> f32
- %0 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input) : f32 -> f32
+ %0 = nvvm.inline_ptx "ex2.approx.ftz.f32 {$w0}, {$r0};" ro (%input : f32) -> f32
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "@$1 ex2.approx.ftz.f32 $0, $1;", "=f,f,b" %{{.*}}, %{{.*}} : (f32, i1) -> f32
- %1 = nvvm.inline_ptx "ex2.approx.ftz.f32 $0, $1;" (%input), predicate = %pred : f32, i1 -> f32
+ %1 = nvvm.inline_ptx "ex2.approx.ftz.f32 {$w0}, {$r0};" ro (%input : f32), predicate = %pred -> f32
llvm.return
}
// CHECK-LABEL: @multi_return(
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32)
llvm.func @multi_return(%a : i32, %b : i32) -> i32 {
- // CHECK: %[[S1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{\0A\09 .reg .pred p;\0A\09 setp.ge.s32 p, $2, $3;\0A\09 selp.s32 $0, $2, $3, p;\0A\09 selp.s32 $1, $2, $3, !p;\0A\09}\0A", "=r,=r,r,r" %[[arg0]], %[[arg1]] : (i32, i32) -> !llvm.struct<(i32, i32)>
+ // CHECK: %[[S1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{.reg .pred p; setp.ge.s32 p, $2, $3; selp.s32 $0, $2,$3, p; selp.s32 $1, $2,$3, p;}", "=r,=r,r,r" %[[arg0]], %[[arg1]] : (i32, i32) -> !llvm.struct<(i32, i32)>
// CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][0] : !llvm.struct<(i32, i32)>
// CHECK: %[[S3:.+]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(i32, i32)>
// CHECK: %[[S4:.+]] = llvm.add %[[S2]], %[[S3]] : i32
// CHECK: llvm.return %[[S4]] : i32
- %r1, %r2 = nvvm.inline_ptx "{\n\t .reg .pred p;\n\t setp.ge.s32 p, $2, $3;\n\t selp.s32 $0, $2, $3, p;\n\t selp.s32 $1, $2, $3, !p;\n\t}\n" (%a, %b) : i32,i32 -> i32,i32
+ %r1, %r2 = nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$w0}, {$r0},{$r1}, p; selp.s32 {$w1}, {$r0},{$r1}, p;}"
+ ro (%a, %b : i32,i32) -> i32,i32
%r3 = llvm.add %r1, %r2 : i32
llvm.return %r3 : i32
}
+
+// CHECK-LABEL: @inline_ptx_multi_rw(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32, %[[arg2:[a-zA-Z0-9_]+]]: f32, %[[arg3:[a-zA-Z0-9_]+]]: f32)
+llvm.func @inline_ptx_multi_rw(%a : i32, %b : i32, %rw_c : f32, %rw_d : f32) -> f32 {
+// CHECK: %[[S0:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{.reg .pred p; setp.ge.s32 p, $2, $3; selp.s32 $0, $2,$3, p; selp.s32 $1, $2,$3, p;}",
+// CHECK-SAME: "=f,=f,r,r,0,1"
+// CHECK-SAME: %[[arg2]], %[[arg3]], %[[arg0]], %[[arg1]]
+// CHECK-SAME: : (f32, f32, i32, i32) -> !llvm.struct<(f32, f32)>
+// CHECK: %[[S1:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct<(f32, f32)>
+// CHECK: %[[S2:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct<(f32, f32)>
+// CHECK: %[[S3:.+]] = llvm.fadd %[[S1]], %[[S2]] : f32
+// CHECK: llvm.return %[[S3]] : f32
+ nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$rw0}, {$r0},{$r1}, p; selp.s32 {$rw1}, {$r0},{$r1}, p;}"
+ ro (%a, %b : i32,i32)
+ rw (%rw_c, %rw_d: f32,f32)
+ %r4 = llvm.fadd %rw_c, %rw_d : f32
+ llvm.return %r4 : f32
+}
+
+// CHECK-LABEL: @inline_ptx_multi_rw_r(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32, %[[arg2:[a-zA-Z0-9_]+]]: f32, %[[arg3:[a-zA-Z0-9_]+]]: f32)
+llvm.func @inline_ptx_multi_rw_r(%a : i32, %b : i32, %rw_c : f32, %rw_d : f32) -> f32 {
+// CHECK: %[[S0:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "{.reg .pred p; setp.ge.s32 p, $4, $5; selp.s32 $0, $4,$5, p; selp.s32 $1, $4,$5, p; selp.s32 $2, $4,$5, p; selp.s32 $3, $4,$5, p;}",
+// CHECK-SAME: "=f,=f,=r,=r,r,r,0,1"
+// CHECK-SAME: %[[arg2]], %[[arg3]], %[[arg0]], %[[arg1]] :
+// CHECK-SAME: (f32, f32, i32, i32) -> !llvm.struct<(f32, f32, i32, i32)>
+// CHECK: %[[S1:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct<(f32, f32, i32, i32)>
+// CHECK: %[[S2:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct<(f32, f32, i32, i32)>
+// CHECK: %[[S3:.+]] = llvm.extractvalue %[[S0]][2] : !llvm.struct<(f32, f32, i32, i32)>
+// CHECK: %[[S4:.+]] = llvm.extractvalue %[[S0]][3] : !llvm.struct<(f32, f32, i32, i32)>
+// CHECK: %[[S5:.+]] = llvm.add %[[S3]], %[[S4]] : i32
+// CHECK: %[[S6:.+]] = llvm.sitofp %[[S5]] : i32 to f32
+// CHECK: %[[S7:.+]] = llvm.fadd %[[S1]], %[[S2]] : f32
+// CHECK: %[[S8:.+]] = llvm.fadd %[[S6]], %[[S2]] : f32
+// CHECK: llvm.return %[[S8]] : f32
+
+ %wo0, %wo1 = nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$rw0}, {$r0},{$r1}, p; selp.s32 {$rw1}, {$r0},{$r1}, p; selp.s32 {$w0}, {$r0},{$r1}, p; selp.s32 {$w1}, {$r0},{$r1}, p;}"
+ ro (%a, %b : i32,i32)
+ rw (%rw_c, %rw_d: f32,f32) -> i32,i32
+ %r3 = llvm.add %wo0, %wo1 : i32
+ %r3f = llvm.sitofp %r3 : i32 to f32
+ %r4 = llvm.fadd %rw_c, %rw_d : f32
+ %r5 = llvm.fadd %r3f, %rw_d : f32
+ llvm.return %r5 : f32
+}
+
+
// -----
// CHECK-LABEL: @nvvm_pmevent
diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py
index 0eef97d95479a..3eb62bef50de9 100644
--- a/mlir/test/python/dialects/nvvm.py
+++ b/mlir/test/python/dialects/nvvm.py
@@ -5,6 +5,8 @@
from mlir.dialects import nvvm
from mlir.dialects import llvm
from mlir.dialects import func
+import mlir.extras.types as T
+from mlir.dialects import arith
def constructAndPrintInModule(f):
@@ -25,6 +27,7 @@ def testSmoke():
"!llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>"
)
shape_attr = Attribute.parse("#nvvm.shape<m = 64, n = 32, k = 16>")
+
# CHECK-LABEL: func @wgmma_f32_f16_f16(%arg0: i64, %arg1: i64)
@func.FuncOp.from_py_func(i64, i64)
def wgmma_f32_f16_f16(desc_a, desc_b):
@@ -48,3 +51,41 @@ def wgmma_f32_f16_f16(desc_a, desc_b):
layoutA=nvvm.MMALayout.col,
layoutB=nvvm.MMALayout.col,
)
+
+
+# CHECK-LABEL: TEST: test_inline_ptx
+# CHECK-LABEL: func.func @my_inline_ptx(
+# CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: f32, %[[arg1:[a-zA-Z0-9_]+]]: f32, %[[arg2:[a-zA-Z0-9_]+]]: i32, %[[arg3:[a-zA-Z0-9_]+]]: i32)
+# CHECK: %[[S0:.+]]:2 = nvvm.inline_ptx
+# CHECK-SAME: ro(%[[arg0]], %[[arg1]] : f32, f32) rw(%[[arg2]], %[[arg3]] : i32, i32) -> f32, f32
+# CHECK: %[[S1:.+]] = arith.addf %[[arg0]], %[[arg1]] : f32
+# CHECK: %[[S2:.+]] = arith.addi %[[arg2]], %[[arg3]] : i32
+# CHECK: %[[S3:.+]] = arith.addf %[[S0]]#0, %[[S0]]#1 : f32
+
+
+@constructAndPrintInModule
+def test_inline_ptx():
+ i32 = T.i32()
+ f32 = T.f32()
+
+ @func.FuncOp.from_py_func(f32, f32, i32, i32)
+ def my_inline_ptx(a, b, c, d):
+ ptx = r"""
+ {
+ .reg .pred p;
+ setp.ge.s32 p, {$r0}, {$r1};
+ selp.s32 {$r0}, {$r0}, {$r1}, p;
+ selp.s32 {$r1}, {$r0}, {$r1}, p;
+ selp.s32 {$rw0}, {$r0}, {$r1}, p;
+ selp.s32 {$rw1}, {$r0}, {$r1}, p;
+ }
+ """
+ wo0, wo1 = nvvm.inline_ptx(
+ read_only_args=[a, b],
+ read_write_args=[c, d],
+ write_only_args=[f32, f32],
+ ptx_code=ptx,
+ )
+ arith.addf(a, b)
+ arith.addi(c, d)
+ arith.addf(wo0, wo1)
More information about the Mlir-commits
mailing list