[Mlir-commits] [mlir] f22008e - [MLIR] Add InferShapedTypeOpInterface bindings

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 11 14:24:10 PDT 2023


Author: Arash Taheri-Dezfouli
Date: 2023-05-11T16:20:47-05:00
New Revision: f22008ed89eac028cd70f91de3adf41a481f6d22

URL: https://github.com/llvm/llvm-project/commit/f22008ed89eac028cd70f91de3adf41a481f6d22
DIFF: https://github.com/llvm/llvm-project/commit/f22008ed89eac028cd70f91de3adf41a481f6d22.diff

LOG: [MLIR] Add InferShapedTypeOpInterface bindings

Add C and python bindings for InferShapedTypeOpInterface
and ShapedTypeComponents. This allows users to invoke
InferShapedTypeOpInterface for ops that implement it.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D149494

Added: 
    

Modified: 
    mlir/include/mlir-c/Interfaces.h
    mlir/lib/Bindings/Python/IRInterfaces.cpp
    mlir/lib/CAPI/Interfaces/Interfaces.cpp
    mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
    mlir/test/python/dialects/python_test.py
    mlir/test/python/python_test_ops.td

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/Interfaces.h b/mlir/include/mlir-c/Interfaces.h
index 405e2bb7173e0..a5a3473eaef59 100644
--- a/mlir/include/mlir-c/Interfaces.h
+++ b/mlir/include/mlir-c/Interfaces.h
@@ -60,6 +60,33 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
     void *properties, intptr_t nRegions, MlirRegion *regions,
     MlirTypesCallback callback, void *userData);
 
+//===----------------------------------------------------------------------===//
+// InferShapedTypeOpInterface.
+//===----------------------------------------------------------------------===//
+
+/// Returns the interface TypeID of the InferShapedTypeOpInterface.
+MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID();
+
+/// These callbacks are used to return multiple shaped type components from
+/// functions while transferring ownership to the caller. The first argument is
+/// the has-rank boolean followed by the rank and a pointer to the shape
+/// (if applicable). The next argument is the element type, then the attribute.
+/// The last argument is an opaque pointer forwarded to the callback by the
+/// caller. This callback may be called multiple times, once for each set of
+/// shaped type components.
+typedef void (*MlirShapedTypeComponentsCallback)(bool, intptr_t,
+                                                 const int64_t *, MlirType,
+                                                 MlirAttribute, void *);
+
+/// Infers the return shaped type components of the operation. Calls `callback`
+/// once per inferred shaped type components on success. Returns failure otherwise.
+MLIR_CAPI_EXPORTED MlirLogicalResult
+mlirInferShapedTypeOpInterfaceInferReturnTypes(
+    MlirStringRef opName, MlirContext context, MlirLocation location,
+    intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
+    void *properties, intptr_t nRegions, MlirRegion *regions,
+    MlirShapedTypeComponentsCallback callback, void *userData);
+
 #ifdef __cplusplus
 }
 #endif

diff  --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp
index 766d6f3e4793e..0a7a25c0005fe 100644
--- a/mlir/lib/Bindings/Python/IRInterfaces.cpp
+++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp
@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <utility>
 #include <optional>
+#include <utility>
 
 #include "IRModule.h"
 #include "mlir-c/BuiltinAttributes.h"
@@ -35,6 +35,83 @@ constexpr static const char *inferReturnTypesDoc =
     R"(Given the arguments required to build an operation, attempts to infer
 its return types. Raises ValueError on failure.)";
 
