[Mlir-commits] [mlir] [mlir][NVVM] Add InferTypeOpInterface to NVVM ops with deterministic result types (PR #188173)

Bastian Hagedorn llvmlistbot at llvm.org
Tue Mar 24 01:15:37 PDT 2026


https://github.com/bastianhagedorn updated https://github.com/llvm/llvm-project/pull/188173

From afe4c1c0fa2c417877ed06ecd1ecf710d5c32a19 Mon Sep 17 00:00:00 2001
From: Bastian Hagedorn <bhagedorn at nvidia.com>
Date: Mon, 23 Mar 2026 08:31:52 +0000
Subject: [PATCH] [mlir][NVVM] Add InferTypeOpInterface to NVVM ops with
 deterministic result types

Add `InferTypeOpAdaptor` (or `InferTypeOpAdaptorWithIsCompatible`) to 10
NVVM ops whose result types can be fully determined from their operands
and attributes, enabling automatic result type inference at op
construction time.

This causes the Python binding generator to emit `results=None` as a
default parameter instead of requiring explicit result types, which
simplifies downstream wrapper code.

Ops with optional results (0 or 1 results depending on operands or attributes):
- MBarrierArriveOp, MBarrierArriveDropOp, MBarrierArriveDropExpectTxOp:
  i64 result unless addr is a shared_cluster pointer (ptr<7>), in which
  case no result
- MBarrierArriveExpectTxOp: same, plus no result when predicate is set
- BarrierOp: i32 result when reductionOp attribute is present
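A minimal Python sketch of the optional-result rules above. It models MLIR
types as plain strings and address spaces as integers (7 = shared_cluster,
0 = generic, which the patch treats like shared::cta). The function names
are hypothetical, this is not the MLIR Python API, and it omits the C++
failure path taken when addr is not an LLVM pointer.

```python
# Hypothetical model of the inference rules; types are plain strings.
SHARED_CLUSTER = 7  # NVVM address space for shared::cluster pointers


def infer_mbarrier_arrive(addr_space, has_predicate=False):
    """Inferred result types for the mbarrier arrive-like ops.

    has_predicate is only meaningful for mbarrier.arrive.expect_tx,
    where a predicate forces the zero-result (inline PTX) form.
    """
    if has_predicate:
        return []
    if addr_space == SHARED_CLUSTER:
        return []
    # shared::cta (3) and generic (0) pointers both yield an i64 token.
    return ["i64"]


def infer_barrier(has_reduction_op):
    """nvvm.barrier yields an i32 only when a reduction is requested."""
    return ["i32"] if has_reduction_op else []


print(infer_mbarrier_arrive(3))        # ['i64']
print(infer_mbarrier_arrive(7))        # []
print(infer_mbarrier_arrive(3, True))  # []
print(infer_barrier(True))             # ['i32']
```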

These ops use InferTypeOpAdaptorWithIsCompatible with a permissive
isCompatibleReturnTypes that accepts either the inferred types or no
results at all, so the existing zero-result assembly form stays valid.
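The permissive compatibility check can be sketched in Python as well
(hypothetical helper, types as strings). Note that the relaxation only goes
one way: an empty result list is always accepted, but an explicit result
must match the inferred one exactly, so requesting an i64 on a
shared_cluster pointer is still rejected.

```python
def is_compatible_optional_result(inferred, actual):
    """Accept no results at all, or an exact match with the inferred types."""
    if not actual:
        # The caller dropped the result: always allowed (fire-and-forget).
        return True
    return list(inferred) == list(actual)


print(is_compatible_optional_result(["i64"], []))       # True
print(is_compatible_optional_result(["i64"], ["i32"]))  # False
print(is_compatible_optional_result([], ["i64"]))       # False
```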

Ops with exactly one result (its type depends on operands/attributes):
- VoteSyncOp: ballot -> i32, any/all/uni -> i1
- MatchSyncOp: any -> i32, all -> struct<(i32, i1)>
- ShflOp: result matches the val type, or struct<(val_type, i1)> when
  return_value_and_is_valid is set
- LdMatrixOp: i32 for a single register, otherwise a struct of i32s; the
  register count is num, doubled for the 16x16 shape
- ClusterLaunchControlQueryCancelOp: is_canceled -> i1, others -> i32
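A hypothetical Python model of the single-result rules, with types rendered
as MLIR-style strings. The kind and query spellings mirror the ones used in
the tests; the functions themselves are illustrative only and do not exist
in the bindings.

```python
def llvm_struct(*fields):
    """Render a literal LLVM struct type string."""
    return "!llvm.struct<(" + ", ".join(fields) + ")>"


def infer_vote_sync(kind):
    return "i32" if kind == "ballot" else "i1"


def infer_match_sync(kind):
    return llvm_struct("i32", "i1") if kind == "all" else "i32"


def infer_shfl(val_type, return_value_and_is_valid=False):
    if return_value_and_is_valid:
        return llvm_struct(val_type, "i1")
    return val_type


def infer_ldmatrix(num, m, n):
    # A 16x16 tile needs two i32 registers per matrix.
    num_elements = num * 2 if (m, n) == (16, 16) else num
    if num_elements == 1:
        return "i32"
    return llvm_struct(*(["i32"] * num_elements))


def infer_query_cancel(query_type):
    return "i1" if query_type == "is_canceled" else "i32"


print(infer_ldmatrix(1, 16, 16))  # !llvm.struct<(i32, i32)>
print(infer_shfl("f32", True))    # !llvm.struct<(f32, i1)>
```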

Also updates the NVGPUToNVVM lowerings of MBarrierArriveOp and
MBarrierArriveExpectTxOp to rely on the inferred result types instead of
passing them explicitly.

Note: this is a source-breaking change for Python callers that pass
result types positionally. The generated Python signature changes from
e.g. `mbarrier_arrive(res, addr, ...)` to
`mbarrier_arrive(addr, *, ..., results=None, ...)`.
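To illustrate the break, here are two hypothetical stand-ins for the
generated builder, assuming (per the signature above) that only addr stays
positional and everything else, including results, becomes keyword-only.
The real generated functions construct ops; these just return tuples so the
call shapes can be compared.

```python
def mbarrier_arrive_old(res, addr, *, loc=None, ip=None):
    """Stand-in for the old signature: result type first, positionally."""
    return ("old", res, addr)


def mbarrier_arrive_new(addr, *, results=None, loc=None, ip=None):
    """Stand-in for the new signature: results=None means 'infer'."""
    return ("new", results, addr)


try:
    # The old positional call shape no longer fits the new signature.
    mbarrier_arrive_new("i64", "%bar")
except TypeError as exc:
    print("source break:", exc)

# Migrated call sites: rely on inference, or pass results by keyword.
print(mbarrier_arrive_new("%bar"))
print(mbarrier_arrive_new("%bar", results=["i64"]))
```

Passing results by keyword as a list of types is an assumption about the
generated builders; dropping the argument and relying on inference is the
path this patch is designed for.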
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td   |  26 ++-
 .../Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp    |  13 +-
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp    | 172 ++++++++++++++
 mlir/test/Dialect/LLVMIR/nvvm.mlir            | 185 +++++++++++++++
 mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir |   4 +-
 mlir/test/python/dialects/nvvm.py             | 218 +++++++++++++++++-
 6 files changed, 594 insertions(+), 24 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 57a0c67e82c47..fccab166ac5ae 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -720,7 +720,8 @@ def NVVM_MBarrierCompleteTxOp : NVVM_VoidIntrinsicOp<"mbarrier.complete_tx"> {
   let hasVerifier = 1;
 }
 
-def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive"> {
+def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive",
+    [InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "MBarrier Arrive Operation";
   let description = [{
     The `nvvm.mbarrier.arrive` operation performs an arrive-on operation on the 
@@ -768,7 +769,8 @@ def NVVM_MBarrierArriveOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive"> {
   let hasVerifier = 1;
 }
 
-def NVVM_MBarrierArriveDropOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop"> {
+def NVVM_MBarrierArriveDropOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop",
+    [InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "MBarrier Arrive-Drop Operation";
   let description = [{
     The `nvvm.mbarrier.arrive_drop` operation decrements the expected arrival
@@ -844,7 +846,8 @@ def NVVM_MBarrierArriveDropNocompleteOp : NVVM_SingleResultIntrinsicOp<"mbarrier
   let assemblyFormat = "$addr `,` $count attr-dict `:` type(operands) `->` type($res)";
 }
 
-def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx"> {
+def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_tx",
+    [InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "MBarrier Arrive with Expected Transaction Count";
   let description = [{
     The `nvvm.mbarrier.arrive.expect_tx` operation performs an expect-tx operation 
@@ -910,7 +913,8 @@ def NVVM_MBarrierArriveExpectTxOp : NVVM_PTXBuilder_Op<"mbarrier.arrive.expect_t
   }];
 }
 
-def NVVM_MBarrierArriveDropExpectTxOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop.expect_tx"> {
+def NVVM_MBarrierArriveDropExpectTxOp : NVVM_SingleResultIntrinsicOp<"mbarrier.arrive_drop.expect_tx",
+    [InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "MBarrier arrive_drop with expected transaction count";
   let description = [{
     The `nvvm.mbarrier.arrive_drop.expect_tx` operation is similar to the
@@ -1123,7 +1127,8 @@ def BarrierReductionAttr
   let assemblyFormat = "`<` $value `>`";
 }
 
-def NVVM_BarrierOp : NVVM_SingleResultIntrinsicOp<"barrier", [AttrSizedOperandSegments]> {
+def NVVM_BarrierOp : NVVM_SingleResultIntrinsicOp<"barrier",
+    [AttrSizedOperandSegments, InferTypeOpAdaptorWithIsCompatible]> {
   let summary = "CTA Barrier Synchronization Op";
   let description = [{
     The `nvvm.barrier` operation performs barrier synchronization and communication 
@@ -1504,7 +1509,7 @@ def ShflKind : I32EnumAttr<"ShflKind", "NVVM shuffle kind",
 def ShflKindAttr : EnumAttr<NVVM_Dialect, ShflKind, "shfl_kind">;
 
 def NVVM_ShflOp :
-  NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>]>,
+  NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>, InferTypeOpAdaptor]>,
   Results<(outs AnyTypeOf<[I32, F32, LLVMStructType]>:$res)>,
   Arguments<(ins I32:$thread_mask,
                  AnyTypeOf<[I32, F32]>:$val,
@@ -1558,7 +1563,7 @@ def VoteSyncKind : I32EnumAttr<"VoteSyncKind", "NVVM vote sync kind",
 def VoteSyncKindAttr : EnumAttr<NVVM_Dialect, VoteSyncKind, "vote_sync_kind">;
 
 def NVVM_VoteSyncOp
-    : NVVM_Op<"vote.sync">,
+    : NVVM_Op<"vote.sync", [InferTypeOpAdaptor]>,
       Results<(outs AnyTypeOf<[I32, I1]>:$res)>,
       Arguments<(ins I32:$mask, I1:$pred, VoteSyncKindAttr:$kind)> {
   let summary = "Vote across thread group";
@@ -3028,7 +3033,7 @@ def NVVM_StMatrixOp: NVVM_Op<"stmatrix">,
   let hasVerifier = 1;
 }
 
-def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
+def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix", [InferTypeOpAdaptor]>,
   Results<(outs AnyType:$res)>,
   Arguments<(ins LLVM_PointerShared:$ptr, I32Attr:$num,
                  MMALayoutAttr:$layout,
@@ -4783,7 +4788,7 @@ def MatchSyncKind : I32EnumAttr<"MatchSyncKind", "NVVM match sync kind",
 
 def MatchSyncKindAttr : EnumAttr<NVVM_Dialect, MatchSyncKind, "match_sync_kind">;
 
-def NVVM_MatchSyncOp : NVVM_Op<"match.sync">,
+def NVVM_MatchSyncOp : NVVM_Op<"match.sync", [InferTypeOpAdaptor]>,
   Results<(outs AnyTypeOf<[I32, LLVMStructType]>:$res)>,
   Arguments<(ins I32:$thread_mask,
                  AnyTypeOf<[I32, I64]>:$val,
@@ -5723,7 +5728,8 @@ def ClusterLaunchControlQueryTypeAttr
 }
 
 def NVVM_ClusterLaunchControlQueryCancelOp
-    : NVVM_Op<"clusterlaunchcontrol.query.cancel", [NVVMRequiresSM<100>]> {
+    : NVVM_Op<"clusterlaunchcontrol.query.cancel",
+              [NVVMRequiresSM<100>, InferTypeOpAdaptor]> {
   let summary = "Query the response of a clusterlaunchcontrol.try.cancel operation";
   let description = [{
     `clusterlaunchcontrol.query.cancel` queries the response of a 
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 6edc8f5c86dd3..956019f490c4c 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -852,9 +852,7 @@ struct NVGPUMBarrierArriveLowering
     Value barrier =
         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
-    Type tokenType = getTypeConverter()->convertType(
-        nvgpu::MBarrierTokenType::get(op->getContext()));
-    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, tokenType, barrier);
+    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveOp>(op, barrier);
     return success();
   }
 };
@@ -911,12 +909,17 @@ struct NVGPUMBarrierArriveExpectTxLowering
         getMbarrierPtr(b, op.getBarriers().getType(), adaptor.getBarriers(),
                        adaptor.getMbarId(), rewriter);
     Value txcount = truncToI32(b, adaptor.getTxcount());
-    rewriter.replaceOpWithNewOp<NVVM::MBarrierArriveExpectTxOp>(
-        op, Type{},       // return-value is optional and is void by default
+    // Use create+eraseOp instead of replaceOpWithNewOp because the nvgpu op
+    // has 0 results, but the NVVM op now infers an i64 result (via
+    // InferTypeOpInterface) for non-cluster pointers. The create() call
+    // triggers InferTypeOpInterface automatically when no result types are
+    // provided; the inferred i64 result is simply unused.
+    NVVM::MBarrierArriveExpectTxOp::create(rewriter, op->getLoc(),
         barrier, txcount, // barrier and txcount
         NVVM::MemScopeKind::CTA, // default scope is CTA
         false,                   // relaxed-semantics is false
         adaptor.getPredicate());
+    rewriter.eraseOp(op);
     return success();
   }
 };
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7d49aa3878ebe..edf142cfdca58 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -302,6 +302,88 @@ LogicalResult MBarrierArriveDropExpectTxOp::verify() {
                                     getRes());
 }
 
+//===----------------------------------------------------------------------===//
+// inferReturnTypes for mbarrier arrive-like ops
+//===----------------------------------------------------------------------===//
+
+/// Shared inference: shared_cluster addr -> no result; otherwise -> i64.
+/// Generic pointers (address space 0) are treated as shared::cta per the PTX
+/// ISA: "If no state space is specified then Generic Addressing is used. If the
+/// address does not fall within .shared::cta then the behavior is undefined."
+/// Therefore only ptr<7> (shared_cluster) produces zero results.
+static LogicalResult inferMBarrierArriveResultTypes(
+    MLIRContext *context, Value addr,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  auto ptrTy = llvm::dyn_cast<LLVM::LLVMPointerType>(addr.getType());
+  if (!ptrTy)
+    return failure();
+  if (ptrTy.getAddressSpace() !=
+      static_cast<unsigned>(NVVMMemorySpace::SharedCluster))
+    inferredReturnTypes.push_back(IntegerType::get(context, 64));
+  return success();
+}
+
+LogicalResult MBarrierArriveOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MBarrierArriveOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
+                                        inferredReturnTypes);
+}
+
+LogicalResult MBarrierArriveDropOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MBarrierArriveDropOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
+                                        inferredReturnTypes);
+}
+
+LogicalResult MBarrierArriveExpectTxOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MBarrierArriveExpectTxOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  // Predicate forces no return value (inline PTX path).
+  // Note: predicate + shared_cluster is rejected by the verifier separately.
+  if (adaptor.getPredicate())
+    return success();
+  return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
+                                        inferredReturnTypes);
+}
+
+LogicalResult MBarrierArriveDropExpectTxOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MBarrierArriveDropExpectTxOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  return inferMBarrierArriveResultTypes(context, adaptor.getAddr(),
+                                        inferredReturnTypes);
+}
+
+/// For ops with optional results, allow the user to omit the result even when
+/// inference would produce one. This preserves backward compatibility: the
+/// result can be silently discarded (e.g., for fire-and-forget arrive ops).
+static bool isCompatibleReturnTypesOptionalResult(TypeRange inferred,
+                                                  TypeRange actual) {
+  if (actual.empty())
+    return true;
+  return inferred == actual;
+}
+
+bool MBarrierArriveOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  return isCompatibleReturnTypesOptionalResult(l, r);
+}
+bool MBarrierArriveDropOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  return isCompatibleReturnTypesOptionalResult(l, r);
+}
+bool MBarrierArriveExpectTxOp::isCompatibleReturnTypes(TypeRange l,
+                                                       TypeRange r) {
+  return isCompatibleReturnTypesOptionalResult(l, r);
+}
+bool MBarrierArriveDropExpectTxOp::isCompatibleReturnTypes(TypeRange l,
+                                                           TypeRange r) {
+  return isCompatibleReturnTypesOptionalResult(l, r);
+}
+
 LogicalResult MBarrierExpectTxOp::verify() {
   return verifyMBarrierArriveLikeOp(getOperation(), getAddr(), getScope());
 }
@@ -2244,6 +2326,23 @@ LogicalResult ShflOp::verify() {
   return success();
 }
 
+LogicalResult ShflOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    ShflOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  Type valType = adaptor.getVal().getType();
+  if (adaptor.getReturnValueAndIsValid()) {
+    // struct<(val_type, i1)>
+    SmallVector<Type> body = {valType, IntegerType::get(context, 1)};
+    inferredReturnTypes.push_back(
+        LLVM::LLVMStructType::getLiteral(context, body));
+  } else {
+    // same type as val
+    inferredReturnTypes.push_back(valType);
+  }
+  return success();
+}
+
 std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                    NVVM::MMAFrag frag, int nRow,
                                                    int nCol,
@@ -2450,6 +2549,26 @@ LogicalResult NVVM::LdMatrixOp::verify() {
   return success();
 }
 
+LogicalResult LdMatrixOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    LdMatrixOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  uint32_t num = adaptor.getNum();
+  uint32_t m = adaptor.getShape().getM();
+  uint32_t n = adaptor.getShape().getN();
+  uint32_t numElements = (m == 16 && n == 16) ? num * 2 : num;
+
+  Type i32 = IntegerType::get(context, 32);
+  if (numElements == 1) {
+    inferredReturnTypes.push_back(i32);
+  } else {
+    inferredReturnTypes.push_back(
+        LLVM::LLVMStructType::getLiteral(context,
+                                         SmallVector<Type>(numElements, i32)));
+  }
+  return success();
+}
+
 LogicalResult NVVM::StMatrixOp::verify() {
   int numMatrix = getSources().size();
   if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
@@ -2798,6 +2917,19 @@ LogicalResult NVVM::BarrierOp::verify() {
   return success();
 }
 
+LogicalResult BarrierOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    BarrierOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (adaptor.getReductionOp())
+    inferredReturnTypes.push_back(IntegerType::get(context, 32));
+  return success();
+}
+
+bool BarrierOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  return isCompatibleReturnTypesOptionalResult(l, r);
+}
+
 LogicalResult NVVM::Tcgen05CpOp::verify() {
   auto mc = getMulticast();
 
@@ -2840,6 +2972,23 @@ LogicalResult NVVM::MatchSyncOp::verify() {
   return success();
 }
 
+LogicalResult MatchSyncOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MatchSyncOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (adaptor.getKind() == NVVM::MatchSyncKind::all) {
+    // all: returns struct<(i32, i1)>
+    SmallVector<Type> body = {IntegerType::get(context, 32),
+                              IntegerType::get(context, 1)};
+    inferredReturnTypes.push_back(
+        LLVM::LLVMStructType::getLiteral(context, body));
+  } else {
+    // any: returns i32
+    inferredReturnTypes.push_back(IntegerType::get(context, 32));
+  }
+  return success();
+}
+
 LogicalResult NVVM::VoteSyncOp::verify() {
   if (getKind() == NVVM::VoteSyncKind::ballot) {
     if (!getType().isInteger(32)) {
@@ -2853,6 +3002,17 @@ LogicalResult NVVM::VoteSyncOp::verify() {
   return success();
 }
 
+LogicalResult VoteSyncOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    VoteSyncOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (adaptor.getKind() == NVVM::VoteSyncKind::ballot)
+    inferredReturnTypes.push_back(IntegerType::get(context, 32));
+  else
+    inferredReturnTypes.push_back(IntegerType::get(context, 1));
+  return success();
+}
+
 LogicalResult NVVM::PrefetchOp::verify() {
   using MemSpace = NVVM::NVVMMemorySpace;
   using CacheLevel = NVVM::PrefetchCacheLevel;
@@ -2946,6 +3106,18 @@ LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
   return success();
 }
 
+LogicalResult ClusterLaunchControlQueryCancelOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    ClusterLaunchControlQueryCancelOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (adaptor.getQueryType() ==
+      NVVM::ClusterLaunchControlQueryType::IS_CANCELED)
+    inferredReturnTypes.push_back(IntegerType::get(context, 1));
+  else
+    inferredReturnTypes.push_back(IntegerType::get(context, 32));
+  return success();
+}
+
 LogicalResult NVVM::ReduxOp::verify() {
   mlir::Type reduxType = getType();
 
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index b11ba944fe4ac..fad9c66bc694a 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -443,6 +443,13 @@ llvm.func private @mbarrier_arrive_shared(%barrier: !llvm.ptr<3>) {
   llvm.return
 }
 
+// CHECK-LABEL: @mbarrier_arrive_cluster
+llvm.func private @mbarrier_arrive_cluster(%barrier: !llvm.ptr<7>) {
+  // CHECK:   nvvm.mbarrier.arrive %{{.*}} : !llvm.ptr<7>{{$}}
+  nvvm.mbarrier.arrive %barrier : !llvm.ptr<7>
+  llvm.return
+}
+
 llvm.func private @mbarrier_arrive_nocomplete(%barrier: !llvm.ptr) {
   %count = nvvm.read.ptx.sreg.ntid.x : i32
   // CHECK:   nvvm.mbarrier.arrive.nocomplete %{{.*}} : !llvm.ptr
@@ -457,6 +464,55 @@ llvm.func private @mbarrier_arrive_nocomplete_shared(%barrier: !llvm.ptr<3>) {
   llvm.return
 }
 
+// CHECK-LABEL: @mbarrier_arrive_drop_shared
+llvm.func private @mbarrier_arrive_drop_shared(%barrier: !llvm.ptr<3>) {
+  // CHECK:   nvvm.mbarrier.arrive_drop %{{.*}} : !llvm.ptr<3> -> i64
+  %0 = nvvm.mbarrier.arrive_drop %barrier : !llvm.ptr<3> -> i64
+  llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_drop_cluster
+llvm.func private @mbarrier_arrive_drop_cluster(%barrier: !llvm.ptr<7>) {
+  // CHECK:   nvvm.mbarrier.arrive_drop %{{.*}} : !llvm.ptr<7>{{$}}
+  nvvm.mbarrier.arrive_drop %barrier : !llvm.ptr<7>
+  llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_expect_tx_shared
+llvm.func private @mbarrier_arrive_expect_tx_shared(%barrier: !llvm.ptr<3>, %txcount: i32) {
+  // CHECK:   nvvm.mbarrier.arrive.expect_tx %{{.*}}, %{{.*}} : !llvm.ptr<3>, i32 -> i64
+  %0 = nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr<3>, i32 -> i64
+  llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_expect_tx_cluster
+llvm.func private @mbarrier_arrive_expect_tx_cluster(%barrier: !llvm.ptr<7>, %txcount: i32) {
+  // CHECK:   nvvm.mbarrier.arrive.expect_tx %{{.*}}, %{{.*}} : !llvm.ptr<7>, i32{{$}}
+  nvvm.mbarrier.arrive.expect_tx %barrier, %txcount : !llvm.ptr<7>, i32
+  llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_expect_tx_predicate
+llvm.func private @mbarrier_arrive_expect_tx_predicate(%barrier: !llvm.ptr<3>, %txcount: i32, %pred: i1) {
+  // CHECK:   nvvm.mbarrier.arrive.expect_tx %{{.*}}, %{{.*}}, predicate = %{{.*}} : !llvm.ptr<3>, i32, i1{{$}}
+  nvvm.mbarrier.arrive.expect_tx %barrier, %txcount, predicate = %pred : !llvm.ptr<3>, i32, i1
+  llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_drop_expect_tx_shared
+llvm.func private @mbarrier_arrive_drop_expect_tx_shared(%barrier: !llvm.ptr<3>, %txcount: i32) {
+  // CHECK:   nvvm.mbarrier.arrive_drop.expect_tx %{{.*}}, %{{.*}} : !llvm.ptr<3>, i32 -> i64
+  %0 = nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount : !llvm.ptr<3>, i32 -> i64
+  llvm.return
+}
+
+// CHECK-LABEL: @mbarrier_arrive_drop_expect_tx_cluster
+llvm.func private @mbarrier_arrive_drop_expect_tx_cluster(%barrier: !llvm.ptr<7>, %txcount: i32) {
+  // CHECK:   nvvm.mbarrier.arrive_drop.expect_tx %{{.*}}, %{{.*}} : !llvm.ptr<7>, i32{{$}}
+  nvvm.mbarrier.arrive_drop.expect_tx %barrier, %txcount : !llvm.ptr<7>, i32
+  llvm.return
+}
+
 // CHECK-LABEL: @wgmma_fence_aligned
 func.func @wgmma_fence_aligned() {
   // CHECK: nvvm.wgmma.fence.aligned
@@ -616,6 +672,135 @@ llvm.func @kernel_func(%arg0: !llvm.ptr {nvvm.grid_constant}) attributes {nvvm.k
 
 // -----
 
+// CHECK-LABEL: @vote_sync_infer_type
+func.func @vote_sync_infer_type(%mask : i32, %pred : i1) {
+  // Ballot infers i32 result
+  // CHECK: nvvm.vote.sync ballot %{{.*}}, %{{.*}} -> i32
+  %0 = nvvm.vote.sync ballot %mask, %pred -> i32
+  // All infers i1 result
+  // CHECK: nvvm.vote.sync all %{{.*}}, %{{.*}} -> i1
+  %1 = nvvm.vote.sync all %mask, %pred -> i1
+  // Any infers i1 result
+  // CHECK: nvvm.vote.sync any %{{.*}}, %{{.*}} -> i1
+  %2 = nvvm.vote.sync any %mask, %pred -> i1
+  // Uni infers i1 result
+  // CHECK: nvvm.vote.sync uni %{{.*}}, %{{.*}} -> i1
+  %3 = nvvm.vote.sync uni %mask, %pred -> i1
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @clusterlaunchcontrol_query_cancel_infer_type
+llvm.func @clusterlaunchcontrol_query_cancel_infer_type(%response : i128) {
+  // is_canceled infers i1
+  // CHECK: nvvm.clusterlaunchcontrol.query.cancel query = is_canceled, %{{.*}} : i1
+  %0 = nvvm.clusterlaunchcontrol.query.cancel query = is_canceled, %response : i1
+  // get_first_cta_id_x infers i32
+  // CHECK: nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %{{.*}} : i32
+  %1 = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %response : i32
+  // get_first_cta_id_y infers i32
+  // CHECK: nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_y, %{{.*}} : i32
+  %2 = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_y, %response : i32
+  // get_first_cta_id_z infers i32
+  // CHECK: nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_z, %{{.*}} : i32
+  %3 = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_z, %response : i32
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @match_sync_infer_type
+func.func @match_sync_infer_type(%mask : i32, %i32val : i32, %i64val : i64) {
+  // any: infers i32
+  // CHECK: nvvm.match.sync any %{{.*}}, %{{.*}} : i32 -> i32
+  %0 = nvvm.match.sync any %mask, %i32val : i32 -> i32
+  // all with i32: infers struct<(i32, i1)>
+  // CHECK: nvvm.match.sync all %{{.*}}, %{{.*}} : i32 -> !llvm.struct<(i32, i1)>
+  %1 = nvvm.match.sync all %mask, %i32val : i32 -> !llvm.struct<(i32, i1)>
+  // all with i64: infers struct<(i32, i1)>
+  // CHECK: nvvm.match.sync all %{{.*}}, %{{.*}} : i64 -> !llvm.struct<(i32, i1)>
+  %2 = nvvm.match.sync all %mask, %i64val : i64 -> !llvm.struct<(i32, i1)>
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @shfl_sync_infer_type
+func.func @shfl_sync_infer_type(%mask : i32, %i32val : i32, %f32val : f32, %offset : i32, %clamp : i32) {
+  // Without return_value_and_is_valid: result = val type
+  // CHECK: nvvm.shfl.sync bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32 -> i32
+  %0 = nvvm.shfl.sync bfly %mask, %i32val, %offset, %clamp : i32 -> i32
+  // CHECK: nvvm.shfl.sync bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 -> f32
+  %1 = nvvm.shfl.sync bfly %mask, %f32val, %offset, %clamp : f32 -> f32
+  // With return_value_and_is_valid: result = struct<(val_type, i1)>
+  // CHECK: nvvm.shfl.sync bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
+  %2 = nvvm.shfl.sync bfly %mask, %i32val, %offset, %clamp {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
+  // CHECK: nvvm.shfl.sync bfly %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
+  %3 = nvvm.shfl.sync bfly %mask, %f32val, %offset, %clamp {return_value_and_is_valid} : f32 -> !llvm.struct<(f32, i1)>
+  llvm.return
+}
+
+// -----
+
+// CHECK-LABEL: @ldmatrix_infer_type
+llvm.func @ldmatrix_infer_type(%ptr : !llvm.ptr<3>) {
+  // num=1, 8x8 → i32
+  // CHECK: nvvm.ldmatrix %{{.*}} {{.*}}num = 1{{.*}} : (!llvm.ptr<3>) -> i32
+  %0 = nvvm.ldmatrix %ptr {num = 1 : i32, layout = #nvvm.mma_layout<row>,
+    shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>,
+    eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> i32
+  // num=2, 8x8 → struct<(i32, i32)>
+  // CHECK: nvvm.ldmatrix %{{.*}} {{.*}}num = 2{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %1 = nvvm.ldmatrix %ptr {num = 2 : i32, layout = #nvvm.mma_layout<row>,
+    shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>,
+    eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  // num=4, 8x8 → struct<(i32, i32, i32, i32)>
+  // CHECK: nvvm.ldmatrix %{{.*}} {{.*}}num = 4{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %2 = nvvm.ldmatrix %ptr {num = 4 : i32, layout = #nvvm.mma_layout<row>,
+    shape = #nvvm.ld_st_matrix_shape<m = 8, n = 8>,
+    eltType = #nvvm.ld_st_matrix_elt_type<b16>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  // num=1, 16x16 → numElements=2 → struct<(i32, i32)>
+  // CHECK: nvvm.ldmatrix %{{.*}} {{.*}}num = 1{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  %3 = nvvm.ldmatrix %ptr {num = 1 : i32, layout = #nvvm.mma_layout<col>,
+    shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>,
+    eltType = #nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)>
+  // num=2, 16x16 → numElements=4 → struct<(i32, i32, i32, i32)>
+  // CHECK: nvvm.ldmatrix %{{.*}} {{.*}}num = 2{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  %4 = nvvm.ldmatrix %ptr {num = 2 : i32, layout = #nvvm.mma_layout<col>,
+    shape = #nvvm.ld_st_matrix_shape<m = 16, n = 16>,
+    eltType = #nvvm.ld_st_matrix_elt_type<b8>} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+  llvm.return
+}
+
+// -----
+
+// Negative tests: explicit result type disagrees with inference
+
+func.func @vote_sync_ballot_wrong_type(%mask: i32, %pred: i1) {
+  // expected-error @below {{vote.sync 'ballot' returns an i32}}
+  %0 = nvvm.vote.sync ballot %mask, %pred -> i1
+  return
+}
+
+// -----
+
+func.func @shfl_sync_wrong_result_type(%mask: i32, %val: i32, %off: i32, %clamp: i32) {
+  // expected-error @below {{expected return type to be of type 'i32' but got 'f32'}}
+  %0 = nvvm.shfl.sync bfly %mask, %val, %off, %clamp : i32 -> f32
+  return
+}
+
+// -----
+
+func.func @match_sync_any_wrong_type(%mask: i32, %val: i32) {
+  // expected-error @below {{must be 32-bit signless integer or LLVM struct type}}
+  %0 = nvvm.match.sync any %mask, %val : i32 -> i1
+  return
+}
+
+// -----
+
 // expected-error @below {{'"nvvm.grid_constant"' must be a unit attribute}}
 llvm.func @kernel_func(%arg0: !llvm.ptr {llvm.byval = i32, nvvm.grid_constant = true}) attributes {nvvm.kernel} {
   llvm.return
diff --git a/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir
index f406922ea3873..96c910b193f12 100644
--- a/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/mbar_arrive.mlir
@@ -108,8 +108,10 @@ llvm.func @mbarrier_arrive_ignore_retval(%count : i32, %barrier: !llvm.ptr<3>) {
   // CHECK-NEXT: %4 = call i64 @llvm.nvvm.mbarrier.arrive.scope.cta.space.cta(ptr addrspace(3) %1, i32 %0)
   // CHECK-NEXT: ret void
   // CHECK-NEXT: }
+  // Result silently discarded (backward compatible form)
   nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<3>
-  nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<3>
+  // Result explicitly captured
+  %0 = nvvm.mbarrier.arrive %barrier, %count : !llvm.ptr<3> -> i64
 
   llvm.return
 }
diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py
index dc8e4f462ad56..7e63adaf23037 100644
--- a/mlir/test/python/dialects/nvvm.py
+++ b/mlir/test/python/dialects/nvvm.py
@@ -93,6 +93,52 @@ def my_inline_ptx(a, b, c, d):
         arith.addf(wo0, wo1)
 
 
+@constructAndPrintInModule
+def test_mbarrier_arrive():
+    ptr_shared = llvm.PointerType.get(3)
+    ptr_cluster = llvm.PointerType.get(7)
+    i32 = T.i32()
+    i64 = T.i64()
+
+    @func.FuncOp.from_py_func(ptr_shared, ptr_cluster, i32)
+    def mbarrier_arrive_ops(barrier_shared, barrier_cluster, txcount):
+        # mbarrier.arrive on shared pointer -> infers i64 result
+        token = nvvm.mbarrier_arrive(barrier_shared)
+
+        # mbarrier.arrive on cluster pointer -> infers no result
+        nvvm.mbarrier_arrive(barrier_cluster)
+
+        # mbarrier.arrive_drop on shared -> i64
+        token2 = nvvm.mbarrier_arrive_drop(barrier_shared)
+
+        # mbarrier.arrive_drop on cluster -> no result
+        nvvm.mbarrier_arrive_drop(barrier_cluster)
+
+        # mbarrier.arrive.expect_tx on shared -> i64
+        token3 = nvvm.mbarrier_arrive_expect_tx(barrier_shared, txcount)
+
+        # mbarrier.arrive.expect_tx on cluster -> no result
+        nvvm.mbarrier_arrive_expect_tx(barrier_cluster, txcount)
+
+        # mbarrier.arrive_drop.expect_tx on shared -> i64
+        token4 = nvvm.mbarrier_arrive_drop_expect_tx(barrier_shared, txcount)
+
+        # mbarrier.arrive_drop.expect_tx on cluster -> no result
+        nvvm.mbarrier_arrive_drop_expect_tx(barrier_cluster, txcount)
+
+
+# CHECK-LABEL:   func.func @mbarrier_arrive_ops(
+# CHECK-SAME:      %[[SHARED:.*]]: !llvm.ptr<3>, %[[CLUSTER:.*]]: !llvm.ptr<7>, %[[TXCOUNT:.*]]: i32)
+# CHECK:           %{{.*}} = nvvm.mbarrier.arrive %[[SHARED]] : !llvm.ptr<3> -> i64
+# CHECK-NEXT:      nvvm.mbarrier.arrive %[[CLUSTER]] : !llvm.ptr<7>{{$}}
+# CHECK-NEXT:      %{{.*}} = nvvm.mbarrier.arrive_drop %[[SHARED]] : !llvm.ptr<3> -> i64
+# CHECK-NEXT:      nvvm.mbarrier.arrive_drop %[[CLUSTER]] : !llvm.ptr<7>{{$}}
+# CHECK-NEXT:      %{{.*}} = nvvm.mbarrier.arrive.expect_tx %[[SHARED]], %[[TXCOUNT]] : !llvm.ptr<3>, i32 -> i64
+# CHECK-NEXT:      nvvm.mbarrier.arrive.expect_tx %[[CLUSTER]], %[[TXCOUNT]] : !llvm.ptr<7>, i32{{$}}
+# CHECK-NEXT:      %{{.*}} = nvvm.mbarrier.arrive_drop.expect_tx %[[SHARED]], %[[TXCOUNT]] : !llvm.ptr<3>, i32 -> i64
+# CHECK-NEXT:      nvvm.mbarrier.arrive_drop.expect_tx %[[CLUSTER]], %[[TXCOUNT]] : !llvm.ptr<7>, i32{{$}}
+
+
 @constructAndPrintInModule
 def test_barriers():
     i32 = T.i32()
@@ -102,21 +148,23 @@ def test_barriers():
     def barriers(mask, vi32, vf32):
         c0 = arith.constant(T.i32(), 0)
         cffff = arith.constant(T.i32(), 0xFFFF)
-        res = nvvm.barrier(
-            res=i32,
+        # barrier without reduction -> no result (inferred)
+        nvvm.barrier(
             barrier_id=c0,
             number_of_threads=cffff,
         )
 
+        # constant seed for the first reduction's predicate operand
+        pred = arith.constant(T.i32(), 1)
         for reduction in (
             nvvm.BarrierReduction.AND,
             nvvm.BarrierReduction.OR,
             nvvm.BarrierReduction.POPC,
         ):
-            res = nvvm.barrier(
-                res=i32,
+            # barrier with reduction -> infers i32 result
+            pred = nvvm.barrier(
                 reduction_op=reduction,
-                reduction_predicate=res,
+                reduction_predicate=pred,
             )
 
         nvvm.barrier0()
@@ -129,15 +177,16 @@ def barriers(mask, vi32, vf32):
         nvvm.cluster_wait(aligned=True)
         nvvm.fence_mbarrier_init()
         nvvm.bar_warp_sync(mask)
-        return res
+        return pred
 
 
 # CHECK-LABEL:   func.func @barriers(
 # CHECK:           %[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: f32) -> i32 {
 # CHECK:           %[[CONSTANT_0:.*]] = arith.constant 0 : i32
 # CHECK:           %[[CONSTANT_1:.*]] = arith.constant 65535 : i32
-# CHECK:           %[[BARRIER_0:.*]] = nvvm.barrier id = %[[CONSTANT_0]] number_of_threads = %[[CONSTANT_1]] -> i32
-# CHECK:           %[[BARRIER_1:.*]] = nvvm.barrier #nvvm.reduction<and> %[[BARRIER_0]] -> i32
+# CHECK:           nvvm.barrier id = %[[CONSTANT_0]] number_of_threads = %[[CONSTANT_1]]
+# CHECK:           %[[PRED:.*]] = arith.constant 1 : i32
+# CHECK:           %[[BARRIER_1:.*]] = nvvm.barrier #nvvm.reduction<and> %[[PRED]] -> i32
 # CHECK:           %[[BARRIER_2:.*]] = nvvm.barrier #nvvm.reduction<or> %[[BARRIER_1]] -> i32
 # CHECK:           %[[BARRIER_3:.*]] = nvvm.barrier #nvvm.reduction<popc> %[[BARRIER_2]] -> i32
 # CHECK:           nvvm.barrier0
@@ -154,6 +203,159 @@ def barriers(mask, vi32, vf32):
 # CHECK:         }
 
 
+@constructAndPrintInModule
+def test_vote_sync_infer_type():
+    i1 = IntegerType.get_signless(1)
+    i32 = T.i32()
+
+    @func.FuncOp.from_py_func(i32, i1)
+    def vote_sync_ops(mask, pred):
+        # ballot -> infers i32 result
+        ballot_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.ballot)
+
+        # any -> infers i1 result
+        any_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.any)
+
+        # all -> infers i1 result
+        all_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.all)
+
+        # uni -> infers i1 result
+        uni_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.uni)
+
+        return ballot_res
+
+
+# CHECK-LABEL:   func.func @vote_sync_ops(
+# CHECK-SAME:      %[[MASK:.*]]: i32, %[[PRED:.*]]: i1) -> i32 {
+# CHECK:           %[[BALLOT:.*]] = nvvm.vote.sync ballot %[[MASK]], %[[PRED]] -> i32
+# CHECK:           %[[ANY:.*]] = nvvm.vote.sync any %[[MASK]], %[[PRED]] -> i1
+# CHECK:           %[[ALL:.*]] = nvvm.vote.sync all %[[MASK]], %[[PRED]] -> i1
+# CHECK:           %[[UNI:.*]] = nvvm.vote.sync uni %[[MASK]], %[[PRED]] -> i1
+# CHECK:           return %[[BALLOT]] : i32
+# CHECK:         }
+
+
+@constructAndPrintInModule
+def test_clusterlaunchcontrol_query_cancel_infer_type():
+    i128 = IntegerType.get_signless(128)
+
+    @func.FuncOp.from_py_func(i128)
+    def query_cancel_ops(response):
+        # is_canceled -> infers i1 result
+        is_canceled = nvvm.clusterlaunchcontrol_query_cancel(
+            nvvm.ClusterLaunchControlQueryType.IS_CANCELED,
+            response,
+        )
+
+        # get_first_cta_id_x -> infers i32 result
+        cta_x = nvvm.clusterlaunchcontrol_query_cancel(
+            nvvm.ClusterLaunchControlQueryType.GET_FIRST_CTA_ID_X,
+            response,
+        )
+
+        return cta_x
+
+
+# CHECK-LABEL:   func.func @query_cancel_ops(
+# CHECK-SAME:      %[[RESPONSE:.*]]: i128) -> i32 {
+# CHECK:           %{{.*}} = nvvm.clusterlaunchcontrol.query.cancel query = is_canceled, %[[RESPONSE]] : i1
+# CHECK:           %[[CTA_X:.*]] = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %[[RESPONSE]] : i32
+# CHECK:           return %[[CTA_X]] : i32
+# CHECK:         }
+
+
+@constructAndPrintInModule
+def test_match_sync_infer_type():
+    i32 = T.i32()
+    i64 = IntegerType.get_signless(64)
+
+    @func.FuncOp.from_py_func(i32, i32, i64)
+    def match_sync_ops(mask, i32val, i64val):
+        # any -> infers i32
+        any_result = nvvm.match_sync(mask, i32val, nvvm.MatchSyncKind.any)
+
+        # all -> infers struct<(i32, i1)>
+        all_result = nvvm.match_sync(mask, i32val, nvvm.MatchSyncKind.all)
+
+        return any_result
+
+
+# CHECK-LABEL:   func.func @match_sync_ops(
+# CHECK-SAME:      %[[MASK:.*]]: i32, %[[I32VAL:.*]]: i32, %[[I64VAL:.*]]: i64) -> i32 {
+# CHECK:           %[[ANY:.*]] = nvvm.match.sync any %[[MASK]], %[[I32VAL]] : i32 -> i32
+# CHECK:           %[[ALL:.*]] = nvvm.match.sync all %[[MASK]], %[[I32VAL]] : i32 -> !llvm.struct<(i32, i1)>
+# CHECK:           return %[[ANY]] : i32
+# CHECK:         }
+
+
+@constructAndPrintInModule
+def test_shfl_sync_infer_type():
+    i32 = T.i32()
+    f32 = T.f32()
+
+    @func.FuncOp.from_py_func(i32, i32, f32, i32, i32)
+    def shfl_sync_ops(mask, i32val, f32val, offset, clamp):
+        # Without return_value_and_is_valid: result = val type
+        i32_result = nvvm.shfl_sync(
+            mask, i32val, offset, clamp, nvvm.ShflKind.bfly
+        )
+
+        f32_result = nvvm.shfl_sync(
+            mask, f32val, offset, clamp, nvvm.ShflKind.bfly
+        )
+
+        # With return_value_and_is_valid: result = struct<(val_type, i1)>
+        struct_result = nvvm.shfl_sync(
+            mask, i32val, offset, clamp, nvvm.ShflKind.bfly,
+            return_value_and_is_valid=True
+        )
+
+        return i32_result
+
+
+# CHECK-LABEL:   func.func @shfl_sync_ops(
+# CHECK-SAME:      %[[MASK:.*]]: i32, %[[I32VAL:.*]]: i32, %[[F32VAL:.*]]: f32, %[[OFF:.*]]: i32, %[[CLAMP:.*]]: i32) -> i32 {
+# CHECK:           %[[I32R:.*]] = nvvm.shfl.sync bfly %[[MASK]], %[[I32VAL]], %[[OFF]], %[[CLAMP]] : i32 -> i32
+# CHECK:           %[[F32R:.*]] = nvvm.shfl.sync bfly %[[MASK]], %[[F32VAL]], %[[OFF]], %[[CLAMP]] : f32 -> f32
+# CHECK:           %[[STRUCT:.*]] = nvvm.shfl.sync bfly %[[MASK]], %[[I32VAL]], %[[OFF]], %[[CLAMP]] {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
+# CHECK:           return %[[I32R]] : i32
+# CHECK:         }
+
+
+@constructAndPrintInModule
+def test_ldmatrix_infer_type():
+    ptr_shared = llvm.PointerType.get(3)
+
+    shape_8x8 = Attribute.parse("#nvvm.ld_st_matrix_shape<m = 8, n = 8>")
+    elt_b16 = Attribute.parse("#nvvm.ld_st_matrix_elt_type<b16>")
+
+    @func.FuncOp.from_py_func(ptr_shared)
+    def ldmatrix_ops(ptr):
+        # num=1 -> infers i32
+        r1 = nvvm.ldmatrix(
+            ptr, num=1, layout=nvvm.MMALayout.row,
+            shape=shape_8x8,
+            elt_type=elt_b16,
+        )
+
+        # num=4 -> infers struct<(i32, i32, i32, i32)>
+        r4 = nvvm.ldmatrix(
+            ptr, num=4, layout=nvvm.MMALayout.row,
+            shape=shape_8x8,
+            elt_type=elt_b16,
+        )
+
+        return r1
+
+
+# CHECK-LABEL:   func.func @ldmatrix_ops(
+# CHECK-SAME:      %[[PTR:.*]]: !llvm.ptr<3>) -> i32 {
+# CHECK:           %[[R1:.*]] = nvvm.ldmatrix %[[PTR]] {{.*}}num = 1{{.*}} : (!llvm.ptr<3>) -> i32
+# CHECK:           %[[R4:.*]] = nvvm.ldmatrix %[[PTR]] {{.*}}num = 4{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+# CHECK:           return %[[R1]] : i32
+# CHECK:         }
+
+
 @constructAndPrintInModule
 def test_reductions():
     i32 = T.i32()
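The optional-result rules exercised by the mbarrier and barrier tests can be summarized as a small plain-Python model. This is an illustrative sketch, not the dialect's code: the real logic lives in the C++ `inferReturnTypes` hooks that `InferTypeOpAdaptor` requires, and the helper names below (`infer_mbarrier_arrive_result`, `infer_barrier_result`, `is_compatible_return_types`) are hypothetical. It assumes, as in the tests above, that address space 3 is shared memory and 7 is shared_cluster, and that the predicate only matters for `mbarrier.arrive.expect_tx`.

```python
from typing import List, Optional

# Assumption taken from the tests: !llvm.ptr<7> is the shared_cluster space.
SHARED_CLUSTER_ADDR_SPACE = 7


def infer_mbarrier_arrive_result(addr_space: int, has_predicate: bool = False) -> List[str]:
    """Model of the optional i64 token of the mbarrier.arrive* ops.

    Returns the list of result types; an empty list means "no result".
    `has_predicate` models the extra rule for mbarrier.arrive.expect_tx.
    """
    if addr_space == SHARED_CLUSTER_ADDR_SPACE or has_predicate:
        return []
    return ["i64"]


def infer_barrier_result(reduction_op: Optional[str]) -> List[str]:
    """nvvm.barrier yields an i32 only when a reduction is requested."""
    return ["i32"] if reduction_op is not None else []


def is_compatible_return_types(inferred: List[str], actual: List[str]) -> bool:
    """Permissive check: accept the inferred types, or no result at all."""
    return actual == inferred or not actual
```

Accepting an empty `actual` list mirrors the permissive `isCompatibleReturnTypes` described in the commit message, which keeps the older zero-result assembly form valid.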
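The fixed-result ops tested above follow equally simple rules. The sketch below models them as plain Python functions that return the printed type strings seen in the CHECK lines; the function names are hypothetical, and `infer_ldmatrix_8x8` deliberately covers only the m = 8, n = 8 shape used in `test_ldmatrix_infer_type`, since the rules for other shapes are not exercised here.

```python
def struct_of(*types: str) -> str:
    # Printed form of an LLVM dialect literal struct, e.g. !llvm.struct<(i32, i1)>.
    return "!llvm.struct<(" + ", ".join(types) + ")>"


def infer_vote_sync(kind: str) -> str:
    # ballot returns a 32-bit lane mask; any/all/uni return a single bit.
    if kind == "ballot":
        return "i32"
    if kind in ("any", "all", "uni"):
        return "i1"
    raise ValueError(f"unknown vote.sync kind: {kind}")


def infer_match_sync(kind: str) -> str:
    # any returns the mask of matching lanes; all also returns an i1 predicate.
    if kind == "any":
        return "i32"
    if kind == "all":
        return struct_of("i32", "i1")
    raise ValueError(f"unknown match.sync kind: {kind}")


def infer_shfl_sync(val_type: str, return_value_and_is_valid: bool = False) -> str:
    # The shuffled value keeps its type; the flag adds an i1 validity bit.
    if return_value_and_is_valid:
        return struct_of(val_type, "i1")
    return val_type


def infer_ldmatrix_8x8(num: int) -> str:
    # Assumption: for shape m = 8, n = 8 each loaded matrix is one i32,
    # and num is one of 1, 2, 4 (PTX .x1/.x2/.x4).
    if num not in (1, 2, 4):
        raise ValueError(f"unsupported num: {num}")
    if num == 1:
        return "i32"
    return struct_of(*(["i32"] * num))


def infer_query_cancel(query: str) -> str:
    # is_canceled is a boolean; the other queries (e.g. get_first_cta_id_x) return i32.
    return "i1" if query == "is_canceled" else "i32"
```

Returning strings keeps the model easy to compare against FileCheck output; the real hooks construct `mlir::Type` values instead.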