[Mlir-commits] [mlir] [MLIR][NVVM] Update mbarrier.test_wait Op (PR #169698)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Nov 26 09:48:48 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir
Author: Durgadoss R (durga4github)
<details>
<summary>Changes</summary>
This patch extends mbarrier.test_wait to support memory scope,
relaxed semantics, and phase parity, completing the updates for
this Op up to Blackwell. Corresponding lit tests are added
to verify the lowering.
---
Full diff: https://github.com/llvm/llvm-project/pull/169698.diff
5 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+20-8)
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+32-10)
- (modified) mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir (-20)
- (modified) mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir (+8)
- (added) mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir (+73)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 57cb9b139111d..cfdf75242bd88 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -980,10 +980,7 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity"
let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
}
-def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
- Results<(outs I1:$res)>,
- Arguments<(ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
- I64:$state)> {
+def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait"> {
let summary = "MBarrier Non-Blocking Test Wait Operation";
let description = [{
The `nvvm.mbarrier.test.wait` operation performs a non-blocking test for the
@@ -1002,9 +999,16 @@ def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
The operation takes the following operands:
- `addr`: A pointer to the memory location of the *mbarrier object*. Uses generic
addressing, but the address must still be in the shared memory space.
- - `state`: An opaque value returned by a previous `mbarrier.arrive`
- operation on the same *mbarrier object* during the current or immediately
- preceding phase.
+ - `stateOrPhase`: This argument represents a `state` when it is a 64-bit value
+ and represents a `phase` when it is a 32-bit value. The `state` is an opaque
+ value returned by a previous `mbarrier.arrive` operation on the same
+ *mbarrier object* during the current or immediately preceding phase.
+ The `phase` is an integer specifying the phase parity (0 or 1).
+ Even phases have parity 0, odd phases have parity 1.
+ - `scope`: This specifies the set of threads that directly observe the memory
+ synchronizing effect of the `mbarrier.test.wait` operation.
+ - `relaxed`: When set to true, the `mbarrier.test.wait` operation has relaxed memory
+ semantics and does not provide any ordering or visibility guarantees.
The operation returns a boolean value indicating whether the specified phase
has completed:
@@ -1031,7 +1035,15 @@ def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait)
}];
- let assemblyFormat = "$addr `,` $state attr-dict `:` type(operands) `->` type($res)";
+ let results = (outs I1:$res);
+ let arguments = (ins
+ AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
+ AnyTypeOf<[I64, I32]>:$stateOrPhase,
+ DefaultValuedAttr<MemScopeKindAttr, "MemScopeKind::CTA">:$scope,
+ DefaultValuedAttr<BoolAttr, "false">:$relaxed);
+
+ let assemblyFormat = "$addr `,` $stateOrPhase attr-dict `:` type(operands) `->` type($res)";
+ let hasVerifier = 1;
let extraClassDeclaration = [{
static mlir::NVVM::IDArgPair
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index e9949547aaea4..bf5d7aab7f19a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -252,10 +252,10 @@ LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr,
NVVM::MemScopeKind scope,
Value retVal = nullptr) {
- bool isSharedCluster = isPtrInSharedClusterSpace(addr);
if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER)
return op->emitError("mbarrier scope must be either CTA or Cluster");
+ bool isSharedCluster = isPtrInSharedClusterSpace(addr);
bool hasRetValue = static_cast<bool>(retVal);
if (isSharedCluster && hasRetValue)
return op->emitError(
@@ -282,6 +282,10 @@ LogicalResult MBarrierCompleteTxOp::verify() {
return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
}
+// Reuses the arrive-like verifier for its scope check (CTA or Cluster only).
+LogicalResult MBarrierTestWaitOp::verify() {
+ return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
+}
+
LogicalResult ConvertFloatToTF32Op::verify() {
using RndMode = NVVM::FPRoundingMode;
switch (getRnd()) {
@@ -2084,16 +2088,34 @@ mlir::NVVM::IDArgPair MBarrierArriveDropNocompleteOp::getIntrinsicIDAndArgs(
mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
- bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
- llvm::Intrinsic::ID id = isShared
- ? llvm::Intrinsic::nvvm_mbarrier_test_wait_shared
- : llvm::Intrinsic::nvvm_mbarrier_test_wait;
- // Fill the Intrinsic Args
- llvm::SmallVector<llvm::Value *> args;
- args.push_back(mt.lookupValue(thisOp.getAddr()));
- args.push_back(mt.lookupValue(thisOp.getState()));
+ bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
+ bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+ // Index into the IDs/relaxedIDs tables below:
+ // bit-0: phase-parity variant (i32 operand), bit-1: cluster scope.
+ size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0);
- return {id, std::move(args)};
+ static constexpr llvm::Intrinsic::ID IDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta};
+ static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::
+ nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta,
+ llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta,
+ llvm::Intrinsic::
+ nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta};
+ auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+ // The intrinsics take an addrspace(3) mbarrier pointer; cast generic ones.
+ llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+ llvm::Value *input = mt.lookupValue(thisOp.getStateOrPhase());
+ bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+ if (needCast)
+ mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+ return {id, {mbar, input}};
}
mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs(
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir
index ae9c7f29bc7a5..9c1d1cc0cdc31 100644
--- a/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir
@@ -54,23 +54,3 @@ llvm.func @mbarrier_inval_shared(%barrier: !llvm.ptr<3>) {
nvvm.mbarrier.inval %barrier : !llvm.ptr<3>
llvm.return
}
-
-llvm.func @mbarrier_test_wait(%barrier: !llvm.ptr, %token : i64) -> i1 {
- // CHECK-LABEL: define i1 @mbarrier_test_wait(ptr %0, i64 %1) {
- // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait(ptr %0, i64 %1)
- // CHECK-NEXT: ret i1 %3
- // CHECK-NEXT: }
- %isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr, i64 -> i1
- llvm.return %isComplete : i1
-}
-
-llvm.func @mbarrier_test_wait_shared(%barrier: !llvm.ptr<3>, %token : i64) {
- // CHECK-LABEL: define void @mbarrier_test_wait_shared(ptr addrspace(3) %0, i64 %1) {
- // CHECK-NEXT: %3 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
- // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.shared(ptr addrspace(3) %0, i64 %1)
- // CHECK-NEXT: ret void
- // CHECK-NEXT: }
- %count = nvvm.read.ptx.sreg.ntid.x : i32
- %isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr<3>, i64 -> i1
- llvm.return
-}
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir
index 4ad76248b7e25..1c0e23516ffe1 100644
--- a/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir
@@ -47,3 +47,11 @@ llvm.func @mbarrier_complete_tx_scope(%barrier: !llvm.ptr<3>, %tx_count: i32) {
nvvm.mbarrier.complete_tx %barrier, %tx_count {scope = #nvvm.mem_scope<sys>} : !llvm.ptr<3>, i32
llvm.return
}
+
+// -----
+
+llvm.func @mbarrier_test_wait(%barrier: !llvm.ptr<3>, %phase: i32) {
+ // expected-error @below {{mbarrier scope must be either CTA or Cluster}}
+ %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<3>, i32 -> i1
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir
new file mode 100644
index 0000000000000..21ab72eeab167
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @mbarrier_test_wait_state(%barrier: !llvm.ptr, %state : i64) {
+ // CHECK-LABEL: define void @mbarrier_test_wait_state(ptr %0, i64 %1) {
+ // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cta.space.cta(ptr addrspace(3) %3, i64 %1)
+ // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cluster.space.cta(ptr addrspace(3) %5, i64 %1)
+ // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i64 %1)
+ // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i64 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.test.wait %barrier, %state : !llvm.ptr, i64 -> i1
+ %1 = nvvm.mbarrier.test.wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1
+
+ %2 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true} : !llvm.ptr, i64 -> i1
+ %3 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1
+ llvm.return
+}
+
+llvm.func @mbarrier_test_wait_shared_state(%barrier: !llvm.ptr<3>, %state : i64) {
+ // CHECK-LABEL: define void @mbarrier_test_wait_shared_state(ptr addrspace(3) %0, i64 %1) {
+ // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1)
+ // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1)
+ // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1)
+ // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.test.wait %barrier, %state : !llvm.ptr<3>, i64 -> i1
+ %1 = nvvm.mbarrier.test.wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1
+
+ %2 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true} : !llvm.ptr<3>, i64 -> i1
+ %3 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1
+ llvm.return
+}
+
+llvm.func @mbarrier_test_wait_phase(%barrier: !llvm.ptr, %phase : i32) {
+ // CHECK-LABEL: define void @mbarrier_test_wait_phase(ptr %0, i32 %1) {
+ // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cta.space.cta(ptr addrspace(3) %3, i32 %1)
+ // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %5, i32 %1)
+ // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i32 %1)
+ // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3)
+ // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.test.wait %barrier, %phase : !llvm.ptr, i32 -> i1
+ %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1
+
+ %2 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true} : !llvm.ptr, i32 -> i1
+ %3 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1
+ llvm.return
+}
+
+llvm.func @mbarrier_test_wait_shared_phase(%barrier: !llvm.ptr<3>, %phase : i32) {
+ // CHECK-LABEL: define void @mbarrier_test_wait_shared_phase(ptr addrspace(3) %0, i32 %1) {
+ // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ %0 = nvvm.mbarrier.test.wait %barrier, %phase : !llvm.ptr<3>, i32 -> i1
+ %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1
+
+ %2 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true} : !llvm.ptr<3>, i32 -> i1
+ %3 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1
+ llvm.return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/169698
More information about the Mlir-commits mailing list