[Mlir-commits] [mlir] [mlir:python] Compute get_op_result_or_value in PyOpView's constructor. (PR #123953)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jan 22 07:20:17 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Peter Hawkins (hawkinsp)
<details>
<summary>Changes</summary>
This logic is in the critical path for constructing an operation from Python. Computing it in C++ is faster than computing it in Python, and the change needed to do so is small.
This change also alters the API contract of _ods_common.get_op_results_or_values: it no longer calls get_op_result_or_value on each element of a sequence, because the C++ code now performs that conversion.
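Most of the diff simply reorders existing code in IRCore.cpp: PyConcreteValue, PyOpResult, getValueTypes and PyOpResultList are moved ahead of PyOpView so that the new operand-conversion helper used by PyOpView::buildGeneric can refer to PyOpResultList.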
---
Patch is 29.16 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123953.diff
5 Files Affected:
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+232-199)
- (modified) mlir/lib/Bindings/Python/IRModule.h (+1-1)
- (modified) mlir/python/mlir/dialects/_ods_common.py (+2-2)
- (modified) mlir/test/mlir-tblgen/op-python-bindings.td (+13-13)
- (modified) mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (+3-6)
``````````diff
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index c862ec84fcbc55..53f2031d5df055 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1478,12 +1478,11 @@ static void maybeInsertOperation(PyOperationRef &op,
nb::object PyOperation::create(std::string_view name,
std::optional<std::vector<PyType *>> results,
- std::optional<std::vector<PyValue *>> operands,
+ const llvm::SmallVector<MlirValue, 4> &operands,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
int regions, DefaultingPyLocation location,
const nb::object &maybeIp, bool inferType) {
- llvm::SmallVector<MlirValue, 4> mlirOperands;
llvm::SmallVector<MlirType, 4> mlirResults;
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
@@ -1492,16 +1491,6 @@ nb::object PyOperation::create(std::string_view name,
if (regions < 0)
throw nb::value_error("number of regions must be >= 0");
- // Unpack/validate operands.
- if (operands) {
- mlirOperands.reserve(operands->size());
- for (PyValue *operand : *operands) {
- if (!operand)
- throw nb::value_error("operand value cannot be None");
- mlirOperands.push_back(operand->get());
- }
- }
-
// Unpack/validate results.
if (results) {
mlirResults.reserve(results->size());
@@ -1559,9 +1548,8 @@ nb::object PyOperation::create(std::string_view name,
// point, exceptions cannot be thrown or else the state will leak.
MlirOperationState state =
mlirOperationStateGet(toMlirStringRef(name), location);
- if (!mlirOperands.empty())
- mlirOperationStateAddOperands(&state, mlirOperands.size(),
- mlirOperands.data());
+ if (!operands.empty())
+ mlirOperationStateAddOperands(&state, operands.size(), operands.data());
state.enableResultTypeInference = inferType;
if (!mlirResults.empty())
mlirOperationStateAddResults(&state, mlirResults.size(),
@@ -1629,6 +1617,142 @@ void PyOperation::erase() {
mlirOperationDestroy(operation);
}
+namespace {
+/// CRTP base class for Python MLIR values that subclass Value and should be
+/// castable from it. The value hierarchy is one level deep and is not supposed
+/// to accommodate other levels unless core MLIR changes.
+template <typename DerivedTy> class PyConcreteValue : public PyValue {
+public:
+ // Derived classes must define statics for:
+ // IsAFunctionTy isaFunction
+ // const char *pyClassName
+ // and redefine bindDerived.
+ using ClassTy = nb::class_<DerivedTy, PyValue>;
+ using IsAFunctionTy = bool (*)(MlirValue);
+
+ PyConcreteValue() = default;
+ PyConcreteValue(PyOperationRef operationRef, MlirValue value)
+ : PyValue(operationRef, value) {}
+ PyConcreteValue(PyValue &orig)
+ : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
+
+ /// Attempts to cast the original value to the derived type and throws on
+ /// type mismatches.
+ static MlirValue castFrom(PyValue &orig) {
+ if (!DerivedTy::isaFunction(orig.get())) {
+ auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
+ throw nb::value_error((Twine("Cannot cast value to ") +
+ DerivedTy::pyClassName + " (from " + origRepr +
+ ")")
+ .str()
+ .c_str());
+ }
+ return orig.get();
+ }
+
+ /// Binds the Python module objects to functions of this class.
+ static void bind(nb::module_ &m) {
+ auto cls = ClassTy(m, DerivedTy::pyClassName);
+ cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
+ cls.def_static(
+ "isinstance",
+ [](PyValue &otherValue) -> bool {
+ return DerivedTy::isaFunction(otherValue);
+ },
+ nb::arg("other_value"));
+ cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](DerivedTy &self) { return self.maybeDownCast(); });
+ DerivedTy::bindDerived(cls);
+ }
+
+ /// Implemented by derived classes to add methods to the Python subclass.
+ static void bindDerived(ClassTy &m) {}
+};
+
+} // namespace
+
+/// Python wrapper for MlirOpResult.
+class PyOpResult : public PyConcreteValue<PyOpResult> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
+ static constexpr const char *pyClassName = "OpResult";
+ using PyConcreteValue::PyConcreteValue;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_prop_ro("owner", [](PyOpResult &self) {
+ assert(
+ mlirOperationEqual(self.getParentOperation()->get(),
+ mlirOpResultGetOwner(self.get())) &&
+ "expected the owner of the value in Python to match that in the IR");
+ return self.getParentOperation().getObject();
+ });
+ c.def_prop_ro("result_number", [](PyOpResult &self) {
+ return mlirOpResultGetResultNumber(self.get());
+ });
+ }
+};
+
+/// Returns the list of types of the values held by container.
+template <typename Container>
+static std::vector<MlirType> getValueTypes(Container &container,
+ PyMlirContextRef &context) {
+ std::vector<MlirType> result;
+ result.reserve(container.size());
+ for (int i = 0, e = container.size(); i < e; ++i) {
+ result.push_back(mlirValueGetType(container.getElement(i).get()));
+ }
+ return result;
+}
+
+/// A list of operation results. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) result list is associated
+/// with the operation whose results these are, and thus extends the lifetime of
+/// this operation.
+class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
+public:
+ static constexpr const char *pyClassName = "OpResultList";
+ using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
+
+ PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
+ intptr_t length = -1, intptr_t step = 1)
+ : Sliceable(startIndex,
+ length == -1 ? mlirOperationGetNumResults(operation->get())
+ : length,
+ step),
+ operation(std::move(operation)) {}
+
+ static void bindDerived(ClassTy &c) {
+ c.def_prop_ro("types", [](PyOpResultList &self) {
+ return getValueTypes(self, self.operation->getContext());
+ });
+ c.def_prop_ro("owner", [](PyOpResultList &self) {
+ return self.operation->createOpView();
+ });
+ }
+
+ PyOperationRef &getOperation() { return operation; }
+
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyOpResultList, PyOpResult>;
+
+ intptr_t getRawNumElements() {
+ operation->checkValid();
+ return mlirOperationGetNumResults(operation->get());
+ }
+
+ PyOpResult getRawElement(intptr_t index) {
+ PyValue value(operation, mlirOperationGetResult(operation->get(), index));
+ return PyOpResult(value);
+ }
+
+ PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+ return PyOpResultList(operation, startIndex, length, step);
+ }
+
+ PyOperationRef operation;
+};
+
//------------------------------------------------------------------------------
// PyOpView
//------------------------------------------------------------------------------
@@ -1730,6 +1854,40 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList,
}
}
+static MlirValue getUniqueResult(MlirOperation operation) {
+ auto numResults = mlirOperationGetNumResults(operation);
+ if (numResults != 1) {
+ auto name = mlirIdentifierStr(mlirOperationGetName(operation));
+ throw nb::value_error((Twine("Cannot call .result on operation ") +
+ StringRef(name.data, name.length) + " which has " +
+ Twine(numResults) +
+ " results (it is only valid for operations with a "
+ "single result)")
+ .str()
+ .c_str());
+ }
+ return mlirOperationGetResult(operation, 0);
+}
+
+static MlirValue getOpResultOrValue(nb::handle operand) {
+ if (operand.is_none()) {
+ throw nb::value_error("contained a None item");
+ }
+ PyOperationBase *op;
+ if (nb::try_cast<PyOperationBase *>(operand, op)) {
+ return getUniqueResult(op->getOperation());
+ }
+ PyOpResultList *opResultList;
+ if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
+ return getUniqueResult(opResultList->getOperation()->get());
+ }
+ PyValue *value;
+ if (nb::try_cast<PyValue *>(operand, value)) {
+ return value->get();
+ }
+ throw nb::value_error("is not a Value");
+}
+
nb::object PyOpView::buildGeneric(
std::string_view name, std::tuple<int, bool> opRegionSpec,
nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
@@ -1780,16 +1938,14 @@ nb::object PyOpView::buildGeneric(
}
// Unpack operands.
- std::vector<PyValue *> operands;
+ llvm::SmallVector<MlirValue, 4> operands;
operands.reserve(operands.size());
if (operandSegmentSpecObj.is_none()) {
// Non-sized operand unpacking.
for (const auto &it : llvm::enumerate(operandList)) {
try {
- operands.push_back(nb::cast<PyValue *>(it.value()));
- if (!operands.back())
- throw nb::cast_error();
- } catch (nb::cast_error &err) {
+ operands.push_back(getOpResultOrValue(it.value()));
+ } catch (nb::builtin_exception &err) {
throw nb::value_error((llvm::Twine("Operand ") +
llvm::Twine(it.index()) + " of operation \"" +
name + "\" must be a Value (" + err.what() + ")")
@@ -1815,29 +1971,31 @@ nb::object PyOpView::buildGeneric(
int segmentSpec = std::get<1>(it.value());
if (segmentSpec == 1 || segmentSpec == 0) {
// Unpack unary element.
- try {
- auto *operandValue = nb::cast<PyValue *>(std::get<0>(it.value()));
- if (operandValue) {
- operands.push_back(operandValue);
- operandSegmentLengths.push_back(1);
- } else if (segmentSpec == 0) {
- // Allowed to be optional.
- operandSegmentLengths.push_back(0);
- } else {
- throw nb::value_error(
- (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
- " of operation \"" + name +
- "\" must be a Value (was None and operand is not optional)")
- .str()
- .c_str());
+ auto &operand = std::get<0>(it.value());
+ if (!operand.is_none()) {
+ try {
+
+ operands.push_back(getOpResultOrValue(operand));
+ } catch (nb::builtin_exception &err) {
+ throw nb::value_error((llvm::Twine("Operand ") +
+ llvm::Twine(it.index()) +
+ " of operation \"" + name +
+ "\" must be a Value (" + err.what() + ")")
+ .str()
+ .c_str());
}
- } catch (nb::cast_error &err) {
- throw nb::value_error((llvm::Twine("Operand ") +
- llvm::Twine(it.index()) + " of operation \"" +
- name + "\" must be a Value (" + err.what() +
- ")")
- .str()
- .c_str());
+
+ operandSegmentLengths.push_back(1);
+ } else if (segmentSpec == 0) {
+ // Allowed to be optional.
+ operandSegmentLengths.push_back(0);
+ } else {
+ throw nb::value_error(
+ (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+ " of operation \"" + name +
+ "\" must be a Value (was None and operand is not optional)")
+ .str()
+ .c_str());
}
} else if (segmentSpec == -1) {
// Unpack sequence by appending.
@@ -1849,10 +2007,7 @@ nb::object PyOpView::buildGeneric(
// Unpack the list.
auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
for (nb::handle segmentItem : segment) {
- operands.push_back(nb::cast<PyValue *>(segmentItem));
- if (!operands.back()) {
- throw nb::type_error("contained a None item");
- }
+ operands.push_back(getOpResultOrValue(segmentItem));
}
operandSegmentLengths.push_back(nb::len(segment));
}
@@ -2266,57 +2421,6 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from,
}
namespace {
-/// CRTP base class for Python MLIR values that subclass Value and should be
-/// castable from it. The value hierarchy is one level deep and is not supposed
-/// to accommodate other levels unless core MLIR changes.
-template <typename DerivedTy>
-class PyConcreteValue : public PyValue {
-public:
- // Derived classes must define statics for:
- // IsAFunctionTy isaFunction
- // const char *pyClassName
- // and redefine bindDerived.
- using ClassTy = nb::class_<DerivedTy, PyValue>;
- using IsAFunctionTy = bool (*)(MlirValue);
-
- PyConcreteValue() = default;
- PyConcreteValue(PyOperationRef operationRef, MlirValue value)
- : PyValue(operationRef, value) {}
- PyConcreteValue(PyValue &orig)
- : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
-
- /// Attempts to cast the original value to the derived type and throws on
- /// type mismatches.
- static MlirValue castFrom(PyValue &orig) {
- if (!DerivedTy::isaFunction(orig.get())) {
- auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
- throw nb::value_error((Twine("Cannot cast value to ") +
- DerivedTy::pyClassName + " (from " + origRepr +
- ")")
- .str()
- .c_str());
- }
- return orig.get();
- }
-
- /// Binds the Python module objects to functions of this class.
- static void bind(nb::module_ &m) {
- auto cls = ClassTy(m, DerivedTy::pyClassName);
- cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
- cls.def_static(
- "isinstance",
- [](PyValue &otherValue) -> bool {
- return DerivedTy::isaFunction(otherValue);
- },
- nb::arg("other_value"));
- cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](DerivedTy &self) { return self.maybeDownCast(); });
- DerivedTy::bindDerived(cls);
- }
-
- /// Implemented by derived classes to add methods to the Python subclass.
- static void bindDerived(ClassTy &m) {}
-};
/// Python wrapper for MlirBlockArgument.
class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
@@ -2342,39 +2446,6 @@ class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
}
};
-/// Python wrapper for MlirOpResult.
-class PyOpResult : public PyConcreteValue<PyOpResult> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
- static constexpr const char *pyClassName = "OpResult";
- using PyConcreteValue::PyConcreteValue;
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro("owner", [](PyOpResult &self) {
- assert(
- mlirOperationEqual(self.getParentOperation()->get(),
- mlirOpResultGetOwner(self.get())) &&
- "expected the owner of the value in Python to match that in the IR");
- return self.getParentOperation().getObject();
- });
- c.def_prop_ro("result_number", [](PyOpResult &self) {
- return mlirOpResultGetResultNumber(self.get());
- });
- }
-};
-
-/// Returns the list of types of the values held by container.
-template <typename Container>
-static std::vector<MlirType> getValueTypes(Container &container,
- PyMlirContextRef &context) {
- std::vector<MlirType> result;
- result.reserve(container.size());
- for (int i = 0, e = container.size(); i < e; ++i) {
- result.push_back(mlirValueGetType(container.getElement(i).get()));
- }
- return result;
-}
-
/// A list of block arguments. Internally, these are stored as consecutive
/// elements, random access is cheap. The argument list is associated with the
/// operation that contains the block (detached blocks are not allowed in
@@ -2481,53 +2552,6 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
PyOperationRef operation;
};
-/// A list of operation results. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) result list is associated
-/// with the operation whose results these are, and thus extends the lifetime of
-/// this operation.
-class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
-public:
- static constexpr const char *pyClassName = "OpResultList";
- using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
-
- PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumResults(operation->get())
- : length,
- step),
- operation(std::move(operation)) {}
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro("types", [](PyOpResultList &self) {
- return getValueTypes(self, self.operation->getContext());
- });
- c.def_prop_ro("owner", [](PyOpResultList &self) {
- return self.operation->createOpView();
- });
- }
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyOpResultList, PyOpResult>;
-
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumResults(operation->get());
- }
-
- PyOpResult getRawElement(intptr_t index) {
- PyValue value(operation, mlirOperationGetResult(operation->get(), index));
- return PyOpResult(value);
- }
-
- PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyOpResultList(operation, startIndex, length, step);
- }
-
- PyOperationRef operation;
-};
-
/// A list of operation successors. Internally, these are stored as consecutive
/// elements, random access is cheap. The (returned) successor list is
/// associated with the operation whose successors these are, and thus extends
@@ -3108,20 +3132,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"result",
[](PyOperationBase &self) {
auto &operation = self.getOperation();
- auto numResults = mlirOperationGetNumResults(operation);
- if (numResults != 1) {
- auto name = mlirIdentifierStr(mlirOperationGetName(operation));
- throw nb::value_error(
- (Twine("Cannot call .result on operation ") +
- StringRef(name.data, name.length) + " which has " +
- Twine(numResults) +
- " results (it is only valid for operations with a "
- "single result)")
- .str()
- .c_str());
- }
- return PyOpResult(operation.getRef(),
- mlirOperationGetResult(operation, 0))
+ return PyOpResult(operation.getRef(), getUniqueResult(operation))
.maybeDownCast();
},
"Shortcut to get an op result if it has only one (throws an error "
@@ -3218,14 +3229,36 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("walk_order") = MlirWalkPostOrder);
nb::class_<PyOperation, PyOperationBase>(m, "Operation")
- .def_static("create", &PyOperation::create, nb::arg("name"),
- nb::arg("results").none() = nb::none(),
- nb::arg("operands").none() = nb::none(),
- ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/123953
More information about the Mlir-commits
mailing list