[Mlir-commits] [mlir] acde3f7 - [mlir:python] Compute get_op_result_or_value in PyOpView's constructor. (#123953)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jan 24 06:26:33 PST 2025
Author: Peter Hawkins
Date: 2025-01-24T06:26:28-08:00
New Revision: acde3f722ff3766f6f793884108d342b78623fe4
URL: https://github.com/llvm/llvm-project/commit/acde3f722ff3766f6f793884108d342b78623fe4
DIFF: https://github.com/llvm/llvm-project/commit/acde3f722ff3766f6f793884108d342b78623fe4.diff
LOG: [mlir:python] Compute get_op_result_or_value in PyOpView's constructor. (#123953)
The get_op_result_or_value logic is on the critical path for constructing an
operation from Python. It is faster to compute this in C++ than in Python, and
moving it there is a minor change.
This change also alters the API contract of
_ods_common.get_op_results_or_values: it no longer calls
get_op_result_or_value on each element of a sequence and instead returns the
sequence unchanged, since the C++ code will now do this.
Most of the diff here is simply reordering the code in IRCore.cpp.
Added:
Modified:
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/python/mlir/dialects/_ods_common.py
mlir/test/mlir-tblgen/op-python-bindings.td
mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 738f1444b15fe5..8e351cb22eb948 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1481,12 +1481,11 @@ static void maybeInsertOperation(PyOperationRef &op,
nb::object PyOperation::create(std::string_view name,
std::optional<std::vector<PyType *>> results,
- std::optional<std::vector<PyValue *>> operands,
+ llvm::ArrayRef<MlirValue> operands,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
int regions, DefaultingPyLocation location,
const nb::object &maybeIp, bool inferType) {
- llvm::SmallVector<MlirValue, 4> mlirOperands;
llvm::SmallVector<MlirType, 4> mlirResults;
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
@@ -1495,16 +1494,6 @@ nb::object PyOperation::create(std::string_view name,
if (regions < 0)
throw nb::value_error("number of regions must be >= 0");
- // Unpack/validate operands.
- if (operands) {
- mlirOperands.reserve(operands->size());
- for (PyValue *operand : *operands) {
- if (!operand)
- throw nb::value_error("operand value cannot be None");
- mlirOperands.push_back(operand->get());
- }
- }
-
// Unpack/validate results.
if (results) {
mlirResults.reserve(results->size());
@@ -1562,9 +1551,8 @@ nb::object PyOperation::create(std::string_view name,
// point, exceptions cannot be thrown or else the state will leak.
MlirOperationState state =
mlirOperationStateGet(toMlirStringRef(name), location);
- if (!mlirOperands.empty())
- mlirOperationStateAddOperands(&state, mlirOperands.size(),
- mlirOperands.data());
+ if (!operands.empty())
+ mlirOperationStateAddOperands(&state, operands.size(), operands.data());
state.enableResultTypeInference = inferType;
if (!mlirResults.empty())
mlirOperationStateAddResults(&state, mlirResults.size(),
@@ -1632,6 +1620,143 @@ void PyOperation::erase() {
mlirOperationDestroy(operation);
}
+namespace {
+/// CRTP base class for Python MLIR values that subclass Value and should be
+/// castable from it. The value hierarchy is one level deep and is not supposed
+/// to accommodate other levels unless core MLIR changes.
+template <typename DerivedTy>
+class PyConcreteValue : public PyValue {
+public:
+ // Derived classes must define statics for:
+ // IsAFunctionTy isaFunction
+ // const char *pyClassName
+ // and redefine bindDerived.
+ using ClassTy = nb::class_<DerivedTy, PyValue>;
+ using IsAFunctionTy = bool (*)(MlirValue);
+
+ PyConcreteValue() = default;
+ PyConcreteValue(PyOperationRef operationRef, MlirValue value)
+ : PyValue(operationRef, value) {}
+ PyConcreteValue(PyValue &orig)
+ : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
+
+ /// Attempts to cast the original value to the derived type and throws on
+ /// type mismatches.
+ static MlirValue castFrom(PyValue &orig) {
+ if (!DerivedTy::isaFunction(orig.get())) {
+ auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
+ throw nb::value_error((Twine("Cannot cast value to ") +
+ DerivedTy::pyClassName + " (from " + origRepr +
+ ")")
+ .str()
+ .c_str());
+ }
+ return orig.get();
+ }
+
+ /// Binds the Python module objects to functions of this class.
+ static void bind(nb::module_ &m) {
+ auto cls = ClassTy(m, DerivedTy::pyClassName);
+ cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
+ cls.def_static(
+ "isinstance",
+ [](PyValue &otherValue) -> bool {
+ return DerivedTy::isaFunction(otherValue);
+ },
+ nb::arg("other_value"));
+ cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+ [](DerivedTy &self) { return self.maybeDownCast(); });
+ DerivedTy::bindDerived(cls);
+ }
+
+ /// Implemented by derived classes to add methods to the Python subclass.
+ static void bindDerived(ClassTy &m) {}
+};
+
+} // namespace
+
+/// Python wrapper for MlirOpResult.
+class PyOpResult : public PyConcreteValue<PyOpResult> {
+public:
+ static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
+ static constexpr const char *pyClassName = "OpResult";
+ using PyConcreteValue::PyConcreteValue;
+
+ static void bindDerived(ClassTy &c) {
+ c.def_prop_ro("owner", [](PyOpResult &self) {
+ assert(
+ mlirOperationEqual(self.getParentOperation()->get(),
+ mlirOpResultGetOwner(self.get())) &&
+ "expected the owner of the value in Python to match that in the IR");
+ return self.getParentOperation().getObject();
+ });
+ c.def_prop_ro("result_number", [](PyOpResult &self) {
+ return mlirOpResultGetResultNumber(self.get());
+ });
+ }
+};
+
+/// Returns the list of types of the values held by container.
+template <typename Container>
+static std::vector<MlirType> getValueTypes(Container &container,
+ PyMlirContextRef &context) {
+ std::vector<MlirType> result;
+ result.reserve(container.size());
+ for (int i = 0, e = container.size(); i < e; ++i) {
+ result.push_back(mlirValueGetType(container.getElement(i).get()));
+ }
+ return result;
+}
+
+/// A list of operation results. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) result list is associated
+/// with the operation whose results these are, and thus extends the lifetime of
+/// this operation.
+class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
+public:
+ static constexpr const char *pyClassName = "OpResultList";
+ using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
+
+ PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
+ intptr_t length = -1, intptr_t step = 1)
+ : Sliceable(startIndex,
+ length == -1 ? mlirOperationGetNumResults(operation->get())
+ : length,
+ step),
+ operation(std::move(operation)) {}
+
+ static void bindDerived(ClassTy &c) {
+ c.def_prop_ro("types", [](PyOpResultList &self) {
+ return getValueTypes(self, self.operation->getContext());
+ });
+ c.def_prop_ro("owner", [](PyOpResultList &self) {
+ return self.operation->createOpView();
+ });
+ }
+
+ PyOperationRef &getOperation() { return operation; }
+
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyOpResultList, PyOpResult>;
+
+ intptr_t getRawNumElements() {
+ operation->checkValid();
+ return mlirOperationGetNumResults(operation->get());
+ }
+
+ PyOpResult getRawElement(intptr_t index) {
+ PyValue value(operation, mlirOperationGetResult(operation->get(), index));
+ return PyOpResult(value);
+ }
+
+ PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+ return PyOpResultList(operation, startIndex, length, step);
+ }
+
+ PyOperationRef operation;
+};
+
//------------------------------------------------------------------------------
// PyOpView
//------------------------------------------------------------------------------
@@ -1733,6 +1858,40 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList,
}
}
+static MlirValue getUniqueResult(MlirOperation operation) {
+ auto numResults = mlirOperationGetNumResults(operation);
+ if (numResults != 1) {
+ auto name = mlirIdentifierStr(mlirOperationGetName(operation));
+ throw nb::value_error((Twine("Cannot call .result on operation ") +
+ StringRef(name.data, name.length) + " which has " +
+ Twine(numResults) +
+ " results (it is only valid for operations with a "
+ "single result)")
+ .str()
+ .c_str());
+ }
+ return mlirOperationGetResult(operation, 0);
+}
+
+static MlirValue getOpResultOrValue(nb::handle operand) {
+ if (operand.is_none()) {
+ throw nb::value_error("contained a None item");
+ }
+ PyOperationBase *op;
+ if (nb::try_cast<PyOperationBase *>(operand, op)) {
+ return getUniqueResult(op->getOperation());
+ }
+ PyOpResultList *opResultList;
+ if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
+ return getUniqueResult(opResultList->getOperation()->get());
+ }
+ PyValue *value;
+ if (nb::try_cast<PyValue *>(operand, value)) {
+ return value->get();
+ }
+ throw nb::value_error("is not a Value");
+}
+
nb::object PyOpView::buildGeneric(
std::string_view name, std::tuple<int, bool> opRegionSpec,
nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
@@ -1783,16 +1942,14 @@ nb::object PyOpView::buildGeneric(
}
// Unpack operands.
- std::vector<PyValue *> operands;
+ llvm::SmallVector<MlirValue, 4> operands;
operands.reserve(operands.size());
if (operandSegmentSpecObj.is_none()) {
// Non-sized operand unpacking.
for (const auto &it : llvm::enumerate(operandList)) {
try {
- operands.push_back(nb::cast<PyValue *>(it.value()));
- if (!operands.back())
- throw nb::cast_error();
- } catch (nb::cast_error &err) {
+ operands.push_back(getOpResultOrValue(it.value()));
+ } catch (nb::builtin_exception &err) {
throw nb::value_error((llvm::Twine("Operand ") +
llvm::Twine(it.index()) + " of operation \"" +
name + "\" must be a Value (" + err.what() + ")")
@@ -1818,29 +1975,31 @@ nb::object PyOpView::buildGeneric(
int segmentSpec = std::get<1>(it.value());
if (segmentSpec == 1 || segmentSpec == 0) {
// Unpack unary element.
- try {
- auto *operandValue = nb::cast<PyValue *>(std::get<0>(it.value()));
- if (operandValue) {
- operands.push_back(operandValue);
- operandSegmentLengths.push_back(1);
- } else if (segmentSpec == 0) {
- // Allowed to be optional.
- operandSegmentLengths.push_back(0);
- } else {
- throw nb::value_error(
- (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
- " of operation \"" + name +
- "\" must be a Value (was None and operand is not optional)")
- .str()
- .c_str());
+ auto &operand = std::get<0>(it.value());
+ if (!operand.is_none()) {
+ try {
+
+ operands.push_back(getOpResultOrValue(operand));
+ } catch (nb::builtin_exception &err) {
+ throw nb::value_error((llvm::Twine("Operand ") +
+ llvm::Twine(it.index()) +
+ " of operation \"" + name +
+ "\" must be a Value (" + err.what() + ")")
+ .str()
+ .c_str());
}
- } catch (nb::cast_error &err) {
- throw nb::value_error((llvm::Twine("Operand ") +
- llvm::Twine(it.index()) + " of operation \"" +
- name + "\" must be a Value (" + err.what() +
- ")")
- .str()
- .c_str());
+
+ operandSegmentLengths.push_back(1);
+ } else if (segmentSpec == 0) {
+ // Allowed to be optional.
+ operandSegmentLengths.push_back(0);
+ } else {
+ throw nb::value_error(
+ (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+ " of operation \"" + name +
+ "\" must be a Value (was None and operand is not optional)")
+ .str()
+ .c_str());
}
} else if (segmentSpec == -1) {
// Unpack sequence by appending.
@@ -1852,10 +2011,7 @@ nb::object PyOpView::buildGeneric(
// Unpack the list.
auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
for (nb::handle segmentItem : segment) {
- operands.push_back(nb::cast<PyValue *>(segmentItem));
- if (!operands.back()) {
- throw nb::type_error("contained a None item");
- }
+ operands.push_back(getOpResultOrValue(segmentItem));
}
operandSegmentLengths.push_back(nb::len(segment));
}
@@ -2269,57 +2425,6 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from,
}
namespace {
-/// CRTP base class for Python MLIR values that subclass Value and should be
-/// castable from it. The value hierarchy is one level deep and is not supposed
-/// to accommodate other levels unless core MLIR changes.
-template <typename DerivedTy>
-class PyConcreteValue : public PyValue {
-public:
- // Derived classes must define statics for:
- // IsAFunctionTy isaFunction
- // const char *pyClassName
- // and redefine bindDerived.
- using ClassTy = nb::class_<DerivedTy, PyValue>;
- using IsAFunctionTy = bool (*)(MlirValue);
-
- PyConcreteValue() = default;
- PyConcreteValue(PyOperationRef operationRef, MlirValue value)
- : PyValue(operationRef, value) {}
- PyConcreteValue(PyValue &orig)
- : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
-
- /// Attempts to cast the original value to the derived type and throws on
- /// type mismatches.
- static MlirValue castFrom(PyValue &orig) {
- if (!DerivedTy::isaFunction(orig.get())) {
- auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
- throw nb::value_error((Twine("Cannot cast value to ") +
- DerivedTy::pyClassName + " (from " + origRepr +
- ")")
- .str()
- .c_str());
- }
- return orig.get();
- }
-
- /// Binds the Python module objects to functions of this class.
- static void bind(nb::module_ &m) {
- auto cls = ClassTy(m, DerivedTy::pyClassName);
- cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
- cls.def_static(
- "isinstance",
- [](PyValue &otherValue) -> bool {
- return DerivedTy::isaFunction(otherValue);
- },
- nb::arg("other_value"));
- cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](DerivedTy &self) { return self.maybeDownCast(); });
- DerivedTy::bindDerived(cls);
- }
-
- /// Implemented by derived classes to add methods to the Python subclass.
- static void bindDerived(ClassTy &m) {}
-};
/// Python wrapper for MlirBlockArgument.
class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
@@ -2345,39 +2450,6 @@ class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
}
};
-/// Python wrapper for MlirOpResult.
-class PyOpResult : public PyConcreteValue<PyOpResult> {
-public:
- static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
- static constexpr const char *pyClassName = "OpResult";
- using PyConcreteValue::PyConcreteValue;
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro("owner", [](PyOpResult &self) {
- assert(
- mlirOperationEqual(self.getParentOperation()->get(),
- mlirOpResultGetOwner(self.get())) &&
- "expected the owner of the value in Python to match that in the IR");
- return self.getParentOperation().getObject();
- });
- c.def_prop_ro("result_number", [](PyOpResult &self) {
- return mlirOpResultGetResultNumber(self.get());
- });
- }
-};
-
-/// Returns the list of types of the values held by container.
-template <typename Container>
-static std::vector<MlirType> getValueTypes(Container &container,
- PyMlirContextRef &context) {
- std::vector<MlirType> result;
- result.reserve(container.size());
- for (int i = 0, e = container.size(); i < e; ++i) {
- result.push_back(mlirValueGetType(container.getElement(i).get()));
- }
- return result;
-}
-
/// A list of block arguments. Internally, these are stored as consecutive
/// elements, random access is cheap. The argument list is associated with the
/// operation that contains the block (detached blocks are not allowed in
@@ -2484,53 +2556,6 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
PyOperationRef operation;
};
-/// A list of operation results. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) result list is associated
-/// with the operation whose results these are, and thus extends the lifetime of
-/// this operation.
-class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
-public:
- static constexpr const char *pyClassName = "OpResultList";
- using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
-
- PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
- intptr_t length = -1, intptr_t step = 1)
- : Sliceable(startIndex,
- length == -1 ? mlirOperationGetNumResults(operation->get())
- : length,
- step),
- operation(std::move(operation)) {}
-
- static void bindDerived(ClassTy &c) {
- c.def_prop_ro("types", [](PyOpResultList &self) {
- return getValueTypes(self, self.operation->getContext());
- });
- c.def_prop_ro("owner", [](PyOpResultList &self) {
- return self.operation->createOpView();
- });
- }
-
-private:
- /// Give the parent CRTP class access to hook implementations below.
- friend class Sliceable<PyOpResultList, PyOpResult>;
-
- intptr_t getRawNumElements() {
- operation->checkValid();
- return mlirOperationGetNumResults(operation->get());
- }
-
- PyOpResult getRawElement(intptr_t index) {
- PyValue value(operation, mlirOperationGetResult(operation->get(), index));
- return PyOpResult(value);
- }
-
- PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
- return PyOpResultList(operation, startIndex, length, step);
- }
-
- PyOperationRef operation;
-};
-
/// A list of operation successors. Internally, these are stored as consecutive
/// elements, random access is cheap. The (returned) successor list is
/// associated with the operation whose successors these are, and thus extends
@@ -3123,20 +3148,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
"result",
[](PyOperationBase &self) {
auto &operation = self.getOperation();
- auto numResults = mlirOperationGetNumResults(operation);
- if (numResults != 1) {
- auto name = mlirIdentifierStr(mlirOperationGetName(operation));
- throw nb::value_error(
- (Twine("Cannot call .result on operation ") +
- StringRef(name.data, name.length) + " which has " +
- Twine(numResults) +
- " results (it is only valid for operations with a "
- "single result)")
- .str()
- .c_str());
- }
- return PyOpResult(operation.getRef(),
- mlirOperationGetResult(operation, 0))
+ return PyOpResult(operation.getRef(), getUniqueResult(operation))
.maybeDownCast();
},
"Shortcut to get an op result if it has only one (throws an error "
@@ -3233,14 +3245,36 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("walk_order") = MlirWalkPostOrder);
nb::class_<PyOperation, PyOperationBase>(m, "Operation")
- .def_static("create", &PyOperation::create, nb::arg("name"),
- nb::arg("results").none() = nb::none(),
- nb::arg("operands").none() = nb::none(),
- nb::arg("attributes").none() = nb::none(),
- nb::arg("successors").none() = nb::none(),
- nb::arg("regions") = 0, nb::arg("loc").none() = nb::none(),
- nb::arg("ip").none() = nb::none(),
- nb::arg("infer_type") = false, kOperationCreateDocstring)
+ .def_static(
+ "create",
+ [](std::string_view name,
+ std::optional<std::vector<PyType *>> results,
+ std::optional<std::vector<PyValue *>> operands,
+ std::optional<nb::dict> attributes,
+ std::optional<std::vector<PyBlock *>> successors, int regions,
+ DefaultingPyLocation location, const nb::object &maybeIp,
+ bool inferType) {
+ // Unpack/validate operands.
+ llvm::SmallVector<MlirValue, 4> mlirOperands;
+ if (operands) {
+ mlirOperands.reserve(operands->size());
+ for (PyValue *operand : *operands) {
+ if (!operand)
+ throw nb::value_error("operand value cannot be None");
+ mlirOperands.push_back(operand->get());
+ }
+ }
+
+ return PyOperation::create(name, results, mlirOperands, attributes,
+ successors, regions, location, maybeIp,
+ inferType);
+ },
+ nb::arg("name"), nb::arg("results").none() = nb::none(),
+ nb::arg("operands").none() = nb::none(),
+ nb::arg("attributes").none() = nb::none(),
+ nb::arg("successors").none() = nb::none(), nb::arg("regions") = 0,
+ nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(),
+ nb::arg("infer_type") = false, kOperationCreateDocstring)
.def_static(
"parse",
[](const std::string &sourceStr, const std::string &sourceName,
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index fd70ac7ac6ec39..dd6e7ef9123746 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -686,7 +686,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
/// Creates an operation. See corresponding python docstring.
static nanobind::object
create(std::string_view name, std::optional<std::vector<PyType *>> results,
- std::optional<std::vector<PyValue *>> operands,
+ llvm::ArrayRef<MlirValue> operands,
std::optional<nanobind::dict> attributes,
std::optional<std::vector<PyBlock *>> successors, int regions,
DefaultingPyLocation location, const nanobind::object &ip,
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 5b67ab03d6f494..d3dbdc604ef4c6 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -115,7 +115,10 @@ def get_op_results_or_values(
_cext.ir.Operation,
_Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
]
-) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]:
+) -> _Union[
+ _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
+ _cext.ir.OpResultList,
+]:
"""Returns the given sequence of values or the results of the given op.
This is useful to implement op constructors so that they can take other ops as
@@ -127,7 +130,7 @@ def get_op_results_or_values(
elif isinstance(arg, _cext.ir.Operation):
return arg.results
else:
- return [get_op_result_or_value(element) for element in arg]
+ return arg
def get_op_result_or_op_results(
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 25833779c2f71f..72963cac64d54a 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -27,8 +27,8 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
// CHECK: attributes = {}
// CHECK: regions = None
// CHECK: operands.append(_get_op_results_or_values(variadic1))
- // CHECK: operands.append(_get_op_result_or_value(non_variadic))
- // CHECK: operands.append(_get_op_result_or_value(variadic2) if variadic2 is not None else None)
+ // CHECK: operands.append(non_variadic)
+ // CHECK: operands.append(variadic2)
// CHECK: _ods_successors = None
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -173,8 +173,8 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
- // CHECK: operands.append(_get_op_result_or_value(_gen_arg_0))
- // CHECK: operands.append(_get_op_result_or_value(_gen_arg_2))
+ // CHECK: operands.append(_gen_arg_0)
+ // CHECK: operands.append(_gen_arg_2)
// CHECK: if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
// CHECK: _ods_get_default_loc_context(loc))
// CHECK: if is_ is not None: attributes["is"] = (is_
@@ -307,9 +307,9 @@ def MissingNamesOp : TestOp<"missing_names"> {
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
- // CHECK: operands.append(_get_op_result_or_value(_gen_arg_0))
- // CHECK: operands.append(_get_op_result_or_value(f32))
- // CHECK: operands.append(_get_op_result_or_value(_gen_arg_2))
+ // CHECK: operands.append(_gen_arg_0)
+ // CHECK: operands.append(f32)
+ // CHECK: operands.append(_gen_arg_2)
// CHECK: results.append(i32)
// CHECK: results.append(_gen_res_1)
// CHECK: results.append(i64)
@@ -349,8 +349,8 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
- // CHECK: operands.append(_get_op_result_or_value(non_optional))
- // CHECK: if optional is not None: operands.append(_get_op_result_or_value(optional))
+ // CHECK: operands.append(non_optional)
+ // CHECK: if optional is not None: operands.append(optional)
// CHECK: _ods_successors = None
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS
@@ -380,7 +380,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
- // CHECK: operands.append(_get_op_result_or_value(non_variadic))
+ // CHECK: operands.append(non_variadic)
// CHECK: operands.extend(_get_op_results_or_values(variadic))
// CHECK: _ods_successors = None
// CHECK: super().__init__(
@@ -445,7 +445,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
- // CHECK: operands.append(_get_op_result_or_value(in_))
+ // CHECK: operands.append(in_)
// CHECK: _ods_successors = None
// CHECK: super().__init__(
// CHECK: self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS
@@ -547,8 +547,8 @@ def SimpleOp : TestOp<"simple"> {
// CHECK: results = []
// CHECK: attributes = {}
// CHECK: regions = None
- // CHECK: operands.append(_get_op_result_or_value(i32))
- // CHECK: operands.append(_get_op_result_or_value(f32))
+ // CHECK: operands.append(i32)
+ // CHECK: operands.append(f32)
// CHECK: results.append(i64)
// CHECK: results.append(f64)
// CHECK: _ods_successors = None
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index e1540d1750ff19..604d2376052a86 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -37,7 +37,6 @@ from ._ods_common import (
equally_sized_accessor as _ods_equally_sized_accessor,
get_default_loc_context as _ods_get_default_loc_context,
get_op_result_or_op_results as _get_op_result_or_op_results,
- get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
segmented_accessor as _ods_segmented_accessor,
)
@@ -501,17 +500,15 @@ constexpr const char *initTemplate = R"Py(
/// Template for appending a single element to the operand/result list.
/// {0} is the field name.
-constexpr const char *singleOperandAppendTemplate =
- "operands.append(_get_op_result_or_value({0}))";
+constexpr const char *singleOperandAppendTemplate = "operands.append({0})";
constexpr const char *singleResultAppendTemplate = "results.append({0})";
/// Template for appending an optional element to the operand/result list.
/// {0} is the field name.
constexpr const char *optionalAppendOperandTemplate =
- "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
+ "if {0} is not None: operands.append({0})";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
- "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
- "None)";
+ "operands.append({0})";
constexpr const char *optionalAppendResultTemplate =
"if {0} is not None: results.append({0})";
More information about the Mlir-commits
mailing list