[Mlir-commits] [mlir] [MLIR][NVVM] Fix predicate operand index in BasicPtxBuilderInterface (PR #189552)

Guray Ozen llvmlistbot at llvm.org
Tue Mar 31 03:08:41 PDT 2026


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/189552

>From c58ec05d1c5d3ee0763886cb3c1aa9e5d8b25612 Mon Sep 17 00:00:00 2001
From: Guray Ozen <gozen at nvidia.com>
Date: Tue, 31 Mar 2026 09:46:20 +0200
Subject: [PATCH 1/3] [MLIR][NVVM] Fix predicate operand index in
 BasicPtxBuilderInterface

The predicate index computation was incorrect: it did not count the write/read-write symbols.

Wrong output before this PR (the predicate refers to `$1`, which is the read operand):
```
  // CHECK: %{{.*}} =  llvm.inline_asm has_side_effects asm_dialect = att "@$1 ex2.approx.ftz.f32 $0, $1;", "=f,f,b" %{{.*}}, %{{.*}} : (f32, i1) -> f32
  %1 = nvvm.inline_ptx "ex2.approx.ftz.f32 {$w0}, {$r0};" ro (%input : f32), predicate = %pred  -> f32
```

Output with this PR (the predicate correctly refers to `$2`):
```
// CHECK: %{{.*}} =  llvm.inline_asm has_side_effects asm_dialect = att "@$2 ex2.approx.ftz.f32 $0, $1;", "=f,f,b" %{{.*}}, %{{.*}} : (f32, i1) -> f32
  %1 = nvvm.inline_ptx "ex2.approx.ftz.f32 {$w0}, {$r0};" ro (%input : f32), predicate = %pred  -> f32
  ```
---
 .../LLVMIR/IR/BasicPtxBuilderInterface.cpp    | 21 ++++++++++++++++++-
 .../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir   | 16 +++++++++++++-
 2 files changed, 35 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index d85b2ad9a0542..e6a83a5c1598c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -423,6 +423,24 @@ static std::string rewriteAsmPlaceholders(llvm::StringRef ptxCode) {
   return out;
 }
 
+/// Return the constraint index of the predicate operand.  The predicate
+/// constraint ("b") is always the last non-tied token in the canonicalized
+/// constraint string.  Tied constraints (digit-only tokens from read-write
+/// canonicalization) are appended at the end, so we walk backwards to skip
+/// them.
+static unsigned getPredicateConstraintIndex(StringRef constraints) {
+  SmallVector<StringRef> tokens;
+  constraints.split(tokens, ',');
+  unsigned numTied = 0;
+  for (auto it = tokens.rbegin(); it != tokens.rend(); ++it) {
+    unsigned id;
+    if (it->trim().getAsInteger(10, id))
+      break;
+    ++numTied;
+  }
+  return tokens.size() - numTied - 1;
+}
+
 LLVM::InlineAsmOp PtxBuilder::build() {
   auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
                                                   LLVM::AsmDialect::AD_ATT);
@@ -443,8 +461,9 @@ LLVM::InlineAsmOp PtxBuilder::build() {
   // Add the predicate to the asm string.
   if (interfaceOp.getPredicate().has_value() &&
       interfaceOp.getPredicate().value()) {
+    unsigned predIdx = getPredicateConstraintIndex(registerConstraints);
     std::string predicateStr = "@%";
-    predicateStr += std::to_string((ptxOperands.size() - 1));
+    predicateStr += std::to_string(predIdx);
     ptxInstruction = predicateStr + " " + ptxInstruction;
   }
 
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 8e16a92d96a7b..9734fd45980f7 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -614,11 +614,25 @@ llvm.func @ex2(%input : f32, %pred : i1) {
   // CHECK: %{{.*}} = llvm.inline_asm has_side_effects asm_dialect = att "ex2.approx.ftz.f32 $0, $1;", "=f,f" %{{.*}} : (f32) -> f32
   %0 = nvvm.inline_ptx "ex2.approx.ftz.f32 {$w0}, {$r0};" ro (%input : f32) -> f32
   
-  // CHECK: %{{.*}} =  llvm.inline_asm has_side_effects asm_dialect = att "@$1 ex2.approx.ftz.f32 $0, $1;", "=f,f,b" %{{.*}}, %{{.*}} : (f32, i1) -> f32
+  // CHECK: %{{.*}} =  llvm.inline_asm has_side_effects asm_dialect = att "@$2 ex2.approx.ftz.f32 $0, $1;", "=f,f,b" %{{.*}}, %{{.*}} : (f32, i1) -> f32
   %1 = nvvm.inline_ptx "ex2.approx.ftz.f32 {$w0}, {$r0};" ro (%input : f32), predicate = %pred  -> f32
   llvm.return
 }
 
+// CHECK-LABEL: @multi_return_pred(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32, %[[pred:[a-zA-Z0-9_]+]]: i1)
+llvm.func @multi_return_pred(%a : i32, %b : i32, %pred : i1) -> i32 {
+  // CHECK: %[[S1:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "@$4 {.reg .pred p; setp.ge.s32 p, $2, $3; selp.s32 $0, $2,$3, p; selp.s32 $1, $2,$3, p;}", "=r,=r,r,r,b" %[[arg0]], %[[arg1]], %[[pred]] : (i32, i32, i1) -> !llvm.struct<(i32, i32)>
+  // CHECK: %[[S2:.+]] = llvm.extractvalue %[[S1]][0] : !llvm.struct<(i32, i32)>
+  // CHECK: %[[S3:.+]] = llvm.extractvalue %[[S1]][1] : !llvm.struct<(i32, i32)>
+  // CHECK: %[[S4:.+]] = llvm.add %[[S2]], %[[S3]] : i32
+  // CHECK: llvm.return %[[S4]] : i32
+   %r1, %r2 = nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$w0}, {$r0},{$r1}, p; selp.s32 {$w1}, {$r0},{$r1}, p;}"
+                  ro (%a, %b : i32,i32), predicate = %pred -> i32,i32
+   %r3 = llvm.add %r1, %r2 : i32
+   llvm.return %r3 : i32
+}
+
 // CHECK-LABEL: @multi_return(
 // CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32)
 llvm.func @multi_return(%a : i32, %b : i32) -> i32 {

>From d6bcc0fb96396e90364509416f382f7251b5e0b7 Mon Sep 17 00:00:00 2001
From: Guray Ozen <guray.ozen at gmail.com>
Date: Tue, 31 Mar 2026 12:02:19 +0200
Subject: [PATCH 2/3] Update
 mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp

Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
---
 mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index e6a83a5c1598c..9af9dede19e59 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -432,7 +432,7 @@ static unsigned getPredicateConstraintIndex(StringRef constraints) {
   SmallVector<StringRef> tokens;
   constraints.split(tokens, ',');
   unsigned numTied = 0;
-  for (auto it = tokens.rbegin(); it != tokens.rend(); ++it) {
+  for (StringRef &token : llvm::reverse(tokens)) {
     unsigned id;
     if (it->trim().getAsInteger(10, id))
       break;

>From baa07c9cde206fd51a212d276a0ca30bac73fd86 Mon Sep 17 00:00:00 2001
From: Guray Ozen <gozen at nvidia.com>
Date: Tue, 31 Mar 2026 12:08:06 +0200
Subject: [PATCH 3/3] fz

---
 .../Conversion/NVVMToLLVM/nvvm-to-llvm.mlir    | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 9734fd45980f7..898bea2bf5e84 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -665,6 +665,24 @@ llvm.func @inline_ptx_multi_rw(%a : i32, %b : i32,  %rw_c : f32, %rw_d : f32) ->
    llvm.return %r4 : f32
 }
 
+// CHECK-LABEL: @inline_ptx_multi_rw_pred(
+// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32, %[[arg2:[a-zA-Z0-9_]+]]: f32, %[[arg3:[a-zA-Z0-9_]+]]: f32, %[[pred:[a-zA-Z0-9_]+]]: i1)
+llvm.func @inline_ptx_multi_rw_pred(%a : i32, %b : i32, %rw_c : f32, %rw_d : f32, %pred : i1) -> f32 {
+// CHECK: %[[S0:.+]] = llvm.inline_asm has_side_effects asm_dialect = att "@$4 {.reg .pred p; setp.ge.s32 p, $2, $3; selp.s32 $0, $2,$3, p; selp.s32 $1, $2,$3, p;}",
+// CHECK-SAME: "=f,=f,r,r,b,0,1"
+// CHECK-SAME: %[[arg2]], %[[arg3]], %[[arg0]], %[[arg1]], %[[pred]]
+// CHECK-SAME: : (f32, f32, i32, i32, i1) -> !llvm.struct<(f32, f32)>
+// CHECK: %[[S1:.+]] = llvm.extractvalue %[[S0]][0] : !llvm.struct<(f32, f32)>
+// CHECK: %[[S2:.+]] = llvm.extractvalue %[[S0]][1] : !llvm.struct<(f32, f32)>
+// CHECK: %[[S3:.+]] = llvm.fadd %[[S1]], %[[S2]] : f32
+// CHECK: llvm.return %[[S3]] : f32
+    nvvm.inline_ptx "{.reg .pred p; setp.ge.s32 p, {$r0}, {$r1}; selp.s32 {$rw0}, {$r0},{$r1}, p; selp.s32 {$rw1}, {$r0},{$r1}, p;}"
+    ro (%a, %b : i32,i32)
+    rw (%rw_c, %rw_d: f32,f32), predicate = %pred
+   %r4 = llvm.fadd %rw_c, %rw_d : f32
+   llvm.return %r4 : f32
+}
+
 // CHECK-LABEL: @inline_ptx_multi_rw_r(
 // CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: i32, %[[arg1:[a-zA-Z0-9_]+]]: i32, %[[arg2:[a-zA-Z0-9_]+]]: f32, %[[arg3:[a-zA-Z0-9_]+]]: f32)
 llvm.func @inline_ptx_multi_rw_r(%a : i32, %b : i32,  %rw_c : f32, %rw_d : f32) -> f32 {



More information about the Mlir-commits mailing list