[Mlir-commits] [mlir] [mlir][NVVM] Add InferTypeOpInterface to NVVM ops with deterministic result types (PR #188173)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 23 21:59:18 PDT 2026


llvmbot wrote:



@llvm/pr-subscribers-mlir-gpu

Author: Bastian Hagedorn (bastianhagedorn)

<details>
<summary>Changes</summary>

Add result type inference to 10 NVVM ops whose result types can be fully determined from their operands and attributes. This enables the Python binding generator to emit `results=None` as a default parameter, removing the need for callers to pass explicit result types.

Ops with always-present results (using `InferTypeOpAdaptor`):
- `VoteSyncOp`: ballot → i32, any/all/uni → i1
- `MatchSyncOp`: any → i32, all → struct<(i32, i1)>
- `ShflOp`: same type as `val`, or struct<(val_type, i1)> when `return_value_and_is_valid` is set
- `LdMatrixOp`: i32 when a single element is loaded, otherwise a struct of i32s; the element count is `num`, doubled for the 16x16 shape
- `ClusterLaunchControlQueryCancelOp`: is_canceled → i1, others → i32
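The always-present-result rules above are small pure functions of attributes and operand types. The sketch below restates them in plain Python so they can be read and checked in isolation. It is a hypothetical model, not the MLIR Python API: types are plain strings, and the function names are invented for illustration. The `ldmatrix_result` rule mirrors the `LdMatrixOp::inferReturnTypes` logic in the diff.

```python
# Hypothetical model of the NVVM result-type rules; not the MLIR Python API.
# MLIR types are represented as plain strings for illustration only.

def vote_sync_result(kind: str) -> str:
    # ballot yields a 32-bit lane mask; any/all/uni yield a predicate
    return "i32" if kind == "ballot" else "i1"

def match_sync_result(kind: str) -> str:
    # all yields (mask, predicate); any yields only the mask
    return "!llvm.struct<(i32, i1)>" if kind == "all" else "i32"

def shfl_result(val_type: str, return_value_and_is_valid: bool) -> str:
    # With the flag set, the shuffled value is paired with a validity bit
    if return_value_and_is_valid:
        return f"!llvm.struct<({val_type}, i1)>"
    return val_type

def ldmatrix_result(num: int, m: int, n: int) -> str:
    # The 16x16 shape doubles the number of i32 elements per thread
    count = num * 2 if (m == 16 and n == 16) else num
    if count == 1:
        return "i32"
    return "!llvm.struct<(" + ", ".join(["i32"] * count) + ")>"

def query_cancel_result(query: str) -> str:
    # is_canceled is a boolean; the get_first_cta_id_* queries are i32
    return "i1" if query == "is_canceled" else "i32"
```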

Ops with optional results (using `InferTypeOpAdaptorWithIsCompatible`):
- `MBarrierArriveOp`: i64 for non-cluster pointers (generic pointers are treated as shared::cta), no result for shared_cluster pointers (`!llvm.ptr<7>`)
- `MBarrierArriveDropOp`: same as above
- `MBarrierArriveExpectTxOp`: same as MBarrierArriveOp, plus no result when a predicate is set
- `MBarrierArriveDropExpectTxOp`: same as MBarrierArriveOp
- `BarrierOp`: i32 when reductionOp is present, no result otherwise
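For the optional-result ops, inference can legitimately produce an empty list. The following hypothetical Python model restates the rules listed above, with the pointer reduced to its address-space number; `SHARED_CLUSTER = 7` matches the `!llvm.ptr<7>` used in the tests, and the helper names are invented for illustration.

```python
from typing import List, Optional

# Address space of shared_cluster pointers (!llvm.ptr<7>); all others infer i64.
SHARED_CLUSTER = 7

def mbarrier_arrive_results(addr_space: int) -> List[str]:
    # Shared rule for arrive, arrive_drop and arrive_drop.expect_tx
    return [] if addr_space == SHARED_CLUSTER else ["i64"]

def mbarrier_arrive_expect_tx_results(addr_space: int,
                                      has_predicate: bool) -> List[str]:
    # A predicate selects the inline-PTX path, which has no result
    if has_predicate:
        return []
    return mbarrier_arrive_results(addr_space)

def barrier_results(reduction_op: Optional[str]) -> List[str]:
    # Only a barrier with a reduction produces an i32 value
    return ["i32"] if reduction_op is not None else []
```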

The optional-result ops use a permissive `isCompatibleReturnTypes` that allows omitting the result, preserving backward compatibility with the existing zero-result assembly form.
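The permissive compatibility check amounts to two cases: an empty declared result list is always accepted, and anything else must match the inferred types exactly. A minimal Python restatement of `isCompatibleReturnTypesOptionalResult` from the diff, with types modeled as strings (the Python function name is hypothetical):

```python
from typing import Sequence

def is_compatible_optional_result(inferred: Sequence[str],
                                  actual: Sequence[str]) -> bool:
    # Omitting the result keeps the existing zero-result assembly form valid
    if len(actual) == 0:
        return True
    # Otherwise the declared result types must equal the inferred ones
    return list(inferred) == list(actual)
```

Note that the check is one-directional: dropping an inferred `i64` is fine, but declaring a result where inference yields none (for example on a shared_cluster pointer) is rejected.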

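Also updates the NVGPUToNVVM conversion to work with the new inferred result types for `MBarrierArriveOp` and `MBarrierArriveExpectTxOp`. Because the nvgpu arrive-expect-tx op has no results while the NVVM op may now infer an i64, that lowering creates the NVVM op and erases the original instead of using `replaceOpWithNewOp`.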

This is a source-breaking change for Python callers that pass result types positionally (e.g. `mbarrier_arrive(res, addr, ...)` becomes `mbarrier_arrive(addr, *, ..., results=None, ...)`). Existing MLIR assembly is fully backward compatible.
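Why the change is source-breaking can be seen with a plain Python stand-in for a generated builder. This is a hypothetical sketch, not the generated binding: the real builder takes more keyword arguments, and here the address operand is reduced to its address-space number. Once `results` becomes a keyword-only parameter defaulting to `None`, an old-style call that passes the result type first no longer binds correctly and raises `TypeError`.

```python
from typing import List, Optional

SHARED_CLUSTER = 7  # !llvm.ptr<7>

def mbarrier_arrive(addr_space: int, *,
                    results: Optional[List[str]] = None) -> List[str]:
    """Hypothetical stand-in for a generated builder; returns result types."""
    if results is None:
        # No explicit results: fall back to result-type inference
        return [] if addr_space == SHARED_CLUSTER else ["i64"]
    # Explicit results are used as given (the op verifier checks them)
    return results
```

For example, `mbarrier_arrive(3)` yields `["i64"]`, while the old positional form `mbarrier_arrive(["i64"], 3)` fails because only one positional argument is accepted.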


---

Patch is 35.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/188173.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+16-10) 
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+8-5) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+172) 
- (modified) mlir/test/Dialect/LLVMIR/nvvm.mlir (+185) 
- (modified) mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir (+3-1) 
- (modified) mlir/test/python/dialects/nvvm.py (+210-8) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 57a0c67e82c47..fccab166ac5ae 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -720,7 +720,8 @@ def NVVM_MBarrierCompleteTxOp : NVVM_VoidIntrinsicOp<"mbarrier.complete_tx"> {
   let hasVerifier = 1;
 }
 
-def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive"> {
+def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive",
+    [InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "MBarrier Arrive Operation";
   let description = [{
     The `nvvm.mbarrier.arrive` operation performs an arrive-on operation on the 
@@ -768,7 +769,8 @@ def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive"> {
   let hasVerifier = 1;
 }
 
-def NVVM_MBarrierArriveDropOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop"> {
+def NVVM_MBarrierArriveDropOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop",
+    [InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "MBarrier Arrive-Drop Operation";
   let description = [{
     The `nvvm.mbarrier.arrive_drop` operation decrements the expected arrival
@@ -844,7 +846,8 @@ def NVVM_MBarrierArriveDropNocompleteOp : NVVM_SingleResultIntrinsicOp<"mbarrier
   let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
 }
 
-def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx"> {
+def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx",
+    [InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "MBarrier Arrive with Expected Transaction Count";
   let description = [{
     The `nvvm.mbarrier.arrive.expect_tx` operation performs an expect-tx operation 
@@ -910,7 +913,8 @@ def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_t
   }];
 }
 
-def NVVM_MBarrierArriveDropExpectTxOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop.expect_tx"> {
+def NVVM_MBarrierArriveDropExpectTxOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop.expect_tx",
+    [InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "MBarrier arrive_drop with expected transaction count";
   let description = [{
     The `nvvm.mbarrier.arrive_drop.expect_tx` operation is similar to the
@@ -1123,7 +1127,8 @@ def BarrierReductionAttr
   let assemblyFormat = "`<` $value `>`";
 }
 
-def NVVM_BarrierOp : NVVM_SingleResultIntrinsicOp<"barrier", [AttrSizedOperandSegments]> {
+def NVVM_BarrierOp : NVVM_SingleResultIntrinsicOp<"barrier",
+    [AttrSizedOperandSegments, InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "CTA Barrier Synchronization Op";
   let description = [{
     The `nvvm.barrier` operation performs barrier synchronization and communication 
@@ -1504,7 +1509,7 @@ def ShflKind : I32EnumAttr<"ShflKind", "NVVM shuffle kind",
 def ShflKindAttr : EnumAttr<NVVM_Dialect, ShflKind, "shfl_kind">;
 
 def NVVM_ShflOp :
-  NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>]>,
+  NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>, InferTypeOpAdaptor]>,
   Results<(outs AnyTypeOf<[I32, F32, LLVMStructType]>:$res)>,
   Arguments<(ins I32:$thread_mask,
                  AnyTypeOf<[I32, F32]>:$val,
@@ -1558,7 +1563,7 @@ def VoteSyncKind : I32EnumAttr<"VoteSyncKind", "NVVM vote sync kind",
 def VoteSyncKindAttr : EnumAttr<NVVM_Dialect, VoteSyncKind, "vote_sync_kind">;
 
 def NVVM_VoteSyncOp
-    : NVVM_Op<"vote.sync">,
+    : NVVM_Op<"vote.sync", [InferTypeOpAdaptor]>,
       Results<(outs AnyTypeOf<[I32, I1]>:$res)>,
       Arguments<(ins I32:$mask, I1:$pred, VoteSyncKindAttr:$kind)> {
   let summary = "Vote across thread group";
@@ -3028,7 +3033,7 @@ def NVVM_StMatrixOp: NVVM_Op<"stmatrix">,
   let hasVerifier = 1;
 }
 
-def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
+def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix", [InferTypeOpAdaptor]>,
   Results<(outs AnyType:$res)>,
   Arguments<(ins LLVM_PointerShared:$ptr, I32Attr:$num,
                  MMALayoutAttr:$layout,
@@ -4783,7 +4788,7 @@ def MatchSyncKind : I32EnumAttr<"MatchSyncKind", "NVVM match sync kind",
 
 def MatchSyncKindAttr : EnumAttr<NVVM_Dialect, MatchSyncKind, "match_sync_kind">;
 
-def NVVM_MatchSyncOp : NVVM_Op<"match.sync">,
+def NVVM_MatchSyncOp : NVVM_Op<"match.sync", [InferTypeOpAdaptor]>,
   Results<(outs AnyTypeOf<[I32, LLVMStructType]>:$res)>,
   Arguments<(ins I32:$thread_mask,
                  AnyTypeOf<[I32, I64]>:$val,
@@ -5723,7 +5728,8 @@ def ClusterLaunchControlQueryTypeAttr
 }
 
 def NVVM_ClusterLaunchControlQueryCancelOp
-    : NVVM_Op<"clusterlaunchcontrol.query.cancel", [NVVMRequiresSM<100>]> {
+    : NVVM_Op<"clusterlaunchcontrol.query.cancel",
+              [NVVMRequiresSM<100>, InferTypeOpAdaptor]> {
   let summary = "Query the response of a clusterlaunchcontrol.try.cancel operation";
   let description = [{
     `clusterlaunchcontrol.query.cancel` queries the response of a 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 6edc8f5c86dd3..956019f490c4c 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -852,9 +852,7 @@ struct NVGPUMBarrierArriveLowering
     Value barrier =
         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
-    Type tokenType = getTypeConverter()->convertType(
-        nvgpu::MBarrierTokenType::get(op->getContext()));
-    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, barrier);
+    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, barrier);
     return success();
   }
 };
@@ -911,12 +909,17 @@ struct NVGPUMBarrierArriveExpectTxLowering
         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
     Value txcount = truncToI32(b, adaptor.getTxcount());
-    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
-        op, Type{},       // return-value is optional and is void by default
+    // Use create+eraseOp instead of replaceOpWithNewOp because the nvgpu op
+    // has 0 results, but the NVVM op now infers an i64 result (via
+    // InferTypeOpInterface) for non-cluster pointers. The create() call
+    // triggers InferTypeOpInterface automatically when no result types are
+    // provided; the inferred i64 result is simply unused.
+    NVVM::MBarrierArriveExpectTxOp::create(rewriter, op->getLoc(),
         barrier, txcount, // barrier and txcount
         NVVM::MemScopeKind::CTA, // default scope is CTA
         false,                   // relaxed-semantics is false
         adaptor.getPredicate());
+    rewriter.eraseOp(op);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7d49aa3878ebe..edf142cfdca58 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -302,6 +302,88 @@ LogicalResult MBarrierArriveDropExpectTxOp::verify() {
                                     getRes());
 }
 
+//===----------------------------------------------------------------------===//
+// inferReturnTypes for mbarrier arrive-like ops
+//===----------------------------------------------------------------------===//
+
+/// Shared inference: shared_cluster addr -> no result; otherwise -> i64.
+/// Generic pointers (address space 0) are treated as shared::cta per the PTX
+/// ISA: "If no state space is specified then Generic Addressing is used. If the
+/// address does not fall within .shared::cta then the behavior is undefined."
+/// Therefore only ptr<7> (shared_cluster) produces zero results.
+static LogicalResult inferMBarrierArriveResultTypes(
+    MLIRContext *context, Value addr,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto ptrTy = llvm::dyn_cast<LLVM::LLVMPointerType>(addr.getType());
+  if (!ptrTy)
+    return failure();
+  if (ptrTy.getAddressSpace() !=
+      static_cast<unsigned>(NVVMMemorySpace::SharedCluster))
+    inferredReturnTypes.push_back(IntegerType::get(context, 64));
+  return success();
+}
+
+LogicalResult MBarrierArriveOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MBarrierArriveOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
+                                        inferredReturnTypes);
+}
+
+LogicalResult MBarrierArriveDropOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MBarrierArriveDropOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
+                                        inferredReturnTypes);
+}
+
+LogicalResult MBarrierArriveExpectTxOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MBarrierArriveExpectTxOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  // Predicate forces no return value (inline PTX path).
+  // Note: predicate + shared_cluster is rejected by the verifier separately.
+  if (adaptor.getPredicate())
+    return success();
+  return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
+                                        inferredReturnTypes);
+}
+
+LogicalResult MBarrierArriveDropExpectTxOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MBarrierArriveDropExpectTxOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
+                                        inferredReturnTypes);
+}
+
+/// For ops with optional results, allow the user to omit the result even when
+/// inference would produce one. This preserves backward compatibility: the
+/// result can be silently discarded (e.g., for fire-and-forget arrive ops).
+static bool isCompatibleReturnTypesOptionalResult(TypeRange inferred,
+                                                  TypeRange actual) {
+  if (actual.empty())
+    return true;
+  return inferred == actual;
+}
+
+bool MBarrierArriveOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  return isCompatibleReturnTypesOptionalResult(l, r);
+}
+bool MBarrierArriveDropOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  return isCompatibleReturnTypesOptionalResult(l, r);
+}
+bool MBarrierArriveExpectTxOp::isCompatibleReturnTypes(TypeRange l,
+                                                       TypeRange r) {
+  return isCompatibleReturnTypesOptionalResult(l, r);
+}
+bool MBarrierArriveDropExpectTxOp::isCompatibleReturnTypes(TypeRange l,
+                                                           TypeRange r) {
+  return isCompatibleReturnTypesOptionalResult(l, r);
+}
+
 LogicalResult MBarrierExpectTxOp::verify() {
   return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
 }
@@ -2244,6 +2326,23 @@ LogicalResult ShflOp::verify() {
   return success();
 }
 
+LogicalResult ShflOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    ShflOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  Type valType = adaptor.getVal().getType();
+  if (adaptor.getReturnValueAndIsValid()) {
+    // struct<(val_type, i1)>
+    SmallVector<Type> body = {valType, IntegerType::get(context, 1)};
+    inferredReturnTypes.push_back(
+        LLVM::LLVMStructType::getLiteral(context, body));
+  } else {
+    // same type as val
+    inferredReturnTypes.push_back(valType);
+  }
+  return success();
+}
+
 std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                    NVVM::MMAFrag frag, int nRow,
                                                    int nCol,
@@ -2450,6 +2549,26 @@ LogicalResult NVVM::LdMatrixOp::verify() {
   return success();
 }
 
+LogicalResult LdMatrixOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    LdMatrixOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  uint32_t num = adaptor.getNum();
+  uint32_t m = adaptor.getShape().getM();
+  uint32_t n = adaptor.getShape().getN();
+  uint32_t numElements = (m == 16 && n == 16) ? num * 2 : num;
+
+  Type i32 = IntegerType::get(context, 32);
+  if (numElements == 1) {
+    inferredReturnTypes.push_back(i32);
+  } else {
+    inferredReturnTypes.push_back(
+        LLVM::LLVMStructType::getLiteral(context,
+                                         SmallVector<Type>(numElements, i32)));
+  }
+  return success();
+}
+
 LogicalResult NVVM::StMatrixOp::verify() {
   int numMatrix = getSources().size();
   if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
@@ -2798,6 +2917,19 @@ LogicalResult NVVM::BarrierOp::verify() {
   return success();
 }
 
+LogicalResult BarrierOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    BarrierOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (adaptor.getReductionOp())
+    inferredReturnTypes.push_back(IntegerType::get(context, 32));
+  return success();
+}
+
+bool BarrierOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  return isCompatibleReturnTypesOptionalResult(l, r);
+}
+
 LogicalResult NVVM::Tcgen05CpOp::verify() {
   auto mc = getMulticast();
 
@@ -2840,6 +2972,23 @@ LogicalResult NVVM::MatchSyncOp::verify() {
   return success();
 }
 
+LogicalResult MatchSyncOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MatchSyncOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (adaptor.getKind() == NVVM::MatchSyncKind::all) {
+    // all: returns struct<(i32, i1)>
+    SmallVector<Type> body = {IntegerType::get(context, 32),
+                                 IntegerType::get(context, 1)};
+    inferredReturnTypes.push_back(
+        LLVM::LLVMStructType::getLiteral(context, body));
+  } else {
+    // any: returns i32
+    inferredReturnTypes.push_back(IntegerType::get(context, 32));
+  }
+  return success();
+}
+
 LogicalResult NVVM::VoteSyncOp::verify() {
   if (getKind() == NVVM::VoteSyncKind::ballot) {
     if (!getType().isInteger(32)) {
@@ -2853,6 +3002,17 @@ LogicalResult NVVM::VoteSyncOp::verify() {
   return success();
 }
 
+LogicalResult VoteSyncOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    VoteSyncOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (adaptor.getKind() == NVVM::VoteSyncKind::ballot)
+    inferredReturnTypes.push_back(IntegerType::get(context, 32));
+  else
+    inferredReturnTypes.push_back(IntegerType::get(context, 1));
+  return success();
+}
+
 LogicalResult NVVM::PrefetchOp::verify() {
   using MemSpace = NVVM::NVVMMemorySpace;
   using CacheLevel = NVVM::PrefetchCacheLevel;
@@ -2946,6 +3106,18 @@ LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
   return success();
 }
 
+LogicalResult ClusterLaunchControlQueryCancelOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    ClusterLaunchControlQueryCancelOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (adaptor.getQueryType() ==
+      NVVM::ClusterLaunchControlQueryType::IS_CANCELED)
+    inferredReturnTypes.push_back(IntegerType::get(context, 1));
+  else
+    inferredReturnTypes.push_back(IntegerType::get(context, 32));
+  return success();
+}
+
 LogicalResult NVVM::ReduxOp::verify() {
   mlir::Type reduxType = getType();
 
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index b11ba944fe4ac..fad9c66bc694a 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -443,6 +443,13 @@ llvm.func private @mbarrier_arrive_shared(%barrier: !llvm.ptr<3>) {
   llvm.return
 }
 
+// CHECK-LABEL: @mbarrier_arrive_cluster
+llvm.func private @mbarrier_arrive_cluster(%barrier: !llvm.ptr<7>) {
+  // CHECK:   nvvm.mbarrier.arrive %{{.*}} : !llvm.ptr<7>{{$}}
+  nvvm.mbarrier.arrive %barrier : !llvm.ptr<7>
+  llvm.return
+}
+
 llvm.func private @mbarrier_arrive_nocomplete(%barrier: !llvm.ptr) {
   %count = nvvm.read.ptx.sreg.ntid.x : i32
   // CHECK:   nvvm.mbarrier.arrive.nocomplete %{{.*}} : !llvm.ptr
@@ -457,6 +464,55 @@ llvm.func private @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) {
   llvm.return
 }
 
+// CHECK-LABEL: @mbarrier_arrive_drop_shared
+llvm.func private @mbarrier_arrive_drop_shared(%barrier: !llvm.ptr<3>) {
+  // CHECK:   nvvm.mbarrier.arrive_drop %{{.*}} : !llvm.ptr<3> -> i64
+  %0 = nvvm.mbarrier.arrive_drop %barrier : !llvm.ptr<3> -> i64
+  llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_drop_cluster
+llvm.func private @mbarrier_arrive_drop_cluster(%barrier: !llvm.ptr<7>) {
+  // CHECK:   nvvm.mbarrier.arrive_drop %{{.*}} : !llvm.ptr<7>{{$}}
+  nvvm.mbarrier.arrive_drop %barrier : !llvm.ptr<7>
+  llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_expect_tx_shared
+llvm.func private @mbarrier_arrive_expect_tx_shared(%barrier: !llvm.ptr<3>, %txcount: i32) {
+  // CHECK:   nvvm.mbarrier.arrive.expect_tx %{{.*}}, %{{.*}} : !llvm.ptr<3>, i32 -> i64
+  %0 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr<3>, i32 -> i64
+  llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_expect_tx_cluster
+llvm.func private @mbarrier_arrive_expect_tx_cluster(%barrier: !llvm.ptr<7>, %txcount: i32) {
+  // CHECK:   nvvm.mbarrier.arrive.expect_tx %{{.*}}, %{{.*}} : !llvm.ptr<7>, i32{{$}}
+  nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr<7>, i32
+  llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_expect_tx_predicate
+llvm.func private @mbarrier_arrive_expect_tx_predicate(%barrier: !llvm.ptr<3>, %txcount: i32, %pred: i1) {
+  // CHECK:   nvvm.mbarrier.arrive.expect_tx %{{.*}}, %{{.*}}, predicate = %{{.*}} : !llvm.ptr<3>, i32, i1{{$}}
+  nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
+  llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_drop_expect_tx_shared
+llvm.func private @mbarrier_arrive_drop_expect_tx_shared(%barrier: !llvm.ptr<3>, %txcount: i32) {
+  // CHECK:   nvvm.mbarrier.arrive_drop.expect_tx %{{.*}}, %{{.*}} : !llvm.ptr<3>, i32 -> i64
+  %0 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount : !llvm.ptr<3>, i32 -> i64
+  llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_drop_expect_tx_cluster
+llvm.func private @mbarrier_arrive_drop_expect_tx_cluster(%barrier: !llvm.ptr<7>, %txcount: i32) {
+  // CHECK:   nvvm.mbarrier.arrive_drop.expect_tx %{{.*}}, %{{.*}} : !llvm.ptr<7>, i32{{$}}
+  nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount : !llvm.ptr<7>, i32
+  llvm.return
+}
+
 // CHECK-LABEL: @wgmma_fence_aligned
 func.func @wgmma_fence_aligned() {
   // CHECK: nvvm.wgmma.fence.aligned
@@ -616,6 +672,135 @@ llvm.func @kernel_func(%arg0: !llvm.ptr {nvvm.grid_constant}) attributes {nvvm.k
 
 // -----
 
+// CHECK-LABEL: @vote_sync_infer_type
+func.func @vote_sync_infer_type(%mask : i32, %pred : i1) {
+  // Ballot infers i32 result
+  // CHECK: nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
+  %0 = nvvm.vote.sync ballot %mask, %pred -> i32
+  // All infers i1 result
+  // CHECK: nvvm.vote.sync all %{{.*}}, %{{.*}} -> i1
+  %1 = nvvm.vote.sync all %mask, %pred -> i1
+  // Any infers i1 result
+  // CHECK: nvvm.vote.sync any %{{.*}}, %{{.*}} -> i1
+  %2 = nvvm.vote.sync any %mask, %pred -> i1
+  // Uni infers i1 result
+  // CHECK: nvvm.vote.sync uni %{{.*}}, %{{.*}} -> i1
+  %3 = nvvm.vote.sync uni %mask, %pred -> i1
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @clusterlaunchcontrol_query_cancel_infer_type
+llvm.func @clusterlaunchcontrol_query_cancel_infer_type(%response : i128) {
+  // is_canceled infers i1
+  // CHECK: nvvm.clusterlaunchcontrol.query.cancel query = is_canceled, %{{.*}} : i1
+  %0 = nvvm.clusterlaunchcontrol.query.cancel query = is_canceled, %response : i1
+  // get_first_cta_id_x infers i32
+  // CHECK: nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %{{.*}} : i32
+  %1 = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %response : i32
+  // get_first_cta_id_y infers i32
+  // CHECK: nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_y, %{{.*}} : i32
+  %2 = nvvm.clusterlaunchco...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/188173

