[Mlir-commits] [mlir] [mlir][python] Add pythonic interface for GPUFuncOp (PR #163596)

Asher Mancinelli llvmlistbot at llvm.org
Wed Oct 15 11:17:25 PDT 2025


https://github.com/ashermancinelli updated https://github.com/llvm/llvm-project/pull/163596

>From 67e32af7a2216e0d26876ceb20c20f94d7462eff Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Wed, 15 Oct 2025 09:54:16 -0700
Subject: [PATCH 1/4] [mlir][python] Add pythonic interface for GPUFuncOp

The func dialect's Python bindings provide a more Pythonic interface for
constructing operations, but the gpu dialect's bindings do not. This is
the first PR toward providing the same conveniences for the gpu dialect,
starting with the gpu.func op.
---
 mlir/python/mlir/dialects/gpu/__init__.py | 116 ++++++++++++++++++++++
 mlir/test/python/dialects/gpu/dialect.py  |  63 ++++++++++++
 2 files changed, 179 insertions(+)

diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py
index 4cd80aa8b7ca8..14b965927e280 100644
--- a/mlir/python/mlir/dialects/gpu/__init__.py
+++ b/mlir/python/mlir/dialects/gpu/__init__.py
@@ -3,5 +3,121 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from .._gpu_ops_gen import *
+from .._gpu_ops_gen import _Dialect
 from .._gpu_enum_gen import *
 from ..._mlir_libs._mlirDialectsGPU import *
+from typing import Callable, Sequence, Union, Optional
+
+try:
+    from ...ir import (
+        FunctionType,
+        TypeAttr,
+        StringAttr,
+        UnitAttr,
+        Block,
+        InsertionPoint,
+        ArrayAttr,
+        Type,
+        DictAttr,
+        Attribute,
+    )
+    from .._ods_common import (
+        get_default_loc_context as _get_default_loc_context,
+        _cext as _ods_cext,
+    )
+except ImportError as e:
+    raise RuntimeError("Error loading imports from extension module") from e
+
+
+FUNCTION_TYPE_ATTRIBUTE_NAME = "function_type"
+KERNEL_ATTRIBUTE_NAME = "gpu.kernel"
+SYM_NAME_ATTRIBUTE_NAME = "sym_name"
+ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
+RESULT_ATTRIBUTE_NAME = "res_attrs"
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class GPUFuncOp(GPUFuncOp):
+    def __init__(
+        self,
+        function_type: Union[FunctionType, TypeAttr],
+        sym_name: Optional[Union[str, StringAttr]] = None,
+        kernel: Optional[bool] = None,
+        body_builder: Optional[Callable[[GPUFuncOp], None]] = None,
+        *args,
+        loc=None,
+        ip=None,
+        **kwargs,
+    ):
+        function_type = (
+            TypeAttr.get(function_type)
+            if not isinstance(function_type, TypeAttr)
+            else function_type
+        )
+        super().__init__(function_type, *args, loc=loc, ip=ip, **kwargs)
+        if sym_name is not None:
+            self.attributes[SYM_NAME_ATTRIBUTE_NAME] = StringAttr.get(str(sym_name))
+        if kernel:
+            self.attributes[KERNEL_ATTRIBUTE_NAME] = UnitAttr.get()
+        if body_builder is not None:
+            with InsertionPoint(self.add_entry_block()):
+                body_builder(self)
+
+    @property
+    def type(self) -> FunctionType:
+        return FunctionType(
+            TypeAttr(self.attributes[FUNCTION_TYPE_ATTRIBUTE_NAME]).value
+        )
+
+    @property
+    def name(self) -> StringAttr:
+        return StringAttr(self.attributes[SYM_NAME_ATTRIBUTE_NAME])
+
+    @property
+    def is_kernel(self) -> bool:
+        return KERNEL_ATTRIBUTE_NAME in self.attributes
+
+    def add_entry_block(self) -> Block:
+        function_type = self.type
+        return self.body.blocks.append(
+            *function_type.inputs,
+            arg_locs=[self.location for _ in function_type.inputs],
+        )
+
+    @property
+    def entry_block(self) -> Block:
+        return self.body.blocks[0]
+
+    @property
+    def arguments(self) -> Sequence[Type]:
+        return self.type.inputs
+
+    @property
+    def arg_attrs(self):
+        if ARGUMENT_ATTRIBUTE_NAME not in self.attributes:
+            return ArrayAttr.get([DictAttr.get({}) for _ in self.type.inputs])
+        return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
+
+    @arg_attrs.setter
+    def arg_attrs(self, attribute: Union[ArrayAttr, list[Attribute]]):
+        if isinstance(attribute, ArrayAttr):
+            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
+        else:
+            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
+                attribute, context=self.context
+            )
+
+    @property
+    def result_attrs(self) -> Optional[ArrayAttr]:
+        if RESULT_ATTRIBUTE_NAME not in self.attributes:
+            return ArrayAttr.get([DictAttr.get({}) for _ in self.type.results])
+        return self.attributes[RESULT_ATTRIBUTE_NAME]
+
+    @result_attrs.setter
+    def result_attrs(self, attribute: Union[ArrayAttr, list[Attribute]]):
+        if isinstance(attribute, ArrayAttr):
+            self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
+        else:
+            self.attributes[RESULT_ATTRIBUTE_NAME] = ArrayAttr.get(
+                attribute, context=self.context
+            )
diff --git a/mlir/test/python/dialects/gpu/dialect.py b/mlir/test/python/dialects/gpu/dialect.py
index 26ee9f34cb332..ce6e3df634e90 100644
--- a/mlir/test/python/dialects/gpu/dialect.py
+++ b/mlir/test/python/dialects/gpu/dialect.py
@@ -4,6 +4,7 @@
 import mlir.dialects.gpu as gpu
 import mlir.dialects.gpu.passes
 from mlir.passmanager import *
