[Mlir-commits] [mlir] [mlir][NVVM] Add InferTypeOpInterface to NVVM ops with deterministic result types (PR #188173)

Bastian Hagedorn llvmlistbot at llvm.org
Thu Apr 9 01:49:26 PDT 2026


https://github.com/bastianhagedorn updated https://github.com/llvm/llvm-project/pull/188173

>From eb5f29145847c8ed5278ce6348adc2f33b358f7c Mon Sep 17 00:00:00 2001
From: Bastian Hagedorn <bhagedorn at nvidia.com>
Date: Mon, 23 Mar 2026 08:31:52 +0000
Subject: [PATCH] [mlir][NVVM] Add InferTypeOpInterface to mbarrier arrive and
 barrier ops

Add result type inference to the 5 NVVM barrier/mbarrier ops with
optional results, using InferTypeOpAdaptorWithIsCompatible.

- MBarrierArriveOp, MBarrierArriveDropOp, MBarrierArriveDropExpectTxOp:
  i64 result for non-shared_cluster pointers, no result otherwise
- MBarrierArriveExpectTxOp: same, plus no result when predicate is set
- BarrierOp: i32 result when reductionOp attribute is present, no result otherwise

A permissive isCompatibleReturnTypes accepts an op that declares no result
even when inference would produce one, preserving backward compatibility
with the existing zero-result assembly form.

This causes the Python binding generator to emit `results=None` as a
default parameter instead of requiring explicit result types.

Also updates the NVGPUToNVVM conversion to work with the new inferred
result types for MBarrierArriveOp and MBarrierArriveExpectTxOp.

Note: this is a source-breaking change for Python callers that pass
result types positionally.

Co-Authored-By: Claude <noreply at anthropic.com>
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   | 15 ++--
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    | 14 ++-
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 88 +++++++++++++++++++
 mlir/test/Dialect/LLVMIR/nvvm.mlir            | 14 +++
 mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir |  4 +-
 mlir/test/python/dialects/nvvm.py             | 48 ++++++++--
 6 files changed, 160 insertions(+), 23 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 57a0c67e82c47..ed1e0982db74a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -720,7 +720,8 @@ def NVVM_MBarrierCompleteTxOp : NVVM_VoidIntrinsicOp<"mbarrier.complete_tx"> {
   let hasVerifier = 1;
 }
 
-def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive"> {
+def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive",
+    [InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "MBarrier Arrive Operation";
   let description = [{
     The `nvvm.mbarrier.arrive` operation performs an arrive-on operation on the 
@@ -768,7 +769,8 @@ def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive"> {
   let hasVerifier = 1;
 }
 
-def NVVM_MBarrierArriveDropOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop"> {
+def NVVM_MBarrierArriveDropOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop",
+    [InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "MBarrier Arrive-Drop Operation";
   let description = [{
     The `nvvm.mbarrier.arrive_drop` operation decrements the expected arrival
@@ -844,7 +846,8 @@ def NVVM_MBarrierArriveDropNocompleteOp : NVVM_SingleResultIntrinsicOp<"mbarrier
   let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
 }
 
-def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx"> {
+def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx",
+    [InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "MBarrier Arrive with Expected Transaction Count";
   let description = [{
     The `nvvm.mbarrier.arrive.expect_tx` operation performs an expect-tx operation 
@@ -910,7 +913,8 @@ def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_t
   }];
 }
 
-def NVVM_MBarrierArriveDropExpectTxOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop.expect_tx"> {
+def NVVM_MBarrierArriveDropExpectTxOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop.expect_tx",
+    [InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "MBarrier arrive_drop with expected transaction count";
   let description = [{
     The `nvvm.mbarrier.arrive_drop.expect_tx` operation is similar to the
@@ -1123,7 +1127,8 @@ def BarrierReductionAttr
   let assemblyFormat = "`<` $value `>`";
 }
 
-def NVVM_BarrierOp : NVVM_SingleResultIntrinsicOp<"barrier", [AttrSizedOperandSegments]> {
+def NVVM_BarrierOp : NVVM_SingleResultIntrinsicOp<"barrier",
+    [AttrSizedOperandSegments, InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "CTA Barrier Synchronization Op";
   let description = [{
     The `nvvm.barrier` operation performs barrier synchronization and communication 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 6edc8f5c86dd3..3da46e9c9f979 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -852,9 +852,7 @@ struct NVGPUMBarrierArriveLowering
     Value barrier =
         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
-    Type tokenType = getTypeConverter()->convertType(
-        nvgpu::MBarrierTokenType::get(op->getContext()));
-    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, barrier);
+    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, barrier);
     return success();
   }
 };
@@ -911,12 +909,12 @@ struct NVGPUMBarrierArriveExpectTxLowering
         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
     Value txcount = truncToI32(b, adaptor.getTxcount());
-    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
-        op, Type{},       // return-value is optional and is void by default
-        barrier, txcount, // barrier and txcount
-        NVVM::MemScopeKind::CTA, // default scope is CTA
-        false,                   // relaxed-semantics is false
+    NVVM::MBarrierArriveExpectTxOp::create(
+        rewriter, op->getLoc(), barrier, txcount, // barrier and txcount
+        NVVM::MemScopeKind::CTA,                  // default scope is CTA
+        false,                                    // relaxed-semantics is false
         adaptor.getPredicate());
+    rewriter.eraseOp(op);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 528e709629ebf..bb76c92aec47b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -302,6 +302,82 @@ LogicalResult MBarrierArriveDropExpectTxOp::verify() {
                                     getRes());
 }
 
+//===----------------------------------------------------------------------===//
+// inferReturnTypes for mbarrier arrive-like ops
+//===----------------------------------------------------------------------===//
+
+/// Only shared_cluster (ptr<7>) produces zero results; all other address
+/// spaces (including generic) return i64.
+static LogicalResult
+inferMBarrierArriveResultTypes(MLIRContext *context, Value addr,
+                               SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (!isPtrInSharedClusterSpace(addr))
+    inferredReturnTypes.push_back(IntegerType::get(context, 64));
+  return success();
+}
+
+LogicalResult
+MBarrierArriveOp::inferReturnTypes(MLIRContext *context,
+                                   std::optional<Location> location,
+                                   MBarrierArriveOp::Adaptor adaptor,
+                                   SmallVectorImpl<Type> &inferredReturnTypes) {
+  return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
+                                        inferredReturnTypes);
+}
+
+LogicalResult MBarrierArriveDropOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MBarrierArriveDropOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
+                                        inferredReturnTypes);
+}
+
+LogicalResult MBarrierArriveExpectTxOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MBarrierArriveExpectTxOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  // Predicate forces no return value (inline PTX path).
+  // Note: predicate + shared_cluster is rejected by the verifier separately.
+  if (adaptor.getPredicate())
+    return success();
+  return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
+                                        inferredReturnTypes);
+}
+
+LogicalResult MBarrierArriveDropExpectTxOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MBarrierArriveDropExpectTxOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
+                                        inferredReturnTypes);
+}
+
+/// For ops with optional results, allow the user to omit the result even when
+/// inference would produce one. This preserves backward compatibility: the
+/// result can be silently discarded (e.g., for fire-and-forget arrive ops).
+static bool isCompatibleReturnTypesOptionalResult(TypeRange inferred,
+                                                  TypeRange actual) {
+  if (actual.empty())
+    return true;
+  return inferred == actual;
+}
+
+bool MBarrierArriveOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  return isCompatibleReturnTypesOptionalResult(l, r);
+}
+bool MBarrierArriveDropOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  return isCompatibleReturnTypesOptionalResult(l, r);
+}
+bool MBarrierArriveExpectTxOp::isCompatibleReturnTypes(TypeRange l,
+                                                       TypeRange r) {
+  return isCompatibleReturnTypesOptionalResult(l, r);
+}
+bool MBarrierArriveDropExpectTxOp::isCompatibleReturnTypes(TypeRange l,
+                                                           TypeRange r) {
+  return isCompatibleReturnTypesOptionalResult(l, r);
+}
+
 LogicalResult MBarrierExpectTxOp::verify() {
   return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
 }
