[Mlir-commits] [mlir] f573bc2 - [mlir][py] Reuse more of CAPI build time inference.

Jacques Pienaar llvmlistbot at llvm.org
Sun Jul 23 21:27:00 PDT 2023


Author: Jacques Pienaar
Date: 2023-07-23T21:26:52-07:00
New Revision: f573bc24d4ea61d68ea61e49eeefbb9a84fa4f34

URL: https://github.com/llvm/llvm-project/commit/f573bc24d4ea61d68ea61e49eeefbb9a84fa4f34
DIFF: https://github.com/llvm/llvm-project/commit/f573bc24d4ea61d68ea61e49eeefbb9a84fa4f34.diff

LOG: [mlir][py] Reuse more of CAPI build time inference.

This reduces the code generated for type inference and instead reuses
the facilities on the C API side that perform the same role.

Differential Revision: https://reviews.llvm.org/D156041

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/test/mlir-tblgen/op-python-bindings.td
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 39049f387eda70..971d2819ade44b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -78,6 +78,7 @@ static const char kOperationCreateDocstring[] =
   ip: An InsertionPoint (defaults to resolve from context manager or set to
     False to disable insertion, even with an insertion point set in the
     context manager).
+  infer_type: Whether to infer result types.
 Returns:
   A new "detached" Operation object. Detached operations can be added
   to blocks, which causes them to become "attached."
@@ -1288,7 +1289,7 @@ py::object PyOperation::create(const std::string &name,
                                std::optional<py::dict> attributes,
                                std::optional<std::vector<PyBlock *>> successors,
                                int regions, DefaultingPyLocation location,
-                               const py::object &maybeIp) {
+                               const py::object &maybeIp, bool inferType) {
   llvm::SmallVector<MlirValue, 4> mlirOperands;
   llvm::SmallVector<MlirType, 4> mlirResults;
   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
@@ -1367,6 +1368,7 @@ py::object PyOperation::create(const std::string &name,
   if (!mlirOperands.empty())
     mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                   mlirOperands.data());
+  state.enableResultTypeInference = inferType;
   if (!mlirResults.empty())
     mlirOperationStateAddResults(&state, mlirResults.size(),
                                  mlirResults.data());
@@ -1398,6 +1400,8 @@ py::object PyOperation::create(const std::string &name,
 
   // Construct the operation.
   MlirOperation operation = mlirOperationCreate(&state);
+  if (!operation.ptr)
+    throw py::value_error("Operation creation failed");
   PyOperationRef created =
       PyOperation::createDetached(location->getContext(), operation);
   maybeInsertOperation(created, maybeIp);
@@ -1441,51 +1445,10 @@ void PyOperation::erase() {
 // PyOpView
 //------------------------------------------------------------------------------
 
-py::object
-PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList,
-                       py::list operandList, std::optional<py::dict> attributes,
-                       std::optional<std::vector<PyBlock *>> successors,
-                       std::optional<int> regions,
-                       DefaultingPyLocation location,
-                       const py::object &maybeIp) {
-  PyMlirContextRef context = location->getContext();
-  // Class level operation construction metadata.
-  std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
-  // Operand and result segment specs are either none, which does no
-  // variadic unpacking, or a list of ints with segment sizes, where each
-  // element is either a positive number (typically 1 for a scalar) or -1 to
-  // indicate that it is derived from the length of the same-indexed operand
-  // or result (implying that it is a list at that position).
-  py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
-  py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
-
-  std::vector<int32_t> operandSegmentLengths;
-  std::vector<int32_t> resultSegmentLengths;
-
-  // Validate/determine region count.
-  auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
-  int opMinRegionCount = std::get<0>(opRegionSpec);
-  bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
-  if (!regions) {
-    regions = opMinRegionCount;
-  }
-  if (*regions < opMinRegionCount) {
-    throw py::value_error(
-        (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
-         llvm::Twine(opMinRegionCount) +
-         " regions but was built with regions=" + llvm::Twine(*regions))
-            .str());
-  }
-  if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
-    throw py::value_error(
-        (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
-         llvm::Twine(opMinRegionCount) +
-         " regions but was built with regions=" + llvm::Twine(*regions))
-            .str());
-  }
-
-  // Unpack results.
-  std::vector<PyType *> resultTypes;
+static void populateResultTypes(StringRef name, py::list resultTypeList,
+                                const py::object &resultSegmentSpecObj,
+                                std::vector<int32_t> &resultSegmentLengths,
+                                std::vector<PyType *> &resultTypes) {
   resultTypes.reserve(resultTypeList.size());
   if (resultSegmentSpecObj.is_none()) {
     // Non-variadic result unpacking.
@@ -1568,6 +1531,56 @@ PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList,
       }
     }
   }
+}
+
+py::object PyOpView::buildGeneric(
+    const py::object &cls, std::optional<py::list> resultTypeList,
+    py::list operandList, std::optional<py::dict> attributes,
+    std::optional<std::vector<PyBlock *>> successors,
+    std::optional<int> regions, DefaultingPyLocation location,
+    const py::object &maybeIp) {
+  PyMlirContextRef context = location->getContext();
+  // Class level operation construction metadata.
+  std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
+  // Operand and result segment specs are either none, which does no
+  // variadic unpacking, or a list of ints with segment sizes, where each
+  // element is either a positive number (typically 1 for a scalar) or -1 to
+  // indicate that it is derived from the length of the same-indexed operand
+  // or result (implying that it is a list at that position).
+  py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
+  py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
+
+  std::vector<int32_t> operandSegmentLengths;
+  std::vector<int32_t> resultSegmentLengths;
+
+  // Validate/determine region count.
+  auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
+  int opMinRegionCount = std::get<0>(opRegionSpec);
+  bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
+  if (!regions) {
+    regions = opMinRegionCount;
+  }
+  if (*regions < opMinRegionCount) {
+    throw py::value_error(
+        (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
+         llvm::Twine(opMinRegionCount) +
+         " regions but was built with regions=" + llvm::Twine(*regions))
+            .str());
+  }
+  if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
+    throw py::value_error(
+        (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
+         llvm::Twine(opMinRegionCount) +
+         " regions but was built with regions=" + llvm::Twine(*regions))
+            .str());
+  }
+
+  // Unpack results.
+  std::vector<PyType *> resultTypes;
+  if (resultTypeList.has_value()) {
+    populateResultTypes(name, *resultTypeList, resultSegmentSpecObj,
+                        resultSegmentLengths, resultTypes);
+  }
 
   // Unpack operands.
   std::vector<PyValue *> operands;
