[Mlir-commits] [mlir] [mlir][NVVM] Add InferTypeOpInterface to sync and ldmatrix ops (PR #188238)
Guray Ozen
llvmlistbot at llvm.org
Thu Apr 9 03:49:52 PDT 2026
https://github.com/grypp updated https://github.com/llvm/llvm-project/pull/188238
From 7c50d9516b1c7fe52744fdae6a5aea0511487284 Mon Sep 17 00:00:00 2001
From: Bastian Hagedorn <bhagedorn at nvidia.com>
Date: Tue, 24 Mar 2026 11:35:23 +0000
Subject: [PATCH] [mlir][NVVM] Add InferTypeOpInterface to sync and ldmatrix ops
Add InferTypeOpAdaptor to 5 NVVM ops with deterministic result types:
- VoteSyncOp: ballot -> i32, any/all/uni -> i1
- MatchSyncOp: any -> i32, all -> struct<(i32, i1)>
- ShflOp: result matches val type, or struct<(val_type, i1)> when
  return_value_and_is_valid is set
- LdMatrixOp: i32 for a single element, otherwise a struct of i32s; the
  element count is num (num * 2 for the 16x16 shape)
- ClusterLaunchControlQueryCancelOp: is_canceled -> i1, others -> i32
These ops always have exactly one result, so InferTypeOpAdaptor is
sufficient (no isCompatibleReturnTypes override needed).
This causes the Python binding generator to emit `results=None` as a
default parameter instead of requiring explicit result types.
Note: this is a source-breaking change for Python callers that pass
result types positionally.
Co-Authored-By: Claude <noreply at anthropic.com>
---
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 11 +-
mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 63 ++++++++++
mlir/test/python/dialects/nvvm.py | 127 ++++++++++++++++++++
3 files changed, 196 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8153111f3cf5a..e0fef69f4f944 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1507,7 +1507,7 @@ def ShflKind : I32EnumAttr<"ShflKind", "NVVM shuffle kind",
def ShflKindAttr : EnumAttr<NVVM_Dialect, ShflKind, "shfl_kind">;
def NVVM_ShflOp :
- NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>]>,
+ NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>, InferTypeOpAdaptor]>,
Results<(outs AnyTypeOf<[I32, F32, LLVMStructType]>:$res)>,
Arguments<(ins I32:$thread_mask,
AnyTypeOf<[I32, F32]>:$val,
@@ -1561,7 +1561,7 @@ def VoteSyncKind : I32EnumAttr<"VoteSyncKind", "NVVM vote sync kind",
def VoteSyncKindAttr : EnumAttr<NVVM_Dialect, VoteSyncKind, "vote_sync_kind">;
def NVVM_VoteSyncOp
- : NVVM_Op<"vote.sync">,
+ : NVVM_Op<"vote.sync", [InferTypeOpAdaptor]>,
Results<(outs AnyTypeOf<[I32, I1]>:$res)>,
Arguments<(ins I32:$mask, I1:$pred, VoteSyncKindAttr:$kind)> {
let summary = "Vote across thread group";
@@ -3220,7 +3220,7 @@ def NVVM_StMatrixOp: NVVM_Op<"stmatrix">,
let hasVerifier = 1;
}
-def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
+def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix", [InferTypeOpAdaptor]>,
Results<(outs AnyType:$res)>,
Arguments<(ins LLVM_PointerShared:$ptr, I32Attr:$num,
MMALayoutAttr:$layout,
@@ -4951,7 +4951,7 @@ def MatchSyncKind : I32EnumAttr<"MatchSyncKind", "NVVM match sync kind",
def MatchSyncKindAttr : EnumAttr<NVVM_Dialect, MatchSyncKind, "match_sync_kind">;
-def NVVM_MatchSyncOp : NVVM_Op<"match.sync">,
+def NVVM_MatchSyncOp : NVVM_Op<"match.sync", [InferTypeOpAdaptor]>,
Results<(outs AnyTypeOf<[I32, LLVMStructType]>:$res)>,
Arguments<(ins I32:$thread_mask,
AnyTypeOf<[I32, I64]>:$val,
@@ -5867,7 +5867,8 @@ def ClusterLaunchControlQueryTypeAttr
}
def NVVM_ClusterLaunchControlQueryCancelOp
- : NVVM_Op<"clusterlaunchcontrol.query.cancel", [NVVMRequiresSM<100>]> {
+ : NVVM_Op<"clusterlaunchcontrol.query.cancel",
+ [NVVMRequiresSM<100>, InferTypeOpAdaptor]> {
let summary = "Query the response of a clusterlaunchcontrol.try.cancel operation";
let description = [{
`clusterlaunchcontrol.query.cancel` queries the response of a
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 442f247445f6e..ab5f6036679f7 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -2271,6 +2271,19 @@ LogicalResult ShflOp::verify() {
return success();
}
+LogicalResult
+ShflOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
+ ShflOp::Adaptor adaptor,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ Type valType = adaptor.getVal().getType();
+ if (adaptor.getReturnValueAndIsValid())
+ inferredReturnTypes.push_back(LLVM::LLVMStructType::getLiteral(
+ context, {valType, IntegerType::get(context, 1)}));
+ else
+ inferredReturnTypes.push_back(valType);
+ return success();
+}
+
std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
NVVM::MMAFrag frag, int nRow,
int nCol,
@@ -2477,6 +2490,23 @@ LogicalResult NVVM::LdMatrixOp::verify() {
return success();
}
+LogicalResult LdMatrixOp::inferReturnTypes(
+ MLIRContext *context, std::optional<Location> location,
+ LdMatrixOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+ uint32_t num = adaptor.getNum();
+ uint32_t m = adaptor.getShape().getM();
+ uint32_t n = adaptor.getShape().getN();
+ uint32_t numElements = (m == 16 && n == 16) ? num * 2 : num;
+
+ Type i32 = IntegerType::get(context, 32);
+ if (numElements == 1)
+ inferredReturnTypes.push_back(i32);
+ else
+ inferredReturnTypes.push_back(LLVM::LLVMStructType::getLiteral(
+ context, SmallVector<Type>(numElements, i32)));
+ return success();
+}
+
LogicalResult NVVM::StMatrixOp::verify() {
int numMatrix = getSources().size();
if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
@@ -2867,6 +2897,19 @@ LogicalResult NVVM::MatchSyncOp::verify() {
return success();
}
+LogicalResult MatchSyncOp::inferReturnTypes(
+ MLIRContext *context, std::optional<Location> location,
+ MatchSyncOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (adaptor.getKind() == NVVM::MatchSyncKind::all) {
+ inferredReturnTypes.push_back(LLVM::LLVMStructType::getLiteral(
+ context,
+ {IntegerType::get(context, 32), IntegerType::get(context, 1)}));
+ } else {
+ inferredReturnTypes.push_back(IntegerType::get(context, 32));
+ }
+ return success();
+}
+
LogicalResult NVVM::VoteSyncOp::verify() {
if (getKind() == NVVM::VoteSyncKind::ballot) {
if (!getType().isInteger(32)) {
@@ -2880,6 +2923,14 @@ LogicalResult NVVM::VoteSyncOp::verify() {
return success();
}
+LogicalResult VoteSyncOp::inferReturnTypes(
+ MLIRContext *context, std::optional<Location> location,
+ VoteSyncOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
+ unsigned width = adaptor.getKind() == NVVM::VoteSyncKind::ballot ? 32 : 1;
+ inferredReturnTypes.push_back(IntegerType::get(context, width));
+ return success();
+}
+
LogicalResult NVVM::PrefetchOp::verify() {
using MemSpace = NVVM::NVVMMemorySpace;
using CacheLevel = NVVM::PrefetchCacheLevel;
@@ -2973,6 +3024,18 @@ LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
return success();
}
+LogicalResult ClusterLaunchControlQueryCancelOp::inferReturnTypes(
+ MLIRContext *context, std::optional<Location> location,
+ ClusterLaunchControlQueryCancelOp::Adaptor adaptor,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ if (adaptor.getQueryType() ==
+ NVVM::ClusterLaunchControlQueryType::IS_CANCELED)
+ inferredReturnTypes.push_back(IntegerType::get(context, 1));
+ else
+ inferredReturnTypes.push_back(IntegerType::get(context, 32));
+ return success();
+}
+
LogicalResult NVVM::ReduxOp::verify() {
mlir::Type reduxType = getType();
diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py
index dc8e4f462ad56..b969faa088a46 100644
--- a/mlir/test/python/dialects/nvvm.py
+++ b/mlir/test/python/dialects/nvvm.py
@@ -154,6 +154,133 @@ def barriers(mask, vi32, vf32):
# CHECK: }
+@constructAndPrintInModule
+def test_vote_sync_infer_type():
+ i1 = IntegerType.get_signless(1)
+ i32 = T.i32()
+
+ @func.FuncOp.from_py_func(i32, i1)
+ def vote_sync_ops(mask, pred):
+ ballot_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.ballot)
+ any_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.any)
+ all_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.all)
+ uni_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.uni)
+ return ballot_res
+
+
+# CHECK-LABEL: func.func @vote_sync_ops(
+# CHECK-SAME: %[[MASK:.*]]: i32, %[[PRED:.*]]: i1) -> i32 {
+# CHECK: %[[BALLOT:.*]] = nvvm.vote.sync ballot %[[MASK]], %[[PRED]] -> i32
+# CHECK: %[[ANY:.*]] = nvvm.vote.sync any %[[MASK]], %[[PRED]] -> i1
+# CHECK: %[[ALL:.*]] = nvvm.vote.sync all %[[MASK]], %[[PRED]] -> i1
+# CHECK: %[[UNI:.*]] = nvvm.vote.sync uni %[[MASK]], %[[PRED]] -> i1
+# CHECK: return %[[BALLOT]] : i32
+
+
+@constructAndPrintInModule
+def test_clusterlaunchcontrol_query_cancel_infer_type():
+ i128 = IntegerType.get_signless(128)
+
+ @func.FuncOp.from_py_func(i128)
+ def query_cancel_ops(response):
+ is_canceled = nvvm.clusterlaunchcontrol_query_cancel(
+ nvvm.ClusterLaunchControlQueryType.IS_CANCELED,
+ response,
+ )
+ cta_x = nvvm.clusterlaunchcontrol_query_cancel(
+ nvvm.ClusterLaunchControlQueryType.GET_FIRST_CTA_ID_X,
+ response,
+ )
+ return cta_x
+
+
+# CHECK-LABEL: func.func @query_cancel_ops(
+# CHECK-SAME: %[[RESPONSE:.*]]: i128) -> i32 {
+# CHECK: %{{.*}} = nvvm.clusterlaunchcontrol.query.cancel query = is_canceled, %[[RESPONSE]] : i1
+# CHECK: %[[CTA_X:.*]] = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %[[RESPONSE]] : i32
+# CHECK: return %[[CTA_X]] : i32
+
+
+@constructAndPrintInModule
+def test_match_sync_infer_type():
+ i32 = T.i32()
+ i64 = IntegerType.get_signless(64)
+
+ @func.FuncOp.from_py_func(i32, i32, i64)
+ def match_sync_ops(mask, i32val, i64val):
+ any_result = nvvm.match_sync(mask, i32val, nvvm.MatchSyncKind.any)
+ all_result = nvvm.match_sync(mask, i32val, nvvm.MatchSyncKind.all)
+ return any_result
+
+
+# CHECK-LABEL: func.func @match_sync_ops(
+# CHECK-SAME: %[[MASK:.*]]: i32, %[[I32VAL:.*]]: i32, %[[I64VAL:.*]]: i64) -> i32 {
+# CHECK: %[[ANY:.*]] = nvvm.match.sync any %[[MASK]], %[[I32VAL]] : i32 -> i32
+# CHECK: %[[ALL:.*]] = nvvm.match.sync all %[[MASK]], %[[I32VAL]] : i32 -> !llvm.struct<(i32, i1)>
+# CHECK: return %[[ANY]] : i32
+
+
+@constructAndPrintInModule
+def test_shfl_sync_infer_type():
+ i32 = T.i32()
+ f32 = T.f32()
+
+ @func.FuncOp.from_py_func(i32, i32, f32, i32, i32)
+ def shfl_sync_ops(mask, i32val, f32val, offset, clamp):
+ i32_result = nvvm.shfl_sync(mask, i32val, offset, clamp, nvvm.ShflKind.bfly)
+ f32_result = nvvm.shfl_sync(mask, f32val, offset, clamp, nvvm.ShflKind.bfly)
+ struct_result = nvvm.shfl_sync(
+ mask,
+ i32val,
+ offset,
+ clamp,
+ nvvm.ShflKind.bfly,
+ return_value_and_is_valid=True,
+ )
+ return i32_result
+
+
+# CHECK-LABEL: func.func @shfl_sync_ops(
+# CHECK-SAME: %[[MASK:.*]]: i32, %[[I32VAL:.*]]: i32, %[[F32VAL:.*]]: f32, %[[OFF:.*]]: i32, %[[CLAMP:.*]]: i32) -> i32 {
+# CHECK: %[[I32R:.*]] = nvvm.shfl.sync bfly %[[MASK]], %[[I32VAL]], %[[OFF]], %[[CLAMP]] : i32 -> i32
+# CHECK: %[[F32R:.*]] = nvvm.shfl.sync bfly %[[MASK]], %[[F32VAL]], %[[OFF]], %[[CLAMP]] : f32 -> f32
+# CHECK: %[[STRUCT:.*]] = nvvm.shfl.sync bfly %[[MASK]], %[[I32VAL]], %[[OFF]], %[[CLAMP]] {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
+# CHECK: return %[[I32R]] : i32
+
+
+@constructAndPrintInModule
+def test_ldmatrix_infer_type():
+ ptr_shared = llvm.PointerType.get(3)
+
+ shape_8x8 = Attribute.parse("#nvvm.ld_st_matrix_shape<m = 8, n = 8>")
+ elt_b16 = Attribute.parse("#nvvm.ld_st_matrix_elt_type<b16>")
+
+ @func.FuncOp.from_py_func(ptr_shared)
+ def ldmatrix_ops(ptr):
+ r1 = nvvm.ldmatrix(
+ ptr,
+ num=1,
+ layout=nvvm.MMALayout.row,
+ shape=shape_8x8,
+ elt_type=elt_b16,
+ )
+ r4 = nvvm.ldmatrix(
+ ptr,
+ num=4,
+ layout=nvvm.MMALayout.row,
+ shape=shape_8x8,
+ elt_type=elt_b16,
+ )
+ return r1
+
+
+# CHECK-LABEL: func.func @ldmatrix_ops(
+# CHECK-SAME: %[[PTR:.*]]: !llvm.ptr<3>) -> i32 {
+# CHECK: %[[R1:.*]] = nvvm.ldmatrix %[[PTR]] {{.*}}num = 1{{.*}} : (!llvm.ptr<3>) -> i32
+# CHECK: %[[R4:.*]] = nvvm.ldmatrix %[[PTR]] {{.*}}num = 4{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+# CHECK: return %[[R1]] : i32
+
+
@constructAndPrintInModule
def test_reductions():
i32 = T.i32()
More information about the Mlir-commits
mailing list