[Mlir-commits] [mlir] [mlir][python] Add tests for gpu.launch(_func) ops (PR #163883)
Maksim Levental
llvmlistbot at llvm.org
Thu Oct 16 19:05:23 PDT 2025
================
@@ -151,3 +160,176 @@ def entry_block(self) -> Block:
     @property
     def arguments(self) -> Sequence[Type]:
         return self.function_type.value.inputs
+
+
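+# Accept plain Python ints for grid/block sizes by materializing them as
+# index-typed constants; ConstantOp/Value operands pass through unchanged.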
+def _convert_literal_to_constant(value: Union[int, ConstantOp, Value]) -> Value:
+    if isinstance(value, int):
+        return constant(T.index(), value)
+    elif isinstance(value, (ConstantOp, Value)):
+        return value
+    else:
+        raise ValueError(f"Invalid value: {value}")
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class LaunchFuncOp(LaunchFuncOp):
+    __doc__ = LaunchFuncOp.__doc__
+
+    def __init__(
+        self,
+        kernel: List[str],
+        grid_size: Tuple[Any, Any, Any],
+        block_size: Tuple[Any, Any, Any],
+        kernel_operands: Optional[List[Value]] = None,
+        async_dependencies: Optional[List[Value]] = None,
+        dynamic_shared_memory_size: Optional[Value] = None,
+        async_object=None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if async_dependencies is None:
+            async_dependencies = []
+        async_token = None
+        if len(async_dependencies):
+            async_token = gpu_async_token()
+
+        grid_size_x, grid_size_y, grid_size_z = map(
+            _convert_literal_to_constant, grid_size
+        )
+        block_size_x, block_size_y, block_size_z = map(
+            _convert_literal_to_constant, block_size
+        )
+
+        super().__init__(
+            async_token,
+            async_dependencies,
+            kernel,
+            grid_size_x,
+            grid_size_y,
+            grid_size_z,
+            block_size_x,
+            block_size_y,
+            block_size_z,
+            kernel_operands,
+            dynamicSharedMemorySize=dynamic_shared_memory_size,
+            asyncObject=async_object,
+            loc=loc,
+            ip=ip,
+        )
+
+
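+# Functional builder for LaunchFuncOp: returns the async token when one is
+# produced, otherwise the op itself.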
+def launch_func(
+    kernel: List[str],
+    grid_size: Tuple[Any, Any, Any],
+    block_size: Tuple[Any, Any, Any],
+    kernel_operands: Optional[List[Value]] = None,
+    async_dependencies: Optional[List[Value]] = None,
+    dynamic_shared_memory_size: Optional[Value] = None,
+    async_object=None,
+    *,
+    loc=None,
+    ip=None,
+) -> Union[Value, List[Value], LaunchFuncOp]:
+    op = LaunchFuncOp(
+        kernel=kernel,
+        grid_size=grid_size,
+        block_size=block_size,
+        kernel_operands=kernel_operands,
+        async_dependencies=async_dependencies,
+        dynamic_shared_memory_size=dynamic_shared_memory_size,
+        async_object=async_object,
+        loc=loc,
+        ip=ip,
+    )
+    results = op.results
+    if len(results) == 1:
+        return results[0]
+    elif len(results) > 1:
+        return results
+    else:
+        return op
+
+
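+# Build a gpu.wait that joins the given async dependencies and always
+# produces a fresh !gpu.async.token.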
+def wait(
+    async_dependencies: Optional[List[Value]] = None, *, loc=None, ip=None
+) -> Union[Value, List[Value], WaitOp]:
+    if async_dependencies is None:
+        async_dependencies = []
+    return get_op_result_or_op_results(
+        WaitOp(gpu_async_token(), async_dependencies, loc=loc, ip=ip)
+    )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class LaunchOp(LaunchOp):
+    __doc__ = LaunchOp.__doc__
+
+    def __init__(
+        self,
+        grid_size: Tuple[Any, Any, Any],
+        block_size: Tuple[Any, Any, Any],
+        async_dependencies=None,
+        dynamic_shared_memory_size: Optional[Value] = None,
+        *,
+        loc=None,
+        ip=None,
+    ):
+        if async_dependencies is None:
+            async_dependencies = []
+        async_token = None
+        if len(async_dependencies):
+            async_token = gpu_async_token()
+        grid_size_x, grid_size_y, grid_size_z = map(
+            _convert_literal_to_constant, grid_size
+        )
+        block_size_x, block_size_y, block_size_z = map(
+            _convert_literal_to_constant, block_size
+        )
+
+        super().__init__(
+            async_token,
+            async_dependencies,
+            grid_size_x,
+            grid_size_y,
+            grid_size_z,
+            block_size_x,
+            block_size_y,
+            block_size_z,
+            dynamicSharedMemorySize=dynamic_shared_memory_size,
+            loc=loc,
+            ip=ip,
+        )
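+        # gpu.launch's body block carries 12 index arguments: block ids,
+        # thread ids, grid dims, and block dims (x/y/z each).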
+        self.regions[0].blocks.append(*[T.index() for _ in range(12)])
+
+
+def launch_(
+    grid_size: Tuple[Any, Any, Any],
+    block_size: Tuple[Any, Any, Any],
+    async_dependencies=None,
+    dynamic_shared_memory_size: Optional[Value] = None,
+    *,
+    loc=None,
+    ip=None,
+):
+    grid_size = tuple(map(_convert_literal_to_constant, grid_size))
+    block_size = tuple(map(_convert_literal_to_constant, block_size))
+    launch_op = LaunchOp(
+        grid_size,
+        block_size,
+        async_dependencies,
+        dynamic_shared_memory_size,
+        loc=loc,
+        ip=ip,
+    )
+    return launch_op
+
+
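+# region_op makes `launch` usable as a decorator: the decorated function's
+# body is emitted inside the launch region, which is closed with a
+# gpu.terminator.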
+launch = region_op(launch_, terminator=lambda *_args: terminator())
+
+
+_printf = printf
+
+
+def printf(format, *args, loc=None, ip=None):
+    return _printf(format=format, args=args, loc=loc, ip=ip)
----------------
makslevental wrote:
ok lol this is a good use of `*args`
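The varargs get forwarded straight through as the op's `args` operands, so a call site reads like a regular printf. A rough sketch of how the new builders compose (illustrative only — the body-builder name, signature, and format string below are my own, not copied from the PR's tests):

```python
from mlir.ir import Context, Location, InsertionPoint, Module
from mlir.dialects import gpu

with Context(), Location.unknown():
    module = Module.create()
    with InsertionPoint(module.body):
        # Plain ints in grid_size/block_size are materialized as index
        # constants by _convert_literal_to_constant.
        @gpu.launch(grid_size=(4, 1, 1), block_size=(32, 1, 1))
        def launch_body(bx, by, bz, tx, ty, tz, *grid_and_block_dims):
            # *args are forwarded as the printf operands.
            gpu.printf("hello from block %d\n", bx)

    print(module)
```

(Kept flat at module level for brevity; in real IR the launch would sit inside a function body.)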
https://github.com/llvm/llvm-project/pull/163883