[Mlir-commits] [mlir] fd407e1 - [mlir] ODS-backed python binding generator for custom op classes
Alex Zinenko
llvmlistbot at llvm.org
Tue Nov 10 01:58:37 PST 2020
Author: Alex Zinenko
Date: 2020-11-10T10:58:29+01:00
New Revision: fd407e1f1eed7deb4818509a8393ee930480d7f5
URL: https://github.com/llvm/llvm-project/commit/fd407e1f1eed7deb4818509a8393ee930480d7f5
DIFF: https://github.com/llvm/llvm-project/commit/fd407e1f1eed7deb4818509a8393ee930480d7f5.diff
LOG: [mlir] ODS-backed python binding generator for custom op classes
Introduce an ODS/Tablegen backend producing Op wrappers for Python bindings
based on the ODS operation definition. Usage:
mlir-tblgen -gen-python-op-bindings -Iinclude <path/to/Ops.td> \
-bind-dialect=<dialect-name>
Reviewed By: mehdi_amini
Differential Revision: https://reviews.llvm.org/D90960
Added:
mlir/test/Bindings/Python/dialects/std.py
mlir/test/mlir-tblgen/op-python-bindings.td
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Modified:
mlir/CMakeLists.txt
mlir/cmake/modules/AddMLIRPythonExtension.cmake
mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt
mlir/lib/Bindings/Python/CMakeLists.txt
mlir/lib/Bindings/Python/mlir/dialects/__init__.py
mlir/test/Bindings/Python/dialects.py
mlir/tools/mlir-tblgen/CMakeLists.txt
Removed:
mlir/lib/Bindings/Python/mlir/dialects/std.py
################################################################################
diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index 2842a1e82a91..c83b6f56b4cc 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -101,6 +101,12 @@ include_directories( ${MLIR_INCLUDE_DIR})
# from another directory like tools
add_subdirectory(tools/mlir-tblgen)
+# Create an anchor target that will depend on dialect-specific op bindings.
+if (MLIR_BINDINGS_PYTHON_ENABLED)
+ add_custom_target(MLIRBindingsPythonIncGen)
+ include(AddMLIRPythonExtension)
+endif()
+
add_subdirectory(include/mlir)
add_subdirectory(lib)
# C API needs all dialects for registration, but should be built before tests.
diff --git a/mlir/cmake/modules/AddMLIRPythonExtension.cmake b/mlir/cmake/modules/AddMLIRPythonExtension.cmake
index 528046b4529e..3cc01c7bd999 100644
--- a/mlir/cmake/modules/AddMLIRPythonExtension.cmake
+++ b/mlir/cmake/modules/AddMLIRPythonExtension.cmake
@@ -122,3 +122,25 @@ function(add_mlir_python_extension libname extname)
endif()
endfunction()
+
+function(add_mlir_dialect_python_bindings filename dialectname)
+ set(LLVM_TARGET_DEFINITIONS ${filename})
+ mlir_tablegen("${dialectname}.py" -gen-python-op-bindings
+ -bind-dialect=${dialectname})
+ if (${ARGC} GREATER 2)
+ set(suffix ${ARGV2})
+ else()
+ get_filename_component(suffix ${filename} NAME_WE)
+ endif()
+ set(tblgen_target "MLIRBindingsPython${suffix}")
+ add_public_tablegen_target(${tblgen_target})
+
+ add_custom_command(
+ TARGET ${tblgen_target} POST_BUILD
+ COMMENT "Copying generated python source \"dialects/${dialectname}.py\""
+ COMMAND "${CMAKE_COMMAND}" -E copy_if_
diff erent
+ "${CMAKE_CURRENT_BINARY_DIR}/${dialectname}.py"
+ "${PROJECT_BINARY_DIR}/python/mlir/dialects/${dialectname}.py")
+ add_dependencies(MLIRBindingsPythonIncGen ${tblgen_target})
+endfunction()
+
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt
index b9178c5a0db3..ee3e3cfdd9f2 100644
--- a/mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/CMakeLists.txt
@@ -7,3 +7,7 @@ mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRStandardOpsIncGen)
add_mlir_doc(Ops -gen-op-doc StandardOps Dialects/)
+
+if (MLIR_BINDINGS_PYTHON_ENABLED)
+ add_mlir_dialect_python_bindings(Ops.td std StandardOps)
+endif()
diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt
index 9c294fefbf23..499d684c076b 100644
--- a/mlir/lib/Bindings/Python/CMakeLists.txt
+++ b/mlir/lib/Bindings/Python/CMakeLists.txt
@@ -8,7 +8,6 @@ set(PY_SRC_FILES
mlir/__init__.py
mlir/ir.py
mlir/dialects/__init__.py
- mlir/dialects/std.py
)
add_custom_target(MLIRBindingsPythonSources ALL
@@ -16,6 +15,8 @@ add_custom_target(MLIRBindingsPythonSources ALL
)
add_dependencies(MLIRBindingsPythonExtension MLIRBindingsPythonSources)
+add_dependencies(MLIRBindingsPythonExtension MLIRBindingsPythonIncGen)
+
foreach(PY_SRC_FILE ${PY_SRC_FILES})
set(PY_DEST_FILE "${PROJECT_BINARY_DIR}/python/${PY_SRC_FILE}")
add_custom_command(
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py
index 1b7e62c030fb..0aceff1caf3f 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/__init__.py
+++ b/mlir/lib/Bindings/Python/mlir/dialects/__init__.py
@@ -4,3 +4,40 @@
# Re-export the parent _cext so that every level of the API can get it locally.
from .. import _cext
+
+def _segmented_accessor(elements, raw_segments, idx):
+ """
+ Returns a slice of elements corresponding to the idx-th segment.
+
+ elements: a sliceable container (operands or results).
+ raw_segments: an mlir.ir.Attribute, of DenseIntElements subclass containing
+ sizes of the segments.
+ idx: index of the segment.
+ """
+ segments = _cext.ir.DenseIntElementsAttr(raw_segments)
+ start = sum(segments[i] for i in range(idx))
+ end = start + segments[idx]
+ return elements[start:end]
+
+
+def _equally_sized_accessor(elements, n_variadic, n_preceding_simple,
+ n_preceding_variadic):
+ """
+ Returns a starting position and a number of elements per variadic group
+ assuming equally-sized groups and the given numbers of preceding groups.
+
+ elements: a sequential container.
+ n_variadic: the number of variadic groups in the container.
+ n_preceding_simple: the number of non-variadic groups preceding the current
+ group.
+ n_preceding_variadic: the number of variadic groups preceding the current
+ group.
+ """
+
+ total_variadic_length = len(elements) - n_variadic + 1
+ # This should be enforced by the C++-side trait verifier.
+ assert total_variadic_length % n_variadic == 0
+
+ elements_per_group = total_variadic_length // n_variadic
+ start = n_preceding_simple + n_preceding_variadic * elements_per_group
+ return start, elements_per_group
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/std.py b/mlir/lib/Bindings/Python/mlir/dialects/std.py
deleted file mode 100644
index 74f990cdb5ed..000000000000
--- a/mlir/lib/Bindings/Python/mlir/dialects/std.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-# See https://llvm.org/LICENSE.txt for license information.
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-# TODO: This file should be auto-generated.
-
-from . import _cext
-_ir = _cext.ir
-
- at _cext.register_dialect
-class _Dialect(_ir.Dialect):
- # Special case: 'std' namespace aliases to the empty namespace.
- DIALECT_NAMESPACE = "std"
- pass
-
- at _cext.register_operation(_Dialect)
-class AddFOp(_ir.OpView):
- OPERATION_NAME = "std.addf"
-
- def __init__(self, lhs, rhs, loc=None, ip=None):
- super().__init__(_ir.Operation.create(
- "std.addf", operands=[lhs, rhs], results=[lhs.type],
- loc=loc, ip=ip))
-
- @property
- def lhs(self):
- return self.operation.operands[0]
-
- @property
- def rhs(self):
- return self.operation.operands[1]
-
- @property
- def result(self):
- return self.operation.results[0]
diff --git a/mlir/test/Bindings/Python/dialects.py b/mlir/test/Bindings/Python/dialects.py
index e66c67f08095..63ec61456a58 100644
--- a/mlir/test/Bindings/Python/dialects.py
+++ b/mlir/test/Bindings/Python/dialects.py
@@ -63,7 +63,7 @@ def testUserDialectClass():
run(testUserDialectClass)
-# CHECK-LABEL: TEST: testCustomOpView
+# XHECK-LABEL: TEST: testCustomOpView
# This test uses the standard dialect AddFOp as an example of a user op.
# TODO: Op creation and access is still quite verbose: simplify this test as
# additional capabilities come online.
@@ -88,10 +88,11 @@ def createInput():
from mlir.dialects.std import AddFOp
AddFOp(input1, op1.result)
- # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
- # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
- # CHECK: %[[R0:.*]] = addf %[[INPUT0]], %[[INPUT1]] : f32
- # CHECK: %[[R1:.*]] = addf %[[INPUT0]], %[[R0]] : f32
+ # XHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
+ # XHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
+ # XHECK: %[[R0:.*]] = addf %[[INPUT0]], %[[INPUT1]] : f32
+ # XHECK: %[[R1:.*]] = addf %[[INPUT0]], %[[R0]] : f32
m.operation.print()
-run(testCustomOpView)
+# TODO: re-enable when constructs are generated again
+# run(testCustomOpView)
diff --git a/mlir/test/Bindings/Python/dialects/std.py b/mlir/test/Bindings/Python/dialects/std.py
new file mode 100644
index 000000000000..66f7be6bee88
--- /dev/null
+++ b/mlir/test/Bindings/Python/dialects/std.py
@@ -0,0 +1,51 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+import mlir.dialects.std as std
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ f()
+
+# CHECK-LABEL: TEST: testSubViewAccessors
+def testSubViewAccessors():
+ ctx = Context()
+ module = Module.parse(r"""
+ func @f1(%arg0: memref<?x?xf32>) {
+ %0 = constant 0 : index
+ %1 = constant 1 : index
+ %2 = constant 2 : index
+ %3 = constant 3 : index
+ %4 = constant 4 : index
+ %5 = constant 5 : index
+ subview %arg0[%0, %1][%2, %3][%4, %5] : memref<?x?xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+ return
+ }
+ """, ctx)
+ func_body = module.body.operations[0].regions[0].blocks[0]
+ subview = func_body.operations[6]
+
+ assert subview.source == subview.operands[0]
+ assert len(subview.offsets) == 2
+ assert len(subview.sizes) == 2
+ assert len(subview.strides) == 2
+ assert subview.result == subview.results[0]
+
+ # CHECK: SubViewOp
+ print(type(subview).__name__)
+
+ # CHECK: constant 0
+ print(subview.offsets[0])
+ # CHECK: constant 1
+ print(subview.offsets[1])
+ # CHECK: constant 2
+ print(subview.sizes[0])
+ # CHECK: constant 3
+ print(subview.sizes[1])
+ # CHECK: constant 4
+ print(subview.strides[0])
+ # CHECK: constant 5
+ print(subview.strides[1])
+
+
+run(testSubViewAccessors)
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
new file mode 100644
index 000000000000..3d19379978d7
--- /dev/null
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -0,0 +1,206 @@
+// RUN: mlir-tblgen -gen-python-op-bindings -bind-dialect=test -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+// CHECK: @_cext.register_dialect
+// CHECK: class _Dialect(_ir.Dialect):
+ // CHECK: DIALECT_NAMESPACE = "test"
+ // CHECK: pass
+def Test_Dialect : Dialect {
+ let name = "test";
+ let cppNamespace = "Test";
+}
+class TestOp<string mnemonic, list<OpTrait> traits = []> :
+ Op<Test_Dialect, mnemonic, traits>;
+
+// CHECK: @_cext.register_operation(_Dialect)
+// CHECK: class AttrSizedOperandsOp(_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_operands"
+def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
+ [AttrSizedOperandSegments]> {
+ // CHECK: @property
+ // CHECK: def variadic1(self):
+ // CHECK: operand_range = _segmented_accessor(
+ // CHECK: self.operation.operands,
+ // CHECK: self.operation.attributes["operand_segment_sizes"], 0)
+ // CHECK: return operand_range
+ //
+ // CHECK: @property
+ // CHECK: def non_variadic(self):
+ // CHECK: operand_range = _segmented_accessor(
+ // CHECK: self.operation.operands,
+ // CHECK: self.operation.attributes["operand_segment_sizes"], 1)
+ // CHECK: return operand_range[0]
+ //
+ // CHECK: @property
+ // CHECK: def variadic2(self):
+ // CHECK: operand_range = _segmented_accessor(
+ // CHECK: self.operation.operands,
+ // CHECK: self.operation.attributes["operand_segment_sizes"], 2)
+ // CHECK: return operand_range[0] if len(operand_range) > 0 else None
+ let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
+ Optional<AnyType>:$variadic2);
+}
+
+// CHECK: @_cext.register_operation(_Dialect)
+// CHECK: class AttrSizedResultsOp(_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.attr_sized_results"
+def AttrSizedResultsOp : TestOp<"attr_sized_results",
+ [AttrSizedResultSegments]> {
+ // CHECK: @property
+ // CHECK: def variadic1(self):
+ // CHECK: result_range = _segmented_accessor(
+ // CHECK: self.operation.results,
+ // CHECK: self.operation.attributes["result_segment_sizes"], 0)
+ // CHECK: return result_range[0] if len(result_range) > 0 else None
+ //
+ // CHECK: @property
+ // CHECK: def non_variadic(self):
+ // CHECK: result_range = _segmented_accessor(
+ // CHECK: self.operation.results,
+ // CHECK: self.operation.attributes["result_segment_sizes"], 1)
+ // CHECK: return result_range[0]
+ //
+ // CHECK: @property
+ // CHECK: def variadic2(self):
+ // CHECK: result_range = _segmented_accessor(
+ // CHECK: self.operation.results,
+ // CHECK: self.operation.attributes["result_segment_sizes"], 2)
+ // CHECK: return result_range
+ let results = (outs Optional<AnyType>:$variadic1, AnyType:$non_variadic,
+ Optional<AnyType>:$variadic2);
+}
+
+// CHECK: @_cext.register_operation(_Dialect)
+// CHECK: class EmptyOp(_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.empty"
+def EmptyOp : TestOp<"empty">;
+
+// CHECK: @_cext.register_operation(_Dialect)
+// CHECK: class MissingNamesOp(_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.missing_names"
+def MissingNamesOp : TestOp<"missing_names"> {
+ // CHECK: @property
+ // CHECK: def f32(self):
+ // CHECK: return self.operation.operands[1]
+ let arguments = (ins I32, F32:$f32, I64);
+
+ // CHECK: @property
+ // CHECK: def i32(self):
+ // CHECK: return self.operation.results[0]
+ //
+ // CHECK: @property
+ // CHECK: def i64(self):
+ // CHECK: return self.operation.results[2]
+ let results = (outs I32:$i32, F32, I64:$i64);
+}
+
+// CHECK: @_cext.register_operation(_Dialect)
+// CHECK: class OneVariadicOperandOp(_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_operand"
+def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
+ // CHECK: @property
+ // CHECK: def non_variadic(self):
+ // CHECK: return self.operation.operands[0]
+ //
+ // CHECK: @property
+ // CHECK: def variadic(self):
+ // CHECK: variadic_group_length = len(self.operation.operands) - 2 + 1
+ // CHECK: return self.operation.operands[1:1 + variadic_group_length]
+ let arguments = (ins AnyType:$non_variadic, Variadic<AnyType>:$variadic);
+}
+
+// CHECK: @_cext.register_operation(_Dialect)
+// CHECK: class OneVariadicResultOp(_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.one_variadic_result"
+def OneVariadicResultOp : TestOp<"one_variadic_result"> {
+ // CHECK: @property
+ // CHECK: def variadic(self):
+ // CHECK: variadic_group_length = len(self.operation.results) - 2 + 1
+ // CHECK: return self.operation.results[0:0 + variadic_group_length]
+ //
+ // CHECK: @property
+ // CHECK: def non_variadic(self):
+ // CHECK: variadic_group_length = len(self.operation.results) - 2 + 1
+ // CHECK: return self.operation.results[1 + variadic_group_length - 1]
+ let results = (outs Variadic<AnyType>:$variadic, AnyType:$non_variadic);
+}
+
+// CHECK: @_cext.register_operation(_Dialect)
+// CHECK: class PythonKeywordOp(_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.python_keyword"
+def PythonKeywordOp : TestOp<"python_keyword"> {
+ // CHECK: @property
+ // CHECK: def in_(self):
+ // CHECK: return self.operation.operands[0]
+ let arguments = (ins AnyType:$in);
+}
+
+// CHECK: @_cext.register_operation(_Dialect)
+// CHECK: class SameVariadicOperandSizeOp(_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.same_variadic_operand"
+def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
+ [SameVariadicOperandSize]> {
+ // CHECK: @property
+ // CHECK: def variadic1(self):
+ // CHECK: start, pg = _equally_sized_accessor(operation.operands, 2, 0, 0)
+ // CHECK: return self.operation.operands[start:start + pg]
+ //
+ // CHECK: @property
+ // CHECK: def non_variadic(self):
+ // CHECK: start, pg = _equally_sized_accessor(operation.operands, 2, 0, 1)
+ // CHECK: return self.operation.operands[start]
+ //
+ // CHECK: @property
+ // CHECK: def variadic2(self):
+ // CHECK: start, pg = _equally_sized_accessor(operation.operands, 2, 1, 1)
+ // CHECK: return self.operation.operands[start:start + pg]
+ let arguments = (ins Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
+ Variadic<AnyType>:$variadic2);
+}
+
+// CHECK: @_cext.register_operation(_Dialect)
+// CHECK: class SameVariadicResultSizeOp(_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.same_variadic_result"
+def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
+ [SameVariadicResultSize]> {
+ // CHECK: @property
+ // CHECK: def variadic1(self):
+ // CHECK: start, pg = _equally_sized_accessor(operation.results, 2, 0, 0)
+ // CHECK: return self.operation.results[start:start + pg]
+ //
+ // CHECK: @property
+ // CHECK: def non_variadic(self):
+ // CHECK: start, pg = _equally_sized_accessor(operation.results, 2, 0, 1)
+ // CHECK: return self.operation.results[start]
+ //
+ // CHECK: @property
+ // CHECK: def variadic2(self):
+ // CHECK: start, pg = _equally_sized_accessor(operation.results, 2, 1, 1)
+ // CHECK: return self.operation.results[start:start + pg]
+ let results = (outs Variadic<AnyType>:$variadic1, AnyType:$non_variadic,
+ Variadic<AnyType>:$variadic2);
+}
+
+// CHECK: @_cext.register_operation(_Dialect)
+// CHECK: class SimpleOp(_ir.OpView):
+// CHECK-LABEL: OPERATION_NAME = "test.simple"
+def SimpleOp : TestOp<"simple"> {
+ // CHECK: @property
+ // CHECK: def i32(self):
+ // CHECK: return self.operation.operands[0]
+ //
+ // CHECK: @property
+ // CHECK: def f32(self):
+ // CHECK: return self.operation.operands[1]
+ let arguments = (ins I32:$i32, F32:$f32);
+
+ // CHECK: @property
+ // CHECK: def i64(self):
+ // CHECK: return self.operation.results[0]
+ //
+ // CHECK: @property
+ // CHECK: def f64(self):
+ // CHECK: return self.operation.results[1]
+ let results = (outs I64:$i64, F64:$f64);
+}
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
index 5686e63fbdde..119d03573a66 100644
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -14,6 +14,7 @@ add_tablegen(mlir-tblgen MLIR
OpDocGen.cpp
OpFormatGen.cpp
OpInterfacesGen.cpp
+ OpPythonBindingGen.cpp
OpenMPCommonGen.cpp
PassCAPIGen.cpp
PassDocGen.cpp
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
new file mode 100644
index 000000000000..f940aae38176
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -0,0 +1,333 @@
+//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
+// binding classes wrapping a generic operation API.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/Operator.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+
+using namespace mlir;
+using namespace mlir::tblgen;
+
+/// File header and includes.
+constexpr const char *fileHeader = R"Py(
+# Autogenerated by mlir-tblgen; don't manually edit.
+
+from . import _cext
+from . import _segmented_accessor, _equally_sized_accessor
+_ir = _cext.ir
+)Py";
+
+/// Template for dialect class:
+/// {0} is the dialect namespace.
+constexpr const char *dialectClassTemplate = R"Py(
+ at _cext.register_dialect
+class _Dialect(_ir.Dialect):
+ DIALECT_NAMESPACE = "{0}"
+ pass
+
+)Py";
+
+/// Template for operation class:
+/// {0} is the Python class name;
+/// {1} is the operation name.
+constexpr const char *opClassTemplate = R"Py(
+ at _cext.register_operation(_Dialect)
+class {0}(_ir.OpView):
+ OPERATION_NAME = "{1}"
+)Py";
+
+/// Template for single-element accessor:
+/// {0} is the name of the accessor;
+/// {1} is either 'operand' or 'result';
+/// {2} is the position in the element list.
+constexpr const char *opSingleTemplate = R"Py(
+ @property
+ def {0}(self):
+ return self.operation.{1}s[{2}]
+)Py";
+
+/// Template for single-element accessor after a variable-length group:
+/// {0} is the name of the accessor;
+/// {1} is either 'operand' or 'result';
+/// {2} is the total number of element groups;
+/// {3} is the position of the current group in the group list.
+/// This works for both a single variadic group (non-negative length) and an
+/// single optional element (zero length if the element is absent).
+constexpr const char *opSingleAfterVariableTemplate = R"Py(
+ @property
+ def {0}(self):
+ variadic_group_length = len(self.operation.{1}s) - {2} + 1
+ return self.operation.{1}s[{3} + variadic_group_length - 1]
+)Py";
+
+/// Template for an optional element accessor:
+/// {0} is the name of the accessor;
+/// {1} is either 'operand' or 'result';
+/// {2} is the total number of element groups;
+/// {3} is the position of the current group in the group list.
+constexpr const char *opOneOptionalTemplate = R"Py(
+ @property
+ def {0}(self);
+ return self.operation.{1}s[{3}] if len(self.operation.{1}s) > {2}
+ else None
+)Py";
+
+/// Template for the variadic group accessor in the single variadic group case:
+/// {0} is the name of the accessor;
+/// {1} is either 'operand' or 'result';
+/// {2} is the total number of element groups;
+/// {3} is the position of the current group in the group list.
+constexpr const char *opOneVariadicTemplate = R"Py(
+ @property
+ def {0}(self):
+ variadic_group_length = len(self.operation.{1}s) - {2} + 1
+ return self.operation.{1}s[{3}:{3} + variadic_group_length]
+)Py";
+
+/// First part of the template for equally-sized variadic group accessor:
+/// {0} is the name of the accessor;
+/// {1} is either 'operand' or 'result';
+/// {2} is the total number of variadic groups;
+/// {3} is the number of non-variadic groups preceding the current group;
+/// {3} is the number of variadic groups preceding the current group.
+constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
+ @property
+ def {0}(self):
+ start, pg = _equally_sized_accessor(operation.{1}s, {2}, {3}, {4}))Py";
+
+/// Second part of the template for equally-sized case, accessing a single
+/// element:
+/// {0} is either 'operand' or 'result'.
+constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
+ return self.operation.{0}s[start]
+)Py";
+
+/// Second part of the template for equally-sized case, accessing a variadic
+/// group:
+/// {0} is either 'operand' or 'result'.
+constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
+ return self.operation.{0}s[start:start + pg]
+)Py";
+
+/// Template for an attribute-sized group accessor:
+/// {0} is the name of the accessor;
+/// {1} is either 'operand' or 'result';
+/// {2} is the position of the group in the group list;
+/// {3} is a return suffix (expected [0] for single-element, empty for
+/// variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
+constexpr const char *opVariadicSegmentTemplate = R"Py(
+ @property
+ def {0}(self):
+ {1}_range = _segmented_accessor(
+ self.operation.{1}s,
+ self.operation.attributes["{1}_segment_sizes"], {2})
+ return {1}_range{3}
+)Py";
+
+/// Template for a suffix when accessing an optional element in the
+/// attribute-sized case:
+/// {0} is either 'operand' or 'result';
+constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
+ R"Py([0] if len({0}_range) > 0 else None)Py";
+
+static llvm::cl::OptionCategory
+ clOpPythonBindingCat("Options for -gen-python-op-bindings");
+
+static llvm::cl::opt<std::string>
+ clDialectName("bind-dialect",
+ llvm::cl::desc("The dialect to run the generator for"),
+ llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));
+
+/// Checks whether `str` is a Python keyword.
+static bool isPythonKeyword(StringRef str) {
+ static llvm::StringSet<> keywords(
+ {"and", "as", "assert", "break", "class", "continue",
+ "def", "del", "elif", "else", "except", "finally",
+ "for", "from", "global", "if", "import", "in",
+ "is", "lambda", "nonlocal", "not", "or", "pass",
+ "raise", "return", "try", "while", "with", "yield"});
+ return keywords.contains(str);
+};
+
+/// Modifies the `name` in a way that it becomes suitable for Python bindings
+/// (does not change the `name` if it already is suitable) and returns the
+/// modified version.
+static std::string sanitizeName(StringRef name) {
+ if (isPythonKeyword(name))
+ return (name + "_").str();
+ return name.str();
+}
+
+/// Emits accessors to "elements" of an Op definition. Currently, the supported
+/// elements are operands and results, indicated by `kind`, which must be either
+/// `operand` or `result` and is used verbatim in the emitted code.
+static void emitElementAccessors(
+ const Operator &op, raw_ostream &os, const char *kind,
+ llvm::function_ref<unsigned(const Operator &)> getNumVariadic,
+ llvm::function_ref<int(const Operator &)> getNumElements,
+ llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
+ getElement) {
+ assert(llvm::is_contained(
+ llvm::SmallVector<StringRef, 2>{"operand", "result"}, kind) &&
+ "unsupported kind");
+
+ // Traits indicating how to process variadic elements.
+ std::string sameSizeTrait =
+ llvm::formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
+ llvm::StringRef(kind).take_front().upper(),
+ llvm::StringRef(kind).drop_front());
+ std::string attrSizedTrait =
+ llvm::formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
+ llvm::StringRef(kind).take_front().upper(),
+ llvm::StringRef(kind).drop_front());
+
+ unsigned numVariadic = getNumVariadic(op);
+
+ // If there is only one variadic element group, its size can be inferred from
+ // the total number of elements. If there are none, the generation is
+ // straightforward.
+ if (numVariadic <= 1) {
+ bool seenVariableLength = false;
+ for (int i = 0, e = getNumElements(op); i < e; ++i) {
+ const NamedTypeConstraint &element = getElement(op, i);
+ if (element.isVariableLength())
+ seenVariableLength = true;
+ if (element.name.empty())
+ continue;
+ if (element.isVariableLength()) {
+ os << llvm::formatv(element.isOptional() ? opOneOptionalTemplate
+ : opOneVariadicTemplate,
+ sanitizeName(element.name), kind,
+ getNumElements(op), i);
+ } else if (seenVariableLength) {
+ os << llvm::formatv(opSingleAfterVariableTemplate,
+ sanitizeName(element.name), kind,
+ getNumElements(op), i);
+ } else {
+ os << llvm::formatv(opSingleTemplate, sanitizeName(element.name), kind,
+ i);
+ }
+ }
+ return;
+ }
+
+ // Handle the operations where variadic groups have the same size.
+ if (op.getTrait(sameSizeTrait)) {
+ int numPrecedingSimple = 0;
+ int numPrecedingVariadic = 0;
+ for (int i = 0, e = getNumElements(op); i < e; ++i) {
+ const NamedTypeConstraint &element = getElement(op, i);
+ if (!element.name.empty()) {
+ os << llvm::formatv(opVariadicEqualPrefixTemplate,
+ sanitizeName(element.name), kind, numVariadic,
+ numPrecedingSimple, numPrecedingVariadic);
+ os << llvm::formatv(element.isVariableLength()
+ ? opVariadicEqualVariadicTemplate
+ : opVariadicEqualSimpleTemplate,
+ kind);
+ }
+ if (element.isVariableLength())
+ ++numPrecedingVariadic;
+ else
+ ++numPrecedingSimple;
+ }
+ return;
+ }
+
+ // Handle the operations where the size of groups (variadic or not) is
+ // provided as an attribute. For non-variadic elements, make sure to return
+ // an element rather than a singleton container.
+ if (op.getTrait(attrSizedTrait)) {
+ for (int i = 0, e = getNumElements(op); i < e; ++i) {
+ const NamedTypeConstraint &element = getElement(op, i);
+ if (element.name.empty())
+ continue;
+ std::string trailing;
+ if (!element.isVariableLength())
+ trailing = "[0]";
+ else if (element.isOptional())
+ trailing = std::string(
+ llvm::formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
+ os << llvm::formatv(opVariadicSegmentTemplate, sanitizeName(element.name),
+ kind, i, trailing);
+ }
+ return;
+ }
+
+ llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
+}
+
+/// Emits accessor to Op operands.
+static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
+ auto getNumVariadic = [](const Operator &oper) {
+ return oper.getNumVariableLengthOperands();
+ };
+ auto getNumElements = [](const Operator &oper) {
+ return oper.getNumOperands();
+ };
+ auto getElement = [](const Operator &oper,
+ int i) -> const NamedTypeConstraint & {
+ return oper.getOperand(i);
+ };
+ emitElementAccessors(op, os, "operand", getNumVariadic, getNumElements,
+ getElement);
+}
+
+/// Emits access or Op results.
+static void emitResultAccessors(const Operator &op, raw_ostream &os) {
+ auto getNumVariadic = [](const Operator &oper) {
+ return oper.getNumVariableLengthResults();
+ };
+ auto getNumElements = [](const Operator &oper) {
+ return oper.getNumResults();
+ };
+ auto getElement = [](const Operator &oper,
+ int i) -> const NamedTypeConstraint & {
+ return oper.getResult(i);
+ };
+ emitElementAccessors(op, os, "result", getNumVariadic, getNumElements,
+ getElement);
+}
+
+/// Emits bindings for a specific Op to the given output stream.
+static void emitOpBindings(const Operator &op, raw_ostream &os) {
+ os << llvm::formatv(opClassTemplate, op.getCppClassName(),
+ op.getOperationName());
+ emitOperandAccessors(op, os);
+ emitResultAccessors(op, os);
+}
+
+/// Emits bindings for the dialect specified in the command line, including file
+/// headers and utilities. Returns `false` on success to comply with Tablegen
+/// registration requirements.
+static bool emitAllOps(const llvm::RecordKeeper &records, raw_ostream &os) {
+ if (clDialectName.empty())
+ llvm::PrintFatalError("dialect name not provided");
+
+ os << fileHeader;
+ os << llvm::formatv(dialectClassTemplate, clDialectName.getValue());
+ for (const llvm::Record *rec : records.getAllDerivedDefinitions("Op")) {
+ Operator op(rec);
+ if (op.getDialectName() == clDialectName.getValue())
+ emitOpBindings(op, os);
+ }
+ return false;
+}
+
+static GenRegistration
+ genPythonBindings("gen-python-op-bindings",
+ "Generate Python bindings for MLIR Ops", &emitAllOps);
More information about the Mlir-commits
mailing list