@@ -2798,6 +2874,18 @@ LogicalResult NVVM::BarrierOp::verify() {
   return success();
 }
 
+LogicalResult BarrierOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    BarrierOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (adaptor.getReductionOp())
+    inferredReturnTypes.push_back(IntegerType::get(context, 32));
+  return success();
+}
+
+bool BarrierOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  return isCompatibleReturnTypesOptionalResult(l, r);
+}
+
 LogicalResult NVVM::Tcgen05CpOp::verify() {
   auto mc = getMulticast();
 
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index b11ba944fe4ac..47350b05a862c 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -443,6 +443,13 @@ llvm.func private @mbarrier_arrive_shared(%barrier: !llvm.ptr<3>) {
   llvm.return
 }
 
+// CHECK-LABEL: @mbarrier_arrive_cluster
+llvm.func private @mbarrier_arrive_cluster(%barrier: !llvm.ptr<7>) {
+  // CHECK:   nvvm.mbarrier.arrive %{{.*}} : !llvm.ptr<7>{{$}}
+  nvvm.mbarrier.arrive %barrier : !llvm.ptr<7>
+  llvm.return
+}
+
 llvm.func private @mbarrier_arrive_nocomplete(%barrier: !llvm.ptr) {
   %count = nvvm.read.ptx.sreg.ntid.x : i32
   // CHECK:   nvvm.mbarrier.arrive.nocomplete %{{.*}} : !llvm.ptr
@@ -457,6 +464,13 @@ llvm.func private @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) {
   llvm.return
 }
 
+// CHECK-LABEL: @mbarrier_arrive_expect_tx_predicate
+llvm.func private @mbarrier_arrive_expect_tx_predicate(%barrier: !llvm.ptr<3>, %txcount: i32, %pred: i1) {
+  // CHECK:   nvvm.mbarrier.arrive.expect_tx %{{.*}}, %{{.*}}, predicate = %{{.*}} : !llvm.ptr<3>, i32, i1{{$}}
+  nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
+  llvm.return
+}
+
 // CHECK-LABEL: @wgmma_fence_aligned
 func.func @wgmma_fence_aligned() {
   // CHECK: nvvm.wgmma.fence.aligned
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir
index f406922ea3873..96c910b193f12 100644
--- a/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir
@@ -108,8 +108,10 @@ llvm.func @mbarrier_arrive_ignore_retval(%count : i32, %barrier: !llvm.ptr<3>) {
   // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %1, i32 %0)
   // CHECK-NEXT: ret void
   // CHECK-NEXT: }
+  // Result silently discarded (backward compatible form)
   nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<3>
-  nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<3>
+  // Result explicitly captured
+  %0 = nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<3> -> i64
 
   llvm.return
 }
diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py
index dc8e4f462ad56..d80126a5c8646 100644
--- a/mlir/test/python/dialects/nvvm.py
+++ b/mlir/test/python/dialects/nvvm.py
@@ -93,6 +93,37 @@ def my_inline_ptx(a, b, c, d):
         arith.addf(wo0, wo1)
 
 
+@constructAndPrintInModule
+def test_mbarrier_arrive():
+    ptr_shared = llvm.PointerType.get(3)
+    ptr_cluster = llvm.PointerType.get(7)
+    i32 = T.i32()
+    i64 = T.i64()
+
+    @func.FuncOp.from_py_func(ptr_shared, ptr_cluster, i32)
+    def mbarrier_arrive_ops(barrier_shared, barrier_cluster, txcount):
+        token = nvvm.mbarrier_arrive(barrier_shared)
+        nvvm.mbarrier_arrive(barrier_cluster)
+        token2 = nvvm.mbarrier_arrive_drop(barrier_shared)
+        nvvm.mbarrier_arrive_drop(barrier_cluster)
+        token3 = nvvm.mbarrier_arrive_expect_tx(barrier_shared, txcount)
+        nvvm.mbarrier_arrive_expect_tx(barrier_cluster, txcount)
+        token4 = nvvm.mbarrier_arrive_drop_expect_tx(barrier_shared, txcount)
+        nvvm.mbarrier_arrive_drop_expect_tx(barrier_cluster, txcount)
+
+
+# CHECK-LABEL:   func.func @mbarrier_arrive_ops(
+# CHECK-SAME:      %[[SHARED:.*]]: !llvm.ptr<3>, %[[CLUSTER:.*]]: !llvm.ptr<7>, %[[TXCOUNT:.*]]: i32)
+# CHECK:           %{{.*}} = nvvm.mbarrier.arrive %[[SHARED]] : !llvm.ptr<3> -> i64
+# CHECK-NEXT:      nvvm.mbarrier.arrive %[[CLUSTER]] : !llvm.ptr<7>{{$}}
+# CHECK-NEXT:      %{{.*}} = nvvm.mbarrier.arrive_drop %[[SHARED]] : !llvm.ptr<3> -> i64
+# CHECK-NEXT:      nvvm.mbarrier.arrive_drop %[[CLUSTER]] : !llvm.ptr<7>{{$}}
+# CHECK-NEXT:      %{{.*}} = nvvm.mbarrier.arrive.expect_tx %[[SHARED]], %[[TXCOUNT]] : !llvm.ptr<3>, i32 -> i64
+# CHECK-NEXT:      nvvm.mbarrier.arrive.expect_tx %[[CLUSTER]], %[[TXCOUNT]] : !llvm.ptr<7>, i32{{$}}
+# CHECK-NEXT:      %{{.*}} = nvvm.mbarrier.arrive_drop.expect_tx %[[SHARED]], %[[TXCOUNT]] : !llvm.ptr<3>, i32 -> i64
+# CHECK-NEXT:      nvvm.mbarrier.arrive_drop.expect_tx %[[CLUSTER]], %[[TXCOUNT]] : !llvm.ptr<7>, i32{{$}}
+
+
 @constructAndPrintInModule
 def test_barriers():
     i32 = T.i32()
@@ -102,21 +133,20 @@ def test_barriers():
     def barriers(mask, vi32, vf32):
         c0 = arith.constant(T.i32(), 0)
         cffff = arith.constant(T.i32(), 0xFFFF)
-        res = nvvm.barrier(
-            res=i32,
+        nvvm.barrier(
             barrier_id=c0,
             number_of_threads=cffff,
         )
 
+        pred = arith.constant(T.i32(), 1)
         for reduction in (
             nvvm.BarrierReduction.AND,
             nvvm.BarrierReduction.OR,
             nvvm.BarrierReduction.POPC,
         ):
-            res = nvvm.barrier(
-                res=i32,
+            pred = nvvm.barrier(
                 reduction_op=reduction,
-                reduction_predicate=res,
+                reduction_predicate=pred,
             )
 
         nvvm.barrier0()
@@ -129,15 +159,16 @@ def barriers(mask, vi32, vf32):
         nvvm.cluster_wait(aligned=True)
         nvvm.fence_mbarrier_init()
         nvvm.bar_warp_sync(mask)
-        return res
+        return pred
 
 
 # CHECK-LABEL:   func.func @barriers(
 # CHECK:           %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: f32) -> i32 {
 # CHECK:           %[[CONSTANT_0:.*]] = arith.constant 0 : i32
 # CHECK:           %[[CONSTANT_1:.*]] = arith.constant 65535 : i32
-# CHECK:           %[[BARRIER_0:.*]] = nvvm.barrier id = %[[CONSTANT_0]] number_of_threads = %[[CONSTANT_1]] -> i32
-# CHECK:           %[[BARRIER_1:.*]] = nvvm.barrier #nvvm.reduction<and> %[[BARRIER_0]] -> i32
+# CHECK:           nvvm.barrier id = %[[CONSTANT_0]] number_of_threads = %[[CONSTANT_1]]
+# CHECK:           %[[PRED:.*]] = arith.constant 1 : i32
+# CHECK:           %[[BARRIER_1:.*]] = nvvm.barrier #nvvm.reduction<and> %[[PRED]] -> i32
 # CHECK:           %[[BARRIER_2:.*]] = nvvm.barrier #nvvm.reduction<or> %[[BARRIER_1]] -> i32
 # CHECK:           %[[BARRIER_3:.*]] = nvvm.barrier #nvvm.reduction<popc> %[[BARRIER_2]] -> i32
 # CHECK:           nvvm.barrier0
@@ -151,7 +182,6 @@ def barriers(mask, vi32, vf32):
 # CHECK:           nvvm.fence.mbarrier.init
 # CHECK:           nvvm.bar.warp.sync %[[ARG0]] : i32
 # CHECK:           return %[[BARRIER_3]] : i32
-# CHECK:         }
 
 
 @constructAndPrintInModule



More information about the Mlir-commits mailing list