[llvm-branch-commits] [mlir] 922b26c - Add Python bindings for the builtin dialect
Mehdi Amini via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jan 21 14:55:01 PST 2021
Author: Mehdi Amini
Date: 2021-01-21T22:44:44Z
New Revision: 922b26cde4d1c89a5fa90e6a1d6d97d0f8eace6d
URL: https://github.com/llvm/llvm-project/commit/922b26cde4d1c89a5fa90e6a1d6d97d0f8eace6d
DIFF: https://github.com/llvm/llvm-project/commit/922b26cde4d1c89a5fa90e6a1d6d97d0f8eace6d.diff
LOG: Add Python bindings for the builtin dialect
This includes some minor customization for FuncOp and ModuleOp.
Differential Revision: https://reviews.llvm.org/D95022
Added:
mlir/lib/Bindings/Python/BuiltinOps.td
mlir/lib/Bindings/Python/mlir/dialects/_builtin.py
mlir/test/Bindings/Python/.style.yapf
mlir/test/Bindings/Python/dialects/builtin.py
Modified:
mlir/lib/Bindings/Python/CMakeLists.txt
mlir/lib/Bindings/Python/mlir/dialects/__init__.py
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/BuiltinOps.td b/mlir/lib/Bindings/Python/BuiltinOps.td
new file mode 100644
index 000000000000..ecbb8227d490
--- /dev/null
+++ b/mlir/lib/Bindings/Python/BuiltinOps.td
@@ -0,0 +1,15 @@
+//===-- BuiltinOps.td - Entry point for builtin bindings ---*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_BUILTIN_OPS
+#define PYTHON_BINDINGS_BUILTIN_OPS
+
+include "mlir/Bindings/Python/Attributes.td"
+include "mlir/IR/BuiltinOps.td"
+
+#endif
diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
index 1749ea2e5472..951aa7883c90 100644
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -11,6 +11,7 @@ set(PY_SRC_FILES
mlir/ir.py
mlir/dialects/__init__.py
mlir/dialects/_linalg.py
+ mlir/dialects/_builtin.py
mlir/ir.py
mlir/passmanager.py
mlir/transforms/__init__.py
@@ -36,6 +37,11 @@ endforeach()
# Generate dialect-specific bindings.
################################################################################
+add_mlir_dialect_python_bindings(MLIRBindingsPythonBuiltinOps
+ TD_FILE BuiltinOps.td
+ DIALECT_NAME builtin)
+add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonBuiltinOps)
+
add_mlir_dialect_python_bindings(MLIRBindingsPythonLinalgOps
TD_FILE LinalgOps.td
DIALECT_NAME linalg
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py
index 9c003b415438..f5a71bf88700 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py
@@ -43,7 +43,7 @@ def class_decorator(parent_opview_cls: type):
except AttributeError:
# Try to default resolve it.
try:
- select_mixin = getattr(ext_module, parent_opview_cls.__name__)
+ mixin_cls = getattr(ext_module, parent_opview_cls.__name__)
except AttributeError:
pass
else:
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_builtin.py b/mlir/lib/Bindings/Python/mlir/dialects/_builtin.py
new file mode 100644
index 000000000000..8d430d5a50da
--- /dev/null
+++ b/mlir/lib/Bindings/Python/mlir/dialects/_builtin.py
@@ -0,0 +1,112 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from mlir.ir import *
+
+
+class ModuleOp:
+  """Specialization for the module op class."""
+
+  def __init__(self, loc=None, ip=None):
+    """Create a module op whose body block already holds its terminator."""
+    super().__init__(
+        self._ods_build_default(operands=[], results=[], loc=loc, ip=ip))
+    body = self.regions[0].blocks.append()
+    # Add the terminator eagerly so callers get a complete body block and only
+    # need to insert their own ops into it.
+    with InsertionPoint(body):
+      Operation.create("module_terminator")
+
+  @property
+  def body(self):
+    """The first block of the body region, created by `__init__`."""
+    return self.regions[0].blocks[0]
+
+
+class FuncOp:
+  """Specialization for the func op class."""
+
+  def __init__(self,
+               name,
+               type,
+               visibility=None,
+               body_builder=None,
+               loc=None,
+               ip=None):
+    """
+    Create a FuncOp with the provided `name`, `type`, and `visibility`.
+    - `name` is a string representing the function name.
+    - `type` is either a FunctionType or a tuple `(inputs, results)` of lists
+      of types, from which a FunctionType is built.
+    - `visibility` is a string matching `public`, `private`, or `nested`, or
+      None (the default) to omit the `sym_visibility` attribute, which MLIR
+      interprets as public visibility.
+    - `body_builder` is an optional callback, when provided a new entry block
+      is created and the callback is invoked with the new op as argument within
+      an InsertionPoint context already set for the block. The callback is
+      expected to insert a terminator in the block.
+    - `loc` and `ip` are the location and insertion point forwarded to the
+      generated op builder.
+    """
+    sym_name = StringAttr.get(str(name))
+
+    # If the type is passed as a tuple, build a FunctionType on the fly.
+    if isinstance(type, tuple):
+      type = FunctionType.get(inputs=type[0], results=type[1])
+
+    type = TypeAttr.get(type)
+    # None leaves the `sym_visibility` attribute unset rather than storing a
+    # string for it.
+    sym_visibility = StringAttr.get(
+        str(visibility)) if visibility is not None else None
+    super().__init__(sym_name, type, sym_visibility, loc, ip)
+    if body_builder is not None:
+      entry_block = self.add_entry_block()
+      with InsertionPoint(entry_block):
+        body_builder(self)
+
+  @property
+  def is_external(self):
+    """True when the function has no body (its body region has no blocks)."""
+    return len(self.regions[0].blocks) == 0
+
+  @property
+  def body(self):
+    """The body region itself; use `entry_block` for its first block."""
+    return self.regions[0]
+
+  @property
+  def type(self):
+    """The FunctionType stored in the `type` attribute."""
+    return FunctionType(TypeAttr(self.attributes["type"]).value)
+
+  @property
+  def visibility(self):
+    """The `sym_visibility` attribute."""
+    # NOTE(review): presumably raises KeyError when the attribute was omitted
+    # (visibility=None at construction); confirm against the bindings.
+    return self.attributes["sym_visibility"]
+
+  @property
+  def name(self):
+    """The `sym_name` attribute holding the function name."""
+    return self.attributes["sym_name"]
+
+  @property
+  def entry_block(self):
+    """The first block of the body; raises IndexError for external functions."""
+    if self.is_external:
+      raise IndexError('External function does not have a body')
+    return self.regions[0].blocks[0]
+
+  def add_entry_block(self):
+    """
+    Add an entry block to the function body using the function signature to
+    infer block arguments.
+    Returns the newly created block.
+    Raises IndexError if the function already has an entry block.
+    """
+    if not self.is_external:
+      raise IndexError('The function already has an entry block!')
+    self.body.blocks.append(*self.type.inputs)
+    return self.body.blocks[0]
diff --git a/mlir/test/Bindings/Python/.style.yapf b/mlir/test/Bindings/Python/.style.yapf
new file mode 100644
index 000000000000..9ef1dc15ba62
--- /dev/null
+++ b/mlir/test/Bindings/Python/.style.yapf
@@ -0,0 +1,4 @@
+[style]
+ based_on_style = google
+ column_limit = 80
+ indent_width = 2
diff --git a/mlir/test/Bindings/Python/dialects/builtin.py b/mlir/test/Bindings/Python/dialects/builtin.py
new file mode 100644
index 000000000000..447a255f6021
--- /dev/null
+++ b/mlir/test/Bindings/Python/dialects/builtin.py
@@ -0,0 +1,71 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+import mlir.dialects.builtin as builtin
+import mlir.dialects.std as std
+
+
+def run(f):
+  print("\nTEST:", f.__name__)
+  f()
+
+
+# CHECK-LABEL: TEST: testBuildFuncOp
+def testBuildFuncOp():
+  ctx = Context()
+  with Location.unknown(ctx):
+    m = builtin.ModuleOp()
+
+    f32 = F32Type.get()
+    tensor_type = RankedTensorType.get((2, 3, 4), f32)
+    with InsertionPoint.at_block_begin(m.body):
+      func = builtin.FuncOp(name="some_func",
+                            type=FunctionType.get(
+                                inputs=[tensor_type, tensor_type],
+                                results=[tensor_type]),
+                            visibility="nested")
+      # CHECK: Name is: "some_func"
+      print("Name is: ", func.name)
+
+      # CHECK: Type is: (tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
+      print("Type is: ", func.type)
+
+      # CHECK: Visibility is: "nested"
+      print("Visibility is: ", func.visibility)
+
+      # The function has no body yet, so reading its entry block is expected
+      # to raise; the printed message is matched below.
+      try:
+        func.entry_block
+      except IndexError as e:
+        # CHECK: External function does not have a body
+        print(e)
+
+      with InsertionPoint(func.add_entry_block()):
+        std.ReturnOp([func.entry_block.arguments[0]])
+
+      # A second entry block is expected to be rejected.
+      try:
+        func.add_entry_block()
+      except IndexError as e:
+        # CHECK: The function already has an entry block!
+        print(e)
+
+      # Try the callback builder and passing type as tuple.
+      func = builtin.FuncOp(name="some_other_func",
+                            type=([tensor_type, tensor_type], [tensor_type]),
+                            visibility="nested",
+                            body_builder=lambda func: std.ReturnOp(
+                                [func.entry_block.arguments[0]]))
+
+    # CHECK: module {
+    # CHECK: func nested @some_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+    # CHECK: return %arg0 : tensor<2x3x4xf32>
+    # CHECK: }
+    # CHECK: func nested @some_other_func(%arg0: tensor<2x3x4xf32>, %arg1: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+    # CHECK: return %arg0 : tensor<2x3x4xf32>
+    # CHECK: }
+    print(m)
+
+
+run(testBuildFuncOp)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0197bfb15577..94bfd58ab3a5 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -716,6 +716,10 @@ static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
os << llvm::formatv(fileHeader, clDialectName.getValue());
os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
+
+ if (clDialectName == "builtin")
+ clDialectName = "";
+
for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
Operator op(rec);
if (op.getDialectName() == clDialectName.getValue())
More information about the llvm-branch-commits
mailing list