[Mlir-commits] [mlir] 14c9207 - [mlir] support interfaces in Python bindings

Alex Zinenko llvmlistbot at llvm.org
Mon Oct 25 03:50:51 PDT 2021


Author: Alex Zinenko
Date: 2021-10-25T12:50:42+02:00
New Revision: 14c9207063bb00823a5126131e50c93f6e288bd3

URL: https://github.com/llvm/llvm-project/commit/14c9207063bb00823a5126131e50c93f6e288bd3
DIFF: https://github.com/llvm/llvm-project/commit/14c9207063bb00823a5126131e50c93f6e288bd3.diff

LOG: [mlir] support interfaces in Python bindings

Introduce initial support for operation interfaces in the C API and Python
bindings. Interfaces are a key component of MLIR's extensibility and should be
available in the bindings to make use of the full potential of MLIR.

This initial implementation exposes InferTypeOpInterface all the way to the
Python bindings, since it can later be used to simplify the operation
construction methods by inferring their return types instead of requiring the
user to specify them. The general infrastructure for binding interfaces is
defined, and InferTypeOpInterface can serve as an example for binding other
interfaces.

Reviewed By: gysit

Differential Revision: https://reviews.llvm.org/D111656

Added: 
    mlir/include/mlir-c/Interfaces.h
    mlir/include/mlir/CAPI/Interfaces.h
    mlir/lib/Bindings/Python/IRInterfaces.cpp
    mlir/lib/CAPI/Interfaces/CMakeLists.txt
    mlir/lib/CAPI/Interfaces/Interfaces.cpp
    mlir/test/python/CMakeLists.txt
    mlir/test/python/lib/CMakeLists.txt
    mlir/test/python/lib/PythonTestCAPI.cpp
    mlir/test/python/lib/PythonTestCAPI.h
    mlir/test/python/lib/PythonTestDialect.cpp
    mlir/test/python/lib/PythonTestDialect.h
    mlir/test/python/lib/PythonTestModule.cpp

Modified: 
    mlir/CMakeLists.txt
    mlir/docs/Bindings/Python.md
    mlir/docs/CAPI.md
    mlir/lib/Bindings/Python/IRModule.h
    mlir/lib/Bindings/Python/MainModule.cpp
    mlir/lib/CAPI/CMakeLists.txt
    mlir/python/CMakeLists.txt
    mlir/python/mlir/dialects/python_test.py
    mlir/test/CMakeLists.txt
    mlir/test/python/dialects/python_test.py
    mlir/test/python/python_test_ops.td
    utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Removed: 
    mlir/python/mlir/dialects/PythonTest.td


################################################################################
diff  --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index 1a2583873c131..58dc116082084 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -123,6 +123,7 @@ add_subdirectory(include/mlir)
 add_subdirectory(lib)
 # C API needs all dialects for registration, but should be built before tests.
 add_subdirectory(lib/CAPI)
+
 if (MLIR_INCLUDE_TESTS)
   add_definitions(-DMLIR_INCLUDE_TESTS)
   add_custom_target(MLIRUnitTests)

diff  --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index 71737753b5faf..7fdc8402d03e2 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -536,6 +536,68 @@ except ValueError:
   concrete = OpResult(value)
 ```
 
+#### Interfaces
+
+MLIR interfaces are a mechanism to interact with the IR without needing to know
+the concrete type of an operation, only some of its aspects. Operation
+interfaces are available as Python classes with the same name as their C++
+counterparts. Objects of these classes can be constructed from either:
+
+-   an object of the `Operation` class or of any `OpView` subclass; in this
+    case, all interface methods are available;
+-   a subclass of `OpView` and a context; in this case, only the *static*
+    interface methods are available as there is no associated operation.
+
+In both cases, construction of the interface raises a `ValueError` if the
+operation class does not implement the interface in the given context (or, for
+operations, in the context that the operation is defined in). As with
+attributes and types, the MLIR context may be provided by a surrounding context
+manager.
+
+```python
+from mlir.ir import Context, InferTypeOpInterface
+
+with Context():
+  op = <...>
+
+  # Attempt to cast the operation into an interface.
+  try:
+    iface = InferTypeOpInterface(op)
+  except ValueError:
+    print("Operation does not implement InferTypeOpInterface.")
+    raise
+
+  # All methods are available on interface objects constructed from an Operation
+  # or an OpView.
+  iface.someInstanceMethod()
+
+  # An interface object can also be constructed given an OpView subclass. It
+  # also needs a context in which the interface will be looked up. The context
+  # can be provided explicitly or set up by the surrounding context manager.
+  try:
+    iface = InferTypeOpInterface(some_dialect.SomeOp)
+  except ValueError:
+    print("SomeOp does not implement InferTypeOpInterface.")
+    raise
+
+  # Calling an instance method on an interface object constructed from a class
+  # will raise TypeError.
+  try:
+    iface.someInstanceMethod()
+  except TypeError:
+    pass
+
+  # One can still call static interface methods though.
+  iface.inferOpReturnTypes(<...>)
+```
+
+If an interface object was constructed from an `Operation` or an `OpView`, the
+operation and its opview are available as the `.operation` and `.opview`
+properties of the interface object, respectively.
+
+Only a subset of operation interfaces are currently provided in Python bindings.
+Attribute and type interfaces are not yet available in Python bindings.
+
 ### Creating IR Objects
 
 Python bindings also support IR creation and manipulation.
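
The construction rules documented in the Interfaces section above can be modeled in a few lines of plain Python. This is a hypothetical sketch that does not use the `mlir` package: `Context`, `OpView`, `InferResultsOp`, `DummyOp` and the name-keyed `registry` are invented stand-ins, whereas the real bindings check the interface by `TypeID` through the C API.

```python
# Hypothetical stand-ins; none of these classes come from the mlir package.


class Context:
    """Stand-in context: maps operation names to implemented interface names."""

    def __init__(self, registry):
        self.registry = registry


class OpView:
    """Stand-in for OpView; concrete subclasses set OPERATION_NAME."""

    OPERATION_NAME = None

    def __init__(self, context):
        self.context = context


class InferTypeOpInterface:
    """Built from an op instance (all methods) or an OpView subclass (static only)."""

    NAME = "InferTypeOpInterface"

    def __init__(self, obj, context=None):
        if isinstance(obj, OpView):
            self._op = obj
            context = obj.context  # instances use the context they live in
            op_name = type(obj).OPERATION_NAME
        elif isinstance(obj, type) and issubclass(obj, OpView):
            if context is None:
                raise RuntimeError("a context is required for a static interface")
            self._op = None
            op_name = obj.OPERATION_NAME
        else:
            raise TypeError("expected an operation or an OpView subclass")
        if self.NAME not in context.registry.get(op_name, ()):
            raise ValueError(f"the operation does not implement {self.NAME}")

    @property
    def opview(self):
        # Instance-only accessor: fails when constructed from a class.
        if self._op is None:
            raise TypeError("Cannot get an opview from a static interface")
        return self._op


class InferResultsOp(OpView):
    OPERATION_NAME = "python_test.infer_results_op"


class DummyOp(OpView):
    OPERATION_NAME = "python_test.dummy_op"


ctx = Context({"python_test.infer_results_op": {"InferTypeOpInterface"}})
op = InferResultsOp(ctx)
iface = InferTypeOpInterface(op)                          # from an instance
iface_static = InferTypeOpInterface(InferResultsOp, ctx)  # from a class
```

Constructing from `DummyOp(ctx)` or from `DummyOp` with `ctx` raises `ValueError`, and reading `iface_static.opview` raises `TypeError`, matching the behavior the documentation describes.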

diff  --git a/mlir/docs/CAPI.md b/mlir/docs/CAPI.md
index 4feb1faf1b8ef..1dd224731ac36 100644
--- a/mlir/docs/CAPI.md
+++ b/mlir/docs/CAPI.md
@@ -194,3 +194,23 @@ counterparts. `wrap` converts a C++ class into a C structure and `unwrap` does
 the inverse conversion. Once the C++ object is available, the API implementation
 should rely on `isa` to implement `mlirXIsAY` and is expected to use `cast`
 inside other API calls.
+
+### Extensions for Interfaces
+
+Interfaces can follow the example of IR interfaces and should be placed in the
+appropriate library (e.g., common interfaces in `mlir-c/Interfaces` and
+dialect-specific interfaces in their dialect library). Similarly to other type
+hierarchies, interfaces are not expected to have objects of their own type and
+instead operate on top-level objects: `MlirAttribute`, `MlirOperation` and
+`MlirType`. Static interface methods are expected to take as leading argument a
+canonical identifier of the class, `MlirStringRef` with the name for operations
+and `MlirTypeID` for attributes and types, followed by `MlirContext` in which
+the interfaces are registered.
+
+Individual interfaces are expected to provide a `mlir<InterfaceName>TypeID()`
+function that can be used to check whether an object or a class implements this
+interface using `mlir<Attribute/Operation/Type>ImplementsInterface` or
+`mlir<Attribute/Operation/Type>ImplementsInterfaceStatic` functions,
+respectively. Rationale: C++ `isa` only works when an object exists, and static
+methods are usually dispatched using templates; lookup by `TypeID` in
+`MLIRContext` works even without an object.
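
The rationale above (lookup by `TypeID` in a context works without an object, and an interface may be attached to an operation in one context but not another) can be illustrated with a small Python model. It is a hypothetical sketch: `TypeID`, `Context.attach`, `implements_interface` and `implements_interface_static` are invented stand-ins that loosely mirror `mlirOperationImplementsInterface` and `mlirOperationImplementsInterfaceStatic`; they are not real bindings.

```python
# Hypothetical model of TypeID-based interface lookup; not an MLIR API.
from dataclasses import dataclass, field


class TypeID:
    """Opaque identifier compared by identity, like mlir::TypeID."""


# Plays the role of mlirInferTypeOpInterfaceTypeID() in this sketch.
INFER_TYPE_OP_INTERFACE_ID = TypeID()


@dataclass
class Context:
    # operation name -> set of interface TypeIDs attached in this context
    attached: dict = field(default_factory=dict)

    def attach(self, op_name, iface_id):
        self.attached.setdefault(op_name, set()).add(iface_id)


@dataclass
class Operation:
    name: str
    context: Context


def implements_interface_static(op_name, context, iface_id):
    # No object needed: a canonical name plus a context suffice.
    return iface_id in context.attached.get(op_name, set())


def implements_interface(op, iface_id):
    # With an object, the name and context come from the operation itself.
    return implements_interface_static(op.name, op.context, iface_id)


ctx_a, ctx_b = Context(), Context()
ctx_a.attach("python_test.infer_results_op", INFER_TYPE_OP_INTERFACE_ID)
```

Here the same operation name implements the interface in `ctx_a` but not in `ctx_b`, which is why the static query takes a context argument.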

diff  --git a/mlir/include/mlir-c/Interfaces.h b/mlir/include/mlir-c/Interfaces.h
new file mode 100644
index 0000000000000..f03dd6ea5c83b
--- /dev/null
+++ b/mlir/include/mlir-c/Interfaces.h
@@ -0,0 +1,67 @@
+//===-- mlir-c/Interfaces.h - C API to Core MLIR IR interfaces ----*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header declares the C interface to MLIR interface classes. It is
+// intended to contain interfaces defined in lib/Interfaces.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_INTERFACES_H
+#define MLIR_C_INTERFACES_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/// Returns `true` if the given operation implements an interface identified by
+/// its TypeID.
+MLIR_CAPI_EXPORTED bool
+mlirOperationImplementsInterface(MlirOperation operation,
+                                 MlirTypeID interfaceTypeID);
+
+/// Returns `true` if the operation identified by its canonical string name
+/// implements the interface identified by its TypeID in the given context.
+/// Note that interfaces may be attached to operations in some contexts and not
+/// others.
+MLIR_CAPI_EXPORTED bool
+mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
+                                       MlirContext context,
+                                       MlirTypeID interfaceTypeID);
+
+//===----------------------------------------------------------------------===//
+// InferTypeOpInterface.
+//===----------------------------------------------------------------------===//
+
+/// Returns the interface TypeID of the InferTypeOpInterface.
+MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID(void);
+
+/// These callbacks are used to return multiple types from functions while
+/// transferring ownership to the caller. The first argument is the number of
+/// consecutive elements pointed to by the second argument. The third argument
+/// is an opaque pointer forwarded to the callback by the caller.
+typedef void (*MlirTypesCallback)(intptr_t, MlirType *, void *);
+
+/// Infers the return types of the operation identified by its canonical name,
+/// given the arguments that will be supplied to its generic builder. On
+/// success, calls `callback` with the inferred result types, potentially
+/// several times. Returns failure otherwise.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
+    MlirStringRef opName, MlirContext context, MlirLocation location,
+    intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
+    intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback,
+    void *userData);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_INTERFACES_H
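
The `MlirTypesCallback` contract declared above (a count, that many elements, and an opaque user-data pointer, possibly delivered over several calls) can be sketched in Python. This is a hypothetical model, not the C API: types are plain strings, `infer_return_types` is an invented producer that signals failure when given `None`, and `append_results` plays the role that `appendResultsCallback` plays in IRInterfaces.cpp. The implementation in Interfaces.cpp in this commit invokes the callback once; accumulating across calls is what the header's contract permits.

```python
# Hypothetical model of the MlirTypesCallback contract; not an MLIR API.
# Types are plain strings here instead of MlirType handles.


def infer_return_types(chunks, callback, user_data):
    """Stand-in producer: returns False (failure) if chunks is None, else
    reports every chunk of inferred types through the callback."""
    if chunks is None:
        return False
    for chunk in chunks:
        # The callback receives a count, the items, and the caller's opaque data.
        callback(len(chunk), chunk, user_data)
    return True


def append_results(n_types, types, user_data):
    # Like appendResultsCallback: accumulate into the caller-owned list.
    user_data.extend(types[:n_types])


inferred = []
ok = infer_return_types([["i32"], ["i64", "f32"]], append_results, inferred)
```

Because the caller owns the accumulator and passes it through `user_data`, the producer never needs to know how results are stored.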

diff  --git a/mlir/include/mlir/CAPI/Interfaces.h b/mlir/include/mlir/CAPI/Interfaces.h
new file mode 100644
index 0000000000000..4154b8c9ec6cc
--- /dev/null
+++ b/mlir/include/mlir/CAPI/Interfaces.h
@@ -0,0 +1,18 @@
+//===- Interfaces.h - C API Utils for MLIR interfaces -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains declarations of implementation details of the C API for
+// MLIR interface classes. This file should not be included from C++ code other
+// than C API implementation nor from C code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CAPI_INTERFACES_H
+#define MLIR_CAPI_INTERFACES_H
+
+#endif // MLIR_CAPI_INTERFACES_H

diff  --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
new file mode 100644
index 0000000000000..c3d41c4d84d79
--- /dev/null
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -0,0 +1,240 @@
+//===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "IRModule.h"
+#include "mlir-c/BuiltinAttributes.h"
+#include "mlir-c/Interfaces.h"
+
+namespace py = pybind11;
+
+namespace mlir {
+namespace python {
+
+constexpr static const char *constructorDoc =
+    R"(Creates an interface from a given operation/opview object or from a
+subclass of OpView. Raises ValueError if the operation does not implement the
+interface.)";
+
+constexpr static const char *operationDoc =
+    R"(Returns an Operation for which the interface was constructed.)";
+
+constexpr static const char *opviewDoc =
+    R"(Returns an OpView subclass _instance_ for which the interface was
+constructed.)";
+
+constexpr static const char *inferReturnTypesDoc =
+    R"(Given the arguments required to build an operation, attempts to infer
+its return types. Raises ValueError on failure.)";
+
+/// CRTP base class for Python classes representing MLIR Op interfaces.
+/// Interface hierarchies are flat so no base class is expected here. The
+/// derived class is expected to define the following static fields:
+///  - `const char *pyClassName` - the name of the Python class to create;
+///  - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
+///    of the interface.
+/// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
+/// interface-specific methods.
+///
+/// An interface class may be constructed from either an Operation/OpView object
+/// or from a subclass of OpView. In the latter case, only the static interface
+/// methods are available, similarly to calling ConcreteOp::staticMethod on the
+/// C++ side. Implementations of concrete interfaces can use the `isStatic`
+/// method to check whether the interface object was constructed from a class or
+/// an operation/opview instance. The `getOpName` method always succeeds and
+/// returns a canonical name of the operation suitable for lookups.
+template <typename ConcreteIface>
+class PyConcreteOpInterface {
+protected:
+  using ClassTy = py::class_<ConcreteIface>;
+  using GetTypeIDFunctionTy = MlirTypeID (*)();
+
+public:
+  /// Constructs an interface instance from an object that is either an
+  /// operation or a subclass of OpView. In the latter case, only the static
+  /// methods of the interface are accessible to the caller.
+  PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
+      : obj(object) {
+    try {
+      operation = &py::cast<PyOperation &>(obj);
+    } catch (py::cast_error &err) {
+      // Do nothing.
+    }
+
+    try {
+      operation = &py::cast<PyOpView &>(obj).getOperation();
+    } catch (py::cast_error &err) {
+      // Do nothing.
+    }
+
+    if (operation != nullptr) {
+      if (!mlirOperationImplementsInterface(*operation,
+                                            ConcreteIface::getInterfaceID())) {
+        std::string msg = "the operation does not implement ";
+        throw py::value_error(msg + ConcreteIface::pyClassName);
+      }
+
+      MlirIdentifier identifier = mlirOperationGetName(*operation);
+      MlirStringRef stringRef = mlirIdentifierStr(identifier);
+      opName = std::string(stringRef.data, stringRef.length);
+    } else {
+      try {
+        opName = obj.attr("OPERATION_NAME").template cast<std::string>();
+      } catch (py::cast_error &err) {
+        throw py::type_error(
+            "Op interface does not refer to an operation or OpView class");
+      }
+
+      if (!mlirOperationImplementsInterfaceStatic(
+              mlirStringRefCreate(opName.data(), opName.length()),
+              context.resolve().get(), ConcreteIface::getInterfaceID())) {
+        std::string msg = "the operation does not implement ";
+        throw py::value_error(msg + ConcreteIface::pyClassName);
+      }
+    }
+  }
+
+  /// Creates the Python bindings for this class in the given module.
+  static void bind(py::module &m) {
+    py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName,
+                                  py::module_local());
+    cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
+            py::arg("context") = py::none(), constructorDoc)
+        .def_property_readonly("operation",
+                               &PyConcreteOpInterface::getOperationObject,
+                               operationDoc)
+        .def_property_readonly("opview", &PyConcreteOpInterface::getOpView,
+                               opviewDoc);
+    ConcreteIface::bindDerived(cls);
+  }
+
+  /// Hook for derived classes to add class-specific bindings.
+  static void bindDerived(ClassTy &cls) {}
+
+  /// Returns `true` if this object was constructed from a subclass of OpView
+  /// rather than from an operation instance.
+  bool isStatic() { return operation == nullptr; }
+
+  /// Returns the operation instance from which this object was constructed.
+  /// Throws a type error if this object was constructed from a subclass of
+  /// OpView.
+  py::object getOperationObject() {
+    if (operation == nullptr) {
+      throw py::type_error("Cannot get an operation from a static interface");
+    }
+
+    return operation->getRef().releaseObject();
+  }
+
+  /// Returns the opview of the operation instance from which this object was
+  /// constructed. Throws a type error if this object was constructed from a
+  /// subclass of OpView.
+  py::object getOpView() {
+    if (operation == nullptr) {
+      throw py::type_error("Cannot get an opview from a static interface");
+    }
+
+    return operation->createOpView();
+  }
+
+  /// Returns the canonical name of the operation this interface is constructed
+  /// from.
+  const std::string &getOpName() { return opName; }
+
+private:
+  PyOperation *operation = nullptr;
+  std::string opName;
+  py::object obj;
+};
+
+/// Python wrapper for InferTypeOpInterface. This interface has only static
+/// methods.
+class PyInferTypeOpInterface
+    : public PyConcreteOpInterface<PyInferTypeOpInterface> {
+public:
+  using PyConcreteOpInterface<PyInferTypeOpInterface>::PyConcreteOpInterface;
+
+  constexpr static const char *pyClassName = "InferTypeOpInterface";
+  constexpr static GetTypeIDFunctionTy getInterfaceID =
+      &mlirInferTypeOpInterfaceTypeID;
+
+  /// C-style user-data structure for type appending callback.
+  struct AppendResultsCallbackData {
+    std::vector<PyType> &inferredTypes;
+    PyMlirContext &pyMlirContext;
+  };
+
+  /// Appends the types provided as the first two arguments to the user-data
+  /// structure (expects AppendResultsCallbackData).
+  static void appendResultsCallback(intptr_t nTypes, MlirType *types,
+                                    void *userData) {
+    auto *data = static_cast<AppendResultsCallbackData *>(userData);
+    data->inferredTypes.reserve(data->inferredTypes.size() + nTypes);
+    for (intptr_t i = 0; i < nTypes; ++i) {
+      data->inferredTypes.push_back(
+          PyType(data->pyMlirContext.getRef(), types[i]));
+    }
+  }
+
+  /// Given the arguments required to build an operation, attempts to infer its
+  /// return types. Throws value_error on failure.
+  std::vector<PyType>
+  inferReturnTypes(llvm::Optional<std::vector<PyValue>> operands,
+                   llvm::Optional<PyAttribute> attributes,
+                   llvm::Optional<std::vector<PyRegion>> regions,
+                   DefaultingPyMlirContext context,
+                   DefaultingPyLocation location) {
+    llvm::SmallVector<MlirValue> mlirOperands;
+    llvm::SmallVector<MlirRegion> mlirRegions;
+
+    if (operands) {
+      mlirOperands.reserve(operands->size());
+      for (PyValue &value : *operands) {
+        mlirOperands.push_back(value);
+      }
+    }
+
+    if (regions) {
+      mlirRegions.reserve(regions->size());
+      for (PyRegion &region : *regions) {
+        mlirRegions.push_back(region);
+      }
+    }
+
+    std::vector<PyType> inferredTypes;
+    PyMlirContext &pyContext = context.resolve();
+    AppendResultsCallbackData data{inferredTypes, pyContext};
+    MlirStringRef opNameRef =
+        mlirStringRefCreate(getOpName().data(), getOpName().length());
+    MlirAttribute attributeDict =
+        attributes ? attributes->get() : mlirAttributeGetNull();
+
+    MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes(
+        opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
+        mlirOperands.data(), attributeDict, mlirRegions.size(),
+        mlirRegions.data(), &appendResultsCallback, &data);
+
+    if (mlirLogicalResultIsFailure(result)) {
+      throw py::value_error("Failed to infer result types");
+    }
+
+    return inferredTypes;
+  }
+
+  static void bindDerived(ClassTy &cls) {
+    cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes,
+            py::arg("operands") = py::none(),
+            py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
+            py::arg("context") = py::none(), py::arg("loc") = py::none(),
+            inferReturnTypesDoc);
+  }
+};
+
+void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); }
+
+} // namespace python
+} // namespace mlir
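
The division of labor in `PyConcreteOpInterface` (a generic base that reads per-interface static fields and calls an overridable `bindDerived` hook) can be mirrored in Python to show how a second interface would plug in. This is a hypothetical sketch, not the bindings: class attributes and classmethods stand in for the C++ CRTP, a dict stands in for the pybind11 module, and `OtherOpInterface` and the string IDs are invented for illustration.

```python
# Hypothetical Python analogue of the PyConcreteOpInterface pattern; a dict
# stands in for the pybind11 module. Not an MLIR API.


class ConcreteOpInterface:
    py_class_name = None  # counterpart of pyClassName
    interface_id = None   # counterpart of getInterfaceID

    @classmethod
    def bind(cls, module):
        # Expose the class under its declared name, then add extra bindings.
        module[cls.py_class_name] = cls
        cls.bind_derived(module)

    @classmethod
    def bind_derived(cls, module):
        """Hook for interface-specific bindings; does nothing by default."""


class InferTypeOpInterface(ConcreteOpInterface):
    py_class_name = "InferTypeOpInterface"
    interface_id = staticmethod(lambda: "InferTypeOpInterface-id")

    @classmethod
    def bind_derived(cls, module):
        module.setdefault("methods", {})[cls.py_class_name] = ["inferReturnTypes"]


class OtherOpInterface(ConcreteOpInterface):  # hypothetical second interface
    py_class_name = "OtherOpInterface"
    interface_id = staticmethod(lambda: "OtherOpInterface-id")


module = {}
InferTypeOpInterface.bind(module)
OtherOpInterface.bind(module)
```

A new interface only declares its name and ID; it overrides the hook only if it has methods of its own, which keeps the shared construction and `operation`/`opview` logic in one place.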

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index ae85ef8507c30..59285c01a4bf1 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -859,6 +859,7 @@ class PyIntegerSet : public BaseContextObject {
 void populateIRAffine(pybind11::module &m);
 void populateIRAttributes(pybind11::module &m);
 void populateIRCore(pybind11::module &m);
+void populateIRInterfaces(pybind11::module &m);
 void populateIRTypes(pybind11::module &m);
 
 } // namespace python

diff  --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index cbade532e8ed0..5489a4d3e3810 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -85,6 +85,7 @@ PYBIND11_MODULE(_mlir, m) {
   populateIRCore(irModule);
   populateIRAffine(irModule);
   populateIRAttributes(irModule);
+  populateIRInterfaces(irModule);
   populateIRTypes(irModule);
 
   // Define and populate PassManager submodule.

diff  --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt
index eed3f38d1a46e..30ccbe94acc53 100644
--- a/mlir/lib/CAPI/CMakeLists.txt
+++ b/mlir/lib/CAPI/CMakeLists.txt
@@ -2,6 +2,7 @@ add_subdirectory(Debug)
 add_subdirectory(Dialect)
 add_subdirectory(Conversion)
 add_subdirectory(ExecutionEngine)
+add_subdirectory(Interfaces)
 add_subdirectory(IR)
 add_subdirectory(Registration)
 add_subdirectory(Transforms)

diff  --git a/mlir/lib/CAPI/Interfaces/CMakeLists.txt b/mlir/lib/CAPI/Interfaces/CMakeLists.txt
new file mode 100644
index 0000000000000..1de5f21d8bac2
--- /dev/null
+++ b/mlir/lib/CAPI/Interfaces/CMakeLists.txt
@@ -0,0 +1,5 @@
+add_mlir_public_c_api_library(MLIRCAPIInterfaces
+  Interfaces.cpp
+
+  LINK_LIBS PUBLIC
+  MLIRInferTypeOpInterface)

diff  --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
new file mode 100644
index 0000000000000..315adb5fbaf68
--- /dev/null
+++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
@@ -0,0 +1,82 @@
+//===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Interfaces.h"
+
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "llvm/ADT/ScopeExit.h"
+
+using namespace mlir;
+
+bool mlirOperationImplementsInterface(MlirOperation operation,
+                                      MlirTypeID interfaceTypeID) {
+  const AbstractOperation *abstractOp =
+      unwrap(operation)->getAbstractOperation();
+  return abstractOp && abstractOp->hasInterface(unwrap(interfaceTypeID));
+}
+
+bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName,
+                                            MlirContext context,
+                                            MlirTypeID interfaceTypeID) {
+  const AbstractOperation *abstractOp = AbstractOperation::lookup(
+      StringRef(operationName.data, operationName.length), unwrap(context));
+  return abstractOp && abstractOp->hasInterface(unwrap(interfaceTypeID));
+}
+
+MlirTypeID mlirInferTypeOpInterfaceTypeID() {
+  return wrap(InferTypeOpInterface::getInterfaceID());
+}
+
+MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
+    MlirStringRef opName, MlirContext context, MlirLocation location,
+    intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
+    intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback,
+    void *userData) {
+  StringRef name(opName.data, opName.length);
+  const AbstractOperation *abstractOp =
+      AbstractOperation::lookup(name, unwrap(context));
+  if (!abstractOp)
+    return mlirLogicalResultFailure();
+
+  llvm::Optional<Location> maybeLocation = llvm::None;
+  if (!mlirLocationIsNull(location))
+    maybeLocation = unwrap(location);
+  SmallVector<Value> unwrappedOperands;
+  (void)unwrapList(nOperands, operands, unwrappedOperands);
+  DictionaryAttr attributeDict;
+  if (!mlirAttributeIsNull(attributes))
+    attributeDict = unwrap(attributes).cast<DictionaryAttr>();
+
+  // Create a vector of unique pointers to regions and make sure they are not
+  // deleted when exiting the scope. This is a hack caused by C++ API expecting
+  // a list of unique pointers to regions (without ownership transfer
+  // semantics) and C API making ownership transfer explicit.
+  SmallVector<std::unique_ptr<Region>> unwrappedRegions;
+  unwrappedRegions.reserve(nRegions);
+  for (intptr_t i = 0; i < nRegions; ++i)
+    unwrappedRegions.emplace_back(unwrap(*(regions + i)));
+  auto cleaner = llvm::make_scope_exit([&]() {
+    for (auto &region : unwrappedRegions)
+      region.release();
+  });
+
+  SmallVector<Type> inferredTypes;
+  if (failed(abstractOp->getInterface<InferTypeOpInterface>()->inferReturnTypes(
+          unwrap(context), maybeLocation, unwrappedOperands, attributeDict,
+          unwrappedRegions, inferredTypes)))
+    return mlirLogicalResultFailure();
+
+  SmallVector<MlirType> wrappedInferredTypes;
+  wrappedInferredTypes.reserve(inferredTypes.size());
+  for (Type t : inferredTypes)
+    wrappedInferredTypes.push_back(wrap(t));
+  callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
+  return mlirLogicalResultSuccess();
+}

diff  --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 8c60a31b0cd7e..54cc51f0ba173 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -113,12 +113,25 @@ declare_mlir_dialect_python_bindings(
     dialects/_memref_ops_ext.py
   DIALECT_NAME memref)
 
-declare_mlir_dialect_python_bindings(
-  ADD_TO_PARENT MLIRPythonTestSources.Dialects
+# TODO: this uses a tablegen file from the test directory and should be
+# decoupled from here.
+declare_mlir_python_sources(
+  MLIRPythonSources.Dialects.PythonTest
   ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
-  TD_FILE dialects/PythonTest.td
-  SOURCES dialects/python_test.py
-  DIALECT_NAME python_test)
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  SOURCES dialects/python_test.py)
+set(LLVM_TARGET_DEFINITIONS
+  "${MLIR_MAIN_SRC_DIR}/test/python/python_test_ops.td")
+mlir_tablegen(
+  "dialects/_python_test_ops_gen.py"
+  -gen-python-op-bindings
+  -bind-dialect=python_test)
+add_public_tablegen_target(PythonTestDialectPyIncGen)
+declare_mlir_python_sources(
+  MLIRPythonSources.Dialects.PythonTest.ops_gen
+  ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
+  ADD_TO_PARENT MLIRPythonSources.Dialects.PythonTest
+  SOURCES "dialects/_python_test_ops_gen.py")
 
 declare_mlir_dialect_python_bindings(
   ADD_TO_PARENT MLIRPythonSources.Dialects
@@ -192,6 +205,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
     ${PYTHON_SOURCE_DIR}/IRAffine.cpp
     ${PYTHON_SOURCE_DIR}/IRAttributes.cpp
     ${PYTHON_SOURCE_DIR}/IRCore.cpp
+    ${PYTHON_SOURCE_DIR}/IRInterfaces.cpp
     ${PYTHON_SOURCE_DIR}/IRModule.cpp
     ${PYTHON_SOURCE_DIR}/IRTypes.cpp
     ${PYTHON_SOURCE_DIR}/PybindUtils.cpp
@@ -201,6 +215,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
   EMBED_CAPI_LINK_LIBS
     MLIRCAPIDebug
     MLIRCAPIIR
+    MLIRCAPIInterfaces
     MLIRCAPIRegistration  # TODO: See about dis-aggregating
 
     # Dialects
@@ -297,6 +312,20 @@ declare_mlir_python_extension(MLIRPythonExtension.Transforms
     MLIRCAPITransforms
 )
 
+# TODO: This should not be included in the main Python extension. However,
+# putting it into MLIRPythonTestSources along with the dialect declaration
+# above confuses the Python module loader when running under lit.
+declare_mlir_python_extension(MLIRPythonExtension.PythonTest
+  MODULE_NAME _mlirPythonTest
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  SOURCES
+    ${MLIR_SOURCE_DIR}/test/python/lib/PythonTestModule.cpp
+  PRIVATE_LINK_LIBS
+    LLVMSupport
+  EMBED_CAPI_LINK_LIBS
+    MLIRCAPIPythonTestDialect
+)
+
 ################################################################################
 # Common CAPI dependency DSO.
 # All python extensions must link through one DSO which exports the CAPI, and
@@ -336,7 +365,6 @@ add_mlir_python_modules(MLIRPythonModules
     MLIRPythonCAPI
   )
 
-
 add_mlir_python_modules(MLIRPythonTestModules
   ROOT_PREFIX "${MLIR_BINARY_DIR}/python_packages/mlir_test/mlir"
   INSTALL_PREFIX "python_packages/mlir_test/mlir"

diff  --git a/mlir/python/mlir/dialects/PythonTest.td b/mlir/python/mlir/dialects/PythonTest.td
deleted file mode 100644
index d3d49395ad45a..0000000000000
--- a/mlir/python/mlir/dialects/PythonTest.td
+++ /dev/null
@@ -1,33 +0,0 @@
-//===-- python_test_ops.td - Python test Op definitions ----*- tablegen -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef PYTHON_TEST_OPS
-#define PYTHON_TEST_OPS
-
-include "mlir/Bindings/Python/Attributes.td"
-include "mlir/IR/OpBase.td"
-
-def Python_Test_Dialect : Dialect {
-  let name = "python_test";
-  let cppNamespace = "PythonTest";
-}
-class TestOp<string mnemonic, list<OpTrait> traits = []>
-    : Op<Python_Test_Dialect, mnemonic, traits>;
-
-def AttributedOp : TestOp<"attributed_op"> {
-  let arguments = (ins I32Attr:$mandatory_i32,
-                   OptionalAttr<I32Attr>:$optional_i32,
-                   UnitAttr:$unit);
-}
-
-def PropertyOp : TestOp<"property_op"> {
-  let arguments = (ins I32Attr:$property,
-                   I32:$idx);
-}
-
-#endif // PYTHON_TEST_OPS

diff  --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py
index 524db4317678d..82c01d5a091c7 100644
--- a/mlir/python/mlir/dialects/python_test.py
+++ b/mlir/python/mlir/dialects/python_test.py
@@ -3,3 +3,8 @@
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from ._python_test_ops_gen import *
+
+
+def register_python_test_dialect(context, load=True):
+  from .._mlir_libs import _mlirPythonTest
+  _mlirPythonTest.register_python_test_dialect(context, load)

diff  --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 33ea11ba97828..8674c65cf4863 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -1,6 +1,10 @@
 add_subdirectory(CAPI)
 add_subdirectory(lib)
 
+if (MLIR_ENABLE_BINDINGS_PYTHON)
+  add_subdirectory(python)
+endif()
+
 # Passed to lit.site.cfg.py.so that the out of tree Standalone dialect test
 # can find MLIR's CMake configuration
 set(MLIR_CMAKE_DIR

diff  --git a/mlir/test/python/CMakeLists.txt b/mlir/test/python/CMakeLists.txt
new file mode 100644
index 0000000000000..c8cb474760e2c
--- /dev/null
+++ b/mlir/test/python/CMakeLists.txt
@@ -0,0 +1,8 @@
+set(LLVM_TARGET_DEFINITIONS python_test_ops.td)
+mlir_tablegen(lib/PythonTestDialect.h.inc -gen-dialect-decls)
+mlir_tablegen(lib/PythonTestDialect.cpp.inc -gen-dialect-defs)
+mlir_tablegen(lib/PythonTestOps.h.inc -gen-op-decls)
+mlir_tablegen(lib/PythonTestOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRPythonTestIncGen)
+
+add_subdirectory(lib)

diff  --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 231c5ad311e43..3d0600e331a57 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -6,8 +6,10 @@
 def run(f):
   print("\nTEST:", f.__name__)
   f()
+  return f
 
 # CHECK-LABEL: TEST: testAttributes
+@run
 def testAttributes():
   with Context() as ctx, Location.unknown():
     ctx.allow_unregistered_dialects = True
@@ -127,4 +129,47 @@ def testAttributes():
     del op.unit
     print(f"Unit: {op.unit}")
 
-run(testAttributes)
+
+# CHECK-LABEL: TEST: inferReturnTypes
+@run
+def inferReturnTypes():
+  with Context() as ctx, Location.unknown(ctx):
+    test.register_python_test_dialect(ctx)
+    module = Module.create()
+    with InsertionPoint(module.body):
+      op = test.InferResultsOp(
+          IntegerType.get_signless(32), IntegerType.get_signless(64))
+      dummy = test.DummyOp()
+
+    # CHECK: [Type(i32), Type(i64)]
+    iface = InferTypeOpInterface(op)
+    print(iface.inferReturnTypes())
+
+    # CHECK: [Type(i32), Type(i64)]
+    iface_static = InferTypeOpInterface(test.InferResultsOp)
+    print(iface_static.inferReturnTypes())
+
+    assert isinstance(iface.opview, test.InferResultsOp)
+    assert iface.opview == iface.operation.opview
+
+    try:
+      iface_static.opview
+    except TypeError:
+      pass
+    else:
+      assert False, ("not expected to be able to obtain an opview from a static"
+                     " interface")
+
+    try:
+      InferTypeOpInterface(dummy)
+    except ValueError:
+      pass
+    else:
+      assert False, "not expected dummy op to implement the interface"
+
+    try:
+      InferTypeOpInterface(test.DummyOp)
+    except ValueError:
+      pass
+    else:
+      assert False, "not expected dummy op class to implement the interface"

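Two patterns in the updated test are worth spelling out. `run` now returns its argument, which is what makes `@run` usable as a decorator: without `return f`, the decorated name would be rebound to `None`. The expected-failure checks use `try`/`except`/`else`, where the `else` branch runs only if no exception was raised. The sketch below restates both in plain Python; `expect_raises` is a hypothetical helper that neither the test nor the MLIR bindings define.

```python
import contextlib


def run(f):
    # Print a FileCheck-style label, execute the test, and return the function
    # so that `@run` does not rebind the decorated name to None.
    print("\nTEST:", f.__name__)
    f()
    return f


@contextlib.contextmanager
def expect_raises(exc_type, message):
    # Hypothetical helper equivalent to the test's try/except/else blocks:
    # swallow `exc_type`, but fail if the body completes without raising.
    try:
        yield
    except exc_type:
        return
    raise AssertionError(message)


@run
def example():
    with expect_raises(ValueError, "int('x') should fail"):
        int("x")
```

In a pytest- or unittest-based suite, `pytest.raises` or `TestCase.assertRaises` would play the role of `expect_raises`.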
diff  --git a/mlir/test/python/lib/CMakeLists.txt b/mlir/test/python/lib/CMakeLists.txt
new file mode 100644
index 0000000000000..cd45eec0ce85c
--- /dev/null
+++ b/mlir/test/python/lib/CMakeLists.txt
@@ -0,0 +1,33 @@
+set(LLVM_OPTIONAL_SOURCES
+  PythonTestCAPI.cpp
+  PythonTestDialect.cpp
+  PythonTestModule.cpp
+)
+
+add_mlir_library(MLIRPythonTestDialect
+  PythonTestDialect.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  DEPENDS
+  MLIRPythonTestIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRInferTypeOpInterface
+  MLIRIR
+  MLIRSupport
+)
+
+add_mlir_public_c_api_library(MLIRCAPIPythonTestDialect
+  PythonTestCAPI.cpp
+
+  DEPENDS
+  MLIRPythonTestIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRCAPIInterfaces
+  MLIRCAPIIR
+  MLIRCAPIRegistration
+  MLIRPythonTestDialect
+)
+

diff  --git a/mlir/test/python/lib/PythonTestCAPI.cpp b/mlir/test/python/lib/PythonTestCAPI.cpp
new file mode 100644
index 0000000000000..474476e741985
--- /dev/null
+++ b/mlir/test/python/lib/PythonTestCAPI.cpp
@@ -0,0 +1,14 @@
+//===- PythonTestCAPI.cpp - C API for the PythonTest dialect --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PythonTestCAPI.h"
+#include "PythonTestDialect.h"
+#include "mlir/CAPI/Registration.h"
+
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test,
+                                      python_test::PythonTestDialect)

diff  --git a/mlir/test/python/lib/PythonTestCAPI.h b/mlir/test/python/lib/PythonTestCAPI.h
new file mode 100644
index 0000000000000..627ce3fe9a151
--- /dev/null
+++ b/mlir/test/python/lib/PythonTestCAPI.h
@@ -0,0 +1,24 @@
+//===- PythonTestCAPI.h - C API for the PythonTest dialect ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TEST_PYTHON_LIB_PYTHONTESTCAPI_H
+#define MLIR_TEST_PYTHON_LIB_PYTHONTESTCAPI_H
+
+#include "mlir-c/Registration.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(PythonTest, python_test);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTCAPI_H

diff  --git a/mlir/test/python/lib/PythonTestDialect.cpp b/mlir/test/python/lib/PythonTestDialect.cpp
new file mode 100644
index 0000000000000..b70c0336b2f64
--- /dev/null
+++ b/mlir/test/python/lib/PythonTestDialect.cpp
@@ -0,0 +1,25 @@
+//===- PythonTestDialect.cpp - PythonTest dialect definition --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PythonTestDialect.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/OpImplementation.h"
+
+#include "PythonTestDialect.cpp.inc"
+
+#define GET_OP_CLASSES
+#include "PythonTestOps.cpp.inc"
+
+namespace python_test {
+void PythonTestDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "PythonTestOps.cpp.inc"
+      >();
+}
+} // namespace python_test

diff  --git a/mlir/test/python/lib/PythonTestDialect.h b/mlir/test/python/lib/PythonTestDialect.h
new file mode 100644
index 0000000000000..e25d00ceec980
--- /dev/null
+++ b/mlir/test/python/lib/PythonTestDialect.h
@@ -0,0 +1,21 @@
+//===- PythonTestDialect.h - PythonTest dialect definition ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H
+#define MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+
+#include "PythonTestDialect.h.inc"
+
+#define GET_OP_CLASSES
+#include "PythonTestOps.h.inc"
+
+#endif // MLIR_TEST_PYTHON_LIB_PYTHONTESTDIALECT_H

diff  --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp
new file mode 100644
index 0000000000000..4232a86518636
--- /dev/null
+++ b/mlir/test/python/lib/PythonTestModule.cpp
@@ -0,0 +1,26 @@
+//===- PythonTestModule.cpp - Python extension for the PythonTest dialect -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PythonTestCAPI.h"
+#include "mlir/Bindings/Python/PybindAdaptors.h"
+
+namespace py = pybind11;
+
+PYBIND11_MODULE(_mlirPythonTest, m) {
+  m.def(
+      "register_python_test_dialect",
+      [](MlirContext context, bool load) {
+        MlirDialectHandle pythonTestDialect =
+            mlirGetDialectHandle__python_test__();
+        mlirDialectHandleRegisterDialect(pythonTestDialect, context);
+        if (load) {
+          mlirDialectHandleLoadDialect(pythonTestDialect, context);
+        }
+      },
+      py::arg("context"), py::arg("load") = true);
+}

diff  --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index d3d49395ad45a..74c90a311f049 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -11,10 +11,11 @@
 
 include "mlir/Bindings/Python/Attributes.td"
 include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 
 def Python_Test_Dialect : Dialect {
   let name = "python_test";
-  let cppNamespace = "PythonTest";
+  let cppNamespace = "python_test";
 }
 class TestOp<string mnemonic, list<OpTrait> traits = []>
     : Op<Python_Test_Dialect, mnemonic, traits>;
@@ -30,4 +31,25 @@ def PropertyOp : TestOp<"property_op"> {
                    I32:$idx);
 }
 
+def DummyOp : TestOp<"dummy_op"> {
+}
+
+def InferResultsOp : TestOp<"infer_results_op", [InferTypeOpInterface]> {
+  let arguments = (ins);
+  let results = (outs AnyInteger:$single, AnyInteger:$doubled);
+
+  let extraClassDeclaration = [{
+    static ::mlir::LogicalResult inferReturnTypes(
+      ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
+      ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
+      ::mlir::RegionRange regions,
+      ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
+      ::mlir::Builder b(context);
+      inferredReturnTypes.push_back(b.getI32Type());
+      inferredReturnTypes.push_back(b.getI64Type());
+      return ::mlir::success();
+    }
+  }];
+}
+
 #endif // PYTHON_TEST_OPS

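Taken together, the new `InferResultsOp` definition and the test pin down the expected behavior of the Python `InferTypeOpInterface` binding. It can be constructed from an operation or from an op class, and in both cases `inferReturnTypes()` yields `[i32, i64]`. `.opview` raises `TypeError` for the class-based ("static") form, and construction raises `ValueError` for an op that does not implement the interface. The following pure-Python toy model restates those expectations so they can be checked without building MLIR. All classes are hypothetical stand-ins and types are plain strings. This is not how the binding is implemented; the real one is C++ code in IRInterfaces.cpp built on the C API.

```python
# Toy model of the behavior the test expects from InferTypeOpInterface.
# Every class here is hypothetical; types are represented as strings.


class ToyOp:
    implements = frozenset()  # names of interfaces the op class implements


class InferResultsOp(ToyOp):
    implements = frozenset({"InferTypeOpInterface"})

    @staticmethod
    def infer_return_types():
        # Mirrors the TableGen definition: always (i32, i64), no operands needed.
        return ["i32", "i64"]


class DummyOp(ToyOp):
    pass


class ToyInferTypeOpInterface:
    def __init__(self, obj):
        # Accept either an op instance or an op class, as the test does.
        is_static = isinstance(obj, type)
        cls = obj if is_static else type(obj)
        if "InferTypeOpInterface" not in cls.implements:
            raise ValueError(f"{cls.__name__} does not implement InferTypeOpInterface")
        self._cls = cls
        self._op = None if is_static else obj

    def infer_return_types(self):
        # Inference only needs the op class, so both forms support it.
        return self._cls.infer_return_types()

    @property
    def opview(self):
        # Only an interface built from a concrete op has an op to view.
        if self._op is None:
            raise TypeError("static interface has no associated operation")
        return self._op
```

Keying inference on the class rather than the instance is what lets the static form work: `inferReturnTypes` in the TableGen definition is a `static` method that never looks at an existing operation.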
diff  --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index fea353d460cb9..a0339919063fe 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -333,6 +333,7 @@ cc_library(
         "include/mlir-c/ExecutionEngine.h",
         "include/mlir-c/IR.h",
         "include/mlir-c/IntegerSet.h",
+        "include/mlir-c/Interfaces.h",
         "include/mlir-c/Pass.h",
         "include/mlir-c/Registration.h",
         "include/mlir-c/Support.h",
@@ -360,6 +361,20 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "CAPIInterfaces",
+    srcs = [
+        "lib/CAPI/Interfaces/Interfaces.cpp",
+    ],
+    includes = ["include"],
+    deps = [
+        ":CAPIIR",
+        ":IR",
+        ":InferTypeOpInterface",
+        "//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "CAPIAsync",
     srcs = [
@@ -558,6 +573,7 @@ cc_library(
         "lib/Bindings/Python/IRAffine.cpp",
         "lib/Bindings/Python/IRAttributes.cpp",
         "lib/Bindings/Python/IRCore.cpp",
+        "lib/Bindings/Python/IRInterfaces.cpp",
         "lib/Bindings/Python/IRModule.cpp",
         "lib/Bindings/Python/IRTypes.cpp",
         "lib/Bindings/Python/Pass.cpp",
@@ -581,6 +597,7 @@ cc_library(
         ":CAPIDebug",
         ":CAPIGPU",
         ":CAPIIR",
+        ":CAPIInterfaces",
         ":CAPILinalg",
         ":CAPIRegistration",
         ":CAPISparseTensor",

