[Mlir-commits] [mlir] 1f8618f - [mlir] python enum bindings generator

Alex Zinenko llvmlistbot at llvm.org
Mon Jul 31 08:43:08 PDT 2023


Author: Alex Zinenko
Date: 2023-07-31T15:42:56Z
New Revision: 1f8618f88c58509d0ef8fae813f708a9dc2a86d8

URL: https://github.com/llvm/llvm-project/commit/1f8618f88c58509d0ef8fae813f708a9dc2a86d8
DIFF: https://github.com/llvm/llvm-project/commit/1f8618f88c58509d0ef8fae813f708a9dc2a86d8.diff

LOG: [mlir] python enum bindings generator

Add an ODS (tablegen) backend to generate Python enum classes and
attribute builders for enum attributes defined in ODS. This will allow
us to keep the enum attribute definitions in sync between C++ and
Python, as opposed to handwritten enum classes in Python that may end up
using mismatched values. This also makes autogenerated bindings more
convenient even in the absence of mixins.

Use this backend for the transform dialect failure propagation mode enum
attribute as a demonstration.

Reviewed By: ingomueller-net

Differential Revision: https://reviews.llvm.org/D156553

Added: 
    mlir/test/mlir-tblgen/enums-python-bindings.td
    mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp

Modified: 
    mlir/python/CMakeLists.txt
    mlir/python/mlir/dialects/_transform_ops_ext.py
    mlir/python/mlir/dialects/transform/__init__.py
    mlir/tools/mlir-tblgen/CMakeLists.txt
    utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel

Removed: 
    


################################################################################
diff  --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index d233194b1819e5..3263bc1db7834c 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -134,6 +134,15 @@ declare_mlir_dialect_python_bindings(
     _mlir_libs/_mlir/dialects/transform/__init__.pyi
   DIALECT_NAME transform)
 
+set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/TransformOps.td")
+mlir_tablegen("dialects/_transform_enum_gen.py" -gen-python-enum-bindings)
+add_public_tablegen_target(MLIRTransformDialectPyEnumGen)
+declare_mlir_python_sources(
+  MLIRPythonSources.Dialects.transform.enum_gen
+  ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
+  ADD_TO_PARENT MLIRPythonSources.Dialects.transform
+  SOURCES "dialects/_transform_enum_gen.py")
+
 declare_mlir_dialect_extension_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"

diff  --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py
index 0db2e3bd93a3aa..b1e7b892536f4a 100644
--- a/mlir/python/mlir/dialects/_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_transform_ops_ext.py
@@ -15,68 +15,66 @@
 
 
 class CastOp:
-
-  def __init__(
-      self,
-      result_type: Type,
-      target: Union[Operation, Value],
-      *,
-      loc=None,
-      ip=None,
-  ):
-    super().__init__(
-        result_type, _get_op_result_or_value(target), loc=loc, ip=ip
-    )
+    # Accepts either an Operation or a Value as the cast target and passes it,
+    # resolved through _get_op_result_or_value, to the parent constructor
+    # together with the result type and the optional loc/ip arguments.
+    def __init__(
+        self,
+        result_type: Type,
+        target: Union[Operation, Value],
+        *,
+        loc=None,
+        ip=None,
+    ):
+        super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip)
 
 
 class ApplyPatternsOp:
+    # Builds the op generically with a single operand (`target`, resolved via
+    # _get_op_result_or_value), no attributes, no results and no successors.
+    def __init__(
+        self,
+        target: Union[Operation, Value, OpView],
+        *,
+        loc=None,
+        ip=None,
+    ):
+        operands = []
+        operands.append(_get_op_result_or_value(target))
+        super().__init__(
+            self.build_generic(
+                attributes={},
+                results=[],
+                operands=operands,
+                successors=None,
+                regions=None,
+                loc=loc,
+                ip=ip,
+            )
+        )
+        # Create the block (with no arguments) that `patterns` returns.
+        self.regions[0].blocks.append()
 
