[Mlir-commits] [mlir] [mlir][NVVM] Add InferTypeOpInterface to NVVM ops with deterministic result types (PR #188173)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 23 21:59:18 PDT 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-gpu
Author: Bastian Hagedorn (bastianhagedorn)
<details>
<summary>Changes</summary>
Add result type inference to 10 NVVM ops whose result types can be fully determined from their operands and attributes. This enables the Python binding generator to emit `results=None` as a default parameter, removing the need for callers to pass explicit result types.
Ops with always-present results (using `InferTypeOpAdaptor`):
- `VoteSyncOp`: ballot → i32, any/all/uni → i1
- `MatchSyncOp`: any → i32, all → struct<(i32, i1)>
- `ShflOp`: same type as `val`, or `struct<(val_type, i1)>` when `return_value_and_is_valid` is set
- `LdMatrixOp`: i32, or a struct of i32s, depending on `num` and `shape` (the m16n16 shape doubles the element count)
- `ClusterLaunchControlQueryCancelOp`: is_canceled → i1, others → i32
Ops with optional results (using `InferTypeOpAdaptorWithIsCompatible`):
- `MBarrierArriveOp`: i64 for any pointer other than `shared_cluster` (generic or shared::cta), no result for `shared_cluster` (`!llvm.ptr<7>`)
- `MBarrierArriveDropOp`: same as above
- `MBarrierArriveExpectTxOp`: same, plus no result when predicate is set
- `MBarrierArriveDropExpectTxOp`: same as MBarrierArriveOp
- `BarrierOp`: i32 when reductionOp is present, no result otherwise
The optional-result ops use a permissive `isCompatibleReturnTypes` that allows omitting the result, preserving backward compatibility with the existing zero-result assembly form.
Also updates the NVGPUToNVVM conversion to work with the new inferred result types for `MBarrierArriveOp` and `MBarrierArriveExpectTxOp`.
This is a source-breaking change for Python callers that pass result types positionally (e.g. `mbarrier_arrive(res, addr, ...)` becomes `mbarrier_arrive(addr, *, ..., results=None, ...)`). Existing MLIR assembly is fully backward compatible.
---
Patch is 35.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/188173.diff
6 Files Affected:
- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+16-10)
- (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+8-5)
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+172)
- (modified) mlir/test/Dialect/LLVMIR/nvvm.mlir (+185)
- (modified) mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir (+3-1)
- (modified) mlir/test/python/dialects/nvvm.py (+210-8)
``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 57a0c67e82c47..fccab166ac5ae 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -720,7 +720,8 @@ def NVVM_MBarrierCompleteTxOp : NVVM_VoidIntrinsicOp<"mbarrier.complete_tx"> {
let hasVerifier = 1;
}
-def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive"> {
+def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive",
+ [InferTypeOpAdaptorWithIsCompatible]> {
let summary = "MBarrier Arrive Operation";
let description = [{
The `nvvm.mbarrier.arrive` operation performs an arrive-on operation on the
@@ -768,7 +769,8 @@ def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive"> {
let hasVerifier = 1;
}
-def NVVM_MBarrierArriveDropOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop"> {
+def NVVM_MBarrierArriveDropOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop",
+ [InferTypeOpAdaptorWithIsCompatible]> {
let summary = "MBarrier Arrive-Drop Operation";
let description = [{
The `nvvm.mbarrier.arrive_drop` operation decrements the expected arrival
@@ -844,7 +846,8 @@ def NVVM_MBarrierArriveDropNocompleteOp : NVVM_SingleResultIntrinsicOp<"mbarrier
let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
}
-def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx"> {
+def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx",
+ [InferTypeOpAdaptorWithIsCompatible]> {
let summary = "MBarrier Arrive with Expected Transaction Count";
let description = [{
The `nvvm.mbarrier.arrive.expect_tx` operation performs an expect-tx operation
@@ -910,7 +913,8 @@ def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_t
}];
}
-def NVVM_MBarrierArriveDropExpectTxOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop.expect_tx"> {
+def NVVM_MBarrierArriveDropExpectTxOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop.expect_tx",
+ [InferTypeOpAdaptorWithIsCompatible]> {
let summary = "MBarrier arrive_drop with expected transaction count";
let description = [{
The `nvvm.mbarrier.arrive_drop.expect_tx` operation is similar to the
@@ -1123,7 +1127,8 @@ def BarrierReductionAttr
let assemblyFormat = "`<` $value `>`";
}
-def NVVM_BarrierOp : NVVM_SingleResultIntrinsicOp<"barrier", [AttrSizedOperandSegments]> {
+def NVVM_BarrierOp : NVVM_SingleResultIntrinsicOp<"barrier",
+ [AttrSizedOperandSegments, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "CTA Barrier Synchronization Op";
let description = [{
The `nvvm.barrier` operation performs barrier synchronization and communication
@@ -1504,7 +1509,7 @@ def ShflKind : I32EnumAttr<"ShflKind", "NVVM shuffle kind",
def ShflKindAttr : EnumAttr<NVVM_Dialect, ShflKind, "shfl_kind">;
def NVVM_ShflOp :
- NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>]>,
+ NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>, InferTypeOpAdaptor]>,
Results<(outs AnyTypeOf<[I32, F32, LLVMStructType]>:$res)>,
Arguments<(ins I32:$thread_mask,
AnyTypeOf<[I32, F32]>:$val,
@@ -1558,7 +1563,7 @@ def VoteSyncKind : I32EnumAttr<"VoteSyncKind", "NVVM vote sync kind",
def VoteSyncKindAttr : EnumAttr<NVVM_Dialect, VoteSyncKind, "vote_sync_kind">;
def NVVM_VoteSyncOp
- : NVVM_Op<"vote.sync">,
+ : NVVM_Op<"vote.sync", [InferTypeOpAdaptor]>,
Results<(outs AnyTypeOf<[I32, I1]>:$res)>,
Arguments<(ins I32:$mask, I1:$pred, VoteSyncKindAttr:$kind)> {
let summary = "Vote across thread group";
@@ -3028,7 +3033,7 @@ def NVVM_StMatrixOp: NVVM_Op<"stmatrix">,
let hasVerifier = 1;
}
-def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
+def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix", [InferTypeOpAdaptor]>,
Results<(outs AnyType:$res)>,
Arguments<(ins LLVM_PointerShared:$ptr, I32Attr:$num,
MMALayoutAttr:$layout,
@@ -4783,7 +4788,7 @@ def MatchSyncKind : I32EnumAttr<"MatchSyncKind", "NVVM match sync kind",
def MatchSyncKindAttr : EnumAttr<NVVM_Dialect, MatchSyncKind, "match_sync_kind">;
-def NVVM_MatchSyncOp : NVVM_Op<"match.sync">,
+def NVVM_MatchSyncOp : NVVM_Op<"match.sync", [InferTypeOpAdaptor]>,
Results<(outs AnyTypeOf<[I32, LLVMStructType]>:$res)>,
Arguments<(ins I32:$thread_mask,
AnyTypeOf<[I32, I64]>:$val,
@@ -5723,7 +5728,8 @@ def ClusterLaunchControlQueryTypeAttr
}
def NVVM_ClusterLaunchControlQueryCancelOp
- : NVVM_Op<"clusterlaunchcontrol.query.cancel", [NVVMRequiresSM<100>]> {
+ : NVVM_Op<"clusterlaunchcontrol.query.cancel",
+ [NVVMRequiresSM<100>, InferTypeOpAdaptor]> {
let summary = "Query the response of a clusterlaunchcontrol.try.cancel operation";
let description = [{
`clusterlaunchcontrol.query.cancel` queries the response of a
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 6edc8f5c86dd3..956019f490c4c 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -852,9 +852,7 @@ struct NVGPUMBarrierArriveLowering
Value barrier =
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
- Type tokenType = getTypeConverter()->convertType(
- nvgpu::MBarrierTokenType::get(op->getContext()));
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, barrier);
+ rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, barrier);
return success();
}
};
@@ -911,12 +909,17 @@ struct NVGPUMBarrierArriveExpectTxLowering
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Value txcount = truncToI32(b, adaptor.getTxcount());
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
- op, Type{}, // return-value is optional and is void by default
+ // Use create+eraseOp instead of replaceOpWithNewOp because the nvgpu op
+ // has 0 results, but the NVVM op now infers an i64 result (via
+ // InferTypeOpInterface) for non-cluster pointers. The create() call
+ // triggers InferTypeOpInterface automatically when no result types are
+ // provided; the inferred i64 result is simply unused.
+ NVVM::MBarrierArriveExpectTxOp::create(rewriter, op->getLoc(),
barrier, txcount, // barrier and txcount
NVVM::MemScopeKind::CTA, // default scope is CTA
false, // relaxed-semantics is false
adaptor.getPredicate());
+ rewriter.eraseOp(op);
return success();
}
};
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7d49aa3878ebe..edf142cfdca58 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -302,6 +302,88 @@ LogicalResult MBarrierArriveDropExpectTxOp::verify() {
getRes());
}
+//===----------------------------------------------------------------------===//
+// inferReturnTypes for mbarrier arrive-like ops
+//===----------------------------------------------------------------------===//
+
+/// Shared result-type inference for the mbarrier arrive-like ops: a
+/// `shared_cluster` address (ptr<7>) produces no result, any other LLVM
+/// pointer produces a single `i64` result.
+///
+/// Generic pointers (address space 0) are treated as shared::cta per the PTX
+/// ISA: "If no state space is specified then Generic Addressing is used. If the
+/// address does not fall within .shared::cta then the behavior is undefined."
+/// Therefore only ptr<7> (shared_cluster) produces zero results.
+///
+/// If `addr` is not an LLVM pointer, inference fails. When `location` is
+/// available, a diagnostic explains why, so the failure is not silent for
+/// builder callers.
+static LogicalResult
+inferMBarrierArriveResultTypes(MLIRContext *context,
+                               std::optional<Location> location, Value addr,
+                               SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto ptrTy = llvm::dyn_cast<LLVM::LLVMPointerType>(addr.getType());
+  if (!ptrTy)
+    return emitOptionalError(location,
+                             "expected 'addr' to be an LLVM pointer type, "
+                             "but got ",
+                             addr.getType());
+  if (ptrTy.getAddressSpace() !=
+      static_cast<unsigned>(NVVMMemorySpace::SharedCluster))
+    inferredReturnTypes.push_back(IntegerType::get(context, 64));
+  return success();
+}
+
+/// `nvvm.mbarrier.arrive`: result depends only on the address space of `addr`.
+LogicalResult MBarrierArriveOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MBarrierArriveOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  return inferMBarrierArriveResultTypes(context, location, adaptor.getAddr(),
+                                        inferredReturnTypes);
+}
+
+/// `nvvm.mbarrier.arrive_drop`: result depends only on the address space of
+/// `addr`.
+LogicalResult MBarrierArriveDropOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MBarrierArriveDropOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  return inferMBarrierArriveResultTypes(context, location, adaptor.getAddr(),
+                                        inferredReturnTypes);
+}
+
+/// `nvvm.mbarrier.arrive.expect_tx`: no result when a predicate is present,
+/// otherwise the result depends on the address space of `addr`.
+LogicalResult MBarrierArriveExpectTxOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MBarrierArriveExpectTxOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  // Predicate forces no return value (inline PTX path).
+  // Note: predicate + shared_cluster is rejected by the verifier separately.
+  if (adaptor.getPredicate())
+    return success();
+  return inferMBarrierArriveResultTypes(context, location, adaptor.getAddr(),
+                                        inferredReturnTypes);
+}
+
+/// `nvvm.mbarrier.arrive_drop.expect_tx`: result depends only on the address
+/// space of `addr`.
+LogicalResult MBarrierArriveDropExpectTxOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MBarrierArriveDropExpectTxOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  return inferMBarrierArriveResultTypes(context, location, adaptor.getAddr(),
+                                        inferredReturnTypes);
+}
+
+/// Compatibility rule for ops whose single result is optional.
+///
+/// An op that has no results is always accepted, so any inferred result is
+/// simply dropped. Otherwise the actual result types must equal the inferred
+/// ones exactly. This keeps the pre-existing zero-result assembly form valid,
+/// e.g. for fire-and-forget arrive ops.
+static bool isCompatibleReturnTypesOptionalResult(TypeRange inferred,
+                                                  TypeRange actual) {
+  return actual.empty() || inferred == actual;
+}
+
+bool MBarrierArriveOp::isCompatibleReturnTypes(TypeRange inferred,
+                                               TypeRange actual) {
+  return isCompatibleReturnTypesOptionalResult(inferred, actual);
+}
+bool MBarrierArriveDropOp::isCompatibleReturnTypes(TypeRange inferred,
+                                                   TypeRange actual) {
+  return isCompatibleReturnTypesOptionalResult(inferred, actual);
+}
+bool MBarrierArriveExpectTxOp::isCompatibleReturnTypes(TypeRange inferred,
+                                                       TypeRange actual) {
+  return isCompatibleReturnTypesOptionalResult(inferred, actual);
+}
+bool MBarrierArriveDropExpectTxOp::isCompatibleReturnTypes(TypeRange inferred,
+                                                           TypeRange actual) {
+  return isCompatibleReturnTypesOptionalResult(inferred, actual);
+}
+
LogicalResult MBarrierExpectTxOp::verify() {
return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
}
@@ -2244,6 +2326,23 @@ LogicalResult ShflOp::verify() {
return success();
}
+/// Result of `nvvm.shfl.sync`: the type of `val`, or, when
+/// `return_value_and_is_valid` is set, the literal struct `(val_type, i1)`.
+LogicalResult ShflOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    ShflOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  Type valueTy = adaptor.getVal().getType();
+  if (!adaptor.getReturnValueAndIsValid()) {
+    inferredReturnTypes.push_back(valueTy);
+    return success();
+  }
+  Type fields[] = {valueTy, IntegerType::get(context, 1)};
+  inferredReturnTypes.push_back(
+      LLVM::LLVMStructType::getLiteral(context, fields));
+  return success();
+}
+
std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
NVVM::MMAFrag frag, int nRow,
int nCol,
@@ -2450,6 +2549,26 @@ LogicalResult NVVM::LdMatrixOp::verify() {
return success();
}
+/// Result of `nvvm.ldmatrix`: `num` i32 values, doubled for the m16n16
+/// shape. A single value is returned as a plain i32; several values are
+/// returned as a literal struct of i32s.
+LogicalResult LdMatrixOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    LdMatrixOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto shape = adaptor.getShape();
+  bool isM16N16 = shape.getM() == 16 && shape.getN() == 16;
+  uint32_t count = isM16N16 ? adaptor.getNum() * 2 : adaptor.getNum();
+
+  Type i32Ty = IntegerType::get(context, 32);
+  Type resultTy =
+      count == 1 ? i32Ty
+                 : Type(LLVM::LLVMStructType::getLiteral(
+                       context, SmallVector<Type>(count, i32Ty)));
+  inferredReturnTypes.push_back(resultTy);
+  return success();
+}
+
LogicalResult NVVM::StMatrixOp::verify() {
int numMatrix = getSources().size();
if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
@@ -2798,6 +2917,19 @@ LogicalResult NVVM::BarrierOp::verify() {
return success();
}
+/// `nvvm.barrier` produces an i32 result only when `reductionOp` is present.
+LogicalResult BarrierOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    BarrierOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (!adaptor.getReductionOp())
+    return success();
+  inferredReturnTypes.push_back(IntegerType::get(context, 32));
+  return success();
+}
+
+/// Accept an op with no results; otherwise require an exact type match.
+bool BarrierOp::isCompatibleReturnTypes(TypeRange inferred, TypeRange actual) {
+  return isCompatibleReturnTypesOptionalResult(inferred, actual);
+}
+
LogicalResult NVVM::Tcgen05CpOp::verify() {
auto mc = getMulticast();
@@ -2840,6 +2972,23 @@ LogicalResult NVVM::MatchSyncOp::verify() {
return success();
}
+/// `match.sync all` produces `struct<(i32, i1)>`; every other kind (i.e.
+/// `any`) produces a plain i32.
+LogicalResult MatchSyncOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MatchSyncOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  Type i32Ty = IntegerType::get(context, 32);
+  if (adaptor.getKind() != NVVM::MatchSyncKind::all) {
+    inferredReturnTypes.push_back(i32Ty);
+    return success();
+  }
+  Type fields[] = {i32Ty, IntegerType::get(context, 1)};
+  inferredReturnTypes.push_back(
+      LLVM::LLVMStructType::getLiteral(context, fields));
+  return success();
+}
+
LogicalResult NVVM::VoteSyncOp::verify() {
if (getKind() == NVVM::VoteSyncKind::ballot) {
if (!getType().isInteger(32)) {
@@ -2853,6 +3002,17 @@ LogicalResult NVVM::VoteSyncOp::verify() {
return success();
}
+/// `vote.sync ballot` produces an i32; the other kinds (`any`, `all`, `uni`)
+/// produce an i1.
+LogicalResult VoteSyncOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    VoteSyncOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  bool isBallot = adaptor.getKind() == NVVM::VoteSyncKind::ballot;
+  inferredReturnTypes.push_back(IntegerType::get(context, isBallot ? 32 : 1));
+  return success();
+}
+
LogicalResult NVVM::PrefetchOp::verify() {
using MemSpace = NVVM::NVVMMemorySpace;
using CacheLevel = NVVM::PrefetchCacheLevel;
@@ -2946,6 +3106,18 @@ LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
return success();
}
+/// The `is_canceled` query produces an i1; every other query kind produces
+/// an i32.
+LogicalResult ClusterLaunchControlQueryCancelOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    ClusterLaunchControlQueryCancelOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  bool isCanceledQuery = adaptor.getQueryType() ==
+                         NVVM::ClusterLaunchControlQueryType::IS_CANCELED;
+  unsigned width = isCanceledQuery ? 1 : 32;
+  inferredReturnTypes.push_back(IntegerType::get(context, width));
+  return success();
+}
+
LogicalResult NVVM::ReduxOp::verify() {
mlir::Type reduxType = getType();
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index b11ba944fe4ac..fad9c66bc694a 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -443,6 +443,13 @@ llvm.func private @mbarrier_arrive_shared(%barrier: !llvm.ptr<3>) {
llvm.return
}
+// CHECK-LABEL: @mbarrier_arrive_cluster
+llvm.func private @mbarrier_arrive_cluster(%mbar: !llvm.ptr<7>) {
+  // shared_cluster barrier (ptr<7>): no result is inferred or printed.
+  // CHECK: nvvm.mbarrier.arrive %{{.*}} : !llvm.ptr<7>{{$}}
+  nvvm.mbarrier.arrive %mbar : !llvm.ptr<7>
+  llvm.return
+}
+
llvm.func private @mbarrier_arrive_nocomplete(%barrier: !llvm.ptr) {
%count = nvvm.read.ptx.sreg.ntid.x : i32
// CHECK: nvvm.mbarrier.arrive.nocomplete %{{.*}} : !llvm.ptr
@@ -457,6 +464,55 @@ llvm.func private @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) {
llvm.return
}
+// CHECK-LABEL: @mbarrier_arrive_drop_shared
+llvm.func private @mbarrier_arrive_drop_shared(%mbar: !llvm.ptr<3>) {
+  // shared::cta barrier: the inferred i64 result matches the spelled type.
+  // CHECK: nvvm.mbarrier.arrive_drop %{{.*}} : !llvm.ptr<3> -> i64
+  %state = nvvm.mbarrier.arrive_drop %mbar : !llvm.ptr<3> -> i64
+  llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_drop_cluster
+llvm.func private @mbarrier_arrive_drop_cluster(%mbar: !llvm.ptr<7>) {
+  // shared_cluster barrier (ptr<7>): no result.
+  // CHECK: nvvm.mbarrier.arrive_drop %{{.*}} : !llvm.ptr<7>{{$}}
+  nvvm.mbarrier.arrive_drop %mbar : !llvm.ptr<7>
+  llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_expect_tx_shared
+llvm.func private @mbarrier_arrive_expect_tx_shared(%mbar: !llvm.ptr<3>, %tx: i32) {
+  // shared::cta barrier without predicate: i64 result.
+  // CHECK: nvvm.mbarrier.arrive.expect_tx %{{.*}}, %{{.*}} : !llvm.ptr<3>, i32 -> i64
+  %state = nvvm.mbarrier.arrive.expect_tx %mbar, %tx : !llvm.ptr<3>, i32 -> i64
+  llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_expect_tx_cluster
+llvm.func private @mbarrier_arrive_expect_tx_cluster(%mbar: !llvm.ptr<7>, %tx: i32) {
+  // shared_cluster barrier (ptr<7>): no result.
+  // CHECK: nvvm.mbarrier.arrive.expect_tx %{{.*}}, %{{.*}} : !llvm.ptr<7>, i32{{$}}
+  nvvm.mbarrier.arrive.expect_tx %mbar, %tx : !llvm.ptr<7>, i32
+  llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_expect_tx_predicate
+llvm.func private @mbarrier_arrive_expect_tx_predicate(%mbar: !llvm.ptr<3>, %tx: i32, %p: i1) {
+  // A predicate suppresses the result even for a shared::cta barrier.
+  // CHECK: nvvm.mbarrier.arrive.expect_tx %{{.*}}, %{{.*}}, predicate = %{{.*}} : !llvm.ptr<3>, i32, i1{{$}}
+  nvvm.mbarrier.arrive.expect_tx %mbar, %tx, predicate = %p : !llvm.ptr<3>, i32, i1
+  llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_drop_expect_tx_shared
+llvm.func private @mbarrier_arrive_drop_expect_tx_shared(%mbar: !llvm.ptr<3>, %tx: i32) {
+  // shared::cta barrier: i64 result.
+  // CHECK: nvvm.mbarrier.arrive_drop.expect_tx %{{.*}}, %{{.*}} : !llvm.ptr<3>, i32 -> i64
+  %state = nvvm.mbarrier.arrive_drop.expect_tx %mbar, %tx : !llvm.ptr<3>, i32 -> i64
+  llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_drop_expect_tx_cluster
+llvm.func private @mbarrier_arrive_drop_expect_tx_cluster(%mbar: !llvm.ptr<7>, %tx: i32) {
+  // shared_cluster barrier (ptr<7>): no result.
+  // CHECK: nvvm.mbarrier.arrive_drop.expect_tx %{{.*}}, %{{.*}} : !llvm.ptr<7>, i32{{$}}
+  nvvm.mbarrier.arrive_drop.expect_tx %mbar, %tx : !llvm.ptr<7>, i32
+  llvm.return
+}
+
// CHECK-LABEL: @wgmma_fence_aligned
func.func @wgmma_fence_aligned() {
// CHECK: nvvm.wgmma.fence.aligned
@@ -616,6 +672,135 @@ llvm.func @kernel_func(%arg0: !llvm.ptr {nvvm.grid_constant}) attributes {nvvm.k
// -----
+// CHECK-LABEL: @vote_sync_infer_type
+llvm.func @vote_sync_infer_type(%mask : i32, %pred : i1) {
+  // ballot: the inferred result type is i32 and must match the spelled type.
+  // CHECK: nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
+  %ballot = nvvm.vote.sync ballot %mask, %pred -> i32
+  // all: the inferred result type is i1.
+  // CHECK: nvvm.vote.sync all %{{.*}}, %{{.*}} -> i1
+  %all = nvvm.vote.sync all %mask, %pred -> i1
+  // any: the inferred result type is i1.
+  // CHECK: nvvm.vote.sync any %{{.*}}, %{{.*}} -> i1
+  %any = nvvm.vote.sync any %mask, %pred -> i1
+  // uni: the inferred result type is i1.
+  // CHECK: nvvm.vote.sync uni %{{.*}}, %{{.*}} -> i1
+  %uni = nvvm.vote.sync uni %mask, %pred -> i1
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @clusterlaunchcontrol_query_cancel_infer_type
+llvm.func @clusterlaunchcontrol_query_cancel_infer_type(%response : i128) {
+ // is_canceled infers i1
+ // CHECK: nvvm.clusterlaunchcontrol.query.cancel query = is_canceled, %{{.*}} : i1
+ %0 = nvvm.clusterlaunchcontrol.query.cancel query = is_canceled, %response : i1
+ // get_first_cta_id_x infers i32
+ // CHECK: nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %{{.*}} : i32
+ %1 = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %response : i32
+ // get_first_cta_id_y infers i32
+ // CHECK: nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_y, %{{.*}} : i32
+ %2 = nvvm.clusterlaunchco...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/188173
More information about the Mlir-commits
mailing list