[Mlir-commits] [mlir] [mlir][NVVM] Add InferTypeOpInterface to NVVM MBarrier ops with deterministic result types (PR #188173)
Bastian Hagedorn
llvmlistbot at llvm.org
Thu Apr 9 21:06:59 PDT 2026
https://github.com/bastianhagedorn updated https://github.com/llvm/llvm-project/pull/188173
From be7ade0c59887c0d31ab669e861d7a620d23b6b9 Mon Sep 17 00:00:00 2001
From: Bastian Hagedorn <bhagedorn at nvidia.com>
Date: Mon, 23 Mar 2026 08:31:52 +0000
Subject: [PATCH] [mlir][NVVM] Add InferTypeOpInterface to mbarrier arrive and
barrier ops
Add result type inference to the 5 NVVM barrier/mbarrier ops with
optional results, using InferTypeOpAdaptorWithIsCompatible.
- MBarrierArriveOp, MBarrierArriveDropOp, MBarrierArriveDropExpectTxOp:
i64 result for non-shared_cluster pointers, no result otherwise
- MBarrierArriveExpectTxOp: same, plus no result when predicate is set
- BarrierOp: i32 result when the reductionOp attribute is present, no result otherwise
A permissive isCompatibleReturnTypes allows omitting the result,
preserving backward compatibility with the existing zero-result
assembly form.
This causes the Python binding generator to emit `results=None` as a
default parameter instead of requiring explicit result types.
Also updates the NVGPUToNVVM conversion to work with the new inferred
result types for MBarrierArriveOp and MBarrierArriveExpectTxOp.
Note: this is a source-breaking change for Python callers that pass
result types positionally.
Co-Authored-By: Claude <noreply at anthropic.com>
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 15 ++--
.../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 14 ++-
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 88 +++++++++++++++++++
mlir/test/Dialect/LLVMIR/nvvm.mlir | 7 ++
mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir | 4 +-
mlir/test/python/dialects/nvvm.py | 47 ++++++++--
6 files changed, 152 insertions(+), 23 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index e0fef69f4f944..9b2a8985a1a44 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -723,7 +723,8 @@ def NVVM_MBarrierCompleteTxOp : NVVM_VoidIntrinsicOp<"mbarrier.complete_tx"> {
let hasVerifier = 1;
}
-def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive"> {
+def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive",
+ [InferTypeOpAdaptorWithIsCompatible]> {
let summary = "MBarrier Arrive Operation";
let description = [{
The `nvvm.mbarrier.arrive` operation performs an arrive-on operation on the
@@ -771,7 +772,8 @@ def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive"> {
let hasVerifier = 1;
}
-def NVVM_MBarrierArriveDropOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop"> {
+def NVVM_MBarrierArriveDropOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop",
+ [InferTypeOpAdaptorWithIsCompatible]> {
let summary = "MBarrier Arrive-Drop Operation";
let description = [{
The `nvvm.mbarrier.arrive_drop` operation decrements the expected arrival
@@ -847,7 +849,8 @@ def NVVM_MBarrierArriveDropNocompleteOp : NVVM_SingleResultIntrinsicOp<"mbarrier
let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
}
-def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx"> {
+def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx",
+ [InferTypeOpAdaptorWithIsCompatible]> {
let summary = "MBarrier Arrive with Expected Transaction Count";
let description = [{
The `nvvm.mbarrier.arrive.expect_tx` operation performs an expect-tx operation
@@ -913,7 +916,8 @@ def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_t
}];
}
-def NVVM_MBarrierArriveDropExpectTxOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop.expect_tx"> {
+def NVVM_MBarrierArriveDropExpectTxOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop.expect_tx",
+ [InferTypeOpAdaptorWithIsCompatible]> {
let summary = "MBarrier arrive_drop with expected transaction count";
let description = [{
The `nvvm.mbarrier.arrive_drop.expect_tx` operation is similar to the
@@ -1126,7 +1130,8 @@ def BarrierReductionAttr
let assemblyFormat = "`<` $value `>`";
}
-def NVVM_BarrierOp : NVVM_SingleResultIntrinsicOp<"barrier", [AttrSizedOperandSegments]> {
+def NVVM_BarrierOp : NVVM_SingleResultIntrinsicOp<"barrier",
+ [AttrSizedOperandSegments, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "CTA Barrier Synchronization Op";
let description = [{
The `nvvm.barrier` operation performs barrier synchronization and communication
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 303dc82a67374..c4175016ab30c 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -852,9 +852,7 @@ struct NVGPUMBarrierArriveLowering
Value barrier =
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
- Type tokenType = getTypeConverter()->convertType(
- nvgpu::MBarrierTokenType::get(op->getContext()));
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, barrier);
+ rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, barrier);
return success();
}
};
@@ -911,12 +909,12 @@ struct NVGPUMBarrierArriveExpectTxLowering
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Value txcount = truncToI32(b, adaptor.getTxcount());
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
- op, Type{}, // return-value is optional and is void by default
- barrier, txcount, // barrier and txcount
- NVVM::MemScopeKind::CTA, // default scope is CTA
- false, // relaxed-semantics is false
+ NVVM::MBarrierArriveExpectTxOp::create(
+ rewriter, op->getLoc(), barrier, txcount, // barrier and txcount
+ NVVM::MemScopeKind::CTA, // default scope is CTA
+ false, // relaxed-semantics is false
adaptor.getPredicate());
+ rewriter.eraseOp(op);
return success();
}
};
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 13f05b8f40ed8..31e7ff209db5c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -302,6 +302,82 @@ LogicalResult MBarrierArriveDropExpectTxOp::verify() {
getRes());
}
+//===----------------------------------------------------------------------===//
+// inferReturnTypes for mbarrier arrive-like ops
+//===----------------------------------------------------------------------===//
+
+/// Only shared_cluster (ptr<7>) produces zero results; all other address
+/// spaces (including generic) return i64.
+static LogicalResult
+inferMBarrierArriveResultTypes(MLIRContext *context, Value addr,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (!isPtrInSharedClusterSpace(addr))
+ inferredReturnTypes.push_back(IntegerType::get(context, 64));
+ return success();
+}
+
+LogicalResult
+MBarrierArriveOp::inferReturnTypes(MLIRContext *context,
+ std::optional<Location> location,
+ MBarrierArriveOp::Adaptor adaptor,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
+ inferredReturnTypes);
+}
+
+LogicalResult MBarrierArriveDropOp::inferReturnTypes(
+ MLIRContext *context, std::optional<Location> location,
+ MBarrierArriveDropOp::Adaptor adaptor,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
+ inferredReturnTypes);
+}
+
+LogicalResult MBarrierArriveExpectTxOp::inferReturnTypes(
+ MLIRContext *context, std::optional<Location> location,
+ MBarrierArriveExpectTxOp::Adaptor adaptor,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ // Predicate forces no return value (inline PTX path).
+ // Note: predicate + shared_cluster is rejected by the verifier separately.
+ if (adaptor.getPredicate())
+ return success();
+ return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
+ inferredReturnTypes);
+}
+
+LogicalResult MBarrierArriveDropExpectTxOp::inferReturnTypes(
+ MLIRContext *context, std::optional<Location> location,
+ MBarrierArriveDropExpectTxOp::Adaptor adaptor,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
+ inferredReturnTypes);
+}
+
+/// For ops with optional results, allow the user to omit the result even when
+/// inference would produce one. This preserves backward compatibility: the
+/// result can be silently discarded (e.g., for fire-and-forget arrive ops).
+static bool isCompatibleReturnTypesOptionalResult(TypeRange inferred,
+ TypeRange actual) {
+ if (actual.empty())
+ return true;
+ return inferred == actual;
+}
+
+bool MBarrierArriveOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+ return isCompatibleReturnTypesOptionalResult(l, r);
+}
+bool MBarrierArriveDropOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+ return isCompatibleReturnTypesOptionalResult(l, r);
+}
+bool MBarrierArriveExpectTxOp::isCompatibleReturnTypes(TypeRange l,
+ TypeRange r) {
+ return isCompatibleReturnTypesOptionalResult(l, r);
+}
+bool MBarrierArriveDropExpectTxOp::isCompatibleReturnTypes(TypeRange l,
+ TypeRange r) {
+ return isCompatibleReturnTypesOptionalResult(l, r);
+}
+
LogicalResult MBarrierExpectTxOp::verify() {
return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
}
@@ -2855,6 +2931,18 @@ LogicalResult NVVM::BarrierOp::verify() {
return success();
}
+LogicalResult BarrierOp::inferReturnTypes(
+ MLIRContext *context, std::optional<Location> location,
+ BarrierOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (adaptor.getReductionOp())
+ inferredReturnTypes.push_back(IntegerType::get(context, 32));
+ return success();
+}
+
+bool BarrierOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+ return isCompatibleReturnTypesOptionalResult(l, r);
+}
+
LogicalResult NVVM::Tcgen05CpOp::verify() {
auto mc = getMulticast();
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index b11ba944fe4ac..c039edc6b5de5 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -457,6 +457,13 @@ llvm.func private @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) {
llvm.return
}
+// CHECK-LABEL: @mbarrier_arrive_expect_tx_predicate
+llvm.func private @mbarrier_arrive_expect_tx_predicate(%barrier: !llvm.ptr<3>, %txcount: i32, %pred: i1) {
+ // CHECK: nvvm.mbarrier.arrive.expect_tx %{{.*}}, %{{.*}}, predicate = %{{.*}} : !llvm.ptr<3>, i32, i1{{$}}
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
+ llvm.return
+}
+
// CHECK-LABEL: @wgmma_fence_aligned
func.func @wgmma_fence_aligned() {
// CHECK: nvvm.wgmma.fence.aligned
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir
index f406922ea3873..96c910b193f12 100644
--- a/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir
@@ -108,8 +108,10 @@ llvm.func @mbarrier_arrive_ignore_retval(%count : i32, %barrier: !llvm.ptr<3>) {
// CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %1, i32 %0)
// CHECK-NEXT: ret void
// CHECK-NEXT: }
+ // Result silently discarded (backward compatible form)
nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<3>
- nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<3>
+ // Result explicitly captured
+ %0 = nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<3> -> i64
llvm.return
}
diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py
index b969faa088a46..24abf617548b8 100644
--- a/mlir/test/python/dialects/nvvm.py
+++ b/mlir/test/python/dialects/nvvm.py
@@ -93,6 +93,36 @@ def my_inline_ptx(a, b, c, d):
arith.addf(wo0, wo1)
+@constructAndPrintInModule
+def test_mbarrier_arrive():
+ ptr_shared = llvm.PointerType.get(3)
+ ptr_cluster = llvm.PointerType.get(7)
+ i32 = T.i32()
+
+ @func.FuncOp.from_py_func(ptr_shared, ptr_cluster, i32)
+ def mbarrier_arrive_ops(barrier_shared, barrier_cluster, txcount):
+ token = nvvm.mbarrier_arrive(barrier_shared)
+ nvvm.mbarrier_arrive(barrier_cluster)
+ token2 = nvvm.mbarrier_arrive_drop(barrier_shared)
+ nvvm.mbarrier_arrive_drop(barrier_cluster)
+ token3 = nvvm.mbarrier_arrive_expect_tx(barrier_shared, txcount)
+ nvvm.mbarrier_arrive_expect_tx(barrier_cluster, txcount)
+ token4 = nvvm.mbarrier_arrive_drop_expect_tx(barrier_shared, txcount)
+ nvvm.mbarrier_arrive_drop_expect_tx(barrier_cluster, txcount)
+
+
+# CHECK-LABEL: func.func @mbarrier_arrive_ops(
+# CHECK-SAME: %[[SHARED:.*]]: !llvm.ptr<3>, %[[CLUSTER:.*]]: !llvm.ptr<7>, %[[TXCOUNT:.*]]: i32)
+# CHECK: %{{.*}} = nvvm.mbarrier.arrive %[[SHARED]] : !llvm.ptr<3> -> i64
+# CHECK-NEXT: nvvm.mbarrier.arrive %[[CLUSTER]] : !llvm.ptr<7>{{$}}
+# CHECK-NEXT: %{{.*}} = nvvm.mbarrier.arrive_drop %[[SHARED]] : !llvm.ptr<3> -> i64
+# CHECK-NEXT: nvvm.mbarrier.arrive_drop %[[CLUSTER]] : !llvm.ptr<7>{{$}}
+# CHECK-NEXT: %{{.*}} = nvvm.mbarrier.arrive.expect_tx %[[SHARED]], %[[TXCOUNT]] : !llvm.ptr<3>, i32 -> i64
+# CHECK-NEXT: nvvm.mbarrier.arrive.expect_tx %[[CLUSTER]], %[[TXCOUNT]] : !llvm.ptr<7>, i32{{$}}
+# CHECK-NEXT: %{{.*}} = nvvm.mbarrier.arrive_drop.expect_tx %[[SHARED]], %[[TXCOUNT]] : !llvm.ptr<3>, i32 -> i64
+# CHECK-NEXT: nvvm.mbarrier.arrive_drop.expect_tx %[[CLUSTER]], %[[TXCOUNT]] : !llvm.ptr<7>, i32{{$}}
+
+
@constructAndPrintInModule
def test_barriers():
i32 = T.i32()
@@ -102,21 +132,20 @@ def test_barriers():
def barriers(mask, vi32, vf32):
c0 = arith.constant(T.i32(), 0)
cffff = arith.constant(T.i32(), 0xFFFF)
- res = nvvm.barrier(
- res=i32,
+ nvvm.barrier(
barrier_id=c0,
number_of_threads=cffff,
)
+ pred = arith.constant(T.i32(), 1)
for reduction in (
nvvm.BarrierReduction.AND,
nvvm.BarrierReduction.OR,
nvvm.BarrierReduction.POPC,
):
- res = nvvm.barrier(
- res=i32,
+ pred = nvvm.barrier(
reduction_op=reduction,
- reduction_predicate=res,
+ reduction_predicate=pred,
)
nvvm.barrier0()
@@ -129,15 +158,16 @@ def barriers(mask, vi32, vf32):
nvvm.cluster_wait(aligned=True)
nvvm.fence_mbarrier_init()
nvvm.bar_warp_sync(mask)
- return res
+ return pred
# CHECK-LABEL: func.func @barriers(
# CHECK: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: f32) -> i32 {
# CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
# CHECK: %[[CONSTANT_1:.*]] = arith.constant 65535 : i32
-# CHECK: %[[BARRIER_0:.*]] = nvvm.barrier id = %[[CONSTANT_0]] number_of_threads = %[[CONSTANT_1]] -> i32
-# CHECK: %[[BARRIER_1:.*]] = nvvm.barrier #nvvm.reduction<and> %[[BARRIER_0]] -> i32
+# CHECK: nvvm.barrier id = %[[CONSTANT_0]] number_of_threads = %[[CONSTANT_1]]
+# CHECK: %[[PRED:.*]] = arith.constant 1 : i32
+# CHECK: %[[BARRIER_1:.*]] = nvvm.barrier #nvvm.reduction<and> %[[PRED]] -> i32
# CHECK: %[[BARRIER_2:.*]] = nvvm.barrier #nvvm.reduction<or> %[[BARRIER_1]] -> i32
# CHECK: %[[BARRIER_3:.*]] = nvvm.barrier #nvvm.reduction<popc> %[[BARRIER_2]] -> i32
# CHECK: nvvm.barrier0
@@ -151,7 +181,6 @@ def barriers(mask, vi32, vf32):
# CHECK: nvvm.fence.mbarrier.init
# CHECK: nvvm.bar.warp.sync %[[ARG0]] : i32
# CHECK: return %[[BARRIER_3]] : i32
-# CHECK: }
@constructAndPrintInModule
More information about the Mlir-commits
mailing list