[Mlir-commits] [mlir] [mlir python] Change PyOpView constructor to construct operations. (PR #123777)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 21 08:50:38 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Peter Hawkins (hawkinsp)

<details>
<summary>Changes</summary>

Previously, the constructors of ODS-generated Python operation classes contained code like this:
```
  super().__init__(self.build_generic(attributes=attributes, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip))
```

This PR changes it to:
```
  super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)
```

This change:
a) avoids an extra call dispatch (to `build_generic`), and
b) passes the class attributes (`OPERATION_NAME`, `_ODS_REGIONS`,
`_ODS_OPERAND_SEGMENTS`, `_ODS_RESULT_SEGMENTS`) directly to the constructor.
Benchmarks show that passing them as arguments is faster than having the
C++ code look them up as attributes on the class.

This PR reduces the running time of the following benchmark on my workstation
from 5.3s to 4.5s:
```
def main(_):
  # Benchmark: times the construction of 1,000,000 `arith.ConstantOp`s (i32
  # constants 0..999999) inserted into the body of a freshly created module.
  # NOTE(review): `ir`, `arith` and `time` are imported outside this snippet;
  # `ir`/`arith` are presumably the MLIR Python bindings modules. The unused
  # `_` parameter presumably receives argv from an absl-style `app.run(main)`
  # entry point -- not shown here.
  with ir.Context(), ir.Location.unknown():
    typ = ir.IntegerType.get_signless(32)
    m = ir.Module.create()
    # New ops are inserted at the end of the module body via this insertion point.
    with ir.InsertionPoint(m.body):
      # Wall-clock time around the op-construction loop only; context/module
      # setup above is excluded from the measurement.
      start = time.time()
      for i in range(1000000):
        arith.ConstantOp(typ, i)
      end = time.time()
      print(f"time: {end - start}")
```

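Since this change only adds an additional overload to the `OpView` constructor
(the existing `OpView(operation)` constructor and the `build_generic`
classmethod keep their signatures) and does not alter any existing behaviors of
those entry points, it should be backwards compatible. One caveat: generated
`__init__` methods no longer call `self.build_generic`, so a subclass that
overrides `build_generic` will no longer have that override invoked from the
generated constructor.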

---
Full diff: https://github.com/llvm/llvm-project/pull/123777.diff


3 Files Affected:

- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+60-13) 
- (modified) mlir/lib/Bindings/Python/IRModule.h (+11-7) 
- (modified) mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (+5-1) 


``````````diff
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 53806ca9f04a49..1c9fb3d2abf557 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -211,6 +211,10 @@ static MlirStringRef toMlirStringRef(const std::string &s) {
   return mlirStringRefCreate(s.data(), s.size());
 }
 
+static MlirStringRef toMlirStringRef(std::string_view s) {
+  return mlirStringRefCreate(s.data(), s.size());
+}
+
 static MlirStringRef toMlirStringRef(const nb::bytes &s) {
   return mlirStringRefCreate(static_cast<const char *>(s.data()), s.size());
 }
@@ -1460,7 +1464,7 @@ static void maybeInsertOperation(PyOperationRef &op,
   }
 }
 
-nb::object PyOperation::create(const std::string &name,
+nb::object PyOperation::create(std::string_view name,
                                std::optional<std::vector<PyType *>> results,
                                std::optional<std::vector<PyValue *>> operands,
                                std::optional<nb::dict> attributes,
@@ -1506,7 +1510,7 @@ nb::object PyOperation::create(const std::string &name,
       } catch (nb::cast_error &err) {
         std::string msg = "Invalid attribute key (not a string) when "
                           "attempting to create the operation \"" +
-                          name + "\" (" + err.what() + ")";
+                          std::string(name) + "\" (" + err.what() + ")";
         throw nb::type_error(msg.c_str());
       }
       try {
@@ -1516,13 +1520,14 @@ nb::object PyOperation::create(const std::string &name,
       } catch (nb::cast_error &err) {
         std::string msg = "Invalid attribute value for the key \"" + key +
                           "\" when attempting to create the operation \"" +
-                          name + "\" (" + err.what() + ")";
+                          std::string(name) + "\" (" + err.what() + ")";
         throw nb::type_error(msg.c_str());
       } catch (std::runtime_error &) {
         // This exception seems thrown when the value is "None".
         std::string msg =
             "Found an invalid (`None`?) attribute value for the key \"" + key +
-            "\" when attempting to create the operation \"" + name + "\"";
+            "\" when attempting to create the operation \"" +
+            std::string(name) + "\"";
         throw std::runtime_error(msg);
       }
     }
@@ -1714,27 +1719,25 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList,
 }
 
 nb::object PyOpView::buildGeneric(
-    const nb::object &cls, std::optional<nb::list> resultTypeList,
-    nb::list operandList, std::optional<nb::dict> attributes,
+    std::string_view name, std::tuple<int, bool> opRegionSpec,
+    nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
+    std::optional<nb::list> resultTypeList, nb::list operandList,
+    std::optional<nb::dict> attributes,
     std::optional<std::vector<PyBlock *>> successors,
     std::optional<int> regions, DefaultingPyLocation location,
     const nb::object &maybeIp) {
   PyMlirContextRef context = location->getContext();
+
   // Class level operation construction metadata.
-  std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
   // Operand and result segment specs are either none, which does no
   // variadic unpacking, or a list of ints with segment sizes, where each
   // element is either a positive number (typically 1 for a scalar) or -1 to
   // indicate that it is derived from the length of the same-indexed operand
   // or result (implying that it is a list at that position).
-  nb::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
-  nb::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
-
   std::vector<int32_t> operandSegmentLengths;
   std::vector<int32_t> resultSegmentLengths;
 
   // Validate/determine region count.
-  auto opRegionSpec = nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
   int opMinRegionCount = std::get<0>(opRegionSpec);
   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
   if (!regions) {
@@ -3236,6 +3239,33 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   auto opViewClass =
       nb::class_<PyOpView, PyOperationBase>(m, "OpView")
           .def(nb::init<nb::object>(), nb::arg("operation"))
+          .def(
+              "__init__",
+              [](PyOpView *self, std::string_view name,
+                 std::tuple<int, bool> opRegionSpec,
+                 nb::object operandSegmentSpecObj,
+                 nb::object resultSegmentSpecObj,
+                 std::optional<nb::list> resultTypeList, nb::list operandList,
+                 std::optional<nb::dict> attributes,
+                 std::optional<std::vector<PyBlock *>> successors,
+                 std::optional<int> regions, DefaultingPyLocation location,
+                 const nb::object &maybeIp) {
+                new (self) PyOpView(PyOpView::buildGeneric(
+                    name, opRegionSpec, operandSegmentSpecObj,
+                    resultSegmentSpecObj, resultTypeList, operandList,
+                    attributes, successors, regions, location, maybeIp));
+              },
+              nb::arg("name"), nb::arg("opRegionSpec"),
+              nb::arg("operandSegmentSpecObj").none() = nb::none(),
+              nb::arg("resultSegmentSpecObj").none() = nb::none(),
+              nb::arg("results").none() = nb::none(),
+              nb::arg("operands").none() = nb::none(),
+              nb::arg("attributes").none() = nb::none(),
+              nb::arg("successors").none() = nb::none(),
+              nb::arg("regions").none() = nb::none(),
+              nb::arg("loc").none() = nb::none(),
+              nb::arg("ip").none() = nb::none())
+
           .def_prop_ro("operation", &PyOpView::getOperationObject)
           .def_prop_ro("opview", [](nb::object self) { return self; })
           .def(
@@ -3250,9 +3280,26 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
   opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();
+  // It is faster to pass the operation_name, ods_regions, and
+  // ods_operand_segments/ods_result_segments as arguments to the constructor,
+  // rather than to access them as attributes.
   opViewClass.attr("build_generic") = classmethod(
-      &PyOpView::buildGeneric, nb::arg("cls"),
-      nb::arg("results").none() = nb::none(),
+      [](nb::handle cls, std::optional<nb::list> resultTypeList,
+         nb::list operandList, std::optional<nb::dict> attributes,
+         std::optional<std::vector<PyBlock *>> successors,
+         std::optional<int> regions, DefaultingPyLocation location,
+         const nb::object &maybeIp) {
+        std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
+        std::tuple<int, bool> opRegionSpec =
+            nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
+        nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
+        nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
+        return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
+                                      resultSegmentSpec, resultTypeList,
+                                      operandList, attributes, successors,
+                                      regions, location, maybeIp);
+      },
+      nb::arg("cls"), nb::arg("results").none() = nb::none(),
       nb::arg("operands").none() = nb::none(),
       nb::arg("attributes").none() = nb::none(),
       nb::arg("successors").none() = nb::none(),
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index d1fb4308dbb77c..2228b55231b0bd 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -685,7 +685,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
 
   /// Creates an operation. See corresponding python docstring.
   static nanobind::object
-  create(const std::string &name, std::optional<std::vector<PyType *>> results,
+  create(std::string_view name, std::optional<std::vector<PyType *>> results,
          std::optional<std::vector<PyValue *>> operands,
          std::optional<nanobind::dict> attributes,
          std::optional<std::vector<PyBlock *>> successors, int regions,
@@ -739,12 +739,16 @@ class PyOpView : public PyOperationBase {
 
   nanobind::object getOperationObject() { return operationObject; }
 
-  static nanobind::object buildGeneric(
-      const nanobind::object &cls, std::optional<nanobind::list> resultTypeList,
-      nanobind::list operandList, std::optional<nanobind::dict> attributes,
-      std::optional<std::vector<PyBlock *>> successors,
-      std::optional<int> regions, DefaultingPyLocation location,
-      const nanobind::object &maybeIp);
+  static nanobind::object
+  buildGeneric(std::string_view name, std::tuple<int, bool> opRegionSpec,
+               nanobind::object operandSegmentSpecObj,
+               nanobind::object resultSegmentSpecObj,
+               std::optional<nanobind::list> resultTypeList,
+               nanobind::list operandList,
+               std::optional<nanobind::dict> attributes,
+               std::optional<std::vector<PyBlock *>> successors,
+               std::optional<int> regions, DefaultingPyLocation location,
+               const nanobind::object &maybeIp);
 
   /// Construct an instance of a class deriving from OpView, bypassing its
   /// `__init__` method. The derived class will typically define a constructor
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 300d9977f03994..e1540d1750ff19 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -496,7 +496,7 @@ constexpr const char *initTemplate = R"Py(
     attributes = {{}
     regions = None
     {1}
-    super().__init__(self.build_generic({2}))
+    super().__init__({2})
 )Py";
 
 /// Template for appending a single element to the operand/result list.
@@ -915,6 +915,10 @@ static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
   functionArgs.push_back("ip=None");
 
   SmallVector<std::string> initArgs;
+  initArgs.push_back("self.OPERATION_NAME");
+  initArgs.push_back("self._ODS_REGIONS");
+  initArgs.push_back("self._ODS_OPERAND_SEGMENTS");
+  initArgs.push_back("self._ODS_RESULT_SEGMENTS");
   initArgs.push_back("attributes=attributes");
   if (!hasInferTypeInterface(op))
     initArgs.push_back("results=results");

``````````

</details>


https://github.com/llvm/llvm-project/pull/123777


More information about the Mlir-commits mailing list