[Mlir-commits] [mlir] [mlir][NVVM] Add InferTypeOpInterface to sync and ldmatrix ops (PR #188238)

Guray Ozen llvmlistbot at llvm.org
Tue Mar 24 05:53:02 PDT 2026


================
@@ -154,6 +154,134 @@ def barriers(mask, vi32, vf32):
 # CHECK:         }
 
 
+@constructAndPrintInModule
+def test_vote_sync_infer_type():
+    i1 = IntegerType.get_signless(1)
+    i32 = T.i32()
+
+    @func.FuncOp.from_py_func(i32, i1)
+    def vote_sync_ops(mask, pred):
+        ballot_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.ballot)
+        any_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.any)
+        all_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.all)
+        uni_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.uni)
+        return ballot_res
+
+
+# CHECK-LABEL:   func.func @vote_sync_ops(
+# CHECK-SAME:      %[[MASK:.*]]: i32, %[[PRED:.*]]: i1) -> i32 {
+# CHECK:           %[[BALLOT:.*]] = nvvm.vote.sync ballot %[[MASK]], %[[PRED]] -> i32
+# CHECK:           %[[ANY:.*]] = nvvm.vote.sync any %[[MASK]], %[[PRED]] -> i1
+# CHECK:           %[[ALL:.*]] = nvvm.vote.sync all %[[MASK]], %[[PRED]] -> i1
+# CHECK:           %[[UNI:.*]] = nvvm.vote.sync uni %[[MASK]], %[[PRED]] -> i1
+# CHECK:           return %[[BALLOT]] : i32
+# CHECK:         }
+
+
+@constructAndPrintInModule
+def test_clusterlaunchcontrol_query_cancel_infer_type():
+    i128 = IntegerType.get_signless(128)
+
+    @func.FuncOp.from_py_func(i128)
+    def query_cancel_ops(response):
+        is_canceled = nvvm.clusterlaunchcontrol_query_cancel(
+            nvvm.ClusterLaunchControlQueryType.IS_CANCELED,
+            response,
+        )
+        cta_x = nvvm.clusterlaunchcontrol_query_cancel(
+            nvvm.ClusterLaunchControlQueryType.GET_FIRST_CTA_ID_X,
+            response,
+        )
+        return cta_x
+
+
+# CHECK-LABEL:   func.func @query_cancel_ops(
+# CHECK-SAME:      %[[RESPONSE:.*]]: i128) -> i32 {
+# CHECK:           %{{.*}} = nvvm.clusterlaunchcontrol.query.cancel query = is_canceled, %[[RESPONSE]] : i1
+# CHECK:           %[[CTA_X:.*]] = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %[[RESPONSE]] : i32
+# CHECK:           return %[[CTA_X]] : i32
+# CHECK:         }
+
+
+@constructAndPrintInModule
+def test_match_sync_infer_type():
+    i32 = T.i32()
+    i64 = IntegerType.get_signless(64)
+
+    @func.FuncOp.from_py_func(i32, i32, i64)
+    def match_sync_ops(mask, i32val, i64val):
+        any_result = nvvm.match_sync(mask, i32val, nvvm.MatchSyncKind.any)
+        all_result = nvvm.match_sync(mask, i32val, nvvm.MatchSyncKind.all)
+        return any_result
+
+
+# CHECK-LABEL:   func.func @match_sync_ops(
+# CHECK-SAME:      %[[MASK:.*]]: i32, %[[I32VAL:.*]]: i32, %[[I64VAL:.*]]: i64) -> i32 {
+# CHECK:           %[[ANY:.*]] = nvvm.match.sync any %[[MASK]], %[[I32VAL]] : i32 -> i32
+# CHECK:           %[[ALL:.*]] = nvvm.match.sync all %[[MASK]], %[[I32VAL]] : i32 -> !llvm.struct<(i32, i1)>
+# CHECK:           return %[[ANY]] : i32
+# CHECK:         }
+
+
+@constructAndPrintInModule
+def test_shfl_sync_infer_type():
+    i32 = T.i32()
+    f32 = T.f32()
+
+    @func.FuncOp.from_py_func(i32, i32, f32, i32, i32)
+    def shfl_sync_ops(mask, i32val, f32val, offset, clamp):
+        i32_result = nvvm.shfl_sync(
+            mask, i32val, offset, clamp, nvvm.ShflKind.bfly
+        )
+        f32_result = nvvm.shfl_sync(
+            mask, f32val, offset, clamp, nvvm.ShflKind.bfly
+        )
+        struct_result = nvvm.shfl_sync(
+            mask, i32val, offset, clamp, nvvm.ShflKind.bfly,
+            return_value_and_is_valid=True
+        )
+        return i32_result
+
+
+# CHECK-LABEL:   func.func @shfl_sync_ops(
+# CHECK-SAME:      %[[MASK:.*]]: i32, %[[I32VAL:.*]]: i32, %[[F32VAL:.*]]: f32, %[[OFF:.*]]: i32, %[[CLAMP:.*]]: i32) -> i32 {
+# CHECK:           %[[I32R:.*]] = nvvm.shfl.sync bfly %[[MASK]], %[[I32VAL]], %[[OFF]], %[[CLAMP]] : i32 -> i32
+# CHECK:           %[[F32R:.*]] = nvvm.shfl.sync bfly %[[MASK]], %[[F32VAL]], %[[OFF]], %[[CLAMP]] : f32 -> f32
+# CHECK:           %[[STRUCT:.*]] = nvvm.shfl.sync bfly %[[MASK]], %[[I32VAL]], %[[OFF]], %[[CLAMP]] {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
+# CHECK:           return %[[I32R]] : i32
+# CHECK:         }
----------------
grypp wrote:

nit: remove this check [i.e. the final `# CHECK:         }` line of the hunk quoted above]

https://github.com/llvm/llvm-project/pull/188238


More information about the Mlir-commits mailing list