+import mlir.ir as ir
 
 
 def run(f):
@@ -64,3 +65,65 @@ def testObjectAttr():
     # CHECK: #gpu.object<#nvvm.target, kernels = <[#gpu.kernel_metadata<"kernel", () -> ()>]>, "BC\C0\DE5\14\00\00\05\00\00\00b\0C0$MY\BEf">
     print(o)
     assert o.kernels == kernelTable
+
+
+# CHECK-LABEL: testGPUFuncOp
+@run
+def testGPUFuncOp():
+    module = Module.create()
+    with InsertionPoint(module.body):
+        gpu_module_name = StringAttr.get("gpu_module")
+        gpumodule = gpu.GPUModuleOp(gpu_module_name)
+        block = gpumodule.bodyRegion.blocks.append()
+
+        def builder(func: gpu.GPUFuncOp) -> None:
+            _ = gpu.GlobalIdOp(gpu.Dimension.x)
+            _ = gpu.ReturnOp([])
+
+        with InsertionPoint(block):
+            name = StringAttr.get("kernel0")
+            func_type = ir.FunctionType.get(inputs=[], results=[])
+            type_attr = TypeAttr.get(func_type)
+            func = gpu.GPUFuncOp(type_attr, name)
+            func.attributes[gpu.SYM_NAME_ATTRIBUTE_NAME] = name
+            func.attributes[gpu.KERNEL_ATTRIBUTE_NAME] = UnitAttr.get()
+            block = func.body.blocks.append()
+            with InsertionPoint(block):
+                builder(func)
+
+            func = gpu.GPUFuncOp(
+                func_type,
+                sym_name="kernel1",
+                kernel=True,
+                body_builder=builder,
+            )
+
+            assert func.name.value == "kernel1"
+            assert func.arg_attrs == ArrayAttr.get([])
+            assert func.result_attrs == ArrayAttr.get([])
+            assert func.arguments == []
+            assert func.entry_block == func.body.blocks[0]
+            assert func.is_kernel
+
+            non_kernel_func = gpu.GPUFuncOp(
+                func_type,
+                sym_name="non_kernel_func",
+                body_builder=builder,
+            )
+            assert not non_kernel_func.is_kernel
+
+    print(module)
+
+    # CHECK: gpu.module @gpu_module
+    # CHECK: gpu.func @kernel0() kernel {
+    # CHECK:   %[[VAL_0:.*]] = gpu.global_id  x
+    # CHECK:   gpu.return
+    # CHECK: }
+    # CHECK: gpu.func @kernel1() kernel {
+    # CHECK:   %[[VAL_0:.*]] = gpu.global_id  x
+    # CHECK:   gpu.return
+    # CHECK: }
+    # CHECK: gpu.func @non_kernel_func() {
+    # CHECK:   %[[VAL_0:.*]] = gpu.global_id  x
+    # CHECK:   gpu.return
+    # CHECK: }

