[llvm-branch-commits] [mlir] 922b26c - Add Python bindings for the builtin dialect

Mehdi Amini via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jan 21 14:55:01 PST 2021


Author: Mehdi Amini
Date: 2021-01-21T22:44:44Z
New Revision: 922b26cde4d1c89a5fa90e6a1d6d97d0f8eace6d

URL: https://github.com/llvm/llvm-project/commit/922b26cde4d1c89a5fa90e6a1d6d97d0f8eace6d
DIFF: https://github.com/llvm/llvm-project/commit/922b26cde4d1c89a5fa90e6a1d6d97d0f8eace6d.diff

LOG: Add Python bindings for the builtin dialect

This includes some minor customizations for FuncOp and ModuleOp (convenience constructors and accessors).

Differential Revision: https://reviews.llvm.org/D95022

Added: 
    mlir/lib/Bindings/Python/BuiltinOps.td
    mlir/lib/Bindings/Python/mlir/dialects/_builtin.py
    mlir/test/Bindings/Python/.style.yapf
    mlir/test/Bindings/Python/dialects/builtin.py

Modified: 
    mlir/lib/Bindings/Python/CMakeLists.txt
    mlir/lib/Bindings/Python/mlir/dialects/__init__.py
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/BuiltinOps.td b/mlir/lib/Bindings/Python/BuiltinOps.td
new file mode 100644
index 000000000000..ecbb8227d490
--- /dev/null
+++ b/mlir/lib/Bindings/Python/BuiltinOps.td
@@ -0,0 +1,15 @@
+//===-- BuiltinOps.td - Entry point for builtin bindings ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_BUILTIN_OPS
+#define PYTHON_BINDINGS_BUILTIN_OPS
+
+include "mlir/Bindings/Python/Attributes.td"
+include "mlir/IR/BuiltinOps.td"
+
+#endif

