[Mlir-commits] [mlir] [mlir][NVVM] Add InferTypeOpInterface to NVVM ops with deterministic result types (PR #188173)
Bastian Hagedorn
llvmlistbot at llvm.org
Tue Mar 24 22:53:51 PDT 2026
https://github.com/bastianhagedorn updated https://github.com/llvm/llvm-project/pull/188173
>From bbea6fc239290b0794f3866759e96b7696f05b49 Mon Sep 17 00:00:00 2001
From: Bastian Hagedorn <bhagedorn at nvidia.com>
Date: Mon, 23 Mar 2026 08:31:52 +0000
Subject: [PATCH] [mlir][NVVM] Add InferTypeOpInterface to mbarrier arrive and
barrier ops
Add result type inference to the 5 NVVM barrier/mbarrier ops with
optional results, using InferTypeOpAdaptorWithIsCompatible.
- MBarrierArriveOp, MBarrierArriveDropOp, MBarrierArriveDropExpectTxOp:
i64 result for non-shared_cluster pointers, no result otherwise
- MBarrierArriveExpectTxOp: same, plus no result when predicate is set
- BarrierOp: i32 result when reductionOp attribute is present
A permissive isCompatibleReturnTypes allows omitting the result,
preserving backward compatibility with the existing zero-result
assembly form.
This causes the Python binding generator to emit `results=None` as a
default parameter instead of requiring explicit result types.
Also updates the NVGPUToNVVM conversion to work with the new inferred
result types for MBarrierArriveOp and MBarrierArriveExpectTxOp.
Note: this is a source-breaking change for Python callers that pass
result types positionally.
Co-Authored-By: Claude <noreply at anthropic.com>
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 15 +-
.../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 13 +-
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 95 +++++++++
mlir/test/Dialect/LLVMIR/nvvm.mlir | 185 ++++++++++++++++++
mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir | 4 +-
mlir/test/python/dialects/nvvm.py | 48 ++++-
6 files changed, 340 insertions(+), 20 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 57a0c67e82c47..ed1e0982db74a 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -720,7 +720,8 @@ def NVVM_MBarrierCompleteTxOp : NVVM_VoidIntrinsicOp<"mbarrier.complete_tx"> {
let hasVerifier = 1;
}
-def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive"> {
+def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive",
+ [InferTypeOpAdaptorWithIsCompatible]> {
let summary = "MBarrier Arrive Operation";
let description = [{
The `nvvm.mbarrier.arrive` operation performs an arrive-on operation on the
@@ -768,7 +769,8 @@ def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive"> {
let hasVerifier = 1;
}
-def NVVM_MBarrierArriveDropOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop"> {
+def NVVM_MBarrierArriveDropOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop",
+ [InferTypeOpAdaptorWithIsCompatible]> {
let summary = "MBarrier Arrive-Drop Operation";
let description = [{
The `nvvm.mbarrier.arrive_drop` operation decrements the expected arrival
@@ -844,7 +846,8 @@ def NVVM_MBarrierArriveDropNocompleteOp : NVVM_SingleResultIntrinsicOp<"mbarrier
let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
}
-def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx"> {
+def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx",
+ [InferTypeOpAdaptorWithIsCompatible]> {
let summary = "MBarrier Arrive with Expected Transaction Count";
let description = [{
The `nvvm.mbarrier.arrive.expect_tx` operation performs an expect-tx operation
@@ -910,7 +913,8 @@ def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_t
}];
}
-def NVVM_MBarrierArriveDropExpectTxOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop.expect_tx"> {
+def NVVM_MBarrierArriveDropExpectTxOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop.expect_tx",
+ [InferTypeOpAdaptorWithIsCompatible]> {
let summary = "MBarrier arrive_drop with expected transaction count";
let description = [{
The `nvvm.mbarrier.arrive_drop.expect_tx` operation is similar to the
@@ -1123,7 +1127,8 @@ def BarrierReductionAttr
let assemblyFormat = "`<` $value `>`";
}
-def NVVM_BarrierOp : NVVM_SingleResultIntrinsicOp<"barrier", [AttrSizedOperandSegments]> {
+def NVVM_BarrierOp : NVVM_SingleResultIntrinsicOp<"barrier",
+ [AttrSizedOperandSegments, InferTypeOpAdaptorWithIsCompatible]> {
let summary = "CTA Barrier Synchronization Op";
let description = [{
The `nvvm.barrier` operation performs barrier synchronization and communication
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 6edc8f5c86dd3..956019f490c4c 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -852,9 +852,7 @@ struct NVGPUMBarrierArriveLowering
Value barrier =
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
- Type tokenType = getTypeConverter()->convertType(
- nvgpu::MBarrierTokenType::get(op->getContext()));
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, barrier);
+ rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, barrier);
return success();
}
};
@@ -911,12 +909,17 @@ struct NVGPUMBarrierArriveExpectTxLowering
getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
adaptor.getMbarId(), rewriter);
Value txcount = truncToI32(b, adaptor.getTxcount());
- rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
- op, Type{}, // return-value is optional and is void by default
+ // Use create+eraseOp instead of replaceOpWithNewOp because the nvgpu op
+ // has 0 results, but the NVVM op now infers an i64 result (via
+ // InferTypeOpInterface) for non-cluster pointers. The create() call
+ // triggers InferTypeOpInterface automatically when no result types are
+ // provided; the inferred i64 result is simply unused.
+ NVVM::MBarrierArriveExpectTxOp::create(rewriter, op->getLoc(),
barrier, txcount, // barrier and txcount
NVVM::MemScopeKind::CTA, // default scope is CTA
false, // relaxed-semantics is false
adaptor.getPredicate());
+ rewriter.eraseOp(op);
return success();
}
};
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 528e709629ebf..242bad5458c32 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -302,6 +302,88 @@ LogicalResult MBarrierArriveDropExpectTxOp::verify() {
getRes());
}
+//===----------------------------------------------------------------------===//
+// inferReturnTypes for mbarrier arrive-like ops
+//===----------------------------------------------------------------------===//
+
+/// Shared result-type inference for the mbarrier arrive-like ops.
+///
+/// shared_cluster addr -> no result; otherwise -> i64.
+/// Generic pointers (address space 0) are treated as shared::cta per the PTX
+/// ISA: "If no state space is specified then Generic Addressing is used. If the
+/// address does not fall within .shared::cta then the behavior is undefined."
+/// Therefore only ptr<7> (shared_cluster) produces zero results.
+///
+/// Fails if `addr` is not an LLVM pointer. When `location` is available, an
+/// error naming the offending type is emitted there, so builder callers that
+/// rely on inference (e.g. the Python bindings) get an actionable message
+/// instead of a generic inference failure.
+static LogicalResult inferMBarrierArriveResultTypes(
+    MLIRContext *context, std::optional<Location> location, Value addr,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto ptrTy = llvm::dyn_cast<LLVM::LLVMPointerType>(addr.getType());
+  if (!ptrTy)
+    return emitOptionalError(
+        location, "expected mbarrier address to be an LLVM pointer, but got ",
+        addr.getType());
+  if (ptrTy.getAddressSpace() !=
+      static_cast<unsigned>(NVVMMemorySpace::SharedCluster))
+    inferredReturnTypes.push_back(IntegerType::get(context, 64));
+  return success();
+}
+
+/// `nvvm.mbarrier.arrive`: i64 result unless the barrier is shared_cluster.
+LogicalResult MBarrierArriveOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MBarrierArriveOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  return inferMBarrierArriveResultTypes(context, location, adaptor.getAddr(),
+                                        inferredReturnTypes);
+}
+
+/// `nvvm.mbarrier.arrive_drop`: i64 result unless the barrier is
+/// shared_cluster.
+LogicalResult MBarrierArriveDropOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MBarrierArriveDropOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  return inferMBarrierArriveResultTypes(context, location, adaptor.getAddr(),
+                                        inferredReturnTypes);
+}
+
+/// `nvvm.mbarrier.arrive.expect_tx`: no result when a predicate is given,
+/// otherwise the same rule as the other arrive-like ops.
+LogicalResult MBarrierArriveExpectTxOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MBarrierArriveExpectTxOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  // Predicate forces no return value (inline PTX path).
+  // Note: predicate + shared_cluster is rejected by the verifier separately.
+  if (adaptor.getPredicate())
+    return success();
+  return inferMBarrierArriveResultTypes(context, location, adaptor.getAddr(),
+                                        inferredReturnTypes);
+}
+
+/// `nvvm.mbarrier.arrive_drop.expect_tx`: i64 result unless the barrier is
+/// shared_cluster.
+LogicalResult MBarrierArriveDropExpectTxOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MBarrierArriveDropExpectTxOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  return inferMBarrierArriveResultTypes(context, location, adaptor.getAddr(),
+                                        inferredReturnTypes);
+}
+
+/// Return-type compatibility for ops whose result is optional.
+///
+/// An op written with no results is always accepted, even when inference would
+/// produce one: the value is simply dropped, which keeps the legacy zero-result
+/// form valid (e.g. fire-and-forget arrives). Any non-empty result list must
+/// match the inferred types exactly.
+static bool isCompatibleReturnTypesOptionalResult(TypeRange inferred,
+                                                  TypeRange actual) {
+  return actual.empty() || inferred == actual;
+}
+
+// The mbarrier arrive-like ops may omit their token result even when one is
+// inferred; each delegates to the shared optional-result check, which is
+// called with the inferred types first and the op's declared types second.
+bool MBarrierArriveOp::isCompatibleReturnTypes(TypeRange inferred,
+                                               TypeRange actual) {
+  return isCompatibleReturnTypesOptionalResult(inferred, actual);
+}
+bool MBarrierArriveDropOp::isCompatibleReturnTypes(TypeRange inferred,
+                                                   TypeRange actual) {
+  return isCompatibleReturnTypesOptionalResult(inferred, actual);
+}
+bool MBarrierArriveExpectTxOp::isCompatibleReturnTypes(TypeRange inferred,
+                                                       TypeRange actual) {
+  return isCompatibleReturnTypesOptionalResult(inferred, actual);
+}
+bool MBarrierArriveDropExpectTxOp::isCompatibleReturnTypes(TypeRange inferred,
+                                                           TypeRange actual) {
+  return isCompatibleReturnTypesOptionalResult(inferred, actual);
+}
+
LogicalResult MBarrierExpectTxOp::verify() {
return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
}
@@ -2798,6 +2880,19 @@ LogicalResult NVVM::BarrierOp::verify() {
return success();
}
+/// A reducing barrier (one carrying a `reductionOp` attribute) produces an i32
+/// result; a plain barrier produces none.
+LogicalResult BarrierOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    BarrierOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (!adaptor.getReductionOp())
+    return success();
+  inferredReturnTypes.push_back(IntegerType::get(context, 32));
+  return success();
+}
+
+/// The reduction result may be left off the op; see
+/// isCompatibleReturnTypesOptionalResult.
+bool BarrierOp::isCompatibleReturnTypes(TypeRange inferred, TypeRange actual) {
+  return isCompatibleReturnTypesOptionalResult(inferred, actual);
+}
+
LogicalResult NVVM::Tcgen05CpOp::verify() {
auto mc = getMulticast();
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index b11ba944fe4ac..fad9c66bc694a 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -443,6 +443,13 @@ llvm.func private @mbarrier_arrive_shared(%barrier: !llvm.ptr<3>) {
llvm.return
}
+// A shared_cluster barrier (ptr<7>) infers no result, so the result-less form
+// must print back with nothing after the pointer type (end-of-line anchored).
+// CHECK-LABEL: @mbarrier_arrive_cluster
+llvm.func private @mbarrier_arrive_cluster(%barrier: !llvm.ptr<7>) {
+ // CHECK: nvvm.mbarrier.arrive %{{.*}} : !llvm.ptr<7>{{$}}
+ nvvm.mbarrier.arrive %barrier : !llvm.ptr<7>
+ llvm.return
+}
+
llvm.func private @mbarrier_arrive_nocomplete(%barrier: !llvm.ptr) {
%count = nvvm.read.ptx.sreg.ntid.x : i32
// CHECK: nvvm.mbarrier.arrive.nocomplete %{{.*}} : !llvm.ptr
@@ -457,6 +464,55 @@ llvm.func private @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) {
llvm.return
}
+// Result inference for the arrive-like ops: barriers outside shared_cluster
+// (here ptr<3>) carry an i64 result, while shared_cluster barriers (ptr<7>)
+// have none. The end-of-line anchors on the result-less cases check that no
+// `-> i64` is printed.
+// CHECK-LABEL: @mbarrier_arrive_drop_shared
+llvm.func private @mbarrier_arrive_drop_shared(%barrier: !llvm.ptr<3>) {
+ // CHECK: nvvm.mbarrier.arrive_drop %{{.*}} : !llvm.ptr<3> -> i64
+ %0 = nvvm.mbarrier.arrive_drop %barrier : !llvm.ptr<3> -> i64
+ llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_drop_cluster
+llvm.func private @mbarrier_arrive_drop_cluster(%barrier: !llvm.ptr<7>) {
+ // CHECK: nvvm.mbarrier.arrive_drop %{{.*}} : !llvm.ptr<7>{{$}}
+ nvvm.mbarrier.arrive_drop %barrier : !llvm.ptr<7>
+ llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_expect_tx_shared
+llvm.func private @mbarrier_arrive_expect_tx_shared(%barrier: !llvm.ptr<3>, %txcount: i32) {
+ // CHECK: nvvm.mbarrier.arrive.expect_tx %{{.*}}, %{{.*}} : !llvm.ptr<3>, i32 -> i64
+ %0 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr<3>, i32 -> i64
+ llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_expect_tx_cluster
+llvm.func private @mbarrier_arrive_expect_tx_cluster(%barrier: !llvm.ptr<7>, %txcount: i32) {
+ // CHECK: nvvm.mbarrier.arrive.expect_tx %{{.*}}, %{{.*}} : !llvm.ptr<7>, i32{{$}}
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr<7>, i32
+ llvm.return
+}
+
+// With a predicate operand, arrive.expect_tx infers no result even though the
+// barrier is in ptr<3>.
+// CHECK-LABEL: @mbarrier_arrive_expect_tx_predicate
+llvm.func private @mbarrier_arrive_expect_tx_predicate(%barrier: !llvm.ptr<3>, %txcount: i32, %pred: i1) {
+ // CHECK: nvvm.mbarrier.arrive.expect_tx %{{.*}}, %{{.*}}, predicate = %{{.*}} : !llvm.ptr<3>, i32, i1{{$}}
+ nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
+ llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_drop_expect_tx_shared
+llvm.func private @mbarrier_arrive_drop_expect_tx_shared(%barrier: !llvm.ptr<3>, %txcount: i32) {
+ // CHECK: nvvm.mbarrier.arrive_drop.expect_tx %{{.*}}, %{{.*}} : !llvm.ptr<3>, i32 -> i64
+ %0 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount : !llvm.ptr<3>, i32 -> i64
+ llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_drop_expect_tx_cluster
+llvm.func private @mbarrier_arrive_drop_expect_tx_cluster(%barrier: !llvm.ptr<7>, %txcount: i32) {
+ // CHECK: nvvm.mbarrier.arrive_drop.expect_tx %{{.*}}, %{{.*}} : !llvm.ptr<7>, i32{{$}}
+ nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount : !llvm.ptr<7>, i32
+ llvm.return
+}
+
// CHECK-LABEL: @wgmma_fence_aligned
func.func @wgmma_fence_aligned() {
// CHECK: nvvm.wgmma.fence.aligned
@@ -616,6 +672,135 @@ llvm.func @kernel_func(%arg0: !llvm.ptr {nvvm.grid_constant}) attributes {nvvm.k
// -----
+// NOTE(review): the NVVMOps.td hunks in this patch add InferTypeOpInterface only
+// to the mbarrier arrive-like ops and nvvm.barrier. The ops exercised in this
+// and the following split chunks (vote.sync, clusterlaunchcontrol.query.cancel,
+// match.sync, shfl.sync, ldmatrix) are not touched here, and every case spells
+// its result type explicitly, so these only round-trip the printed form.
+// Confirm whether these "infer_type" tests belong in this patch.
+// CHECK-LABEL: @vote_sync_infer_type
+func.func @vote_sync_infer_type(%mask : i32, %pred : i1) {
+ // Ballot infers i32 result
+ // CHECK: nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
+ %0 = nvvm.vote.sync ballot %mask, %pred -> i32
+ // All infers i1 result
+ // CHECK: nvvm.vote.sync all %{{.*}}, %{{.*}} -> i1
+ %1 = nvvm.vote.sync all %mask, %pred -> i1
+ // Any infers i1 result
+ // CHECK: nvvm.vote.sync any %{{.*}}, %{{.*}} -> i1
+ %2 = nvvm.vote.sync any %mask, %pred -> i1
+ // Uni infers i1 result
+ // CHECK: nvvm.vote.sync uni %{{.*}}, %{{.*}} -> i1
+ %3 = nvvm.vote.sync uni %mask, %pred -> i1
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @clusterlaunchcontrol_query_cancel_infer_type
+llvm.func @clusterlaunchcontrol_query_cancel_infer_type(%response : i128) {
+ // is_canceled infers i1
+ // CHECK: nvvm.clusterlaunchcontrol.query.cancel query = is_canceled, %{{.*}} : i1
+ %0 = nvvm.clusterlaunchcontrol.query.cancel query = is_canceled, %response : i1
+ // get_first_cta_id_x infers i32
+ // CHECK: nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %{{.*}} : i32
+ %1 = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %response : i32
+ // get_first_cta_id_y infers i32
+ // CHECK: nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_y, %{{.*}} : i32
+ %2 = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_y, %response : i32
+ // get_first_cta_id_z infers i32
+ // CHECK: nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_z, %{{.*}} : i32
+ %3 = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_z, %response : i32
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @match_sync_infer_type
+func.func @match_sync_infer_type(%mask : i32, %i32val : i32, %i64val : i64) {
+ // any: infers i32
+ // CHECK: nvvm.match.sync any %{{.*}}, %{{.*}} : i32 -> i32
+ %0 = nvvm.match.sync any %mask, %i32val : i32 -> i32
+ // all with i32: infers struct<(i32, i1)>
+ // CHECK: nvvm.match.sync all %{{.*}}, %{{.*}} : i32 -> !llvm.struct<(i32, i1)>
+ %1 = nvvm.match.sync all %mask, %i32val : i32 -> !llvm.struct<(i32, i1)>
+ // all with i64: infers struct<(i32, i1)>
+ // CHECK: nvvm.match.sync all %{{.*}}, %{{.*}} : i64 -> !llvm.struct<(i32, i1)>
+ %2 = nvvm.match.sync all %mask, %i64val : i64 -> !llvm.struct<(i32, i1)>
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @shfl_sync_infer_type
+func.func @shfl_sync_infer_type(%mask : i32, %i32val : i32, %f32val : f32, %offset : i32, %clamp : i32) {
+ // Without return_value_and_is_valid: result = val type
+ // CHECK: nvvm.shfl.sync bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32 -> i32
+ %0 = nvvm.shfl.sync bfly %mask, %i32val, %offset, %clamp : i32 -> i32
+ // CHECK: nvvm.shfl.sync bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 -> f32
+ %1 = nvvm.shfl.sync bfly %mask, %f32val, %offset, %clamp : f32 -> f32
+ // With return_value_and_is_valid: result = struct<(val_type, i1)>
+ // CHECK: nvvm.shfl.sync bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
+ %2 = nvvm.shfl.sync bfly %mask, %i32val, %offset, %clamp {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
+ // CHECK: nvvm.shfl.sync bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
+ %3 = nvvm.shfl.sync bfly %mask, %f32val, %offset, %clamp {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
+ llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @ldmatrix_infer_type
+llvm.func @ldmatrix_infer_type(%ptr : !llvm.ptr<3>) {
+ // num=1, 8x8 → i32
+ // CHECK: nvvm.ldmatrix %{{.*}} {{.*}}num = 1{{.*}} : (!llvm.ptr<3>) -> i32
+ %0 = nvvm.ldmatrix %ptr {num = 1 : i32, layout = #nvvm.mma_layout<row>,
+ shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>,
+ eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> i32
+ // num=2, 8x8 → struct<(i32, i32)>
+ // CHECK: nvvm.ldmatrix %{{.*}} {{.*}}num = 2{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ %1 = nvvm.ldmatrix %ptr {num = 2 : i32, layout = #nvvm.mma_layout<row>,
+ shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>,
+ eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ // num=4, 8x8 → struct<(i32, i32, i32, i32)>
+ // CHECK: nvvm.ldmatrix %{{.*}} {{.*}}num = 4{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+ %2 = nvvm.ldmatrix %ptr {num = 4 : i32, layout = #nvvm.mma_layout<row>,
+ shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>,
+ eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+ // num=1, 16x16 → numElements=2 → struct<(i32, i32)>
+ // CHECK: nvvm.ldmatrix %{{.*}} {{.*}}num = 1{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ %3 = nvvm.ldmatrix %ptr {num = 1 : i32, layout = #nvvm.mma_layout<col>,
+ shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>,
+ eltType = #nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+ // num=2, 16x16 → numElements=4 → struct<(i32, i32, i32, i32)>
+ // CHECK: nvvm.ldmatrix %{{.*}} {{.*}}num = 2{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+ %4 = nvvm.ldmatrix %ptr {num = 2 : i32, layout = #nvvm.mma_layout<col>,
+ shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>,
+ eltType = #nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+ llvm.return
+}
+
+// -----
+
+// Negative tests: an explicit result type that disagrees with the type the op
+// requires is rejected.
+// NOTE(review): these ops do not gain InferTypeOpInterface in this patch; the
+// diagnostics below appear to come from each op's own verifier or result type
+// constraint rather than from result-type inference.
+
+// vote.sync ballot must produce an i32.
+func.func @vote_sync_ballot_wrong_type(%mask: i32, %pred: i1) {
+ // expected-error @below {{vote.sync 'ballot' returns an i32}}
+ %0 = nvvm.vote.sync ballot %mask, %pred -> i1
+ return
+}
+
+// -----
+
+// shfl.sync must return the type of the shuffled value (i32 here).
+func.func @shfl_sync_wrong_result_type(%mask: i32, %val: i32, %off: i32, %clamp: i32) {
+ // expected-error @below {{expected return type to be of type 'i32' but got 'f32'}}
+ %0 = nvvm.shfl.sync bfly %mask, %val, %off, %clamp : i32 -> f32
+ return
+}
+
+// -----
+
+// match.sync rejects result types other than i32 or an LLVM struct.
+func.func @match_sync_any_wrong_type(%mask: i32, %val: i32) {
+ // expected-error @below {{must be 32-bit signless integer or LLVM struct type}}
+ %0 = nvvm.match.sync any %mask, %val : i32 -> i1
+ return
+}
+
+// -----
+
// expected-error @below {{'"nvvm.grid_constant"' must be a unit attribute}}
llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant = true}) attributes {nvvm.kernel} {
llvm.return
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir
index f406922ea3873..96c910b193f12 100644
--- a/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir
@@ -108,8 +108,10 @@ llvm.func @mbarrier_arrive_ignore_retval(%count : i32, %barrier: !llvm.ptr<3>) {
// CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %1, i32 %0)
// CHECK-NEXT: ret void
// CHECK-NEXT: }
+ // Result silently discarded (backward compatible form)
nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<3>
- nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<3>
+ // Result explicitly captured
+ %0 = nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<3> -> i64
llvm.return
}
diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py
index dc8e4f462ad56..d80126a5c8646 100644
--- a/mlir/test/python/dialects/nvvm.py
+++ b/mlir/test/python/dialects/nvvm.py
@@ -93,6 +93,37 @@ def my_inline_ptx(a, b, c, d):
arith.addf(wo0, wo1)
+ at constructAndPrintInModule
+def test_mbarrier_arrive():
+ ptr_shared = llvm.PointerType.get(3)
+ ptr_cluster = llvm.PointerType.get(7)
+ i32 = T.i32()
+ i64 = T.i64()
+
+ @func.FuncOp.from_py_func(ptr_shared, ptr_cluster, i32)
+ def mbarrier_arrive_ops(barrier_shared, barrier_cluster, txcount):
+ token = nvvm.mbarrier_arrive(barrier_shared)
+ nvvm.mbarrier_arrive(barrier_cluster)
+ token2 = nvvm.mbarrier_arrive_drop(barrier_shared)
+ nvvm.mbarrier_arrive_drop(barrier_cluster)
+ token3 = nvvm.mbarrier_arrive_expect_tx(barrier_shared, txcount)
+ nvvm.mbarrier_arrive_expect_tx(barrier_cluster, txcount)
+ token4 = nvvm.mbarrier_arrive_drop_expect_tx(barrier_shared, txcount)
+ nvvm.mbarrier_arrive_drop_expect_tx(barrier_cluster, txcount)
+
+
+# CHECK-LABEL: func.func @mbarrier_arrive_ops(
+# CHECK-SAME: %[[SHARED:.*]]: !llvm.ptr<3>, %[[CLUSTER:.*]]: !llvm.ptr<7>, %[[TXCOUNT:.*]]: i32)
+# CHECK: %{{.*}} = nvvm.mbarrier.arrive %[[SHARED]] : !llvm.ptr<3> -> i64
+# CHECK-NEXT: nvvm.mbarrier.arrive %[[CLUSTER]] : !llvm.ptr<7>{{$}}
+# CHECK-NEXT: %{{.*}} = nvvm.mbarrier.arrive_drop %[[SHARED]] : !llvm.ptr<3> -> i64
+# CHECK-NEXT: nvvm.mbarrier.arrive_drop %[[CLUSTER]] : !llvm.ptr<7>{{$}}
+# CHECK-NEXT: %{{.*}} = nvvm.mbarrier.arrive.expect_tx %[[SHARED]], %[[TXCOUNT]] : !llvm.ptr<3>, i32 -> i64
+# CHECK-NEXT: nvvm.mbarrier.arrive.expect_tx %[[CLUSTER]], %[[TXCOUNT]] : !llvm.ptr<7>, i32{{$}}
+# CHECK-NEXT: %{{.*}} = nvvm.mbarrier.arrive_drop.expect_tx %[[SHARED]], %[[TXCOUNT]] : !llvm.ptr<3>, i32 -> i64
+# CHECK-NEXT: nvvm.mbarrier.arrive_drop.expect_tx %[[CLUSTER]], %[[TXCOUNT]] : !llvm.ptr<7>, i32{{$}}
+
+
@constructAndPrintInModule
def test_barriers():
i32 = T.i32()
@@ -102,21 +133,20 @@ def test_barriers():
def barriers(mask, vi32, vf32):
c0 = arith.constant(T.i32(), 0)
cffff = arith.constant(T.i32(), 0xFFFF)
- res = nvvm.barrier(
- res=i32,
+ nvvm.barrier(
barrier_id=c0,
number_of_threads=cffff,
)
+ pred = arith.constant(T.i32(), 1)
for reduction in (
nvvm.BarrierReduction.AND,
nvvm.BarrierReduction.OR,
nvvm.BarrierReduction.POPC,
):
- res = nvvm.barrier(
- res=i32,
+ pred = nvvm.barrier(
reduction_op=reduction,
- reduction_predicate=res,
+ reduction_predicate=pred,
)
nvvm.barrier0()
@@ -129,15 +159,16 @@ def barriers(mask, vi32, vf32):
nvvm.cluster_wait(aligned=True)
nvvm.fence_mbarrier_init()
nvvm.bar_warp_sync(mask)
- return res
+ return pred
# CHECK-LABEL: func.func @barriers(
# CHECK: %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: f32) -> i32 {
# CHECK: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
# CHECK: %[[CONSTANT_1:.*]] = arith.constant 65535 : i32
-# CHECK: %[[BARRIER_0:.*]] = nvvm.barrier id = %[[CONSTANT_0]] number_of_threads = %[[CONSTANT_1]] -> i32
-# CHECK: %[[BARRIER_1:.*]] = nvvm.barrier #nvvm.reduction<and> %[[BARRIER_0]] -> i32
+# CHECK: nvvm.barrier id = %[[CONSTANT_0]] number_of_threads = %[[CONSTANT_1]]
+# CHECK: %[[PRED:.*]] = arith.constant 1 : i32
+# CHECK: %[[BARRIER_1:.*]] = nvvm.barrier #nvvm.reduction<and> %[[PRED]] -> i32
# CHECK: %[[BARRIER_2:.*]] = nvvm.barrier #nvvm.reduction<or> %[[BARRIER_1]] -> i32
# CHECK: %[[BARRIER_3:.*]] = nvvm.barrier #nvvm.reduction<popc> %[[BARRIER_2]] -> i32
# CHECK: nvvm.barrier0
@@ -151,7 +182,6 @@ def barriers(mask, vi32, vf32):
# CHECK: nvvm.fence.mbarrier.init
# CHECK: nvvm.bar.warp.sync %[[ARG0]] : i32
# CHECK: return %[[BARRIER_3]] : i32
-# CHECK: }
@constructAndPrintInModule
More information about the Mlir-commits
mailing list