[Mlir-commits] [mlir] [mlir][NVVM] Add InferTypeOpInterface to sync and ldmatrix ops (PR #188238)
Guray Ozen
llvmlistbot at llvm.org
Tue Mar 24 05:53:02 PDT 2026
================
@@ -154,6 +154,134 @@ def barriers(mask, vi32, vf32):
# CHECK: }
+@constructAndPrintInModule
+def test_vote_sync_infer_type():
+    i1 = IntegerType.get_signless(1)
+    i32 = T.i32()
+
+    @func.FuncOp.from_py_func(i32, i1)
+    def vote_sync_ops(mask, pred):
+        ballot_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.ballot)
+        any_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.any)
+        all_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.all)
+        uni_res = nvvm.vote_sync(mask, pred, nvvm.VoteSyncKind.uni)
+        return ballot_res
+
+
+# CHECK-LABEL: func.func @vote_sync_ops(
+# CHECK-SAME: %[[MASK:.*]]: i32, %[[PRED:.*]]: i1) -> i32 {
+# CHECK: %[[BALLOT:.*]] = nvvm.vote.sync ballot %[[MASK]], %[[PRED]] -> i32
+# CHECK: %[[ANY:.*]] = nvvm.vote.sync any %[[MASK]], %[[PRED]] -> i1
+# CHECK: %[[ALL:.*]] = nvvm.vote.sync all %[[MASK]], %[[PRED]] -> i1
+# CHECK: %[[UNI:.*]] = nvvm.vote.sync uni %[[MASK]], %[[PRED]] -> i1
+# CHECK: return %[[BALLOT]] : i32
+# CHECK: }
+
+
+@constructAndPrintInModule
+def test_clusterlaunchcontrol_query_cancel_infer_type():
+    i128 = IntegerType.get_signless(128)
+
+    @func.FuncOp.from_py_func(i128)
+    def query_cancel_ops(response):
+        is_canceled = nvvm.clusterlaunchcontrol_query_cancel(
+            nvvm.ClusterLaunchControlQueryType.IS_CANCELED,
+            response,
+        )
+        cta_x = nvvm.clusterlaunchcontrol_query_cancel(
+            nvvm.ClusterLaunchControlQueryType.GET_FIRST_CTA_ID_X,
+            response,
+        )
+        return cta_x
+
+
+# CHECK-LABEL: func.func @query_cancel_ops(
+# CHECK-SAME: %[[RESPONSE:.*]]: i128) -> i32 {
+# CHECK: %{{.*}} = nvvm.clusterlaunchcontrol.query.cancel query = is_canceled, %[[RESPONSE]] : i1
+# CHECK: %[[CTA_X:.*]] = nvvm.clusterlaunchcontrol.query.cancel query = get_first_cta_id_x, %[[RESPONSE]] : i32
+# CHECK: return %[[CTA_X]] : i32
+# CHECK: }
+
+
+@constructAndPrintInModule
+def test_match_sync_infer_type():
+    i32 = T.i32()
+    i64 = IntegerType.get_signless(64)
+
+    @func.FuncOp.from_py_func(i32, i32, i64)
+    def match_sync_ops(mask, i32val, i64val):
+        any_result = nvvm.match_sync(mask, i32val, nvvm.MatchSyncKind.any)
+        all_result = nvvm.match_sync(mask, i32val, nvvm.MatchSyncKind.all)
+        return any_result
+
+
+# CHECK-LABEL: func.func @match_sync_ops(
+# CHECK-SAME: %[[MASK:.*]]: i32, %[[I32VAL:.*]]: i32, %[[I64VAL:.*]]: i64) -> i32 {
+# CHECK: %[[ANY:.*]] = nvvm.match.sync any %[[MASK]], %[[I32VAL]] : i32 -> i32
+# CHECK: %[[ALL:.*]] = nvvm.match.sync all %[[MASK]], %[[I32VAL]] : i32 -> !llvm.struct<(i32, i1)>
+# CHECK: return %[[ANY]] : i32
+# CHECK: }
+
+
+@constructAndPrintInModule
+def test_shfl_sync_infer_type():
+    i32 = T.i32()
+    f32 = T.f32()
+
+    @func.FuncOp.from_py_func(i32, i32, f32, i32, i32)
+    def shfl_sync_ops(mask, i32val, f32val, offset, clamp):
+        i32_result = nvvm.shfl_sync(
+            mask, i32val, offset, clamp, nvvm.ShflKind.bfly
+        )
+        f32_result = nvvm.shfl_sync(
+            mask, f32val, offset, clamp, nvvm.ShflKind.bfly
+        )
+        struct_result = nvvm.shfl_sync(
+            mask, i32val, offset, clamp, nvvm.ShflKind.bfly,
+            return_value_and_is_valid=True
+        )
+        return i32_result
+
+
+# CHECK-LABEL: func.func @shfl_sync_ops(
+# CHECK-SAME: %[[MASK:.*]]: i32, %[[I32VAL:.*]]: i32, %[[F32VAL:.*]]: f32, %[[OFF:.*]]: i32, %[[CLAMP:.*]]: i32) -> i32 {
+# CHECK: %[[I32R:.*]] = nvvm.shfl.sync bfly %[[MASK]], %[[I32VAL]], %[[OFF]], %[[CLAMP]] : i32 -> i32
+# CHECK: %[[F32R:.*]] = nvvm.shfl.sync bfly %[[MASK]], %[[F32VAL]], %[[OFF]], %[[CLAMP]] : f32 -> f32
+# CHECK: %[[STRUCT:.*]] = nvvm.shfl.sync bfly %[[MASK]], %[[I32VAL]], %[[OFF]], %[[CLAMP]] {return_value_and_is_valid} : i32 -> !llvm.struct<(i32, i1)>
+# CHECK: return %[[I32R]] : i32
+# CHECK: }
----------------
grypp wrote:
nit: remove this check
https://github.com/llvm/llvm-project/pull/188238
More information about the Mlir-commits
mailing list