[Mlir-commits] [mlir] [mlir][python] Add tests for gpu.launch(_func) ops (PR #163883)

Asher Mancinelli llvmlistbot at llvm.org
Thu Oct 16 15:49:39 PDT 2025


https://github.com/ashermancinelli created https://github.com/llvm/llvm-project/pull/163883

Add tests for launching GPU kernels (`gpu.launch_func`) and GPU regions (`gpu.launch`), and correct some small documentation mistakes. I wonder if we could add Python builders that take 3-tuples for the dim3 launch parameters and let the async token and async dependencies default to `None` and an empty list, respectively, essentially supporting the same use cases as the builders on the C++ side, such as:
```cpp
    OpBuilder<(ins "GPUFuncOp":$kernelFunc, "KernelDim3":$gridSize,
      "KernelDim3":$blockSize, "Value":$dynamicSharedMemorySize,
      "ValueRange":$kernelOperands,
      CArg<"Type", "nullptr">:$asyncTokenType,
      CArg<"ValueRange", "{}">:$asyncDependencies,
      CArg<"std::optional<KernelDim3>", "std::nullopt">:$clusterSize)>,
```
This PR only tests what is currently there, but if folks support a Python builder that mirrors the C++ builders while preserving the current use cases, I'll add that next. Thanks in advance!

>From 1db01a2bda0d3704ef273ec665d4be98e4c9472d Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Thu, 16 Oct 2025 15:42:54 -0700
Subject: [PATCH] [mlir][python] Add tests for gpu.launch(_func) ops

These are the tests I wish I could have referred to during development.
Also corrected some small documentation mistakes.
---
 mlir/docs/Dialects/GPU.md                  |  2 +-
 mlir/include/mlir/Dialect/GPU/IR/GPUOps.td |  2 +-
 mlir/test/python/dialects/gpu/dialect.py   | 99 +++++++++++++++++++++-
 3 files changed, 100 insertions(+), 3 deletions(-)

diff --git a/mlir/docs/Dialects/GPU.md b/mlir/docs/Dialects/GPU.md
index 8d4d2ca3e5743..c16ed57737e5b 100644
--- a/mlir/docs/Dialects/GPU.md
+++ b/mlir/docs/Dialects/GPU.md
@@ -121,7 +121,7 @@ func.func @main() {
     gpu.launch
         blocks(%0, %1, %2) in (%3 = %c1, %4 = %c1, %5 = %c1)
         threads(%6, %7, %8) in (%9 = %c2, %10 = %c1, %11 = %c1) {
-        gpu.printf "Hello from %d\n" %6 : index
+        gpu.printf "Hello from %d\n", %6 : index
         gpu.terminator
     }
     return
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 987fc13e0508d..a6c6038e1e224 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -584,7 +584,7 @@ def GPU_DynamicSharedMemoryOp : GPU_Op<"dynamic_shared_memory", [Pure]>
     This operation provides a memref pointer to the start of dynamic shared
     memory, often referred to as workgroup memory. It's important to note that
     this dynamic shared memory needs to be allocated at kernel launch. One can
-    conveniently utilize `the dynamic_shared_memory_size` parameter of
+    conveniently utilize the `dynamic_shared_memory_size` parameter of
     `gpu.launch` for this purpose.
 
     Examples:
diff --git a/mlir/test/python/dialects/gpu/dialect.py b/mlir/test/python/dialects/gpu/dialect.py
index 66c401886804c..24f20d109b3d0 100644
--- a/mlir/test/python/dialects/gpu/dialect.py
+++ b/mlir/test/python/dialects/gpu/dialect.py
@@ -2,7 +2,8 @@
 
 from mlir.ir import *
 import mlir.ir as ir
-import mlir.dialects.gpu as gpu
+from mlir.dialects import gpu, func, arith, math
+from mlir.extras import types as T
 import mlir.dialects.gpu.passes
 from mlir.passmanager import *
 
@@ -157,3 +158,99 @@ def builder(func: gpu.GPUFuncOp) -> None:
     # CHECK:   %[[VAL_0:.*]] = gpu.global_id  x
     # CHECK:   gpu.return
     # CHECK: }
+
+# CHECK-LABEL: testGPULaunchFuncOp
+@run
+def testGPULaunchFuncOp():
+    """Build an async gpu.launch_func of a gpu.module kernel, chained by gpu.wait."""
+    module = Module.create()
+    # The kernel lives in a nested gpu.module, so mark the module as a container.
+    module.operation.attributes["gpu.container_module"] = UnitAttr.get()
+    with InsertionPoint(module.body):
+        gpu_module = gpu.GPUModuleOp("gpu_module")
+        block = gpu_module.bodyRegion.blocks.append()
+
+    with InsertionPoint(block):
+        gpu_func = gpu.GPUFuncOp(
+            FunctionType.get([], []),
+            "kernel",
+            body_builder=lambda _kernel: gpu.return_([]),
+            kernel=True,
+        )
+
+    with InsertionPoint(module.body):
+        host = func.FuncOp(type=FunctionType.get([], []), name="host")
+
+    with InsertionPoint(host.add_entry_block()):
+        # A 1x1x1 grid of 1x1x1 blocks: every launch dimension reuses one constant.
+        c1 = arith.constant(T.index(), 1)
+        sym_ref = SymbolRefAttr.get([gpu_module.sym_name.value, gpu_func.name.value])
+        token_type = Type.parse("!gpu.async.token")
+        token = gpu.wait(async_token=token_type, async_dependencies=[])
+        token = gpu.launch_func(
+            async_token=token_type,
+            async_dependencies=[token],
+            kernel=sym_ref,
+            grid_size_x=c1,
+            grid_size_y=c1,
+            grid_size_z=c1,
+            block_size_x=c1,
+            block_size_y=c1,
+            block_size_z=c1,
+            kernel_operands=[],
+        )
+        gpu.wait(async_token=None, async_dependencies=[token])
+        func.ReturnOp([])
+
+    print(module)
+
+    # CHECK-LABEL:   gpu.module @gpu_module {
+    # CHECK:           gpu.func @kernel() kernel {
+    # CHECK:             gpu.return
+    # CHECK:           }
+    # CHECK:         }
+
+    # CHECK-LABEL:   func.func @host() {
+    # CHECK:           %[[CONSTANT_0:.*]] = arith.constant 1 : index
+    # CHECK:           %[[WAIT_0:.*]] = gpu.wait async
+    # CHECK:           %[[LAUNCH_FUNC_0:.*]] = gpu.launch_func async {{\[}}%[[WAIT_0]]] @gpu_module::@kernel blocks in (%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]]) threads in (%[[CONSTANT_0]], %[[CONSTANT_0]], %[[CONSTANT_0]])
+    # CHECK:           gpu.wait {{\[}}%[[LAUNCH_FUNC_0]]]
+    # CHECK:           return
+    # CHECK:         }
+
+
+# CHECK-LABEL: testGPULaunchOp
+@run
+def testGPULaunchOp():
+    """Build a gpu.launch whose body emits gpu.printf of an f32 func argument."""
+    module = Module.create()
+    with InsertionPoint(module.body):
+        host = func.FuncOp(type=FunctionType.get([T.f32()], []), name="gpu_printf")
+
+    entry_block = host.add_entry_block()
+    with InsertionPoint(entry_block):
+        c1 = arith.constant(T.index(), 1)
+        # No async token or deps; grid sizes (x, y, z) then block sizes (x, y, z).
+        launch = gpu.launch(None, [], c1, c1, c1, c1, c1, c1)
+        launch_block = launch.regions[0].blocks.append()
+        for _ in range(12):  # 3 block ids, 3 thread ids, 3 grid sizes, 3 block sizes
+            launch_block.add_argument(T.index(), Location.unknown())
+
+    with InsertionPoint(launch_block):
+        gpu.printf("%f", [entry_block.arguments[0]])
+        gpu.terminator()
+
+    with InsertionPoint(entry_block):
+        func.ReturnOp([])
+
+    print(module)
+
+    # CHECK-LABEL:   func.func @gpu_printf(
+    # CHECK-SAME:      %[[ARG0:.*]]: f32) {
+    # CHECK:           %[[CONSTANT_0:.*]] = arith.constant 1 : index
+    # CHECK:           gpu.launch blocks(%[[VAL_0:.*]], %[[VAL_1:.*]], %[[VAL_2:.*]]) in (%[[VAL_3:.*]] = %[[CONSTANT_0]], %[[VAL_4:.*]] = %[[CONSTANT_0]], %[[VAL_5:.*]] = %[[CONSTANT_0]]) threads(%[[VAL_6:.*]], %[[VAL_7:.*]], %[[VAL_8:.*]]) in (%[[VAL_9:.*]] = %[[CONSTANT_0]], %[[VAL_10:.*]] = %[[CONSTANT_0]], %[[VAL_11:.*]] = %[[CONSTANT_0]]) {
+    # CHECK:             gpu.printf "%f", %[[ARG0]] : f32
+    # CHECK:             gpu.terminator
+    # CHECK:           }
+    # CHECK:           return
+    # CHECK:         }



More information about the Mlir-commits mailing list