+constexpr static const char *inferReturnTypeComponentsDoc =
+    R"(Given the arguments required to build an operation, attempts to infer
+its return shaped type components. Raises ValueError on failure.)";
+
+namespace {
+
+/// Takes in an optional list of operands and converts them into a SmallVector
+/// of MlirValues. Returns an empty SmallVector if the list is empty.
+llvm::SmallVector<MlirValue> wrapOperands(std::optional<py::list> operandList) {
+  llvm::SmallVector<MlirValue> mlirOperands;
+
+  if (!operandList || operandList->empty()) {
+    return mlirOperands;
+  }
+
+  // Note: as the list may contain other lists this may not be final size.
+  mlirOperands.reserve(operandList->size());
+  for (const auto &&it : llvm::enumerate(*operandList)) {
+    PyValue *val;
+    try {
+      val = py::cast<PyValue *>(it.value());
+      if (!val)
+        throw py::cast_error();
+      mlirOperands.push_back(val->get());
+      continue;
+    } catch (py::cast_error &err) {
+      // Intentionally unhandled to try sequence below first.
+      (void)err;
+    }
+
+    try {
+      auto vals = py::cast<py::sequence>(it.value());
+      for (py::object v : vals) {
+        try {
+          val = py::cast<PyValue *>(v);
+          if (!val)
+            throw py::cast_error();
+          mlirOperands.push_back(val->get());
+        } catch (py::cast_error &err) {
+          throw py::value_error(
+              (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+               " must be a Value or Sequence of Values (" + err.what() + ")")
+                  .str());
+        }
+      }
+      continue;
+    } catch (py::cast_error &err) {
+      throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+                             " must be a Value or Sequence of Values (" +
+                             err.what() + ")")
+                                .str());
+    }
+
+    throw py::cast_error();
+  }
+
+  return mlirOperands;
+}
+
+/// Takes in an optional vector of PyRegions and returns a SmallVector of
+/// MlirRegion. Returns an empty SmallVector if the list is empty.
+llvm::SmallVector<MlirRegion>
+wrapRegions(std::optional<std::vector<PyRegion>> regions) {
+  llvm::SmallVector<MlirRegion> mlirRegions;
+
+  if (regions) {
+    mlirRegions.reserve(regions->size());
+    for (PyRegion &region : *regions) {
+      mlirRegions.push_back(region);
+    }
+  }
+
+  return mlirRegions;
+}
+
+} // namespace
+
 /// CRTP base class for Python classes representing MLIR Op interfaces.
 /// Interface hierarchies are flat so no base class is expected here. The
 /// derived class is expected to define the following static fields:
