[Mlir-commits] [mlir] [mlir:python] Compute get_op_result_or_value in PyOpView's constructor. (PR #123953)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jan 22 07:20:17 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Peter Hawkins (hawkinsp)

<details>
<summary>Changes</summary>

This logic is on the critical path for constructing an operation from Python. Computing it in C++ is faster than computing it in Python, and the change required to do so is small.

This change also alters the API contract of
`_ods_common.get_op_results_or_values`: it no longer calls `get_op_result_or_value` on each element of a sequence, because the C++ code now performs that conversion.

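Most of the diff simply moves existing code within IRCore.cpp. `PyConcreteValue`, `PyOpResult`, `getValueTypes` and `PyOpResultList` are moved ahead of the `PyOpView` section so that the new operand-conversion helper, which accepts `OpResultList` arguments, can refer to them.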

---

Patch is 29.16 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123953.diff


5 Files Affected:

- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+232-199) 
- (modified) mlir/lib/Bindings/Python/IRModule.h (+1-1) 
- (modified) mlir/python/mlir/dialects/_ods_common.py (+2-2) 
- (modified) mlir/test/mlir-tblgen/op-python-bindings.td (+13-13) 
- (modified) mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (+3-6) 


``````````diff
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index c862ec84fcbc55..53f2031d5df055 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1478,12 +1478,11 @@ static void maybeInsertOperation(PyOperationRef &op,
 
 nb::object PyOperation::create(std::string_view name,
                                std::optional<std::vector<PyType *>> results,
-                               std::optional<std::vector<PyValue *>> operands,
+                               const llvm::SmallVector<MlirValue, 4> &operands,
                                std::optional<nb::dict> attributes,
                                std::optional<std::vector<PyBlock *>> successors,
                                int regions, DefaultingPyLocation location,
                                const nb::object &maybeIp, bool inferType) {
-  llvm::SmallVector<MlirValue, 4> mlirOperands;
   llvm::SmallVector<MlirType, 4> mlirResults;
   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
@@ -1492,16 +1491,6 @@ nb::object PyOperation::create(std::string_view name,
   if (regions < 0)
     throw nb::value_error("number of regions must be >= 0");
 
-  // Unpack/validate operands.
-  if (operands) {
-    mlirOperands.reserve(operands->size());
-    for (PyValue *operand : *operands) {
-      if (!operand)
-        throw nb::value_error("operand value cannot be None");
-      mlirOperands.push_back(operand->get());
-    }
-  }
-
   // Unpack/validate results.
   if (results) {
     mlirResults.reserve(results->size());
@@ -1559,9 +1548,8 @@ nb::object PyOperation::create(std::string_view name,
   // point, exceptions cannot be thrown or else the state will leak.
   MlirOperationState state =
       mlirOperationStateGet(toMlirStringRef(name), location);
-  if (!mlirOperands.empty())
-    mlirOperationStateAddOperands(&state, mlirOperands.size(),
-                                  mlirOperands.data());
+  if (!operands.empty())
+    mlirOperationStateAddOperands(&state, operands.size(), operands.data());
   state.enableResultTypeInference = inferType;
   if (!mlirResults.empty())
     mlirOperationStateAddResults(&state, mlirResults.size(),
@@ -1629,6 +1617,142 @@ void PyOperation::erase() {
   mlirOperationDestroy(operation);
 }
 
+namespace {
+/// CRTP base class for Python MLIR values that subclass Value and should be
+/// castable from it. The value hierarchy is one level deep and is not supposed
+/// to accommodate other levels unless core MLIR changes.
+template <typename DerivedTy> class PyConcreteValue : public PyValue {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  // and redefine bindDerived.
+  using ClassTy = nb::class_<DerivedTy, PyValue>;
+  using IsAFunctionTy = bool (*)(MlirValue);
+
+  PyConcreteValue() = default;
+  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
+      : PyValue(operationRef, value) {}
+  PyConcreteValue(PyValue &orig)
+      : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
+
+  /// Attempts to cast the original value to the derived type and throws on
+  /// type mismatches.
+  static MlirValue castFrom(PyValue &orig) {
+    if (!DerivedTy::isaFunction(orig.get())) {
+      auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
+      throw nb::value_error((Twine("Cannot cast value to ") +
+                             DerivedTy::pyClassName + " (from " + origRepr +
+                             ")")
+                                .str()
+                                .c_str());
+    }
+    return orig.get();
+  }
+
+  /// Binds the Python module objects to functions of this class.
+  static void bind(nb::module_ &m) {
+    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
+    cls.def_static(
+        "isinstance",
+        [](PyValue &otherValue) -> bool {
+          return DerivedTy::isaFunction(otherValue);
+        },
+        nb::arg("other_value"));
+    cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+            [](DerivedTy &self) { return self.maybeDownCast(); });
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &m) {}
+};
+
+} // namespace
+
+/// Python wrapper for MlirOpResult.
+class PyOpResult : public PyConcreteValue<PyOpResult> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
+  static constexpr const char *pyClassName = "OpResult";
+  using PyConcreteValue::PyConcreteValue;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_prop_ro("owner", [](PyOpResult &self) {
+      assert(
+          mlirOperationEqual(self.getParentOperation()->get(),
+                             mlirOpResultGetOwner(self.get())) &&
+          "expected the owner of the value in Python to match that in the IR");
+      return self.getParentOperation().getObject();
+    });
+    c.def_prop_ro("result_number", [](PyOpResult &self) {
+      return mlirOpResultGetResultNumber(self.get());
+    });
+  }
+};
+
+/// Returns the list of types of the values held by container.
+template <typename Container>
+static std::vector<MlirType> getValueTypes(Container &container,
+                                           PyMlirContextRef &context) {
+  std::vector<MlirType> result;
+  result.reserve(container.size());
+  for (int i = 0, e = container.size(); i < e; ++i) {
+    result.push_back(mlirValueGetType(container.getElement(i).get()));
+  }
+  return result;
+}
+
+/// A list of operation results. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) result list is associated
+/// with the operation whose results these are, and thus extends the lifetime of
+/// this operation.
+class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
+public:
+  static constexpr const char *pyClassName = "OpResultList";
+  using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
+
+  PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
+                 intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirOperationGetNumResults(operation->get())
+                               : length,
+                  step),
+        operation(std::move(operation)) {}
+
+  static void bindDerived(ClassTy &c) {
+    c.def_prop_ro("types", [](PyOpResultList &self) {
+      return getValueTypes(self, self.operation->getContext());
+    });
+    c.def_prop_ro("owner", [](PyOpResultList &self) {
+      return self.operation->createOpView();
+    });
+  }
+
+  PyOperationRef &getOperation() { return operation; }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyOpResultList, PyOpResult>;
+
+  intptr_t getRawNumElements() {
+    operation->checkValid();
+    return mlirOperationGetNumResults(operation->get());
+  }
+
+  PyOpResult getRawElement(intptr_t index) {
+    PyValue value(operation, mlirOperationGetResult(operation->get(), index));
+    return PyOpResult(value);
+  }
+
+  PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyOpResultList(operation, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+};
+
 //------------------------------------------------------------------------------
 // PyOpView
 //------------------------------------------------------------------------------
@@ -1730,6 +1854,40 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList,
   }
 }
 
+static MlirValue getUniqueResult(MlirOperation operation) {
+  auto numResults = mlirOperationGetNumResults(operation);
+  if (numResults != 1) {
+    auto name = mlirIdentifierStr(mlirOperationGetName(operation));
+    throw nb::value_error((Twine("Cannot call .result on operation ") +
+                           StringRef(name.data, name.length) + " which has " +
+                           Twine(numResults) +
+                           " results (it is only valid for operations with a "
+                           "single result)")
+                              .str()
+                              .c_str());
+  }
+  return mlirOperationGetResult(operation, 0);
+}
+
+static MlirValue getOpResultOrValue(nb::handle operand) {
+  if (operand.is_none()) {
+    throw nb::value_error("contained a None item");
+  }
+  PyOperationBase *op;
+  if (nb::try_cast<PyOperationBase *>(operand, op)) {
+    return getUniqueResult(op->getOperation());
+  }
+  PyOpResultList *opResultList;
+  if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
+    return getUniqueResult(opResultList->getOperation()->get());
+  }
+  PyValue *value;
+  if (nb::try_cast<PyValue *>(operand, value)) {
+    return value->get();
+  }
+  throw nb::value_error("is not a Value");
+}
+
 nb::object PyOpView::buildGeneric(
     std::string_view name, std::tuple<int, bool> opRegionSpec,
     nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
@@ -1780,16 +1938,14 @@ nb::object PyOpView::buildGeneric(
   }
 
   // Unpack operands.
-  std::vector<PyValue *> operands;
+  llvm::SmallVector<MlirValue, 4> operands;
   operands.reserve(operands.size());
   if (operandSegmentSpecObj.is_none()) {
     // Non-sized operand unpacking.
     for (const auto &it : llvm::enumerate(operandList)) {
       try {
-        operands.push_back(nb::cast<PyValue *>(it.value()));
-        if (!operands.back())
-          throw nb::cast_error();
-      } catch (nb::cast_error &err) {
+        operands.push_back(getOpResultOrValue(it.value()));
+      } catch (nb::builtin_exception &err) {
         throw nb::value_error((llvm::Twine("Operand ") +
                                llvm::Twine(it.index()) + " of operation \"" +
                                name + "\" must be a Value (" + err.what() + ")")
@@ -1815,29 +1971,31 @@ nb::object PyOpView::buildGeneric(
       int segmentSpec = std::get<1>(it.value());
       if (segmentSpec == 1 || segmentSpec == 0) {
         // Unpack unary element.
-        try {
-          auto *operandValue = nb::cast<PyValue *>(std::get<0>(it.value()));
-          if (operandValue) {
-            operands.push_back(operandValue);
-            operandSegmentLengths.push_back(1);
-          } else if (segmentSpec == 0) {
-            // Allowed to be optional.
-            operandSegmentLengths.push_back(0);
-          } else {
-            throw nb::value_error(
-                (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
-                 " of operation \"" + name +
-                 "\" must be a Value (was None and operand is not optional)")
-                    .str()
-                    .c_str());
+        auto &operand = std::get<0>(it.value());
+        if (!operand.is_none()) {
+          try {
+
+            operands.push_back(getOpResultOrValue(operand));
+          } catch (nb::builtin_exception &err) {
+            throw nb::value_error((llvm::Twine("Operand ") +
+                                   llvm::Twine(it.index()) +
+                                   " of operation \"" + name +
+                                   "\" must be a Value (" + err.what() + ")")
+                                      .str()
+                                      .c_str());
           }
-        } catch (nb::cast_error &err) {
-          throw nb::value_error((llvm::Twine("Operand ") +
-                                 llvm::Twine(it.index()) + " of operation \"" +
-                                 name + "\" must be a Value (" + err.what() +
-                                 ")")
-                                    .str()
-                                    .c_str());
+
+          operandSegmentLengths.push_back(1);
+        } else if (segmentSpec == 0) {
+          // Allowed to be optional.
+          operandSegmentLengths.push_back(0);
+        } else {
+          throw nb::value_error(
+              (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+               " of operation \"" + name +
+               "\" must be a Value (was None and operand is not optional)")
+                  .str()
+                  .c_str());
         }
       } else if (segmentSpec == -1) {
         // Unpack sequence by appending.
@@ -1849,10 +2007,7 @@ nb::object PyOpView::buildGeneric(
             // Unpack the list.
             auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
             for (nb::handle segmentItem : segment) {
-              operands.push_back(nb::cast<PyValue *>(segmentItem));
-              if (!operands.back()) {
-                throw nb::type_error("contained a None item");
-              }
+              operands.push_back(getOpResultOrValue(segmentItem));
             }
             operandSegmentLengths.push_back(nb::len(segment));
           }
@@ -2266,57 +2421,6 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from,
 }
 
 namespace {
-/// CRTP base class for Python MLIR values that subclass Value and should be
-/// castable from it. The value hierarchy is one level deep and is not supposed
-/// to accommodate other levels unless core MLIR changes.
-template <typename DerivedTy>
-class PyConcreteValue : public PyValue {
-public:
-  // Derived classes must define statics for:
-  //   IsAFunctionTy isaFunction
-  //   const char *pyClassName
-  // and redefine bindDerived.
-  using ClassTy = nb::class_<DerivedTy, PyValue>;
-  using IsAFunctionTy = bool (*)(MlirValue);
-
-  PyConcreteValue() = default;
-  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
-      : PyValue(operationRef, value) {}
-  PyConcreteValue(PyValue &orig)
-      : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
-
-  /// Attempts to cast the original value to the derived type and throws on
-  /// type mismatches.
-  static MlirValue castFrom(PyValue &orig) {
-    if (!DerivedTy::isaFunction(orig.get())) {
-      auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
-      throw nb::value_error((Twine("Cannot cast value to ") +
-                             DerivedTy::pyClassName + " (from " + origRepr +
-                             ")")
-                                .str()
-                                .c_str());
-    }
-    return orig.get();
-  }
-
-  /// Binds the Python module objects to functions of this class.
-  static void bind(nb::module_ &m) {
-    auto cls = ClassTy(m, DerivedTy::pyClassName);
-    cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
-    cls.def_static(
-        "isinstance",
-        [](PyValue &otherValue) -> bool {
-          return DerivedTy::isaFunction(otherValue);
-        },
-        nb::arg("other_value"));
-    cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
-            [](DerivedTy &self) { return self.maybeDownCast(); });
-    DerivedTy::bindDerived(cls);
-  }
-
-  /// Implemented by derived classes to add methods to the Python subclass.
-  static void bindDerived(ClassTy &m) {}
-};
 
 /// Python wrapper for MlirBlockArgument.
 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
@@ -2342,39 +2446,6 @@ class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
   }
 };
 
-/// Python wrapper for MlirOpResult.
-class PyOpResult : public PyConcreteValue<PyOpResult> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
-  static constexpr const char *pyClassName = "OpResult";
-  using PyConcreteValue::PyConcreteValue;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_prop_ro("owner", [](PyOpResult &self) {
-      assert(
-          mlirOperationEqual(self.getParentOperation()->get(),
-                             mlirOpResultGetOwner(self.get())) &&
-          "expected the owner of the value in Python to match that in the IR");
-      return self.getParentOperation().getObject();
-    });
-    c.def_prop_ro("result_number", [](PyOpResult &self) {
-      return mlirOpResultGetResultNumber(self.get());
-    });
-  }
-};
-
-/// Returns the list of types of the values held by container.
-template <typename Container>
-static std::vector<MlirType> getValueTypes(Container &container,
-                                           PyMlirContextRef &context) {
-  std::vector<MlirType> result;
-  result.reserve(container.size());
-  for (int i = 0, e = container.size(); i < e; ++i) {
-    result.push_back(mlirValueGetType(container.getElement(i).get()));
-  }
-  return result;
-}
-
 /// A list of block arguments. Internally, these are stored as consecutive
 /// elements, random access is cheap. The argument list is associated with the
 /// operation that contains the block (detached blocks are not allowed in
@@ -2481,53 +2552,6 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
   PyOperationRef operation;
 };
 
-/// A list of operation results. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) result list is associated
-/// with the operation whose results these are, and thus extends the lifetime of
-/// this operation.
-class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
-public:
-  static constexpr const char *pyClassName = "OpResultList";
-  using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
-
-  PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
-                 intptr_t length = -1, intptr_t step = 1)
-      : Sliceable(startIndex,
-                  length == -1 ? mlirOperationGetNumResults(operation->get())
-                               : length,
-                  step),
-        operation(std::move(operation)) {}
-
-  static void bindDerived(ClassTy &c) {
-    c.def_prop_ro("types", [](PyOpResultList &self) {
-      return getValueTypes(self, self.operation->getContext());
-    });
-    c.def_prop_ro("owner", [](PyOpResultList &self) {
-      return self.operation->createOpView();
-    });
-  }
-
-private:
-  /// Give the parent CRTP class access to hook implementations below.
-  friend class Sliceable<PyOpResultList, PyOpResult>;
-
-  intptr_t getRawNumElements() {
-    operation->checkValid();
-    return mlirOperationGetNumResults(operation->get());
-  }
-
-  PyOpResult getRawElement(intptr_t index) {
-    PyValue value(operation, mlirOperationGetResult(operation->get(), index));
-    return PyOpResult(value);
-  }
-
-  PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
-    return PyOpResultList(operation, startIndex, length, step);
-  }
-
-  PyOperationRef operation;
-};
-
 /// A list of operation successors. Internally, these are stored as consecutive
 /// elements, random access is cheap. The (returned) successor list is
 /// associated with the operation whose successors these are, and thus extends
@@ -3108,20 +3132,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           "result",
           [](PyOperationBase &self) {
             auto &operation = self.getOperation();
-            auto numResults = mlirOperationGetNumResults(operation);
-            if (numResults != 1) {
-              auto name = mlirIdentifierStr(mlirOperationGetName(operation));
-              throw nb::value_error(
-                  (Twine("Cannot call .result on operation ") +
-                   StringRef(name.data, name.length) + " which has " +
-                   Twine(numResults) +
-                   " results (it is only valid for operations with a "
-                   "single result)")
-                      .str()
-                      .c_str());
-            }
-            return PyOpResult(operation.getRef(),
-                              mlirOperationGetResult(operation, 0))
+            return PyOpResult(operation.getRef(), getUniqueResult(operation))
                 .maybeDownCast();
           },
           "Shortcut to get an op result if it has only one (throws an error "
@@ -3218,14 +3229,36 @@ void mlir::python::populateIRCore(nb::module_ &m) {
            nb::arg("walk_order") = MlirWalkPostOrder);
 
   nb::class_<PyOperation, PyOperationBase>(m, "Operation")
-      .def_static("create", &PyOperation::create, nb::arg("name"),
-                  nb::arg("results").none() = nb::none(),
-                  nb::arg("operands").none() = nb::none(),
-      ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/123953


More information about the Mlir-commits mailing list