-  def __init__(
-      self,
-      target: Union[Operation, Value, OpView],
-      *,
-      loc=None,
-      ip=None,
-  ):
-    operands = []
-    operands.append(_get_op_result_or_value(target))
-    super().__init__(
-        self.build_generic(attributes={},
-                           results=[],
-                           operands=operands,
-                           successors=None,
-                           regions=None,
-                           loc=loc,
-                           ip=ip))
-    self.regions[0].blocks.append()
-
-  @property
-  def patterns(self) -> Block:
-    return self.regions[0].blocks[0]
+    # The first block of the first region, i.e. the block appended in
+    # __init__; presumably where pattern ops are meant to be inserted.
+    @property
+    def patterns(self) -> Block:
+        return self.regions[0].blocks[0]
 
 
 class testGetParentOp:
-
-  def __init__(
-      self,
-      result_type: Type,
-      target: Union[Operation, Value],
-      *,
-      isolated_from_above: bool = False,
-      op_name: Optional[str] = None,
-      deduplicate: bool = False,
-      loc=None,
-      ip=None,
-  ):
-    super().__init__(
-        result_type,
-        _get_op_result_or_value(target),
-        isolated_from_above=isolated_from_above,
-        op_name=op_name,
-        deduplicate=deduplicate,
-        loc=loc,
-        ip=ip,
-    )
+    # Forwards every argument unchanged to the parent constructor, except that
+    # `target` (an Operation or Value) is first resolved through
+    # _get_op_result_or_value.
+    def __init__(
+        self,
+        result_type: Type,
+        target: Union[Operation, Value],
+        *,
+        isolated_from_above: bool = False,
+        op_name: Optional[str] = None,
+        deduplicate: bool = False,
+        loc=None,
+        ip=None,
+    ):
+        super().__init__(
+            result_type,
+            _get_op_result_or_value(target),
+            isolated_from_above=isolated_from_above,
+            op_name=op_name,
+            deduplicate=deduplicate,
+            loc=loc,
+            ip=ip,
+        )
 
 
 class MergeHandlesOp:
@@ -130,12 +128,6 @@ def __init__(
             else None
         )
         root_type = root.type if not isinstance(target, Type) else target
-        if not isinstance(failure_propagation_mode, Attribute):
-            failure_propagation_mode_attr = IntegerAttr.get(
-                IntegerType.get_signless(32), failure_propagation_mode._as_int()
-            )
-        else:
-            failure_propagation_mode_attr = failure_propagation_mode
 
         if extra_bindings is None:
             extra_bindings = []
