[Mlir-commits] [mlir] [mlir][NVVM] Add InferTypeOpInterface to sync and ldmatrix ops (PR #188238)

Guray Ozen llvmlistbot at llvm.org
Thu Apr 9 03:49:52 PDT 2026


https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/188238

>From 7c50d9516b1c7fe52744fdae6a5aea0511487284 Mon Sep 17 00:00:00 2001
From: Bastian Hagedorn <bhagedorn at nvidia.com>
Date: Tue, 24 Mar 2026 11:35:23 +0000
Subject: [PATCH] [mlir][NVVM] Add InferTypeOpInterface to sync and ldmatrix
 ops

Add InferTypeOpAdaptor to 5 NVVM ops with deterministic result types:

- VoteSyncOp: ballot -> i32, any/all/uni -> i1
- MatchSyncOp: any -> i32, all -> struct<(i32, i1)>
- ShflOp: result matches val type, or struct<(val_type, i1)> with
  return_value_and_is_valid
- LdMatrixOp: i32 or struct of i32s based on num and shape
- ClusterLaunchControlQueryCancelOp: is_canceled -> i1, others -> i32

These ops always have exactly one result whose type is fully determined by
their operands and attributes, so InferTypeOpAdaptor with the default
(exact-match) isCompatibleReturnTypes is sufficient; no override is needed.

This causes the Python binding generator to emit `results=None` as a
default parameter instead of requiring explicit result types.

Note: this is a source-breaking change for Python callers that pass
result types positionally.

Co-Authored-By: Claude <noreply at anthropic.com>
---
 mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td |  11 +-
 mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp  |  63 ++++++++++
 mlir/test/python/dialects/nvvm.py           | 127 ++++++++++++++++++++
 3 files changed, 196 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8153111f3cf5a..e0fef69f4f944 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1507,7 +1507,7 @@ def ShflKind : I32EnumAttr<"ShflKind", "NVVM shuffle kind",
 def ShflKindAttr : EnumAttr<NVVM_Dialect, ShflKind, "shfl_kind">;
 
 def NVVM_ShflOp :
-  NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>]>,
+  NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>, InferTypeOpAdaptor]>,
   Results<(outs AnyTypeOf<[I32, F32, LLVMStructType]>:$res)>,
   Arguments<(ins I32:$thread_mask,
                  AnyTypeOf<[I32, F32]>:$val,
@@ -1561,7 +1561,7 @@ def VoteSyncKind : I32EnumAttr<"VoteSyncKind", "NVVM vote sync kind",
 def VoteSyncKindAttr : EnumAttr<NVVM_Dialect, VoteSyncKind, "vote_sync_kind">;
 
 def NVVM_VoteSyncOp
-    : NVVM_Op<"vote.sync">,
+    : NVVM_Op<"vote.sync", [InferTypeOpAdaptor]>,
       Results<(outs AnyTypeOf<[I32, I1]>:$res)>,
       Arguments<(ins I32:$mask, I1:$pred, VoteSyncKindAttr:$kind)> {
   let summary = "Vote across thread group";
@@ -3220,7 +3220,7 @@ def NVVM_StMatrixOp: NVVM_Op<"stmatrix">,
   let hasVerifier = 1;
 }
 
-def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
+def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix", [InferTypeOpAdaptor]>,
   Results<(outs AnyType:$res)>,
   Arguments<(ins LLVM_PointerShared:$ptr, I32Attr:$num,
                  MMALayoutAttr:$layout,
@@ -4951,7 +4951,7 @@ def MatchSyncKind : I32EnumAttr<"MatchSyncKind", "NVVM match sync kind",
 
 def MatchSyncKindAttr : EnumAttr<NVVM_Dialect, MatchSyncKind, "match_sync_kind">;
 
-def NVVM_MatchSyncOp : NVVM_Op<"match.sync">,
+def NVVM_MatchSyncOp : NVVM_Op<"match.sync", [InferTypeOpAdaptor]>,
   Results<(outs AnyTypeOf<[I32, LLVMStructType]>:$res)>,
   Arguments<(ins I32:$thread_mask,
                  AnyTypeOf<[I32, I64]>:$val,
@@ -5867,7 +5867,8 @@ def ClusterLaunchControlQueryTypeAttr
 }
 
 def NVVM_ClusterLaunchControlQueryCancelOp
-    : NVVM_Op<"clusterlaunchcontrol.query.cancel", [NVVMRequiresSM<100>]> {
+    : NVVM_Op<"clusterlaunchcontrol.query.cancel",
+              [NVVMRequiresSM<100>, InferTypeOpAdaptor]> {
   let summary = "Query the response of a clusterlaunchcontrol.try.cancel operation";
   let description = [{
     `clusterlaunchcontrol.query.cancel` queries the response of a 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 442f247445f6e..ab5f6036679f7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -2271,6 +2271,19 @@ LogicalResult ShflOp::verify() {
   return success();
 }
 
+/// Result: `val`'s type, or struct<(val type, i1)> if the is-valid flag is set.
+LogicalResult ShflOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    ShflOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+  Type resultType = adaptor.getVal().getType();
+  // `return_value_and_is_valid` appends an i1 to the shuffled value.
+  if (adaptor.getReturnValueAndIsValid())
+    resultType = LLVM::LLVMStructType::getLiteral(
+        context, {resultType, IntegerType::get(context, 1)});
+  inferredReturnTypes.push_back(resultType);
+  return success();
+}
+
 std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                    NVVM::MMAFrag frag, int nRow,
                                                    int nCol,
@@ -2477,6 +2490,33 @@ LogicalResult NVVM::LdMatrixOp::verify() {
   return success();
 }
 
+/// Infers the ldmatrix result: i32 when only one element is loaded, otherwise
+/// an !llvm.struct of `numElements` i32 values.
+LogicalResult LdMatrixOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    LdMatrixOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+  uint32_t num = adaptor.getNum();
+  // This also runs when the op is built without explicit result types (e.g.
+  // `results=None` from Python), possibly before the verifier has checked
+  // `num`; reject unsupported counts instead of materializing an arbitrarily
+  // large struct type or wrapping around in `num * 2`.
+  if (num != 1 && num != 2 && num != 4)
+    return emitOptionalError(location,
+                             "expected num attribute to be 1, 2 or 4");
+  uint32_t m = adaptor.getShape().getM();
+  uint32_t n = adaptor.getShape().getN();
+  // A 16x16 shape contributes two i32 elements per matrix.
+  uint32_t numElements = (m == 16 && n == 16) ? num * 2 : num;
+
+  Type i32 = IntegerType::get(context, 32);
+  if (numElements == 1)
+    inferredReturnTypes.push_back(i32);
+  else
+    inferredReturnTypes.push_back(LLVM::LLVMStructType::getLiteral(
+        context, SmallVector<Type>(numElements, i32)));
+  return success();
+}
+
 LogicalResult NVVM::StMatrixOp::verify() {
   int numMatrix = getSources().size();
   if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
@@ -2867,6 +2897,19 @@ LogicalResult NVVM::MatchSyncOp::verify() {
   return success();
 }
 
+/// Kind `all` yields !llvm.struct<(i32, i1)>; any other kind yields an i32.
+LogicalResult MatchSyncOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MatchSyncOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+  Type i32 = IntegerType::get(context, 32);
+  Type resultType = i32;
+  if (adaptor.getKind() == NVVM::MatchSyncKind::all)
+    resultType = LLVM::LLVMStructType::getLiteral(
+        context, {i32, IntegerType::get(context, 1)});
+  inferredReturnTypes.push_back(resultType);
+  return success();
+}
+
 LogicalResult NVVM::VoteSyncOp::verify() {
   if (getKind() == NVVM::VoteSyncKind::ballot) {
     if (!getType().isInteger(32)) {
@@ -2880,6 +2923,15 @@ LogicalResult NVVM::VoteSyncOp::verify() {
   return success();
 }
 
+/// Kind `ballot` yields an i32; `any`, `all` and `uni` yield an i1.
+LogicalResult VoteSyncOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    VoteSyncOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+  bool isBallot = adaptor.getKind() == NVVM::VoteSyncKind::ballot;
+  inferredReturnTypes.push_back(IntegerType::get(context, isBallot ? 32 : 1));
+  return success();
+}
+
 LogicalResult NVVM::PrefetchOp::verify() {
   using MemSpace = NVVM::NVVMMemorySpace;
   using CacheLevel = NVVM::PrefetchCacheLevel;
@@ -2973,6 +3024,18 @@ LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
   return success();
 }
 
+/// Query `is_canceled` yields an i1; every other query kind yields an i32.
+LogicalResult ClusterLaunchControlQueryCancelOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    ClusterLaunchControlQueryCancelOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  bool isCanceledQuery = adaptor.getQueryType() ==
+                         NVVM::ClusterLaunchControlQueryType::IS_CANCELED;
+  inferredReturnTypes.push_back(
+      IntegerType::get(context, isCanceledQuery ? 1 : 32));
+  return success();
+}
+
 LogicalResult NVVM::ReduxOp::verify() {
   mlir::Type reduxType = getType();
 
diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py
index dc8e4f462ad56..b969faa088a46 100644
--- a/mlir/test/python/dialects/nvvm.py
+++ b/mlir/test/python/dialects/nvvm.py
@@ -154,6 +154,133 @@ def barriers(mask, vi32, vf32):
 # CHECK:         }
 
 
+ at constructAndPrintInModule
+def test_vote_sync_infer_type():
+    i1 = IntegerType.get_signless(1)
+    i32 = T.i32()
+
+    @func.FuncOp.from_py_func(i32, i1)
+    def vote_sync_ops(mask, pred):
+        ballot_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.ballot)
+        any_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.any)
+        all_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.all)
+        uni_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.uni)
+        return ballot_res
+
+
+# CHECK-LABEL:   func.func @vote_sync_ops(
+# CHECK-SAME:      %[[MASK:.*]]: i32, %[[PRED:.*]]: i1) -> i32 {
+# CHECK:           %[[BALLOT:.*]] = nvvm.vote.sync ballot %[[MASK]], %[[PRED]] -> i32
+# CHECK:           %[[ANY:.*]] = nvvm.vote.sync any %[[MASK]], %[[PRED]] -> i1
+# CHECK:           %[[ALL:.*]] = nvvm.vote.sync all %[[MASK]], %[[PRED]] -> i1
+# CHECK:           %[[UNI:.*]] = nvvm.vote.sync uni %[[MASK]], %[[PRED]] -> i1
+# CHECK:           return %[[BALLOT]] : i32
+
+
+ at constructAndPrintInModule
+def test_clusterlaunchcontrol_query_cancel_infer_type():
+    i128 = IntegerType.get_signless(128)
+
+    @func.FuncOp.from_py_func(i128)
+    def query_cancel_ops(response):
+        is_canceled = nvvm.clusterlaunchcontrol_query_cancel(
+            nvvm.ClusterLaunchControlQueryType.IS_CANCELED,
+            response,
+        )
+        cta_x = nvvm.clusterlaunchcontrol_query_cancel(
+            nvvm.ClusterLaunchControlQueryType.GET_FIRST_CTA_ID_X,
+            response,
+        )
+        return cta_x
+
+
+# CHECK-LABEL:   func.func @query_cancel_ops(
+# CHECK-SAME:      %[[RESPONSE:.*]]: i128) -> i32 {
+# CHECK:           %{{.*}} = nvvm.clusterlaunchcontrol.query.cancel query = is_canceled, %[[RESPONSE]] : i1
+# CHECK:           %[[CTA_X:.*]] = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %[[RESPONSE]] : i32
+# CHECK:           return %[[CTA_X]] : i32
+
+
+ at constructAndPrintInModule
+def test_match_sync_infer_type():
+    i32 = T.i32()
+    i64 = IntegerType.get_signless(64)
+
+    @func.FuncOp.from_py_func(i32, i32, i64)
+    def match_sync_ops(mask, i32val, i64val):
+        any_result = nvvm.match_sync(mask, i32val, nvvm.MatchSyncKind.any)
+        all_result = nvvm.match_sync(mask, i32val, nvvm.MatchSyncKind.all)
+        return any_result
+
+
+# CHECK-LABEL:   func.func @match_sync_ops(
+# CHECK-SAME:      %[[MASK:.*]]: i32, %[[I32VAL:.*]]: i32, %[[I64VAL:.*]]: i64) -> i32 {
+# CHECK:           %[[ANY:.*]] = nvvm.match.sync any %[[MASK]], %[[I32VAL]] : i32 -> i32
+# CHECK:           %[[ALL:.*]] = nvvm.match.sync all %[[MASK]], %[[I32VAL]] : i32 -> !llvm.struct<(i32, i1)>
+# CHECK:           return %[[ANY]] : i32
+
+
+ at constructAndPrintInModule
+def test_shfl_sync_infer_type():
+    i32 = T.i32()
+    f32 = T.f32()
+
+    @func.FuncOp.from_py_func(i32, i32, f32, i32, i32)
+    def shfl_sync_ops(mask, i32val, f32val, offset, clamp):
+        i32_result = nvvm.shfl_sync(mask, i32val, offset, clamp, nvvm.ShflKind.bfly)
+        f32_result = nvvm.shfl_sync(mask, f32val, offset, clamp, nvvm.ShflKind.bfly)
+        struct_result = nvvm.shfl_sync(
+            mask,
+            i32val,
+            offset,
+            clamp,
+            nvvm.ShflKind.bfly,
+            return_value_and_is_valid=True,
+        )
+        return i32_result
+
+
+# CHECK-LABEL:   func.func @shfl_sync_ops(
+# CHECK-SAME:      %[[MASK:.*]]: i32, %[[I32VAL:.*]]: i32, %[[F32VAL:.*]]: f32, %[[OFF:.*]]: i32, %[[CLAMP:.*]]: i32) -> i32 {
+# CHECK:           %[[I32R:.*]] = nvvm.shfl.sync bfly %[[MASK]], %[[I32VAL]], %[[OFF]], %[[CLAMP]] : i32 -> i32
+# CHECK:           %[[F32R:.*]] = nvvm.shfl.sync bfly %[[MASK]], %[[F32VAL]], %[[OFF]], %[[CLAMP]] : f32 -> f32
+# CHECK:           %[[STRUCT:.*]] = nvvm.shfl.sync bfly %[[MASK]], %[[I32VAL]], %[[OFF]], %[[CLAMP]] {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
+# CHECK:           return %[[I32R]] : i32
+
+
+ at constructAndPrintInModule
+def test_ldmatrix_infer_type():
+    ptr_shared = llvm.PointerType.get(3)
+
+    shape_8x8 = Attribute.parse("#nvvm.ld_st_matrix_shape<m = 8, n = 8>")
+    elt_b16 = Attribute.parse("#nvvm.ld_st_matrix_elt_type<b16>")
+
+    @func.FuncOp.from_py_func(ptr_shared)
+    def ldmatrix_ops(ptr):
+        r1 = nvvm.ldmatrix(
+            ptr,
+            num=1,
+            layout=nvvm.MMALayout.row,
+            shape=shape_8x8,
+            elt_type=elt_b16,
+        )
+        r4 = nvvm.ldmatrix(
+            ptr,
+            num=4,
+            layout=nvvm.MMALayout.row,
+            shape=shape_8x8,
+            elt_type=elt_b16,
+        )
+        return r1
+
+
+# CHECK-LABEL:   func.func @ldmatrix_ops(
+# CHECK-SAME:      %[[PTR:.*]]: !llvm.ptr<3>) -> i32 {
+# CHECK:           %[[R1:.*]] = nvvm.ldmatrix %[[PTR]] {{.*}}num = 1{{.*}} : (!llvm.ptr<3>) -> i32
+# CHECK:           %[[R4:.*]] = nvvm.ldmatrix %[[PTR]] {{.*}}num = 4{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+# CHECK:           return %[[R1]] : i32
+
+
 @constructAndPrintInModule
 def test_reductions():
     i32 = T.i32()



More information about the Mlir-commits mailing list