[Mlir-commits] [mlir] [mlir][NVVM] Add InferTypeOpInterface to sync and ldmatrix ops (PR #188238)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Mar 24 05:24:26 PDT 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Bastian Hagedorn (bastianhagedorn)

<details>
<summary>Changes</summary>

Add InferTypeOpAdaptor to 5 NVVM ops with deterministic result types:

- VoteSyncOp: ballot -> i32, any/all/uni -> i1
- MatchSyncOp: any -> i32, all -> struct<(i32, i1)>
- ShflOp: result matches val type, or struct<(val_type, i1)> with return_value_and_is_valid
- LdMatrixOp: i32 or struct of i32s based on num and shape
- ClusterLaunchControlQueryCancelOp: is_canceled -> i1, others -> i32

Note: this is a source-breaking change for Python callers that pass result types positionally.

---
Full diff: https://github.com/llvm/llvm-project/pull/188238.diff


3 Files Affected:

- (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+6-5) 
- (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+73) 
- (modified) mlir/test/python/dialects/nvvm.py (+128) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 57a0c67e82c47..b977f26635979 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -1504,7 +1504,7 @@ def ShflKind : I32EnumAttr<"ShflKind", "NVVM shuffle kind",
 def ShflKindAttr : EnumAttr<NVVM_Dialect, ShflKind, "shfl_kind">;
 
 def NVVM_ShflOp :
-  NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>]>,
+  NVVM_Op<"shfl.sync", [NVVMRequiresSM<30>, InferTypeOpAdaptor]>,
   Results<(outs AnyTypeOf<[I32, F32, LLVMStructType]>:$res)>,
   Arguments<(ins I32:$thread_mask,
                  AnyTypeOf<[I32, F32]>:$val,
@@ -1558,7 +1558,7 @@ def VoteSyncKind : I32EnumAttr<"VoteSyncKind", "NVVM vote sync kind",
 def VoteSyncKindAttr : EnumAttr<NVVM_Dialect, VoteSyncKind, "vote_sync_kind">;
 
 def NVVM_VoteSyncOp
-    : NVVM_Op<"vote.sync">,
+    : NVVM_Op<"vote.sync", [InferTypeOpAdaptor]>,
       Results<(outs AnyTypeOf<[I32, I1]>:$res)>,
       Arguments<(ins I32:$mask, I1:$pred, VoteSyncKindAttr:$kind)> {
   let summary = "Vote across thread group";
@@ -3028,7 +3028,7 @@ def NVVM_StMatrixOp: NVVM_Op<"stmatrix">,
   let hasVerifier = 1;
 }
 
-def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix">,
+def NVVM_LdMatrixOp: NVVM_Op<"ldmatrix", [InferTypeOpAdaptor]>,
   Results<(outs AnyType:$res)>,
   Arguments<(ins LLVM_PointerShared:$ptr, I32Attr:$num,
                  MMALayoutAttr:$layout,
@@ -4783,7 +4783,7 @@ def MatchSyncKind : I32EnumAttr<"MatchSyncKind", "NVVM match sync kind",
 
 def MatchSyncKindAttr : EnumAttr<NVVM_Dialect, MatchSyncKind, "match_sync_kind">;
 
-def NVVM_MatchSyncOp : NVVM_Op<"match.sync">,
+def NVVM_MatchSyncOp : NVVM_Op<"match.sync", [InferTypeOpAdaptor]>,
   Results<(outs AnyTypeOf<[I32, LLVMStructType]>:$res)>,
   Arguments<(ins I32:$thread_mask,
                  AnyTypeOf<[I32, I64]>:$val,
@@ -5723,7 +5723,8 @@ def ClusterLaunchControlQueryTypeAttr
 }
 
 def NVVM_ClusterLaunchControlQueryCancelOp
-    : NVVM_Op<"clusterlaunchcontrol.query.cancel", [NVVMRequiresSM<100>]> {
+    : NVVM_Op<"clusterlaunchcontrol.query.cancel",
+              [NVVMRequiresSM<100>, InferTypeOpAdaptor]> {
   let summary = "Query the response of a clusterlaunchcontrol.try.cancel operation";
   let description = [{
     `clusterlaunchcontrol.query.cancel` queries the response of a 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 528e709629ebf..f0e2fbf2ac560 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -2244,6 +2244,21 @@ LogicalResult ShflOp::verify() {
   return success();
 }
 
+LogicalResult ShflOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    ShflOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  Type valType = adaptor.getVal().getType();
+  if (adaptor.getReturnValueAndIsValid()) {
+    SmallVector<Type> body = {valType, IntegerType::get(context, 1)};
+    inferredReturnTypes.push_back(
+        LLVM::LLVMStructType::getLiteral(context, body));
+  } else {
+    inferredReturnTypes.push_back(valType);
+  }
+  return success();
+}
+
 std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
                                                    NVVM::MMAFrag frag, int nRow,
                                                    int nCol,
@@ -2450,6 +2465,26 @@ LogicalResult NVVM::LdMatrixOp::verify() {
   return success();
 }
 
+LogicalResult LdMatrixOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    LdMatrixOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  uint32_t num = adaptor.getNum();
+  uint32_t m = adaptor.getShape().getM();
+  uint32_t n = adaptor.getShape().getN();
+  uint32_t numElements = (m == 16 && n == 16) ? num * 2 : num;
+
+  Type i32 = IntegerType::get(context, 32);
+  if (numElements == 1) {
+    inferredReturnTypes.push_back(i32);
+  } else {
+    inferredReturnTypes.push_back(
+        LLVM::LLVMStructType::getLiteral(context,
+                                         SmallVector<Type>(numElements, i32)));
+  }
+  return success();
+}
+
 LogicalResult NVVM::StMatrixOp::verify() {
   int numMatrix = getSources().size();
   if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
@@ -2840,6 +2875,21 @@ LogicalResult NVVM::MatchSyncOp::verify() {
   return success();
 }
 
+LogicalResult MatchSyncOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    MatchSyncOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (adaptor.getKind() == NVVM::MatchSyncKind::all) {
+    SmallVector<Type> body = {IntegerType::get(context, 32),
+                              IntegerType::get(context, 1)};
+    inferredReturnTypes.push_back(
+        LLVM::LLVMStructType::getLiteral(context, body));
+  } else {
+    inferredReturnTypes.push_back(IntegerType::get(context, 32));
+  }
+  return success();
+}
+
 LogicalResult NVVM::VoteSyncOp::verify() {
   if (getKind() == NVVM::VoteSyncKind::ballot) {
     if (!getType().isInteger(32)) {
@@ -2853,6 +2903,17 @@ LogicalResult NVVM::VoteSyncOp::verify() {
   return success();
 }
 
+LogicalResult VoteSyncOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    VoteSyncOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (adaptor.getKind() == NVVM::VoteSyncKind::ballot)
+    inferredReturnTypes.push_back(IntegerType::get(context, 32));
+  else
+    inferredReturnTypes.push_back(IntegerType::get(context, 1));
+  return success();
+}
+
 LogicalResult NVVM::PrefetchOp::verify() {
   using MemSpace = NVVM::NVVMMemorySpace;
   using CacheLevel = NVVM::PrefetchCacheLevel;
@@ -2946,6 +3007,18 @@ LogicalResult NVVM::ClusterLaunchControlQueryCancelOp::verify() {
   return success();
 }
 
+LogicalResult ClusterLaunchControlQueryCancelOp::inferReturnTypes(
+    MLIRContext *context, std::optional<Location> location,
+    ClusterLaunchControlQueryCancelOp::Adaptor adaptor,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  if (adaptor.getQueryType() ==
+      NVVM::ClusterLaunchControlQueryType::IS_CANCELED)
+    inferredReturnTypes.push_back(IntegerType::get(context, 1));
+  else
+    inferredReturnTypes.push_back(IntegerType::get(context, 32));
+  return success();
+}
+
 LogicalResult NVVM::ReduxOp::verify() {
   mlir::Type reduxType = getType();
 
diff --git a/mlir/test/python/dialects/nvvm.py b/mlir/test/python/dialects/nvvm.py
index dc8e4f462ad56..b5e5298a94ab4 100644
--- a/mlir/test/python/dialects/nvvm.py
+++ b/mlir/test/python/dialects/nvvm.py
@@ -154,6 +154,134 @@ def barriers(mask, vi32, vf32):
 # CHECK:         }
 
 
+@constructAndPrintInModule
+def test_vote_sync_infer_type():
+    i1 = IntegerType.get_signless(1)
+    i32 = T.i32()
+
+    @func.FuncOp.from_py_func(i32, i1)
+    def vote_sync_ops(mask, pred):
+        ballot_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.ballot)
+        any_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.any)
+        all_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.all)
+        uni_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.uni)
+        return ballot_res
+
+
+# CHECK-LABEL:   func.func @vote_sync_ops(
+# CHECK-SAME:      %[[MASK:.*]]: i32, %[[PRED:.*]]: i1) -> i32 {
+# CHECK:           %[[BALLOT:.*]] = nvvm.vote.sync ballot %[[MASK]], %[[PRED]] -> i32
+# CHECK:           %[[ANY:.*]] = nvvm.vote.sync any %[[MASK]], %[[PRED]] -> i1
+# CHECK:           %[[ALL:.*]] = nvvm.vote.sync all %[[MASK]], %[[PRED]] -> i1
+# CHECK:           %[[UNI:.*]] = nvvm.vote.sync uni %[[MASK]], %[[PRED]] -> i1
+# CHECK:           return %[[BALLOT]] : i32
+# CHECK:         }
+
+
+@constructAndPrintInModule
+def test_clusterlaunchcontrol_query_cancel_infer_type():
+    i128 = IntegerType.get_signless(128)
+
+    @func.FuncOp.from_py_func(i128)
+    def query_cancel_ops(response):
+        is_canceled = nvvm.clusterlaunchcontrol_query_cancel(
+            nvvm.ClusterLaunchControlQueryType.IS_CANCELED,
+            response,
+        )
+        cta_x = nvvm.clusterlaunchcontrol_query_cancel(
+            nvvm.ClusterLaunchControlQueryType.GET_FIRST_CTA_ID_X,
+            response,
+        )
+        return cta_x
+
+
+# CHECK-LABEL:   func.func @query_cancel_ops(
+# CHECK-SAME:      %[[RESPONSE:.*]]: i128) -> i32 {
+# CHECK:           %{{.*}} = nvvm.clusterlaunchcontrol.query.cancel query = is_canceled, %[[RESPONSE]] : i1
+# CHECK:           %[[CTA_X:.*]] = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %[[RESPONSE]] : i32
+# CHECK:           return %[[CTA_X]] : i32
+# CHECK:         }
+
+
+@constructAndPrintInModule
+def test_match_sync_infer_type():
+    i32 = T.i32()
+    i64 = IntegerType.get_signless(64)
+
+    @func.FuncOp.from_py_func(i32, i32, i64)
+    def match_sync_ops(mask, i32val, i64val):
+        any_result = nvvm.match_sync(mask, i32val, nvvm.MatchSyncKind.any)
+        all_result = nvvm.match_sync(mask, i32val, nvvm.MatchSyncKind.all)
+        return any_result
+
+
+# CHECK-LABEL:   func.func @match_sync_ops(
+# CHECK-SAME:      %[[MASK:.*]]: i32, %[[I32VAL:.*]]: i32, %[[I64VAL:.*]]: i64) -> i32 {
+# CHECK:           %[[ANY:.*]] = nvvm.match.sync any %[[MASK]], %[[I32VAL]] : i32 -> i32
+# CHECK:           %[[ALL:.*]] = nvvm.match.sync all %[[MASK]], %[[I32VAL]] : i32 -> !llvm.struct<(i32, i1)>
+# CHECK:           return %[[ANY]] : i32
+# CHECK:         }
+
+
+@constructAndPrintInModule
+def test_shfl_sync_infer_type():
+    i32 = T.i32()
+    f32 = T.f32()
+
+    @func.FuncOp.from_py_func(i32, i32, f32, i32, i32)
+    def shfl_sync_ops(mask, i32val, f32val, offset, clamp):
+        i32_result = nvvm.shfl_sync(
+            mask, i32val, offset, clamp, nvvm.ShflKind.bfly
+        )
+        f32_result = nvvm.shfl_sync(
+            mask, f32val, offset, clamp, nvvm.ShflKind.bfly
+        )
+        struct_result = nvvm.shfl_sync(
+            mask, i32val, offset, clamp, nvvm.ShflKind.bfly,
+            return_value_and_is_valid=True
+        )
+        return i32_result
+
+
+# CHECK-LABEL:   func.func @shfl_sync_ops(
+# CHECK-SAME:      %[[MASK:.*]]: i32, %[[I32VAL:.*]]: i32, %[[F32VAL:.*]]: f32, %[[OFF:.*]]: i32, %[[CLAMP:.*]]: i32) -> i32 {
+# CHECK:           %[[I32R:.*]] = nvvm.shfl.sync bfly %[[MASK]], %[[I32VAL]], %[[OFF]], %[[CLAMP]] : i32 -> i32
+# CHECK:           %[[F32R:.*]] = nvvm.shfl.sync bfly %[[MASK]], %[[F32VAL]], %[[OFF]], %[[CLAMP]] : f32 -> f32
+# CHECK:           %[[STRUCT:.*]] = nvvm.shfl.sync bfly %[[MASK]], %[[I32VAL]], %[[OFF]], %[[CLAMP]] {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
+# CHECK:           return %[[I32R]] : i32
+# CHECK:         }
+
+
+@constructAndPrintInModule
+def test_ldmatrix_infer_type():
+    ptr_shared = llvm.PointerType.get(3)
+
+    shape_8x8 = Attribute.parse("#nvvm.ld_st_matrix_shape<m = 8, n = 8>")
+    elt_b16 = Attribute.parse("#nvvm.ld_st_matrix_elt_type<b16>")
+
+    @func.FuncOp.from_py_func(ptr_shared)
+    def ldmatrix_ops(ptr):
+        r1 = nvvm.ldmatrix(
+            ptr, num=1, layout=nvvm.MMALayout.row,
+            shape=shape_8x8,
+            elt_type=elt_b16,
+        )
+        r4 = nvvm.ldmatrix(
+            ptr, num=4, layout=nvvm.MMALayout.row,
+            shape=shape_8x8,
+            elt_type=elt_b16,
+        )
+        return r1
+
+
+# CHECK-LABEL:   func.func @ldmatrix_ops(
+# CHECK-SAME:      %[[PTR:.*]]: !llvm.ptr<3>) -> i32 {
+# CHECK:           %[[R1:.*]] = nvvm.ldmatrix %[[PTR]] {{.*}}num = 1{{.*}} : (!llvm.ptr<3>) -> i32
+# CHECK:           %[[R4:.*]] = nvvm.ldmatrix %[[PTR]] {{.*}}num = 4{{.*}} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32, i32, i32)>
+# CHECK:           return %[[R1]] : i32
+# CHECK:         }
+
+
 @constructAndPrintInModule
 def test_reductions():
     i32 = T.i32()

``````````

</details>


https://github.com/llvm/llvm-project/pull/188238


More information about the Mlir-commits mailing list