diff  --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
index 1749ea2e5472..951aa7883c90 100644
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -11,6 +11,7 @@ set(PY_SRC_FILES
   mlir/ir.py
   mlir/dialects/__init__.py
   mlir/dialects/_linalg.py
+  mlir/dialects/_builtin.py
   mlir/ir.py
   mlir/passmanager.py
   mlir/transforms/__init__.py
@@ -36,6 +37,11 @@ endforeach()
 # Generate dialect-specific bindings.
 ################################################################################
 
+add_mlir_dialect_python_bindings(MLIRBindingsPythonBuiltinOps
+  TD_FILE BuiltinOps.td
+  DIALECT_NAME builtin)
+add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonBuiltinOps)
+
 add_mlir_dialect_python_bindings(MLIRBindingsPythonLinalgOps
   TD_FILE LinalgOps.td
   DIALECT_NAME linalg

diff  --git a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py
index 9c003b415438..f5a71bf88700 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py
@@ -43,7 +43,7 @@ def class_decorator(parent_opview_cls: type):
     except AttributeError:
       # Try to default resolve it.
       try:
-        select_mixin = getattr(ext_module, parent_opview_cls.__name__)
+        mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
       except AttributeError:
         pass
     else:

diff  --git a/mlir/lib/Bindings/Python/mlir/dialects/_builtin.py b/mlir/lib/Bindings/Python/mlir/dialects/_builtin.py
new file mode 100644
index 000000000000..8d430d5a50da
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/dialects/_builtin.py
@@ -0,0 +1,115 @@
+#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+#  See https://llvm.org/LICENSE.txt for license information.
+#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from mlir.ir import *
+
+
+class ModuleOp:
+  """Specialization for the module op class."""
+
+  def __init__(self, loc=None, ip=None):
+    """Create a module whose body block is pre-populated with a terminator.
+
+    `loc` and `ip` are forwarded unchanged to `_ods_build_default`.
+    """
+    super().__init__(
+        self._ods_build_default(operands=[], results=[], loc=loc, ip=ip))
+    body = self.regions[0].blocks.append()
+    # Seed the new block with a `module_terminator` op.
+    with InsertionPoint(body):
+      Operation.create("module_terminator")
+
+  @property
+  def body(self):
+    """Return the first block of the module region (created by `__init__`)."""
+    return self.regions[0].blocks[0]
+
+
+class FuncOp:
+  """Specialization for the func op class."""
+
+  def __init__(self,
+               name,
+               type,
+               visibility=None,
+               body_builder=None,
+               loc=None,
+               ip=None):
+    """
+    Create a FuncOp with the provided `name`, `type`, and `visibility`.
+    - `name` is a string representing the function name.
+    - `type` is either a FunctionType or a tuple `(inputs, results)` of two
+      lists of types, from which a FunctionType is built.
+    - `visibility` is a string matching `public`, `private`, or `nested`.
+      `None` (the default) omits the `sym_visibility` attribute entirely; any
+      other value, including the empty string, is stored verbatim as a string
+      attribute.
+    - `body_builder` is an optional callback. When provided, a new entry block
+      is created and the callback is invoked with the new op as argument within
+      an InsertionPoint context already set for the block. The callback is
+      expected to insert a terminator in the block.
+    - `loc` and `ip` are forwarded to the parent op class constructor.
+    """
+    sym_name = StringAttr.get(str(name))
+
+    # If the type is passed as a tuple, build a FunctionType on the fly.
+    if isinstance(type, tuple):
+      type = FunctionType.get(inputs=type[0], results=type[1])
+
+    type = TypeAttr.get(type)
+    # Only `None` omits the attribute; "" is still stored as a StringAttr.
+    sym_visibility = StringAttr.get(
+        str(visibility)) if visibility is not None else None
+    super().__init__(sym_name, type, sym_visibility, loc, ip)
+    if body_builder:
+      entry_block = self.add_entry_block()
+      with InsertionPoint(entry_block):
+        body_builder(self)
+
+  @property
+  def is_external(self):
+    """True when the function has no body (its region holds no blocks)."""
+    return len(self.regions[0].blocks) == 0
+
+  @property
+  def body(self):
+    """Return the body *region* (unlike `ModuleOp.body`, which is a block)."""
+    return self.regions[0]
+
+  @property
+  def type(self):
+    """Return the function signature decoded from the `type` attribute."""
+    return FunctionType(TypeAttr(self.attributes["type"]).value)
+
+  @property
+  def visibility(self):
+    """Return the raw `sym_visibility` attribute.
+
+    NOTE(review): the attribute is omitted when `visibility=None` was passed
+    to `__init__`; confirm what this lookup does in that case.
+    """
+    return self.attributes["sym_visibility"]
+
+  @property
+  def name(self):
+    """Return the raw `sym_name` attribute."""
+    return self.attributes["sym_name"]
+
+  @property
+  def entry_block(self):
+    """Return the first block of the body; raise IndexError if external."""
+    if self.is_external:
+      raise IndexError('External function does not have a body')
+    return self.regions[0].blocks[0]
+
+  def add_entry_block(self):
+    """
+    Add an entry block to the function body, using the function signature to
+    infer the block arguments. Returns the newly created block.
+
+    Raises IndexError if the function already has an entry block.
+    """
+    if not self.is_external:
+      raise IndexError('The function already has an entry block!')
+    self.body.blocks.append(*self.type.inputs)
+    return self.body.blocks[0]

diff  --git a/mlir/test/Bindings/Python/.style.yapf b/mlir/test/Bindings/Python/.style.yapf
new file mode 100644
index 000000000000..9ef1dc15ba62
--- /dev/null
+++ b/mlir/test/Bindings/Python/.style.yapf
@@ -0,0 +1,4 @@
+[style]
+  based_on_style = google
+  column_limit = 80
+  indent_width = 2

diff  --git a/mlir/test/Bindings/Python/dialects/builtin.py b/mlir/test/Bindings/Python/dialects/builtin.py
new file mode 100644
index 000000000000..447a255f6021
--- /dev/null
+++ b/mlir/test/Bindings/Python/dialects/builtin.py
@@ -0,0 +1,71 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+import mlir.dialects.builtin as builtin
+import mlir.dialects.std as std
+
+
+def run(f):
+  """Print a `TEST:` header naming `f`, then invoke `f`."""
+  print("\nTEST:", f.__name__)
+  f()
+
+
+# CHECK-LABEL: TEST: testBuildFuncOp
+def testBuildFuncOp():
+  ctx = Context()
+  with Location.unknown(ctx):
+    m = builtin.ModuleOp()
+
+    f32 = F32Type.get()
+    tensor_type = RankedTensorType.get((2, 3, 4), f32)
+    with InsertionPoint.at_block_begin(m.body):
+      func = builtin.FuncOp(name="some_func",
+                            type=FunctionType.get(
+                                inputs=[tensor_type, tensor_type],
+                                results=[tensor_type]),
+                            visibility="nested")
+      # CHECK: Name is: "some_func"
+      print("Name is: ", func.name)
+
+      # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
+      print("Type is: ", func.type)
+
+      # CHECK: Visibility is: "nested"
+      print("Visibility is: ", func.visibility)
+
+      # No body was requested, so the function is external and has no block.
+      try:
+        func.entry_block
+      except IndexError as e:
+        # CHECK: External function does not have a body
+        print(e)
+
+      with InsertionPoint(func.add_entry_block()):
+        std.ReturnOp([func.entry_block.arguments[0]])
+
+      # A second entry block must be rejected.
+      try:
+        func.add_entry_block()
+      except IndexError as e:
+        # CHECK: The function already has an entry block!
+        print(e)
+
+      # Try the callback builder and passing type as tuple.
+      func = builtin.FuncOp(name="some_other_func",
+                            type=([tensor_type, tensor_type], [tensor_type]),
+                            visibility="nested",
+                            body_builder=lambda func: std.ReturnOp(
+                                [func.entry_block.arguments[0]]))
+
+  # CHECK: module  {
+  # CHECK:  func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  # CHECK:   return %arg0 : tensor<2x3x4xf32>
+  # CHECK:  }
+  # CHECK:  func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+  # CHECK:   return %arg0 : tensor<2x3x4xf32>
+  # CHECK:  }
+  print(m)
+
+
+run(testBuildFuncOp)

diff  --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0197bfb15577..94bfd58ab3a5 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -716,6 +716,10 @@ static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
 
   os << llvm::formatv(fileHeader, clDialectName.getValue());
   os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
+
+  if (clDialectName == "builtin")
+    clDialectName = "";
+
   for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
     Operator op(rec);
     if (op.getDialectName() == clDialectName.getValue())


        


More information about the llvm-branch-commits mailing list