[Mlir-commits] [mlir] [MLIR][NVVM] Update mbarrier Ops to use AnyTypeOf[] (3/3) (PR #167567)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 11 10:50:12 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Durgadoss R (durga4github)
<details>
<summary>Changes</summary>
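This is a follow-up to PRs #165558 and #165993.
This patch updates the remaining two Ops (`nvvm.mbarrier.arrive.expect_tx` and
`nvvm.mbarrier.try_wait.parity`) to use the AnyTypeOf[] construct, folding their
separate `.shared` variants into the same Op and completing the migration for
the mbarrier family of Ops.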
---
Full diff: https://github.com/llvm/llvm-project/pull/167567.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+6-60)
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (-14)
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+39-14)
- (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+5-5)
- (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+3-3)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 1cc5b74a3cb67..13a2af0efe87d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -745,7 +745,9 @@ def NVVM_MBarrierArriveNocompleteOp : NVVM_Op<"mbarrier.arrive.nocomplete">,
}
def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx">,
- Arguments<(ins LLVM_AnyPointer:$addr, I32:$txcount, PtxPredicate:$predicate)> {
+ Arguments<(ins
+ AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
+ I32:$txcount, PtxPredicate:$predicate)> {
let summary = "MBarrier Arrive with Expected Transaction Count";
let description = [{
The `nvvm.mbarrier.arrive.expect_tx` operation performs an expect-tx operation
@@ -773,28 +775,12 @@ def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_t
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive)
}];
let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;"); }
- }];
-}
-
-def NVVM_MBarrierArriveExpectTxSharedOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx.shared">,
- Arguments<(ins LLVM_PointerShared:$addr, I32:$txcount, PtxPredicate:$predicate)> {
- let summary = "Shared MBarrier Arrive with Expected Transaction Count";
- let description = [{
- This Op is the same as `nvvm.mbarrier.arrive.expect_tx` except that the *mbarrier object*
- should be accessed using a shared-memory pointer instead of a generic-memory pointer.
-
- [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive)
- }];
- let assemblyFormat = "$addr `,` $txcount (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() { return std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"); }
- }];
}
def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity">,
- Arguments<(ins LLVM_AnyPointer:$addr, I32:$phase, I32:$ticks)> {
+ Arguments<(ins
+ AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
+ I32:$phase, I32:$ticks)> {
let summary = "MBarrier Potentially-Blocking Try Wait with Phase Parity";
let description = [{
The `nvvm.mbarrier.try_wait.parity` operation performs a potentially-blocking
@@ -847,46 +833,6 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity"
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait)
}];
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() {
- return std::string(
- "{\n\t"
- ".reg .pred P1; \n\t"
- "LAB_WAIT: \n\t"
- "mbarrier.try_wait.parity.b64 P1, [%0], %1, %2; \n\t"
- "@P1 bra.uni DONE; \n\t"
- "bra.uni LAB_WAIT; \n\t"
- "DONE: \n\t"
- "}"
- );
- }
- }];
-}
-
-def NVVM_MBarrierTryWaitParitySharedOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity.shared">,
- Arguments<(ins LLVM_PointerShared:$addr, I32:$phase, I32:$ticks)> {
- let summary = "Shared MBarrier Potentially-Blocking Try Wait with Phase Parity";
- let description = [{
- This Op is the same as `nvvm.mbarrier.try_wait.parity` except that the *mbarrier object*
- should be accessed using a shared-memory pointer instead of a generic-memory pointer.
-
- [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait)
- }];
- let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
- let extraClassDefinition = [{
- std::string $cppClass::getPtx() {
- return std::string(
- "{\n\t"
- ".reg .pred P1; \n\t"
- "LAB_WAIT: \n\t"
- "mbarrier.try_wait.parity.shared.b64 P1, [%0], %1, %2; \n\t"
- "@P1 bra.uni DONE; \n\t"
- "bra.uni LAB_WAIT; \n\t"
- "DONE: \n\t"
- "}"
- );
- }
- }];
}
def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 9348d3c172a07..3a70f787da124 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -922,13 +922,6 @@ struct NVGPUMBarrierArriveExpectTxLowering
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Value txcount = truncToI32(b, adaptor.getTxcount());
-
- if (isMbarrierShared(op.getBarriers().getType())) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxSharedOp>(
- op, barrier, txcount, adaptor.getPredicate());
- return success();
- }
-
rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
op, barrier, txcount, adaptor.getPredicate());
return success();
@@ -949,13 +942,6 @@ struct NVGPUMBarrierTryWaitParityLowering
Value ticks = truncToI32(b, adaptor.getTicks());
Value phase =
LLVM::ZExtOp::create(b, b.getI32Type(), adaptor.getPhaseParity());
-
- if (isMbarrierShared(op.getBarriers().getType())) {
- rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParitySharedOp>(
- op, barrier, phase, ticks);
- return success();
- }
-
rewriter.replaceOpWithNewOp<NVVM::MBarrierTryWaitParityOp>(op, barrier,
phase, ticks);
return success();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index d43f8815be16d..d3bbda175093d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -47,6 +47,19 @@ using namespace NVVM;
static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
+//===----------------------------------------------------------------------===//
+// Helper/Utility methods
+//===----------------------------------------------------------------------===//
+
+static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
+ auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
+ return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
+}
+
+static bool isPtrInSharedCTASpace(mlir::Value ptr) {
+ return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
+}
+
//===----------------------------------------------------------------------===//
// Verifier methods
//===----------------------------------------------------------------------===//
@@ -1741,26 +1754,38 @@ void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
//===----------------------------------------------------------------------===//
std::string NVVM::MBarrierInitOp::getPtx() {
- unsigned addressSpace =
- llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
- return (addressSpace == NVVMMemorySpace::Shared)
- ? std::string("mbarrier.init.shared.b64 [%0], %1;")
- : std::string("mbarrier.init.b64 [%0], %1;");
+ bool isShared = isPtrInSharedCTASpace(getAddr());
+ return isShared ? std::string("mbarrier.init.shared.b64 [%0], %1;")
+ : std::string("mbarrier.init.b64 [%0], %1;");
}
-//===----------------------------------------------------------------------===//
-// getIntrinsicID/getIntrinsicIDAndArgs methods
-//===----------------------------------------------------------------------===//
-
-static bool isPtrInAddrSpace(mlir::Value ptr, NVVMMemorySpace targetAS) {
- auto ptrTy = llvm::cast<LLVM::LLVMPointerType>(ptr.getType());
- return ptrTy.getAddressSpace() == static_cast<unsigned>(targetAS);
+std::string NVVM::MBarrierArriveExpectTxOp::getPtx() {
+ bool isShared = isPtrInSharedCTASpace(getAddr());
+ return isShared
+ ? std::string("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;")
+ : std::string("mbarrier.arrive.expect_tx.b64 _, [%0], %1;");
}
-static bool isPtrInSharedCTASpace(mlir::Value ptr) {
- return isPtrInAddrSpace(ptr, NVVMMemorySpace::Shared);
+std::string NVVM::MBarrierTryWaitParityOp::getPtx() {
+ bool isShared = isPtrInSharedCTASpace(getAddr());
+ std::string space = isShared ? ".shared" : "";
+
+ return "{\n\t"
+ ".reg .pred P1; \n\t"
+ "LAB_WAIT: \n\t"
+ "mbarrier.try_wait.parity" +
+ space +
+ ".b64 P1, [%0], %1, %2; \n\t"
+ "@P1 bra.uni DONE; \n\t"
+ "bra.uni LAB_WAIT; \n\t"
+ "DONE: \n\t"
+ "}";
}
+//===----------------------------------------------------------------------===//
+// getIntrinsicID/getIntrinsicIDAndArgs methods
+//===----------------------------------------------------------------------===//
+
mlir::NVVM::IDArgPair MBarrierInitOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::MBarrierInitOp>(op);
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
index dcf4ddb2dd48c..0eb44789fe31d 100644
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -603,14 +603,14 @@ func.func @mbarrier_txcount() {
%txcount = arith.constant 256 : index
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
+ // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
scf.yield
} else {
%txcount = arith.constant 0 : index
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]]
+ // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount : !barrierType
scf.yield
}
@@ -620,7 +620,7 @@ func.func @mbarrier_txcount() {
%ticks = arith.constant 10000000 : index
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
+ // CHECK: nvvm.mbarrier.try_wait.parity %[[barPtr3]]
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
func.return
@@ -649,14 +649,14 @@ func.func @mbarrier_txcount_pred() {
%txcount = arith.constant 256 : index
// CHECK: %[[base2:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr2:.+]] = llvm.getelementptr %[[base2]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.arrive.expect_tx.shared %[[barPtr2]], {{.*}}, predicate = %[[P]]
+ // CHECK: nvvm.mbarrier.arrive.expect_tx %[[barPtr2]], {{.*}}, predicate = %[[P]]
nvgpu.mbarrier.arrive.expect_tx %barrier[%c0], %txcount, predicate = %pred : !barrierType
%phase_c0 = arith.constant 0 : i1
%ticks = arith.constant 10000000 : index
// CHECK: %[[base3:.+]] = llvm.extractvalue %[[barStr]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)>
// CHECK: %[[barPtr3:.+]] = llvm.getelementptr %[[base3]][%[[mid]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i64
- // CHECK: nvvm.mbarrier.try_wait.parity.shared %[[barPtr3]]
+ // CHECK: nvvm.mbarrier.try_wait.parity %[[barPtr3]]
nvgpu.mbarrier.try_wait.parity %barrier[%c0], %phase_c0, %ticks : !barrierType
func.return
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index a9356c5cb60bb..a94fcb4856db4 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -17,9 +17,9 @@ llvm.func @init_mbarrier(%barrier_gen : !llvm.ptr, %barrier : !llvm.ptr<3>, %cou
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
llvm.func @init_mbarrier_arrive_expect_tx(%barrier : !llvm.ptr<3>, %txcount : i32, %pred : i1) {
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r"
- nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount : !llvm.ptr<3>, i32
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr<3>, i32
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$2 mbarrier.arrive.expect_tx.shared.b64 _, [$0], $1;", "r,r,b"
- nvvm.mbarrier.arrive.expect_tx.shared %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
llvm.return
}
@@ -44,7 +44,7 @@ llvm.func @init_mbarrier_try_wait_shared(%barrier : !llvm.ptr<3>, %ticks : i32,
// CHECK-SAME: DONE:
// CHECK-SAME: }",
// CHECK-SAME: "r,r,r"
- nvvm.mbarrier.try_wait.parity.shared %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
+ nvvm.mbarrier.try_wait.parity %barrier, %phase, %ticks : !llvm.ptr<3>, i32, i32
llvm.return
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/167567
More information about the Mlir-commits
mailing list