@@ -152,7 +144,7 @@ def __init__(
 
         super().__init__(
             results_=results,
-            failure_propagation_mode=failure_propagation_mode_attr,
+            failure_propagation_mode=failure_propagation_mode,
             root=root,
             extra_bindings=extra_bindings,
         )

diff  --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index b505a490aeb97b..b020ad35fcf062 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -2,22 +2,6 @@
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
-from enum import Enum
-
-
-class FailurePropagationMode(Enum):
-    """Propagation mode for silenceable errors."""
-
-    PROPAGATE = 1
-    SUPPRESS = 2
-
-    def _as_int(self):
-        if self is FailurePropagationMode.PROPAGATE:
-            return 1
-
-        assert self is FailurePropagationMode.SUPPRESS
-        return 2
-
-
+from .._transform_enum_gen import *
 from .._transform_ops_gen import *
 from ..._mlir_libs._mlirDialectsTransform import *

diff  --git a/mlir/test/mlir-tblgen/enums-python-bindings.td b/mlir/test/mlir-tblgen/enums-python-bindings.td
new file mode 100644
index 00000000000000..5272eba50f0e7a
--- /dev/null
+++ b/mlir/test/mlir-tblgen/enums-python-bindings.td
@@ -0,0 +1,57 @@
+// RUN: mlir-tblgen -gen-python-enum-bindings %s -I %S/../../include | FileCheck %s
+
+include "mlir/IR/EnumAttr.td"
+
+// CHECK: Autogenerated by mlir-tblgen; don't manually edit.
+
+// CHECK: from enum import Enum
+// CHECK: from ._ods_common import _cext as _ods_cext
+// CHECK: _ods_ir = _ods_cext.ir
+
+def One : I32EnumAttrCase<"CaseOne", 1, "one">;
+def Two : I32EnumAttrCase<"CaseTwo", 2, "two">;
+
+def MyEnum : I32EnumAttr<"MyEnum", "An example 32-bit enum", [One, Two]>;
+// CHECK: def _register_attribute_builder(kind):
+// CHECK:     def decorator_builder(func):
+// CHECK:         _ods_ir.AttrBuilder.insert(kind, func)
+// CHECK:         return func
+// CHECK:     return decorator_builder
+
+// CHECK-LABEL: class MyEnum(Enum):
+// CHECK:     """An example 32-bit enum"""
+
+// CHECK:     CASE_ONE = 1
+// CHECK:     CASE_TWO = 2
+
+// CHECK:     def _as_int(self):
+// CHECK:         if self is MyEnum.CASE_ONE:
+// CHECK:             return 1
+// CHECK:         if self is MyEnum.CASE_TWO:
+// CHECK:             return 2
+// CHECK:         assert False, "Unknown MyEnum enum entry."
+
+def One64 : I64EnumAttrCase<"CaseOne64", 1, "one">;
+def Two64 : I64EnumAttrCase<"CaseTwo64", 2, "two">;
+
+def MyEnum64 : I64EnumAttr<"MyEnum64", "An example 64-bit enum", [One64, Two64]>;
+// CHECK: @_register_attribute_builder("MyEnum")
+// CHECK: def _my_enum(x, context):
+// CHECK:     return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(32, context=context), x._as_int())
+
+// CHECK-LABEL: class MyEnum64(Enum):
+// CHECK:     """An example 64-bit enum"""
+
+// CHECK:     CASE_ONE64 = 1
+// CHECK:     CASE_TWO64 = 2
+
+// CHECK:     def _as_int(self):
+// CHECK:         if self is MyEnum64.CASE_ONE64:
+// CHECK:             return 1
+// CHECK:         if self is MyEnum64.CASE_TWO64:
+// CHECK:             return 2
+// CHECK:         assert False, "Unknown MyEnum64 enum entry."
+
+// CHECK: @_register_attribute_builder("MyEnum64")
+// CHECK: def _my_enum64(x, context):
+// CHECK:     return _ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless(64, context=context), x._as_int())

diff  --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index ce16899fd4dd24..f2c5e4b3f87afb 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -14,6 +14,7 @@ add_tablegen(mlir-tblgen MLIR
   DialectGen.cpp
   DirectiveCommonGen.cpp
   EnumsGen.cpp
+  EnumPythonBindingGen.cpp
   FormatGen.cpp
   LLVMIRConversionGen.cpp
   LLVMIRIntrinsicGen.cpp

diff  --git a/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
new file mode 100644
index 00000000000000..9748e33e2ebe8a
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/EnumPythonBindingGen.cpp
@@ -0,0 +1,130 @@
+//===- EnumPythonBindingGen.cpp - Generator of Python API for ODS enums ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// EnumPythonBindingGen uses ODS specification of MLIR enum attributes to
+// generate the corresponding Python binding classes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/Attribute.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+/// Preamble written at the top of every generated Python file. It contains an
+/// autogeneration notice, the names the emitted enum classes and attribute
+/// builders rely on (`Enum` and the `_ods_ir` module alias), and the
+/// `_register_attribute_builder` decorator. That decorator inserts the
+/// decorated function into `_ods_ir.AttrBuilder` under the given attribute
+/// kind and returns the function unchanged.
+constexpr const char *fileHeader = R"Py(
+# Autogenerated by mlir-tblgen; don't manually edit.
+
+from enum import Enum
+from ._ods_common import _cext as _ods_cext
+_ods_ir = _ods_cext.ir
+
+# Convenience decorator for registering user-friendly Attribute builders.
+def _register_attribute_builder(kind):
+    def decorator_builder(func):
+        _ods_ir.AttrBuilder.insert(kind, func)
+        return func
+
+    return decorator_builder
+
+)Py";
+
+/// Converts a CamelCase enum case symbol into the UPPER_SNAKE_CASE spelling
+/// used for Python enum members (e.g. "CaseOne" becomes "CASE_ONE").
+static std::string makePythonEnumCaseName(StringRef name) {
+  std::string snakeCase = llvm::convertToSnakeFromCamelCase(name);
+  return StringRef(snakeCase).upper();
+}
+
+/// Escapes `text` so that it can be embedded inside a triple-quoted Python
+/// string literal. Backslashes and double quotes get a leading backslash, so
+/// the text can neither form escape sequences nor close the docstring early
+/// (e.g. when it ends with a `"`).
+static std::string escapeForPythonDocstring(StringRef text) {
+  std::string escaped;
+  escaped.reserve(text.size());
+  for (char c : text) {
+    if (c == '\\' || c == '"')
+      escaped.push_back('\\');
+    escaped.push_back(c);
+  }
+  return escaped;
+}
+
+/// Emits the Python class for the given enum.
+///
+/// The class derives from `Enum`. If `description` is non-empty, it becomes
+/// the class docstring. The class has one member per case, named in
+/// UPPER_SNAKE_CASE and holding the case's integer value, and an `_as_int`
+/// method mapping each member back to that value.
+static void emitEnumClass(StringRef enumName, StringRef description,
+                          ArrayRef<EnumAttrCase> cases, raw_ostream &os) {
+  os << llvm::formatv("class {0}(Enum):\n", enumName);
+  // The summary comes straight from ODS; escape it so that quotes or
+  // backslashes in it cannot produce invalid Python.
+  if (!description.empty())
+    os << llvm::formatv("    \"\"\"{0}\"\"\"\n",
+                        escapeForPythonDocstring(description));
+  os << "\n";
+
+  for (const EnumAttrCase &enumCase : cases) {
+    os << llvm::formatv("    {0} = {1}\n",
+                        makePythonEnumCaseName(enumCase.getSymbol()),
+                        enumCase.getValue());
+  }
+
+  os << "\n";
+  os << "    def _as_int(self):\n";
+  for (const EnumAttrCase &enumCase : cases) {
+    os << llvm::formatv("        if self is {0}.{1}:\n", enumName,
+                        makePythonEnumCaseName(enumCase.getSymbol()));
+    os << llvm::formatv("            return {0}\n", enumCase.getValue());
+  }
+  // Reached only if none of the identity checks above matched.
+  os << llvm::formatv("        assert False, \"Unknown {0} enum entry.\"\n\n\n",
+                      enumName);
+}
+
+/// Parses the bitwidth B out of an underlying-type spelling of the form
+/// "uintB_t"; ODS does not expose the bitwidth directly. Follows the LLVM
+/// convention of returning `false` on success (with B stored in `bitwidth`).
+/// Returns `true` when the prefix/suffix is missing or B is not a base-10
+/// integer.
+static bool extractUIntBitwidth(StringRef uintType, int64_t &bitwidth) {
+  StringRef digits = uintType;
+  const bool hasUIntShape =
+      digits.consume_front("uint") && digits.consume_back("_t");
+  if (!hasUIntShape)
+    return true;
+  return digits.getAsInteger(/*Radix=*/10, bitwidth);
+}
+
+/// Emits an attribute builder for the given enum attribute to support automatic
+/// conversion between enum values and attributes in Python.
+///
+/// The builder is registered under the enum attribute's def name. It wraps
+/// the value returned by the enum's `_as_int` in a signless `IntegerAttr`
+/// whose bitwidth comes from the underlying "uintB_t" type. Returns `false`
+/// on success and `true` on failure, after printing a diagnostic to stderr.
+static bool emitAttributeBuilder(const EnumAttr &enumAttr, raw_ostream &os) {
+  int64_t bitwidth = 0;
+  if (extractUIntBitwidth(enumAttr.getUnderlyingType(), bitwidth)) {
+    // Name the offending enum and end the line so the diagnostic does not run
+    // into subsequent tool output.
+    llvm::errs() << "failed to identify bitwidth of "
+                 << enumAttr.getUnderlyingType() << " for enum "
+                 << enumAttr.getEnumClassName() << "\n";
+    return true;
+  }
+
+  os << llvm::formatv("@_register_attribute_builder(\"{0}\")\n",
+                      enumAttr.getAttrDefName());
+  os << llvm::formatv(
+      "def _{0}(x, context):\n",
+      llvm::convertToSnakeFromCamelCase(enumAttr.getAttrDefName()));
+  os << llvm::formatv(
+      "    return "
+      "_ods_ir.IntegerAttr.get(_ods_ir.IntegerType.get_signless({0}, "
+      "context=context), x._as_int())\n\n",
+      bitwidth);
+  return false;
+}
+
+/// Emits Python bindings for all enums in the record keeper.
+///
+/// Writes the file header first. Then, for every `EnumAttrInfo` definition,
+/// it writes a Python `Enum` class followed by its attribute builder. Bit
+/// enums are rejected. Returns `false` on success, `true` on failure.
+static bool emitPythonEnums(const llvm::RecordKeeper &recordKeeper,
+                            raw_ostream &os) {
+  os << fileHeader;
+  std::vector<llvm::Record *> defs =
+      recordKeeper.getAllDerivedDefinitionsIfDefined("EnumAttrInfo");
+  for (const llvm::Record *def : defs) {
+    EnumAttr enumAttr(*def);
+    if (enumAttr.isBitEnum()) {
+      llvm::errs() << "bit enums not supported\n";
+      return true;
+    }
+    emitEnumClass(enumAttr.getEnumClassName(), enumAttr.getSummary(),
+                  enumAttr.getAllCases(), os);
+    // Propagate builder failures (e.g. an underlying type whose bitwidth
+    // cannot be determined). Otherwise mlir-tblgen would report success while
+    // silently emitting an enum without its attribute builder.
+    if (emitAttributeBuilder(enumAttr, os))
+      return true;
+  }
+  return false;
+}
+
+// Hooks the Python enum binding generator into mlir-tblgen, where it is
+// selected with the `-gen-python-enum-bindings` command-line option.
+static mlir::GenRegistration genPythonEnumBindings(
+    "gen-python-enum-bindings", "Generate Python bindings for enum attributes",
+    emitPythonEnums);

diff  --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
index bb818edf51793e..cc41eb564cf47a 100644
--- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel
@@ -732,6 +732,25 @@ filegroup(
 # Transform dialect and extensions.
 ##---------------------------------------------------------------------------##
 
+
+gentbl_filegroup(
+    name = "TransformEnumPyGen",
+    tbl_outs = [
+        (
+            ["-gen-python-enum-bindings"],
+            "mlir/dialects/_transform_enum_gen.py",
+        ),
+    ],
+    tblgen = "//mlir:mlir-tblgen",
+    td_file = "mlir/dialects/TransformOps.td",
+    deps = [
+        "//mlir:CallInterfacesTdFiles",
+        "//mlir:FunctionInterfacesTdFiles",
+        "//mlir:OpBaseTdFiles",
+        "//mlir:TransformDialectTdFiles",
+    ],
+)
+
 gentbl_filegroup(
     name = "TransformOpsPyGen",
     tbl_outs = [
@@ -898,6 +917,7 @@ filegroup(
         ":MemRefTransformOpsPyGen",
         ":PDLTransformOpsPyGen",
         ":StructuredTransformOpsPyGen",
+        ":TransformEnumPyGen",
         ":TransformOpsPyGen",
     ],
 )


        


More information about the Mlir-commits mailing list