@@ -104,7 +181,7 @@ class PyConcreteOpInterface {
 
   /// Creates the Python bindings for this class in the given module.
   static void bind(py::module &m) {
-    py::class_<ConcreteIface> cls(m, "InferTypeOpInterface",
+    py::class_<ConcreteIface> cls(m, ConcreteIface::pyClassName,
                                   py::module_local());
     cls.def(py::init<py::object, DefaultingPyMlirContext>(), py::arg("object"),
             py::arg("context") = py::none(), constructorDoc)
@@ -155,7 +232,7 @@ class PyConcreteOpInterface {
   py::object obj;
 };
 
-/// Python wrapper for InterTypeOpInterface. This interface has only static
+/// Python wrapper for InferTypeOpInterface. This interface has only static
 /// methods.
 class PyInferTypeOpInterface
     : public PyConcreteOpInterface<PyInferTypeOpInterface> {
@@ -191,59 +268,8 @@ class PyInferTypeOpInterface
                    std::optional<std::vector<PyRegion>> regions,
                    DefaultingPyMlirContext context,
                    DefaultingPyLocation location) {
-    llvm::SmallVector<MlirValue> mlirOperands;
-    llvm::SmallVector<MlirRegion> mlirRegions;
-
-    if (operandList && !operandList->empty()) {
-      // Note: as the list may contain other lists this may not be final size.
-      mlirOperands.reserve(operandList->size());
-      for (const auto& it : llvm::enumerate(*operandList)) {
-        PyValue* val;
-        try {
-          val = py::cast<PyValue *>(it.value());
-          if (!val)
-            throw py::cast_error();
-          mlirOperands.push_back(val->get());
-          continue;
-        } catch (py::cast_error &err) {
-          // Intentionally unhandled to try sequence below first.
-          (void)err;
-        }
-
-        try {
-          auto vals = py::cast<py::sequence>(it.value());
-          for (py::object v : vals) {
-            try {
-              val = py::cast<PyValue *>(v);
-              if (!val)
-                throw py::cast_error();
-              mlirOperands.push_back(val->get());
-            } catch (py::cast_error &err) {
-              throw py::value_error(
-                  (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
-                   " must be a Value or Sequence of Values (" + err.what() +
-                   ")")
-                      .str());
-            }
-          }
-          continue;
-        } catch (py::cast_error &err) {
-          throw py::value_error(
-              (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
-               " must be a Value or Sequence of Values (" + err.what() + ")")
-                  .str());
-        }
-
-        throw py::cast_error();
-      }
-    }
-
-    if (regions) {
-      mlirRegions.reserve(regions->size());
-      for (PyRegion &region : *regions) {
-        mlirRegions.push_back(region);
-      }
-    }
+    llvm::SmallVector<MlirValue> mlirOperands = wrapOperands(operandList);
+    llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(regions);
 
     std::vector<PyType> inferredTypes;
     PyMlirContext &pyContext = context.resolve();
@@ -275,7 +301,172 @@ class PyInferTypeOpInterface
   }
 };
 
-void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); }
+/// Wrapper around a shaped type components object.
+class PyShapedTypeComponents {
+public:
+  PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {}
+  PyShapedTypeComponents(py::list shape, MlirType elementType)
+      : shape(shape), elementType(elementType), ranked(true) {}
+  PyShapedTypeComponents(py::list shape, MlirType elementType,
+                         MlirAttribute attribute)
+      : shape(shape), elementType(elementType), attribute(attribute),
+        ranked(true) {}
+  PyShapedTypeComponents(PyShapedTypeComponents &) = delete;
+  PyShapedTypeComponents(PyShapedTypeComponents &&other)
+      : shape(other.shape), elementType(other.elementType),
+        attribute(other.attribute), ranked(other.ranked) {}
+
+  static void bind(py::module &m) {
+    py::class_<PyShapedTypeComponents>(m, "ShapedTypeComponents",
+                                       py::module_local())
+        .def_property_readonly(
+            "element_type",
+            [](PyShapedTypeComponents &self) {
+              return PyType(PyMlirContext::forContext(
+                                mlirTypeGetContext(self.elementType)),
+                            self.elementType);
+            },
+            "Returns the element type of the shaped type components.")
+        .def_static(
+            "get",
+            [](PyType &elementType) {
+              return PyShapedTypeComponents(elementType);
+            },
+            py::arg("element_type"),
+            "Create an shaped type components object with only the element "
+            "type.")
+        .def_static(
+            "get",
+            [](py::list shape, PyType &elementType) {
+              return PyShapedTypeComponents(shape, elementType);
+            },
+            py::arg("shape"), py::arg("element_type"),
+            "Create a ranked shaped type components object.")
+        .def_static(
+            "get",
+            [](py::list shape, PyType &elementType, PyAttribute &attribute) {
+              return PyShapedTypeComponents(shape, elementType, attribute);
+            },
+            py::arg("shape"), py::arg("element_type"), py::arg("attribute"),
+            "Create a ranked shaped type components object with attribute.")
+        .def_property_readonly(
+            "has_rank",
+            [](PyShapedTypeComponents &self) -> bool { return self.ranked; },
+            "Returns whether the given shaped type component is ranked.")
+        .def_property_readonly(
+            "rank",
+            [](PyShapedTypeComponents &self) -> py::object {
+              if (!self.ranked) {
+                return py::none();
+              }
+              return py::int_(self.shape.size());
+            },
+            "Returns the rank of the given ranked shaped type components. If "
+            "the shaped type components does not have a rank, None is "
+            "returned.")
+        .def_property_readonly(
+            "shape",
+            [](PyShapedTypeComponents &self) -> py::object {
+              if (!self.ranked) {
+                return py::none();
+              }
+              return py::list(self.shape);
+            },
+            "Returns the shape of the ranked shaped type components as a list "
+            "of integers. Returns none if the shaped type component does not "
+            "have a rank.");
+  }
+
+  pybind11::object getCapsule();
+  static PyShapedTypeComponents createFromCapsule(pybind11::object capsule);
+
+private:
+  py::list shape;
+  MlirType elementType;
+  MlirAttribute attribute;
+  bool ranked{false};
+};
+
+/// Python wrapper for InferShapedTypeOpInterface. This interface has only
+/// static methods.
+class PyInferShapedTypeOpInterface
+    : public PyConcreteOpInterface<PyInferShapedTypeOpInterface> {
+public:
+  using PyConcreteOpInterface<
+      PyInferShapedTypeOpInterface>::PyConcreteOpInterface;
+
+  constexpr static const char *pyClassName = "InferShapedTypeOpInterface";
+  constexpr static GetTypeIDFunctionTy getInterfaceID =
+      &mlirInferShapedTypeOpInterfaceTypeID;
+
+  /// C-style user-data structure for type appending callback.
+  struct AppendResultsCallbackData {
+    std::vector<PyShapedTypeComponents> &inferredShapedTypeComponents;
+  };
+
+  /// Appends the shaped type components provided as unpacked shape, element
+  /// type, attribute to the user-data.
+  static void appendResultsCallback(bool hasRank, intptr_t rank,
+                                    const int64_t *shape, MlirType elementType,
+                                    MlirAttribute attribute, void *userData) {
+    auto *data = static_cast<AppendResultsCallbackData *>(userData);
+    if (!hasRank) {
+      data->inferredShapedTypeComponents.emplace_back(elementType);
+    } else {
+      py::list shapeList;
+      for (intptr_t i = 0; i < rank; ++i) {
+        shapeList.append(shape[i]);
+      }
+      data->inferredShapedTypeComponents.emplace_back(shapeList, elementType,
+                                                      attribute);
+    }
+  }
+
+  /// Given the arguments required to build an operation, attempts to infer the
+  /// shaped type components. Throws value_error on failure.
+  std::vector<PyShapedTypeComponents> inferReturnTypeComponents(
+      std::optional<py::list> operandList,
+      std::optional<PyAttribute> attributes, void *properties,
+      std::optional<std::vector<PyRegion>> regions,
+      DefaultingPyMlirContext context, DefaultingPyLocation location) {
+    llvm::SmallVector<MlirValue> mlirOperands = wrapOperands(operandList);
+    llvm::SmallVector<MlirRegion> mlirRegions = wrapRegions(regions);
+
+    std::vector<PyShapedTypeComponents> inferredShapedTypeComponents;
+    PyMlirContext &pyContext = context.resolve();
+    AppendResultsCallbackData data{inferredShapedTypeComponents};
+    MlirStringRef opNameRef =
+        mlirStringRefCreate(getOpName().data(), getOpName().length());
+    MlirAttribute attributeDict =
+        attributes ? attributes->get() : mlirAttributeGetNull();
+
+    MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes(
+        opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(),
+        mlirOperands.data(), attributeDict, properties, mlirRegions.size(),
+        mlirRegions.data(), &appendResultsCallback, &data);
+
+    if (mlirLogicalResultIsFailure(result)) {
+      throw py::value_error("Failed to infer result shape type components");
+    }
+
+    return inferredShapedTypeComponents;
+  }
+
+  static void bindDerived(ClassTy &cls) {
+    cls.def("inferReturnTypeComponents",
+            &PyInferShapedTypeOpInterface::inferReturnTypeComponents,
+            py::arg("operands") = py::none(),
+            py::arg("attributes") = py::none(), py::arg("regions") = py::none(),
+            py::arg("properties") = py::none(), py::arg("context") = py::none(),
+            py::arg("loc") = py::none(), inferReturnTypeComponentsDoc);
+  }
+};
+
+void populateIRInterfaces(py::module &m) {
+  PyInferTypeOpInterface::bind(m);
+  PyShapedTypeComponents::bind(m);
+  PyInferShapedTypeOpInterface::bind(m);
+}
 
 } // namespace python
 } // namespace mlir