@@ -1694,7 +1707,8 @@ PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList,
                              /*operands=*/std::move(operands),
                              /*attributes=*/std::move(attributes),
                              /*successors=*/std::move(successors),
-                             /*regions=*/*regions, location, maybeIp);
+                             /*regions=*/*regions, location, maybeIp,
+                             !resultTypeList);
 }
 
 pybind11::object PyOpView::constructDerived(const pybind11::object &cls,
@@ -2854,7 +2868,7 @@ void mlir::python::populateIRCore(py::module &m) {
                   py::arg("attributes") = py::none(),
                   py::arg("successors") = py::none(), py::arg("regions") = 0,
                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
-                  kOperationCreateDocstring)
+                  py::arg("infer_type") = false, kOperationCreateDocstring)
       .def_static(
           "parse",
           [](const std::string &sourceStr, const std::string &sourceName,

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 5da1d7d25ebe9c..d1911730c1ede0 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -655,7 +655,8 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
          std::optional<std::vector<PyValue *>> operands,
          std::optional<pybind11::dict> attributes,
          std::optional<std::vector<PyBlock *>> successors, int regions,
-         DefaultingPyLocation location, const pybind11::object &ip);
+         DefaultingPyLocation location, const pybind11::object &ip,
+         bool inferType);
 
   /// Creates an OpView suitable for this operation.
   pybind11::object createOpView();
