[Mlir-commits] [mlir] [MLIR][NVVM] Update mbarrier.test_wait Op (PR #169698)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 26 09:48:48 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Durgadoss R (durga4github)

<details>
<summary>Changes</summary>

This patch extends the mbarrier.test_wait Op to support scope,
memory semantics (relaxed), and phase-parity, completing the
updates for this Op up to Blackwell. Corresponding lit tests are
added to verify the lowering.

---
Full diff: https://github.com/llvm/llvm-project/pull/169698.diff


5 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+20-8) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+32-10) 
- (modified) mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir (-20) 
- (modified) mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir (+8) 
- (added) mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir (+73) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 57cb9b139111d..cfdf75242bd88 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -980,10 +980,7 @@ def NVVM_MBarrierTryWaitParityOp : NVVM_PTXBuilder_Op<"mbarrier.try_wait.parity"
   let assemblyFormat = "$addr `,` $phase `,` $ticks attr-dict `:` type(operands)";
 }
 
-def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
-  Results<(outs I1:$res)>,
-  Arguments<(ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
-                 I64:$state)> {
+def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait"> {
   let summary = "MBarrier Non-Blocking Test Wait Operation";
   let description = [{
     The `nvvm.mbarrier.test.wait` operation performs a non-blocking test for the 
@@ -1002,9 +999,16 @@ def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
     The operation takes the following operands:
     - `addr`: A pointer to the memory location of the *mbarrier object*. Uses generic 
       addressing, but the address must still be in the shared memory space.
-    - `state`: An opaque value returned by a previous `mbarrier.arrive` 
-      operation on the same *mbarrier object* during the current or immediately 
-      preceding phase.
+    - `stateOrPhase`: This argument represents a `state` when it is a 64-bit value
+      and represents a `phase` when it is a 32-bit value. The `state` is an opaque
+      value returned by a previous `mbarrier.arrive` operation on the same
+      *mbarrier object* during the current or immediately preceding phase.
+      The `phase` is an integer specifying the phase parity (0 or 1).
+      Even phases have parity 0, odd phases have parity 1.
+    - `scope`: This specifies the set of threads that directly observe the memory
+      synchronizing effect of the `mbarrier.test.wait` operation.
+    - `relaxed`: When set to true, the `mbarrier.test.wait` operation has relaxed
+      memory semantics and does not provide any ordering or visibility guarantees.
 
     The operation returns a boolean value indicating whether the specified phase 
     has completed:
@@ -1031,7 +1035,15 @@ def NVVM_MBarrierTestWaitOp : NVVM_Op<"mbarrier.test.wait">,
     [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-mbarrier-test-wait-try-wait)
   }];
 
-  let assemblyFormat = "$addr `,` $state attr-dict `:` type(operands) `->` type($res)";
+  let results = (outs I1:$res);
+  let arguments = (ins
+    AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$addr,
+    AnyTypeOf<[I64, I32]>:$stateOrPhase,
+    DefaultValuedAttr<MemScopeKindAttr, "MemScopeKind::CTA">:$scope,
+    DefaultValuedAttr<BoolAttr, "false">:$relaxed);
+
+  let assemblyFormat = "$addr `,` $stateOrPhase attr-dict `:` type(operands) `->` type($res)";
+  let hasVerifier = 1;
 
   let extraClassDeclaration = [{
     static mlir::NVVM::IDArgPair
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index e9949547aaea4..bf5d7aab7f19a 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -252,10 +252,10 @@ LogicalResult CpAsyncBulkGlobalToSharedClusterOp::verify() {
 static LogicalResult verifyMBarrierArriveLikeOp(Operation *op, Value addr,
                                                 NVVM::MemScopeKind scope,
                                                 Value retVal = nullptr) {
-  bool isSharedCluster = isPtrInSharedClusterSpace(addr);
   if (scope != NVVM::MemScopeKind::CTA && scope != NVVM::MemScopeKind::CLUSTER)
     return op->emitError("mbarrier scope must be either CTA or Cluster");
 
+  bool isSharedCluster = isPtrInSharedClusterSpace(addr);
   bool hasRetValue = static_cast<bool>(retVal);
   if (isSharedCluster && hasRetValue)
     return op->emitError(
@@ -282,6 +282,10 @@ LogicalResult MBarrierCompleteTxOp::verify() {
   return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
 }
 
+LogicalResult MBarrierTestWaitOp::verify() {
+  return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
+}
+
 LogicalResult ConvertFloatToTF32Op::verify() {
   using RndMode = NVVM::FPRoundingMode;
   switch (getRnd()) {
@@ -2084,16 +2088,34 @@ mlir::NVVM::IDArgPair MBarrierArriveDropNocompleteOp::getIntrinsicIDAndArgs(
 mlir::NVVM::IDArgPair MBarrierTestWaitOp::getIntrinsicIDAndArgs(
     Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
   auto thisOp = cast<NVVM::MBarrierTestWaitOp>(op);
-  bool isShared = isPtrInSharedCTASpace(thisOp.getAddr());
-  llvm::Intrinsic::ID id = isShared
-                               ? llvm::Intrinsic::nvvm_mbarrier_test_wait_shared
-                               : llvm::Intrinsic::nvvm_mbarrier_test_wait;
-  // Fill the Intrinsic Args
-  llvm::SmallVector<llvm::Value *> args;
-  args.push_back(mt.lookupValue(thisOp.getAddr()));
-  args.push_back(mt.lookupValue(thisOp.getState()));
+  bool isPhaseParity = thisOp.getStateOrPhase().getType().isInteger(32);
+  bool isClusterScope = thisOp.getScope() == NVVM::MemScopeKind::CLUSTER;
+  // bit-0: isPhaseParity
+  // bit-1: Scope
+  size_t index = ((isClusterScope ? 1 : 0) << 1) | (isPhaseParity ? 1 : 0);
 
-  return {id, std::move(args)};
+  static constexpr llvm::Intrinsic::ID IDs[] = {
+      llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cta_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cta_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_test_wait_scope_cluster_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_test_wait_parity_scope_cluster_space_cta};
+  static constexpr llvm::Intrinsic::ID relaxedIDs[] = {
+      llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cta_space_cta,
+      llvm::Intrinsic::
+          nvvm_mbarrier_test_wait_parity_relaxed_scope_cta_space_cta,
+      llvm::Intrinsic::nvvm_mbarrier_test_wait_relaxed_scope_cluster_space_cta,
+      llvm::Intrinsic::
+          nvvm_mbarrier_test_wait_parity_relaxed_scope_cluster_space_cta};
+  auto id = thisOp.getRelaxed() ? relaxedIDs[index] : IDs[index];
+
+  // Tidy-up the Intrinsic Args
+  llvm::Value *mbar = mt.lookupValue(thisOp.getAddr());
+  llvm::Value *input = mt.lookupValue(thisOp.getStateOrPhase());
+  bool needCast = isPtrInGenericSpace(thisOp.getAddr());
+  if (needCast)
+    mbar = castPtrToAddrSpace(builder, mbar, NVVMMemorySpace::Shared);
+
+  return {id, {mbar, input}};
 }
 
 mlir::NVVM::IDArgPair CpAsyncMBarrierArriveOp::getIntrinsicIDAndArgs(
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir
index ae9c7f29bc7a5..9c1d1cc0cdc31 100644
--- a/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_init.mlir
@@ -54,23 +54,3 @@ llvm.func @mbarrier_inval_shared(%barrier: !llvm.ptr<3>) {
   nvvm.mbarrier.inval %barrier : !llvm.ptr<3>
   llvm.return
 }
-
-llvm.func @mbarrier_test_wait(%barrier: !llvm.ptr, %token : i64) -> i1 {
-  // CHECK-LABEL: define i1 @mbarrier_test_wait(ptr %0, i64 %1) {
-  // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait(ptr %0, i64 %1)
-  // CHECK-NEXT: ret i1 %3
-  // CHECK-NEXT: }
-  %isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr, i64 -> i1
-  llvm.return %isComplete : i1
-}
-
-llvm.func @mbarrier_test_wait_shared(%barrier: !llvm.ptr<3>, %token : i64) {
-  // CHECK-LABEL: define void @mbarrier_test_wait_shared(ptr addrspace(3) %0, i64 %1) {
-  // CHECK-NEXT: %3 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
-  // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.shared(ptr addrspace(3) %0, i64 %1)
-  // CHECK-NEXT: ret void
-  // CHECK-NEXT: }
-  %count = nvvm.read.ptx.sreg.ntid.x : i32
-  %isComplete = nvvm.mbarrier.test.wait %barrier, %token : !llvm.ptr<3>, i64 -> i1
-  llvm.return
-}
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir
index 4ad76248b7e25..1c0e23516ffe1 100644
--- a/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_invalid.mlir
@@ -47,3 +47,11 @@ llvm.func @mbarrier_complete_tx_scope(%barrier: !llvm.ptr<3>, %tx_count: i32) {
   nvvm.mbarrier.complete_tx %barrier, %tx_count {scope = #nvvm.mem_scope<sys>} : !llvm.ptr<3>, i32
   llvm.return
 }
+
+// -----
+
+llvm.func @mbarrier_test_wait(%barrier: !llvm.ptr<3>, %phase: i32) {
+  // expected-error @below {{mbarrier scope must be either CTA or Cluster}}
+  %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<gpu>} : !llvm.ptr<3>, i32 -> i1
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir
new file mode 100644
index 0000000000000..21ab72eeab167
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_test_wait.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+llvm.func @mbarrier_test_wait_state(%barrier: !llvm.ptr, %state : i64) {
+  // CHECK-LABEL: define void @mbarrier_test_wait_state(ptr %0, i64 %1) {
+  // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cta.space.cta(ptr addrspace(3) %3, i64 %1)
+  // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cluster.space.cta(ptr addrspace(3) %5, i64 %1)
+  // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i64 %1)
+  // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i64 %1)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  %0 = nvvm.mbarrier.test.wait %barrier, %state : !llvm.ptr, i64 -> i1
+  %1 = nvvm.mbarrier.test.wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1
+
+  %2 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true} : !llvm.ptr, i64 -> i1
+  %3 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i64 -> i1
+  llvm.return
+}
+
+llvm.func @mbarrier_test_wait_shared_state(%barrier: !llvm.ptr<3>, %state : i64) {
+  // CHECK-LABEL: define void @mbarrier_test_wait_shared_state(ptr addrspace(3) %0, i64 %1) {
+  // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1)
+  // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1)
+  // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i64 %1)
+  // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i64 %1)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  %0 = nvvm.mbarrier.test.wait %barrier, %state : !llvm.ptr<3>, i64 -> i1
+  %1 = nvvm.mbarrier.test.wait %barrier, %state {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1
+
+  %2 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true} : !llvm.ptr<3>, i64 -> i1
+  %3 = nvvm.mbarrier.test.wait %barrier, %state {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i64 -> i1
+  llvm.return
+}
+
+llvm.func @mbarrier_test_wait_phase(%barrier: !llvm.ptr, %phase : i32) {
+  // CHECK-LABEL: define void @mbarrier_test_wait_phase(ptr %0, i32 %1) {
+  // CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cta.space.cta(ptr addrspace(3) %3, i32 %1)
+  // CHECK-NEXT: %5 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %5, i32 %1)
+  // CHECK-NEXT: %7 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %8 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %7, i32 %1)
+  // CHECK-NEXT: %9 = addrspacecast ptr %0 to ptr addrspace(3)
+  // CHECK-NEXT: %10 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %9, i32 %1)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  %0 = nvvm.mbarrier.test.wait %barrier, %phase : !llvm.ptr, i32 -> i1
+  %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1
+
+  %2 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true} : !llvm.ptr, i32 -> i1
+  %3 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr, i32 -> i1
+  llvm.return
+}
+
+llvm.func @mbarrier_test_wait_shared_phase(%barrier: !llvm.ptr<3>, %phase : i32) {
+  // CHECK-LABEL: define void @mbarrier_test_wait_shared_phase(ptr addrspace(3) %0, i32 %1) {
+  // CHECK-NEXT: %3 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+  // CHECK-NEXT: %4 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+  // CHECK-NEXT: %5 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cta.space.cta(ptr addrspace(3) %0, i32 %1)
+  // CHECK-NEXT: %6 = call i1 @llvm.nvvm.mbarrier.test.wait.parity.relaxed.scope.cluster.space.cta(ptr addrspace(3) %0, i32 %1)
+  // CHECK-NEXT: ret void
+  // CHECK-NEXT: }
+  %0 = nvvm.mbarrier.test.wait %barrier, %phase : !llvm.ptr<3>, i32 -> i1
+  %1 = nvvm.mbarrier.test.wait %barrier, %phase {scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1
+
+  %2 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true} : !llvm.ptr<3>, i32 -> i1
+  %3 = nvvm.mbarrier.test.wait %barrier, %phase {relaxed = true, scope = #nvvm.mem_scope<cluster>} : !llvm.ptr<3>, i32 -> i1
+  llvm.return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/169698


More information about the Mlir-commits mailing list