diff  --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
index 029feed3a3593..e597a7bcb4f23 100644
--- a/mlir/lib/CAPI/Interfaces/Interfaces.cpp
+++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp
@@ -11,14 +11,65 @@
 #include "mlir-c/Interfaces.h"
 
 #include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Interfaces.h"
 #include "mlir/CAPI/Support.h"
 #include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "llvm/ADT/ScopeExit.h"
 #include <optional>
 
 using namespace mlir;
 
+namespace {
+
+std::optional<RegisteredOperationName>
+getRegisteredOperationName(MlirContext context, MlirStringRef opName) {
+  StringRef name(opName.data, opName.length);
+  std::optional<RegisteredOperationName> info =
+      RegisteredOperationName::lookup(name, unwrap(context));
+  return info;
+}
+
+std::optional<Location> maybeGetLocation(MlirLocation location) {
+  std::optional<Location> maybeLocation;
+  if (!mlirLocationIsNull(location))
+    maybeLocation = unwrap(location);
+  return maybeLocation;
+}
+
+SmallVector<Value> unwrapOperands(intptr_t nOperands, MlirValue *operands) {
+  SmallVector<Value> unwrappedOperands;
+  (void)unwrapList(nOperands, operands, unwrappedOperands);
+  return unwrappedOperands;
+}
+
+DictionaryAttr unwrapAttributes(MlirAttribute attributes) {
+  DictionaryAttr attributeDict;
+  if (!mlirAttributeIsNull(attributes))
+    attributeDict = unwrap(attributes).cast<DictionaryAttr>();
+  return attributeDict;
+}
+
+SmallVector<std::unique_ptr<Region>> unwrapRegions(intptr_t nRegions,
+                                                   MlirRegion *regions) {
+  // Create a vector of unique pointers to regions and make sure they are not
+  // deleted when exiting the scope. This is a hack caused by C++ API expecting
+  // an list of unique pointers to regions (without ownership transfer
+  // semantics) and C API making ownership transfer explicit.
+  SmallVector<std::unique_ptr<Region>> unwrappedRegions;
+  unwrappedRegions.reserve(nRegions);
+  for (intptr_t i = 0; i < nRegions; ++i)
+    unwrappedRegions.emplace_back(unwrap(*(regions + i)));
+  auto cleaner = llvm::make_scope_exit([&]() {
+    for (auto &region : unwrappedRegions)
+      region.release();
+  });
+  return unwrappedRegions;
+}
+
+} // namespace
+
 bool mlirOperationImplementsInterface(MlirOperation operation,
                                       MlirTypeID interfaceTypeID) {
   std::optional<RegisteredOperationName> info =
@@ -45,31 +96,15 @@ MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
     MlirTypesCallback callback, void *userData) {
   StringRef name(opName.data, opName.length);
   std::optional<RegisteredOperationName> info =
-      RegisteredOperationName::lookup(name, unwrap(context));
+      getRegisteredOperationName(context, opName);
   if (!info)
     return mlirLogicalResultFailure();
 
-  std::optional<Location> maybeLocation;
-  if (!mlirLocationIsNull(location))
-    maybeLocation = unwrap(location);
-  SmallVector<Value> unwrappedOperands;
-  (void)unwrapList(nOperands, operands, unwrappedOperands);
-  DictionaryAttr attributeDict;
-  if (!mlirAttributeIsNull(attributes))
-    attributeDict = unwrap(attributes).cast<DictionaryAttr>();
-
-  // Create a vector of unique pointers to regions and make sure they are not
-  // deleted when exiting the scope. This is a hack caused by C++ API expecting
-  // an list of unique pointers to regions (without ownership transfer
-  // semantics) and C API making ownership transfer explicit.
-  SmallVector<std::unique_ptr<Region>> unwrappedRegions;
-  unwrappedRegions.reserve(nRegions);
-  for (intptr_t i = 0; i < nRegions; ++i)
-    unwrappedRegions.emplace_back(unwrap(*(regions + i)));
-  auto cleaner = llvm::make_scope_exit([&]() {
-    for (auto &region : unwrappedRegions)
-      region.release();
-  });
+  std::optional<Location> maybeLocation = maybeGetLocation(location);
+  SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
+  DictionaryAttr attributeDict = unwrapAttributes(attributes);
+  SmallVector<std::unique_ptr<Region>> unwrappedRegions =
+      unwrapRegions(nRegions, regions);
 
   SmallVector<Type> inferredTypes;
   if (failed(info->getInterface<InferTypeOpInterface>()->inferReturnTypes(
@@ -84,3 +119,51 @@ MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes(
   callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData);
   return mlirLogicalResultSuccess();
 }
+
+MlirTypeID mlirInferShapedTypeOpInterfaceTypeID() {
+  return wrap(InferShapedTypeOpInterface::getInterfaceID());
+}
+
+MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes(
+    MlirStringRef opName, MlirContext context, MlirLocation location,
+    intptr_t nOperands, MlirValue *operands, MlirAttribute attributes,
+    void *properties, intptr_t nRegions, MlirRegion *regions,
+    MlirShapedTypeComponentsCallback callback, void *userData) {
+  std::optional<RegisteredOperationName> info =
+      getRegisteredOperationName(context, opName);
+  if (!info)
+    return mlirLogicalResultFailure();
+
+  std::optional<Location> maybeLocation = maybeGetLocation(location);
+  SmallVector<Value> unwrappedOperands = unwrapOperands(nOperands, operands);
+  DictionaryAttr attributeDict = unwrapAttributes(attributes);
+  SmallVector<std::unique_ptr<Region>> unwrappedRegions =
+      unwrapRegions(nRegions, regions);
+
+  SmallVector<ShapedTypeComponents> inferredTypeComponents;
+  if (failed(info->getInterface<InferShapedTypeOpInterface>()
+                 ->inferReturnTypeComponents(
+                     unwrap(context), maybeLocation,
+                     mlir::ValueRange(llvm::ArrayRef(unwrappedOperands)),
+                     attributeDict, properties, unwrappedRegions,
+                     inferredTypeComponents)))
+    return mlirLogicalResultFailure();
+
+  bool hasRank;
+  intptr_t rank;
+  const int64_t *shapeData;
+  for (ShapedTypeComponents t : inferredTypeComponents) {
+    if (t.hasRank()) {
+      hasRank = true;
+      rank = t.getDims().size();
+      shapeData = t.getDims().data();
+    } else {
+      hasRank = false;
+      rank = 0;
+      shapeData = nullptr;
+    }
+    callback(hasRank, rank, shapeData, wrap(t.getElementType()),
+             wrap(t.getAttribute()), userData);
+  }
+  return mlirLogicalResultSuccess();
+}

diff  --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 75b25bd8c1c9d..714935fe12e28 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -62,6 +62,7 @@ __all__ = [
     "FloatAttr",
     "FunctionType",
     "IndexType",
+    "InferShapedTypeOpInterface",
     "InferTypeOpInterface",
     "InsertionPoint",
     "IntegerAttr",
@@ -88,6 +89,7 @@ __all__ = [
     "RegionIterator",
     "RegionSequence",
     "ShapedType",
+    "ShapedTypeComponents",
     "StringAttr",
     "SymbolTable",
     "TupleType",
@@ -689,9 +691,17 @@ class IndexType(Type):
     @staticmethod
     def isinstance(arg: Any) -> bool: ...
 
+class InferShapedTypeOpInterface:
+    def __init__(self, object: object, context: Optional[Context] = None) -> None: ...
+    def inferReturnTypeComponents(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, properties = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[ShapedTypeComponents]: ...
+    @property
+    def operation(self) -> Operation: ...
+    @property
+    def opview(self) -> OpView: ...
+
 class InferTypeOpInterface:
     def __init__(self, object: object, context: Optional[Context] = None) -> None: ...
-    def inferReturnTypes(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ...
+    def inferReturnTypes(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, properties = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ...
     @property
     def operation(self) -> Operation: ...
     @property
@@ -1016,6 +1026,18 @@ class ShapedType(Type):
     @property
     def shape(self) -> List[int]: ...
 
+class ShapedTypeComponents:
+    @property
+    def element_type(self) -> Type: ...
+    @staticmethod
+    def get(*args, **kwargs) -> ShapedTypeComponents: ...
+    @property
+    def has_rank(self) -> bool: ...
+    @property
+    def rank(self) -> int: ...
+    @property
+    def shape(self) -> List[int]: ...
+
 # TODO: Auto-generated. Audit and fix.
 class StringAttr(Attribute):
     def __init__(self, cast_from_attr: Attribute) -> None: ...

diff  --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index d826540bec1da..8280e5ec73a76 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -1,6 +1,7 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 from mlir.ir import *
+import mlir.dialects.func as func
 import mlir.dialects.python_test as test
 import mlir.dialects.tensor as tensor
 
@@ -330,3 +331,55 @@ def __str__(self):
 
       # CHECK: False
       print(tt.is_null())
+
+
+# CHECK-LABEL: TEST: inferReturnTypeComponents
+@run
+def inferReturnTypeComponents():
+    with Context() as ctx, Location.unknown(ctx):
+        test.register_python_test_dialect(ctx)
+        module = Module.create()
+        i32 = IntegerType.get_signless(32)
+        with InsertionPoint(module.body):
+            resultType = UnrankedTensorType.get(i32)
+            operandTypes = [
+                RankedTensorType.get([1, 3, 10, 10], i32),
+                UnrankedTensorType.get(i32),
+            ]
+            f = func.FuncOp(
+                "test_inferReturnTypeComponents", (operandTypes, [resultType])
+            )
+            entry_block = Block.create_at_start(f.operation.regions[0], operandTypes)
+            with InsertionPoint(entry_block):
+                ranked_op = test.InferShapedTypeComponentsOp(
+                    resultType, entry_block.arguments[0]
+                )
+                unranked_op = test.InferShapedTypeComponentsOp(
+                    resultType, entry_block.arguments[1]
+                )
+
+        # CHECK: has rank: True
+        # CHECK: rank: 4
+        # CHECK: element type: i32
+        # CHECK: shape: [1, 3, 10, 10]
+        iface = InferShapedTypeOpInterface(ranked_op)
+        shaped_type_components = iface.inferReturnTypeComponents(
+            operands=[ranked_op.operand]
+        )[0]
+        print("has rank:", shaped_type_components.has_rank)
+        print("rank:", shaped_type_components.rank)
+        print("element type:", shaped_type_components.element_type)
+        print("shape:", shaped_type_components.shape)
+
+        # CHECK: has rank: False
+        # CHECK: rank: None
+        # CHECK: element type: i32
+        # CHECK: shape: None
+        iface = InferShapedTypeOpInterface(unranked_op)
+        shaped_type_components = iface.inferReturnTypeComponents(
+          operands=[unranked_op.operand]
+        )[0]
+        print("has rank:", shaped_type_components.has_rank)
+        print("rank:", shaped_type_components.rank)
+        print("element type:", shaped_type_components.element_type)
+        print("shape:", shaped_type_components.shape)

diff  --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 692fcb938961e..e1a03c6ee217d 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -90,6 +90,33 @@ def InferResultsImpliedOp : TestOp<"infer_results_implied_op"> {
   let results = (outs I32:$integer, F64:$flt, Index:$index);
 }
 
+def InferShapedTypeComponentsOp : TestOp<"infer_shaped_type_components_op",
+  [DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+                             ["inferReturnTypeComponents"]>]> {
+  let arguments = (ins AnyTensor:$operand);
+  let results = (outs AnyTensor:$result);
+
+  let extraClassDefinition = [{
+    ::mlir::LogicalResult $cppClass::inferReturnTypeComponents(
+      ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
+      ::mlir::ValueShapeRange operands, ::mlir::DictionaryAttr attributes,
+      ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
+      ::llvm::SmallVectorImpl<
+        ::mlir::ShapedTypeComponents>& inferredShapedTypeComponents) {
+      $cppClass::Adaptor adaptor(operands, attributes, properties, regions);
+      auto operandType =
+          adaptor.getOperand().getType().cast<::mlir::ShapedType>();
+      if (operandType.hasRank()) {
+        inferredShapedTypeComponents.emplace_back(operandType.getShape(),
+            operandType.getElementType());
+      } else {
+        inferredShapedTypeComponents.emplace_back(operandType.getElementType());
+      }
+      return ::mlir::success();
+    }
+  }];
+}
+
 def SameOperandAndResultTypeOp : TestOp<"same_operand_and_result_type_op",
                                         [SameOperandsAndResultType]> {
   let arguments = (ins Variadic<AnyType>);


        


More information about the Mlir-commits mailing list