[Mlir-commits] [mlir] [MLIR][NVVM] Fix predicate operand index in BasicPtxBuilderInterface (PR #189552)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Mar 31 00:47:24 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
Author: Guray Ozen (grypp)
<details>
<summary>Changes</summary>
The predicate operand index was computed incorrectly: it did not account for write and read-write operands.
Before (wrong: the guard uses `$1`, which is the input operand, not the predicate):
```
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "@$1 ex2.approx.ftz.f32 $0, $1;", "=f,f,b" %{{.*}}, %{{.*}} : (f32, i1) -> f32
%1 = nvvm.inline_ptx "ex2.approx.ftz.f32 {$w0}, {$r0};" ro (%input : f32), predicate = %pred -> f32
```
After this PR (the guard correctly references the predicate as `$2`):
```
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "@$2 ex2.approx.ftz.f32 $0, $1;", "=f,f,b" %{{.*}}, %{{.*}} : (f32, i1) -> f32
%1 = nvvm.inline_ptx "ex2.approx.ftz.f32 {$w0}, {$r0};" ro (%input : f32), predicate = %pred -> f32
```
---
Full diff: https://github.com/llvm/llvm-project/pull/189552.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp (+20-1)
- (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+15-1)
``````````diff
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index d85b2ad9a0542..e6a83a5c1598c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -423,6 +423,24 @@ static std::string rewriteAsmPlaceholders(llvm::StringRef ptxCode) {
return out;
}
+/// Return the constraint index of the predicate operand. The predicate
+/// constraint ("b") is always the last non-tied token in the canonicalized
+/// constraint string. Tied constraints (digit-only tokens from read-write
+/// canonicalization) are appended at the end, so we walk backwards to skip
+/// them.
+static unsigned getPredicateConstraintIndex(StringRef constraints) {
+ SmallVector<StringRef> tokens;
+ constraints.split(tokens, ',');
+ unsigned numTied = 0;
+ for (auto it = tokens.rbegin(); it != tokens.rend(); ++it) {
+ unsigned id;
+ if (it->trim().getAsInteger(10, id))
+ break;
+ ++numTied;
+ }
+ return tokens.size() - numTied - 1;
+}
+
LLVM::InlineAsmOp PtxBuilder::build() {
auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
LLVM::AsmDialect::AD_ATT);
@@ -443,8 +461,9 @@ LLVM::InlineAsmOp PtxBuilder::build() {
// Add the predicate to the asm string.
if (interfaceOp.getPredicate().has_value() &&
interfaceOp.getPredicate().value()) {
+ unsigned predIdx = getPredicateConstraintIndex(registerConstraints);
std::string predicateStr = "@%";
- predicateStr += std::to_string((ptxOperands.size() - 1));
+ predicateStr += std::to_string(predIdx);
ptxInstruction = predicateStr + " " + ptxInstruction;
}
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 8e16a92d96a7b..9734fd45980f7 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -614,11 +614,25 @@ llvm.func @ex2(%input : f32, %pred : i1) {
// CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "ex2.approx.ftz.f32 $0, $1;", "=f,f" %{{.*}} : (f32) -> f32
%0 = nvvm.inline_ptx "ex2.approx.ftz.f32 {$w0}, {$r0};" ro (%input : f32) -> f32
- // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "@$1 ex2.approx.ftz.f32 $0, $1;", "=f,f,b" %{{.*}}, %{{.*}} : (f32, i1) -> f32
+ // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "@$2 ex2.approx.ftz.f32 $0, $1;", "=f,f,b" %{{.*}}, %{{.*}} : (f32, i1) -> f32
%1 = nvvm.inline_ptx "ex2.approx.ftz.f32 {$w0}, {$r0};" ro (%input : f32), predicate = %pred -> f32
llvm.return
}
+// CHECK-LABEL: @multi_return_pred(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32, %[[pred:[a-zA-Z0-9_]+]]: i1)
+llvm.func @multi_return_pred(%a : i32, %b : i32, %pred : i1) -> i32 {
+ // CHECK: %[[S1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "@$4 {.reg .pred p; setp.ge.s32 p, $2, $3; selp.s32 $0, $2,$3, p; selp.s32 $1, $2,$3, p;}", "=r,=r,r,r,b" %[[arg0]], %[[arg1]], %[[pred]] : (i32, i32, i1) -> !llvm.struct<(i32, i32)>
+ // CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][0] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[S3:.+]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(i32, i32)>
+ // CHECK: %[[S4:.+]] = llvm.add %[[S2]], %[[S3]] : i32
+ // CHECK: llvm.return %[[S4]] : i32
+ %r1, %r2 = nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$w0}, {$r0},{$r1}, p; selp.s32 {$w1}, {$r0},{$r1}, p;}"
+ ro (%a, %b : i32,i32), predicate = %pred -> i32,i32
+ %r3 = llvm.add %r1, %r2 : i32
+ llvm.return %r3 : i32
+}
+
// CHECK-LABEL: @multi_return(
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32)
llvm.func @multi_return(%a : i32, %b : i32) -> i32 {
``````````
</details>
https://github.com/llvm/llvm-project/pull/189552
More information about the Mlir-commits
mailing list