@@ -704,13 +705,12 @@ class PyOpView : public PyOperationBase {
 
   pybind11::object getOperationObject() { return operationObject; }
 
-  static pybind11::object
-  buildGeneric(const pybind11::object &cls, pybind11::list resultTypeList,
-               pybind11::list operandList,
-               std::optional<pybind11::dict> attributes,
-               std::optional<std::vector<PyBlock *>> successors,
-               std::optional<int> regions, DefaultingPyLocation location,
-               const pybind11::object &maybeIp);
+  static pybind11::object buildGeneric(
+      const pybind11::object &cls, std::optional<pybind11::list> resultTypeList,
+      pybind11::list operandList, std::optional<pybind11::dict> attributes,
+      std::optional<std::vector<PyBlock *>> successors,
+      std::optional<int> regions, DefaultingPyLocation location,
+      const pybind11::object &maybeIp);
 
   /// Construct an instance of a class deriving from OpView, bypassing its
   /// `__init__` method. The derived class will typically define a constructor

diff  --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 036ea5b4a852e2..de979f7e8f43e6 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -245,14 +245,10 @@ def EmptyOp : TestOp<"empty">;
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
 def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
   // CHECK:  def __init__(self, *, loc=None, ip=None):
-  // CHECK:    operands = []
-  // CHECK:    results = []
   // CHECK:    _ods_context = _ods_get_default_loc_context(loc)
-  // CHECK:    results = _ods_ir.InferTypeOpInterface(InferResultTypesImpliedOp).inferReturnTypes(
-  // CHECK:        operands=operands,
-  // CHECK:        attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
-  // CHECK:        context=_ods_context,
-  // CHECK:        loc=loc)
+  // CHECK:   super().__init__(self.build_generic(
+  // CHECK:     attributes=attributes, operands=operands,
+  // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
   let results = (outs I32:$i32, F32:$f32);
 }
 
@@ -260,13 +256,9 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
 def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
   // CHECK:  def __init__(self, *, loc=None, ip=None):
   // CHECK:    operands = []
-  // CHECK:    results = []
-  // CHECK:    _ods_context = _ods_get_default_loc_context(loc)
-  // CHECK:    results = _ods_ir.InferTypeOpInterface(InferResultTypesOp).inferReturnTypes(
-  // CHECK:        operands=operands,
-  // CHECK:        attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
-  // CHECK:        context=_ods_context,
-  // CHECK:        loc=loc)
+  // CHECK:   super().__init__(self.build_generic(
+  // CHECK:     attributes=attributes, operands=operands,
+  // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
   let results = (outs AnyType, AnyType, AnyType);
 }
 

diff  --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index f17e6e9eda4031..dd6e52d300efe1 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -493,9 +493,7 @@ constexpr const char *initTemplate = R"Py(
     attributes = {{}
     regions = None
     {1}
-    super().__init__(self.build_generic(
-      attributes=attributes, results=results, operands=operands,
-      successors=_ods_successors, regions=regions, loc=loc, ip=ip))
+    super().__init__(self.build_generic({2}))
 )Py";
 
 /// Template for appending a single element to the operand/result list.
@@ -755,17 +753,6 @@ _ods_derived_result_type = (
 /// Python code template appending {0} type {1} times to the results list.
 constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})";
 
-/// Python code template for inferring the operation results using the
-/// corresponding interface:
-///   - {0} is the name of the class for which the types are inferred.
-constexpr const char *inferTypeInterfaceTemplate =
-    R"PY(results = _ods_ir.InferTypeOpInterface({0}).inferReturnTypes(
-    operands=operands,
-    attributes=_ods_ir.DictAttr.get(attributes, context=_ods_context),
-    context=_ods_context,
-    loc=loc)
-)PY";
-
 /// Appends the given multiline string as individual strings into
 /// `builderLines`.
 static void appendLineByLine(StringRef string,
@@ -805,12 +792,8 @@ populateBuilderLinesResult(const Operator &op,
     return;
   }
 
-  if (hasInferTypeInterface(op)) {
-    appendLineByLine(
-        llvm::formatv(inferTypeInterfaceTemplate, op.getCppClassName()).str(),
-        builderLines);
+  if (hasInferTypeInterface(op))
     return;
-  }
 
   // For each element, find or generate a name.
   for (int i = 0, e = op.getNumResults(); i < e; ++i) {
@@ -934,8 +917,20 @@ static void emitDefaultOpBuilder(const Operator &op, raw_ostream &os) {
   }
   functionArgs.push_back("loc=None");
   functionArgs.push_back("ip=None");
+
+  SmallVector<std::string> initArgs;
+  initArgs.push_back("attributes=attributes");
+  if (!hasInferTypeInterface(op))
+    initArgs.push_back("results=results");
+  initArgs.push_back("operands=operands");
+  initArgs.push_back("successors=_ods_successors");
+  initArgs.push_back("regions=regions");
+  initArgs.push_back("loc=loc");
+  initArgs.push_back("ip=ip");
+
   os << llvm::formatv(initTemplate, llvm::join(functionArgs, ", "),
-                      llvm::join(builderLines, "\n    "));
+                      llvm::join(builderLines, "\n    "),
+                      llvm::join(initArgs, ", "));
 }
 
 static void emitSegmentSpec(


        


More information about the Mlir-commits mailing list