>From 380a6946af8f738483b0e630d7f58727b2d9257e Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Wed, 15 Oct 2025 10:40:53 -0700
Subject: [PATCH 2/4] Improve documentation, add known block/grid size attrs

---
 mlir/python/mlir/dialects/gpu/__init__.py | 51 +++++++++++++++++++++--
 mlir/test/python/dialects/gpu/dialect.py  | 21 +++++++---
 2 files changed, 63 insertions(+), 9 deletions(-)

diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py
index 14b965927e280..9192ff1ad5e4e 100644
--- a/mlir/python/mlir/dialects/gpu/__init__.py
+++ b/mlir/python/mlir/dialects/gpu/__init__.py
@@ -6,7 +6,7 @@
 from .._gpu_ops_gen import _Dialect
 from .._gpu_enum_gen import *
 from ..._mlir_libs._mlirDialectsGPU import *
-from typing import Callable, Sequence, Union, Optional
+from typing import Callable, Sequence, Union, Optional, List
 
 try:
     from ...ir import (
@@ -20,6 +20,7 @@
         Type,
         DictAttr,
         Attribute,
+        DenseI32ArrayAttr,
     )
     from .._ods_common import (
         get_default_loc_context as _get_default_loc_context,
@@ -29,26 +30,44 @@
     raise RuntimeError("Error loading imports from extension module") from e
 
 
-FUNCTION_TYPE_ATTRIBUTE_NAME = "function_type"
 KERNEL_ATTRIBUTE_NAME = "gpu.kernel"
+KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME = "gpu.known_block_size"
+KNOWN_GRID_SIZE_ATTRIBUTE_NAME = "gpu.known_grid_size"
+
+FUNCTION_TYPE_ATTRIBUTE_NAME = "function_type"
 SYM_NAME_ATTRIBUTE_NAME = "sym_name"
 ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
 RESULT_ATTRIBUTE_NAME = "res_attrs"
 
-
 @_ods_cext.register_operation(_Dialect, replace=True)
 class GPUFuncOp(GPUFuncOp):
+    __doc__ = GPUFuncOp.__doc__
+
     def __init__(
         self,
         function_type: Union[FunctionType, TypeAttr],
         sym_name: Optional[Union[str, StringAttr]] = None,
         kernel: Optional[bool] = None,
         body_builder: Optional[Callable[[GPUFuncOp], None]] = None,
+        known_block_size: Optional[Union[List[int], DenseI32ArrayAttr]] = None,
+        known_grid_size: Optional[Union[List[int], DenseI32ArrayAttr]] = None,
         *args,
         loc=None,
         ip=None,
         **kwargs,
     ):
+        """
+        Create a GPUFuncOp with the provided `function_type`, `sym_name`, `kernel`, `body_builder`, `known_block_size`, and `known_grid_size`.
+        - `function_type` is a FunctionType or a TypeAttr.
+        - `sym_name` is a string or a StringAttr representing the function name.
+        - `kernel` is a boolean representing whether the function is a kernel.
+        - `body_builder` is an optional callback. When provided, a new entry block
+          is created and the callback is invoked with the new op as argument within
+          an InsertionPoint context already set for the block. The callback is
+          expected to insert a terminator in the block.
+        - `known_block_size` is an optional list of integers or a DenseI32ArrayAttr representing the known block size.
+        - `known_grid_size` is an optional list of integers or a DenseI32ArrayAttr representing the known grid size.
+        """
         function_type = (
             TypeAttr.get(function_type)
             if not isinstance(function_type, TypeAttr)
@@ -62,6 +81,20 @@ def __init__(
         if body_builder is not None:
             with InsertionPoint(self.add_entry_block()):
                 body_builder(self)
+        if known_block_size is not None:
+            if isinstance(known_block_size, list):
+                self.attributes[KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME] = (
+                    DenseI32ArrayAttr.get(known_block_size)
+                )
+            else:
+                self.attributes[KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME] = known_block_size
+        if known_grid_size is not None:
+            if isinstance(known_grid_size, list):
+                self.attributes[KNOWN_GRID_SIZE_ATTRIBUTE_NAME] = (
+                    DenseI32ArrayAttr.get(known_grid_size)
+                )
+            else:
+                self.attributes[KNOWN_GRID_SIZE_ATTRIBUTE_NAME] = known_grid_size
 
     @property
     def type(self) -> FunctionType:
@@ -77,6 +110,18 @@ def name(self) -> StringAttr:
     def is_kernel(self) -> bool:
         return KERNEL_ATTRIBUTE_NAME in self.attributes
 
+    @property
+    def known_block_size(self) -> Optional[DenseI32ArrayAttr]:
+        if KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME not in self.attributes:
+            return None
+        return self.attributes[KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME]
+
+    @property
+    def known_grid_size(self) -> Optional[DenseI32ArrayAttr]:
+        if KNOWN_GRID_SIZE_ATTRIBUTE_NAME not in self.attributes:
+            return None
+        return self.attributes[KNOWN_GRID_SIZE_ATTRIBUTE_NAME]
+
     def add_entry_block(self) -> Block:
         function_type = self.type
         return self.body.blocks.append(
diff --git a/mlir/test/python/dialects/gpu/dialect.py b/mlir/test/python/dialects/gpu/dialect.py
index ce6e3df634e90..ecc9278b04dc5 100644
--- a/mlir/test/python/dialects/gpu/dialect.py
+++ b/mlir/test/python/dialects/gpu/dialect.py
@@ -1,10 +1,10 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 from mlir.ir import *
+import mlir.ir as ir
 import mlir.dialects.gpu as gpu
 import mlir.dialects.gpu.passes
 from mlir.passmanager import *
-import mlir.ir as ir
 
 
 def run(f):
@@ -70,6 +70,7 @@ def testObjectAttr():
 # CHECK-LABEL: testGPUFuncOp
 @run
 def testGPUFuncOp():
+    assert gpu.GPUFuncOp.__doc__ is not None
     module = Module.create()
     with InsertionPoint(module.body):
         gpu_module_name = StringAttr.get("gpu_module")
@@ -77,8 +78,8 @@ def testGPUFuncOp():
         block = gpumodule.bodyRegion.blocks.append()
 
         def builder(func: gpu.GPUFuncOp) -> None:
-            _ = gpu.GlobalIdOp(gpu.Dimension.x)
-            _ = gpu.ReturnOp([])
+            gpu.GlobalIdOp(gpu.Dimension.x)
+            gpu.ReturnOp([])
 
         with InsertionPoint(block):
             name = StringAttr.get("kernel0")
@@ -96,6 +97,8 @@ def builder(func: gpu.GPUFuncOp) -> None:
                 sym_name="kernel1",
                 kernel=True,
                 body_builder=builder,
+                known_block_size=[1, 2, 3],
+                known_grid_size=DenseI32ArrayAttr.get([4, 5, 6]),
             )
 
             assert func.name.value == "kernel1"
@@ -104,13 +107,17 @@ def builder(func: gpu.GPUFuncOp) -> None:
             assert func.arguments == []
             assert func.entry_block == func.body.blocks[0]
             assert func.is_kernel
+            assert func.known_block_size == DenseI32ArrayAttr.get([1, 2, 3])
+            assert func.known_grid_size == DenseI32ArrayAttr.get([4, 5, 6])
 
-            non_kernel_func = gpu.GPUFuncOp(
+            func = gpu.GPUFuncOp(
                 func_type,
                 sym_name="non_kernel_func",
                 body_builder=builder,
             )
-            assert not non_kernel_func.is_kernel
+            assert not func.is_kernel
+            assert func.known_block_size is None
+            assert func.known_grid_size is None
 
     print(module)
 
@@ -119,7 +126,9 @@ def builder(func: gpu.GPUFuncOp) -> None:
     # CHECK:   %[[VAL_0:.*]] = gpu.global_id  x
     # CHECK:   gpu.return
     # CHECK: }
-    # CHECK: gpu.func @kernel1() kernel {
+    # CHECK: gpu.func @kernel1() kernel attributes
+    # CHECK-SAME: gpu.known_block_size = array<i32: 1, 2, 3>
+    # CHECK-SAME: gpu.known_grid_size = array<i32: 4, 5, 6>
     # CHECK:   %[[VAL_0:.*]] = gpu.global_id  x
     # CHECK:   gpu.return
     # CHECK: }

>From 09a604cd6642ed1500a998d6c1e3c623a9cbba7b Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Wed, 15 Oct 2025 10:43:35 -0700
Subject: [PATCH 3/4] Apply code formatting

---
 mlir/python/mlir/dialects/gpu/__init__.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py
index 9192ff1ad5e4e..726e5e440ceed 100644
--- a/mlir/python/mlir/dialects/gpu/__init__.py
+++ b/mlir/python/mlir/dialects/gpu/__init__.py
@@ -39,6 +39,7 @@
 ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
 RESULT_ATTRIBUTE_NAME = "res_attrs"
 
+
 @_ods_cext.register_operation(_Dialect, replace=True)
 class GPUFuncOp(GPUFuncOp):
     __doc__ = GPUFuncOp.__doc__
@@ -90,8 +91,8 @@ def __init__(
                 self.attributes[KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME] = known_block_size
         if known_grid_size is not None:
             if isinstance(known_grid_size, list):
-                self.attributes[KNOWN_GRID_SIZE_ATTRIBUTE_NAME] = (
-                    DenseI32ArrayAttr.get(known_grid_size)
+                self.attributes[KNOWN_GRID_SIZE_ATTRIBUTE_NAME] = DenseI32ArrayAttr.get(
+                    known_grid_size
                 )
             else:
                 self.attributes[KNOWN_GRID_SIZE_ATTRIBUTE_NAME] = known_grid_size

>From 19ce1177cca0e2a4959bdea6e4b84623ded3828c Mon Sep 17 00:00:00 2001
From: Asher Mancinelli <ashermancinelli at gmail.com>
Date: Wed, 15 Oct 2025 11:16:47 -0700
Subject: [PATCH 4/4] Apply review suggestions

---
 mlir/python/mlir/dialects/gpu/__init__.py | 111 +++++++++++-----------
 mlir/test/python/dialects/gpu/dialect.py  |  13 ++-
 2 files changed, 64 insertions(+), 60 deletions(-)

diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py
index 726e5e440ceed..df17ebd9a6db6 100644
--- a/mlir/python/mlir/dialects/gpu/__init__.py
+++ b/mlir/python/mlir/dialects/gpu/__init__.py
@@ -30,28 +30,27 @@
     raise RuntimeError("Error loading imports from extension module") from e
 
 
-KERNEL_ATTRIBUTE_NAME = "gpu.kernel"
-KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME = "gpu.known_block_size"
-KNOWN_GRID_SIZE_ATTRIBUTE_NAME = "gpu.known_grid_size"
-
-FUNCTION_TYPE_ATTRIBUTE_NAME = "function_type"
-SYM_NAME_ATTRIBUTE_NAME = "sym_name"
-ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
-RESULT_ATTRIBUTE_NAME = "res_attrs"
-
-
 @_ods_cext.register_operation(_Dialect, replace=True)
 class GPUFuncOp(GPUFuncOp):
     __doc__ = GPUFuncOp.__doc__
 
+    KERNEL_ATTR_NAME = "gpu.kernel"
+    KNOWN_BLOCK_SIZE_ATTR_NAME = "known_block_size"
+    KNOWN_GRID_SIZE_ATTR_NAME = "known_grid_size"
+
+    FUNCTION_TYPE_ATTR_NAME = "function_type"
+    SYM_NAME_ATTR_NAME = "sym_name"
+    ARGUMENT_ATTR_NAME = "arg_attrs"
+    RESULT_ATTR_NAME = "res_attrs"
+
     def __init__(
         self,
         function_type: Union[FunctionType, TypeAttr],
         sym_name: Optional[Union[str, StringAttr]] = None,
         kernel: Optional[bool] = None,
         body_builder: Optional[Callable[[GPUFuncOp], None]] = None,
-        known_block_size: Optional[Union[List[int], DenseI32ArrayAttr]] = None,
-        known_grid_size: Optional[Union[List[int], DenseI32ArrayAttr]] = None,
+        known_block_size: Optional[Union[Sequence[int], DenseI32ArrayAttr]] = None,
+        known_grid_size: Optional[Union[Sequence[int], DenseI32ArrayAttr]] = None,
         *args,
         loc=None,
         ip=None,
@@ -75,56 +74,52 @@ def __init__(
             else function_type
         )
         super().__init__(function_type, *args, loc=loc, ip=ip, **kwargs)
+
+        if isinstance(sym_name, str):
+            sym_name = StringAttr.get(str(sym_name))
         if sym_name is not None:
-            self.attributes[SYM_NAME_ATTRIBUTE_NAME] = StringAttr.get(str(sym_name))
+            self.attributes[self.SYM_NAME_ATTR_NAME] = sym_name
+
         if kernel:
-            self.attributes[KERNEL_ATTRIBUTE_NAME] = UnitAttr.get()
-        if body_builder is not None:
-            with InsertionPoint(self.add_entry_block()):
-                body_builder(self)
+            self.attributes[self.KERNEL_ATTR_NAME] = UnitAttr.get()
         if known_block_size is not None:
-            if isinstance(known_block_size, list):
-                self.attributes[KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME] = (
+            if isinstance(known_block_size, Sequence):
+                self.attributes[self.KNOWN_BLOCK_SIZE_ATTR_NAME] = (
                     DenseI32ArrayAttr.get(known_block_size)
                 )
+            elif isinstance(known_block_size, DenseI32ArrayAttr):
+                self.attributes[self.KNOWN_BLOCK_SIZE_ATTR_NAME] = known_block_size
             else:
-                self.attributes[KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME] = known_block_size
+                raise ValueError(
+                    "known_block_size must be a list of integers or a DenseI32ArrayAttr"
+                )
+
         if known_grid_size is not None:
-            if isinstance(known_grid_size, list):
-                self.attributes[KNOWN_GRID_SIZE_ATTRIBUTE_NAME] = DenseI32ArrayAttr.get(
+            if isinstance(known_grid_size, Sequence):
+                self.attributes[self.KNOWN_GRID_SIZE_ATTR_NAME] = DenseI32ArrayAttr.get(
                     known_grid_size
                 )
+            elif isinstance(known_grid_size, DenseI32ArrayAttr):
+                self.attributes[self.KNOWN_GRID_SIZE_ATTR_NAME] = known_grid_size
             else:
-                self.attributes[KNOWN_GRID_SIZE_ATTRIBUTE_NAME] = known_grid_size
+                raise ValueError(
+                    "known_grid_size must be a list of integers or a DenseI32ArrayAttr"
+                )
 
-    @property
-    def type(self) -> FunctionType:
-        return FunctionType(
-            TypeAttr(self.attributes[FUNCTION_TYPE_ATTRIBUTE_NAME]).value
-        )
+        if body_builder is not None:
+            with InsertionPoint(self.add_entry_block()):
+                body_builder(self)
 
     @property
     def name(self) -> StringAttr:
-        return StringAttr(self.attributes[SYM_NAME_ATTRIBUTE_NAME])
+        return StringAttr(self.attributes[self.SYM_NAME_ATTR_NAME])
 
     @property
     def is_kernel(self) -> bool:
-        return KERNEL_ATTRIBUTE_NAME in self.attributes
-
-    @property
-    def known_block_size(self) -> Optional[DenseI32ArrayAttr]:
-        if KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME not in self.attributes:
-            return None
-        return self.attributes[KNOWN_BLOCK_SIZE_ATTRIBUTE_NAME]
-
-    @property
-    def known_grid_size(self) -> Optional[DenseI32ArrayAttr]:
-        if KNOWN_GRID_SIZE_ATTRIBUTE_NAME not in self.attributes:
-            return None
-        return self.attributes[KNOWN_GRID_SIZE_ATTRIBUTE_NAME]
+        return self.KERNEL_ATTR_NAME in self.attributes
 
     def add_entry_block(self) -> Block:
-        function_type = self.type
+        function_type = self.function_type.value
         return self.body.blocks.append(
             *function_type.inputs,
             arg_locs=[self.location for _ in function_type.inputs],
@@ -136,34 +131,38 @@ def entry_block(self) -> Block:
 
     @property
     def arguments(self) -> Sequence[Type]:
-        return self.type.inputs
+        return self.function_type.value.inputs
 
     @property
     def arg_attrs(self):
-        if ARGUMENT_ATTRIBUTE_NAME not in self.attributes:
-            return ArrayAttr.get([DictAttr.get({}) for _ in self.type.inputs])
-        return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])
+        if self.ARGUMENT_ATTR_NAME not in self.attributes:
+            return ArrayAttr.get(
+                [DictAttr.get({}) for _ in self.function_type.value.inputs]
+            )
+        return ArrayAttr(self.attributes[self.ARGUMENT_ATTR_NAME])
 
     @arg_attrs.setter
-    def arg_attrs(self, attribute: Union[ArrayAttr, list[Attribute]]):
+    def arg_attrs(self, attribute: Union[ArrayAttr, Sequence[Attribute]]):
         if isinstance(attribute, ArrayAttr):
-            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
+            self.attributes[self.ARGUMENT_ATTR_NAME] = attribute
         else:
-            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
+            self.attributes[self.ARGUMENT_ATTR_NAME] = ArrayAttr.get(
                 attribute, context=self.context
             )
 
     @property
     def result_attrs(self) -> Optional[ArrayAttr]:
-        if RESULT_ATTRIBUTE_NAME not in self.attributes:
-            return ArrayAttr.get([DictAttr.get({}) for _ in self.type.results])
-        return self.attributes[RESULT_ATTRIBUTE_NAME]
+        if self.RESULT_ATTR_NAME not in self.attributes:
+            return ArrayAttr.get(
+                [DictAttr.get({}) for _ in self.function_type.value.results]
+            )
+        return self.attributes[self.RESULT_ATTR_NAME]
 
     @result_attrs.setter
-    def result_attrs(self, attribute: Union[ArrayAttr, list[Attribute]]):
+    def result_attrs(self, attribute: Union[ArrayAttr, Sequence[Attribute]]):
         if isinstance(attribute, ArrayAttr):
-            self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
+            self.attributes[self.RESULT_ATTR_NAME] = attribute
         else:
-            self.attributes[RESULT_ATTRIBUTE_NAME] = ArrayAttr.get(
+            self.attributes[self.RESULT_ATTR_NAME] = ArrayAttr.get(
                 attribute, context=self.context
             )
diff --git a/mlir/test/python/dialects/gpu/dialect.py b/mlir/test/python/dialects/gpu/dialect.py
index ecc9278b04dc5..beccc65a15999 100644
--- a/mlir/test/python/dialects/gpu/dialect.py
+++ b/mlir/test/python/dialects/gpu/dialect.py
@@ -86,8 +86,8 @@ def builder(func: gpu.GPUFuncOp) -> None:
             func_type = ir.FunctionType.get(inputs=[], results=[])
             type_attr = TypeAttr.get(func_type)
             func = gpu.GPUFuncOp(type_attr, name)
-            func.attributes[gpu.SYM_NAME_ATTRIBUTE_NAME] = name
-            func.attributes[gpu.KERNEL_ATTRIBUTE_NAME] = UnitAttr.get()
+            func.attributes["sym_name"] = name
+            func.attributes["gpu.kernel"] = UnitAttr.get()
             block = func.body.blocks.append()
             with InsertionPoint(block):
                 builder(func)
@@ -102,13 +102,18 @@ def builder(func: gpu.GPUFuncOp) -> None:
             )
 
             assert func.name.value == "kernel1"
+            assert func.function_type.value == func_type
             assert func.arg_attrs == ArrayAttr.get([])
             assert func.result_attrs == ArrayAttr.get([])
             assert func.arguments == []
             assert func.entry_block == func.body.blocks[0]
             assert func.is_kernel
-            assert func.known_block_size == DenseI32ArrayAttr.get([1, 2, 3])
-            assert func.known_grid_size == DenseI32ArrayAttr.get([4, 5, 6])
+            assert func.known_block_size == DenseI32ArrayAttr.get(
+                [1, 2, 3]
+            ), func.known_block_size
+            assert func.known_grid_size == DenseI32ArrayAttr.get(
+                [4, 5, 6]
+            ), func.known_grid_size
 
             func = gpu.GPUFuncOp(
                 func_type,



More information about the Mlir-commits mailing list