[Mlir-commits] [mlir] acde3f7 - [mlir:python] Compute get_op_result_or_value in PyOpView's constructor. (#123953)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jan 24 06:26:33 PST 2025


Author: Peter Hawkins
Date: 2025-01-24T06:26:28-08:00
New Revision: acde3f722ff3766f6f793884108d342b78623fe4

URL: https://github.com/llvm/llvm-project/commit/acde3f722ff3766f6f793884108d342b78623fe4
DIFF: https://github.com/llvm/llvm-project/commit/acde3f722ff3766f6f793884108d342b78623fe4.diff

LOG: [mlir:python] Compute get_op_result_or_value in PyOpView's constructor. (#123953)

This logic is on the critical path for constructing an operation from
Python. Computing it in C++ is faster than computing it in Python, and
the change required is small.

This change also alters the API contract of
_ods_common.get_op_results_or_values: it no longer calls
get_op_result_or_value on each element of a sequence, because the C++
code now performs that conversion.

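Most of the diff simply moves existing code within IRCore.cpp so that
PyConcreteValue, PyOpResult and PyOpResultList are defined before the new
operand-conversion helpers used by PyOpView::buildGeneric.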

Added: 
    

Modified: 
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/Bindings/Python/IRModule.h
    mlir/python/mlir/dialects/_ods_common.py
    mlir/test/mlir-tblgen/op-python-bindings.td
    mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 738f1444b15fe5..8e351cb22eb948 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1481,12 +1481,11 @@ static void maybeInsertOperation(PyOperationRef &op,
 
 nb::object PyOperation::create(std::string_view name,
                                std::optional<std::vector<PyType *>> results,
-                               std::optional<std::vector<PyValue *>> operands,
+                               llvm::ArrayRef<MlirValue> operands,
                                std::optional<nb::dict> attributes,
                                std::optional<std::vector<PyBlock *>> successors,
                                int regions, DefaultingPyLocation location,
                                const nb::object &maybeIp, bool inferType) {
-  llvm::SmallVector<MlirValue, 4> mlirOperands;
   llvm::SmallVector<MlirType, 4> mlirResults;
   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
@@ -1495,16 +1494,6 @@ nb::object PyOperation::create(std::string_view name,
   if (regions < 0)
     throw nb::value_error("number of regions must be >= 0");
 
-  // Unpack/validate operands.
-  if (operands) {
-    mlirOperands.reserve(operands->size());
-    for (PyValue *operand : *operands) {
-      if (!operand)
-        throw nb::value_error("operand value cannot be None");
-      mlirOperands.push_back(operand->get());
-    }
-  }
-
   // Unpack/validate results.
   if (results) {
     mlirResults.reserve(results->size());
@@ -1562,9 +1551,8 @@ nb::object PyOperation::create(std::string_view name,
   // point, exceptions cannot be thrown or else the state will leak.
   MlirOperationState state =
       mlirOperationStateGet(toMlirStringRef(name), location);
-  if (!mlirOperands.empty())
-    mlirOperationStateAddOperands(&state, mlirOperands.size(),
-                                  mlirOperands.data());
+  if (!operands.empty())
+    mlirOperationStateAddOperands(&state, operands.size(), operands.data());
   state.enableResultTypeInference = inferType;
   if (!mlirResults.empty())
     mlirOperationStateAddResults(&state, mlirResults.size(),
@@ -1632,6 +1620,143 @@ void PyOperation::erase() {
   mlirOperationDestroy(operation);
 }
 
+namespace {
+/// CRTP base class for Python MLIR values that subclass Value and should be
+/// castable from it. The value hierarchy is one level deep and is not supposed
+/// to accommodate other levels unless core MLIR changes.
+template <typename DerivedTy>
+class PyConcreteValue : public PyValue {
+public:
+  // Derived classes must define statics for:
+  //   IsAFunctionTy isaFunction
+  //   const char *pyClassName
+  // and redefine bindDerived.
+  using ClassTy = nb::class_<DerivedTy, PyValue>;
+  using IsAFunctionTy = bool (*)(MlirValue);
+
+  PyConcreteValue() = default;
+  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
+      : PyValue(operationRef, value) {}
+  PyConcreteValue(PyValue &orig)
+      : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
+
+  /// Attempts to cast the original value to the derived type and throws on
+  /// type mismatches.
+  static MlirValue castFrom(PyValue &orig) {
+    if (!DerivedTy::isaFunction(orig.get())) {
+      auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
+      throw nb::value_error((Twine("Cannot cast value to ") +
+                             DerivedTy::pyClassName + " (from " + origRepr +
+                             ")")
+                                .str()
+                                .c_str());
+    }
+    return orig.get();
+  }
+
+  /// Binds the Python module objects to functions of this class.
+  static void bind(nb::module_ &m) {
+    auto cls = ClassTy(m, DerivedTy::pyClassName);
+    cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
+    cls.def_static(
+        "isinstance",
+        [](PyValue &otherValue) -> bool {
+          return DerivedTy::isaFunction(otherValue);
+        },
+        nb::arg("other_value"));
+    cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
+            [](DerivedTy &self) { return self.maybeDownCast(); });
+    DerivedTy::bindDerived(cls);
+  }
+
+  /// Implemented by derived classes to add methods to the Python subclass.
+  static void bindDerived(ClassTy &m) {}
+};
+
+/// Python wrapper for MlirOpResult.
+class PyOpResult : public PyConcreteValue<PyOpResult> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
+  static constexpr const char *pyClassName = "OpResult";
+  using PyConcreteValue::PyConcreteValue;
+
+  static void bindDerived(ClassTy &c) {
+    c.def_prop_ro("owner", [](PyOpResult &self) {
+      assert(
+          mlirOperationEqual(self.getParentOperation()->get(),
+                             mlirOpResultGetOwner(self.get())) &&
+          "expected the owner of the value in Python to match that in the IR");
+      return self.getParentOperation().getObject();
+    });
+    c.def_prop_ro("result_number", [](PyOpResult &self) {
+      return mlirOpResultGetResultNumber(self.get());
+    });
+  }
+};
+
+/// Returns the list of types of the values held by container.
+template <typename Container>
+static std::vector<MlirType> getValueTypes(Container &container,
+                                           PyMlirContextRef &context) {
+  std::vector<MlirType> result;
+  result.reserve(container.size());
+  for (intptr_t i = 0, e = container.size(); i < e; ++i) {
+    result.push_back(mlirValueGetType(container.getElement(i).get()));
+  }
+  return result;
+}
+
+/// A list of operation results. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) result list is associated
+/// with the operation whose results these are, and thus extends the lifetime of
+/// this operation.
+class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
+public:
+  static constexpr const char *pyClassName = "OpResultList";
+  using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
+
+  PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
+                 intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirOperationGetNumResults(operation->get())
+                               : length,
+                  step),
+        operation(std::move(operation)) {}
+
+  static void bindDerived(ClassTy &c) {
+    c.def_prop_ro("types", [](PyOpResultList &self) {
+      return getValueTypes(self, self.operation->getContext());
+    });
+    c.def_prop_ro("owner", [](PyOpResultList &self) {
+      return self.operation->createOpView();
+    });
+  }
+
+  PyOperationRef &getOperation() { return operation; }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyOpResultList, PyOpResult>;
+
+  intptr_t getRawNumElements() {
+    operation->checkValid();
+    return mlirOperationGetNumResults(operation->get());
+  }
+
+  PyOpResult getRawElement(intptr_t index) {
+    PyValue value(operation, mlirOperationGetResult(operation->get(), index));
+    return PyOpResult(value);
+  }
+
+  PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyOpResultList(operation, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+};
+
+} // namespace
+
 //------------------------------------------------------------------------------
 // PyOpView
 //------------------------------------------------------------------------------
@@ -1733,6 +1858,50 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList,
   }
 }
 
+/// Returns the sole result of `operation`. Throws a Python ValueError naming
+/// the operation if it does not have exactly one result.
+static MlirValue getUniqueResult(MlirOperation operation) {
+  auto numResults = mlirOperationGetNumResults(operation);
+  if (numResults != 1) {
+    auto name = mlirIdentifierStr(mlirOperationGetName(operation));
+    throw nb::value_error((Twine("Cannot call .result on operation ") +
+                           StringRef(name.data, name.length) + " which has " +
+                           Twine(numResults) +
+                           " results (it is only valid for operations with a "
+                           "single result)")
+                              .str()
+                              .c_str());
+  }
+  return mlirOperationGetResult(operation, 0);
+}
+
+/// Converts one Python operand to an MlirValue. Accepts a Value, an operation
+/// (anything castable to PyOperationBase), or an OpResultList; the latter two
+/// resolve to the unique result of the operation involved. Throws a Python
+/// ValueError for None, for unsupported objects, or when that operation does
+/// not have exactly one result.
+/// NOTE(review): an OpResultList resolves via its owning operation, so any
+/// slicing of the list is ignored; confirm this matches the Python
+/// get_op_result_or_value helper this replaces.
+static MlirValue getOpResultOrValue(nb::handle operand) {
+  if (operand.is_none()) {
+    throw nb::value_error("contained a None item");
+  }
+  PyOperationBase *op = nullptr;
+  if (nb::try_cast<PyOperationBase *>(operand, op)) {
+    return getUniqueResult(op->getOperation());
+  }
+  PyOpResultList *opResultList = nullptr;
+  if (nb::try_cast<PyOpResultList *>(operand, opResultList)) {
+    return getUniqueResult(opResultList->getOperation()->get());
+  }
+  PyValue *value = nullptr;
+  if (nb::try_cast<PyValue *>(operand, value)) {
+    return value->get();
+  }
+  throw nb::value_error("is not a Value");
+}
+
 nb::object PyOpView::buildGeneric(
     std::string_view name, std::tuple<int, bool> opRegionSpec,
     nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj,
@@ -1783,16 +1942,14 @@ nb::object PyOpView::buildGeneric(
   }
 
   // Unpack operands.
-  std::vector<PyValue *> operands;
-  operands.reserve(operands.size());
+  llvm::SmallVector<MlirValue, 4> operands;
+  operands.reserve(nb::len(operandList));
   if (operandSegmentSpecObj.is_none()) {
     // Non-sized operand unpacking.
     for (const auto &it : llvm::enumerate(operandList)) {
       try {
-        operands.push_back(nb::cast<PyValue *>(it.value()));
-        if (!operands.back())
-          throw nb::cast_error();
-      } catch (nb::cast_error &err) {
+        operands.push_back(getOpResultOrValue(it.value()));
+      } catch (nb::builtin_exception &err) {
         throw nb::value_error((llvm::Twine("Operand ") +
                                llvm::Twine(it.index()) + " of operation \"" +
                                name + "\" must be a Value (" + err.what() + ")")
@@ -1818,29 +1975,30 @@ nb::object PyOpView::buildGeneric(
       int segmentSpec = std::get<1>(it.value());
       if (segmentSpec == 1 || segmentSpec == 0) {
         // Unpack unary element.
-        try {
-          auto *operandValue = nb::cast<PyValue *>(std::get<0>(it.value()));
-          if (operandValue) {
-            operands.push_back(operandValue);
-            operandSegmentLengths.push_back(1);
-          } else if (segmentSpec == 0) {
-            // Allowed to be optional.
-            operandSegmentLengths.push_back(0);
-          } else {
-            throw nb::value_error(
-                (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
-                 " of operation \"" + name +
-                 "\" must be a Value (was None and operand is not optional)")
-                    .str()
-                    .c_str());
+        auto &operand = std::get<0>(it.value());
+        if (!operand.is_none()) {
+          try {
+            operands.push_back(getOpResultOrValue(operand));
+          } catch (nb::builtin_exception &err) {
+            throw nb::value_error((llvm::Twine("Operand ") +
+                                   llvm::Twine(it.index()) +
+                                   " of operation \"" + name +
+                                   "\" must be a Value (" + err.what() + ")")
+                                      .str()
+                                      .c_str());
           }
-        } catch (nb::cast_error &err) {
-          throw nb::value_error((llvm::Twine("Operand ") +
-                                 llvm::Twine(it.index()) + " of operation \"" +
-                                 name + "\" must be a Value (" + err.what() +
-                                 ")")
-                                    .str()
-                                    .c_str());
+
+          operandSegmentLengths.push_back(1);
+        } else if (segmentSpec == 0) {
+          // Allowed to be optional.
+          operandSegmentLengths.push_back(0);
+        } else {
+          throw nb::value_error(
+              (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+               " of operation \"" + name +
+               "\" must be a Value (was None and operand is not optional)")
+                  .str()
+                  .c_str());
         }
       } else if (segmentSpec == -1) {
         // Unpack sequence by appending.
@@ -1852,10 +2011,7 @@ nb::object PyOpView::buildGeneric(
             // Unpack the list.
             auto segment = nb::cast<nb::sequence>(std::get<0>(it.value()));
             for (nb::handle segmentItem : segment) {
-              operands.push_back(nb::cast<PyValue *>(segmentItem));
-              if (!operands.back()) {
-                throw nb::type_error("contained a None item");
-              }
+              operands.push_back(getOpResultOrValue(segmentItem));
             }
             operandSegmentLengths.push_back(nb::len(segment));
           }
@@ -2269,57 +2425,6 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from,
 }
 
 namespace {
-/// CRTP base class for Python MLIR values that subclass Value and should be
-/// castable from it. The value hierarchy is one level deep and is not supposed
-/// to accommodate other levels unless core MLIR changes.
-template <typename DerivedTy>
-class PyConcreteValue : public PyValue {
-public:
-  // Derived classes must define statics for:
-  //   IsAFunctionTy isaFunction
-  //   const char *pyClassName
-  // and redefine bindDerived.
-  using ClassTy = nb::class_<DerivedTy, PyValue>;
-  using IsAFunctionTy = bool (*)(MlirValue);
-
-  PyConcreteValue() = default;
-  PyConcreteValue(PyOperationRef operationRef, MlirValue value)
-      : PyValue(operationRef, value) {}
-  PyConcreteValue(PyValue &orig)
-      : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
-
-  /// Attempts to cast the original value to the derived type and throws on
-  /// type mismatches.
-  static MlirValue castFrom(PyValue &orig) {
-    if (!DerivedTy::isaFunction(orig.get())) {
-      auto origRepr = nb::cast<std::string>(nb::repr(nb::cast(orig)));
-      throw nb::value_error((Twine("Cannot cast value to ") +
-                             DerivedTy::pyClassName + " (from " + origRepr +
-                             ")")
-                                .str()
-                                .c_str());
-    }
-    return orig.get();
-  }
-
-  /// Binds the Python module objects to functions of this class.
-  static void bind(nb::module_ &m) {
-    auto cls = ClassTy(m, DerivedTy::pyClassName);
-    cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
-    cls.def_static(
-        "isinstance",
-        [](PyValue &otherValue) -> bool {
-          return DerivedTy::isaFunction(otherValue);
-        },
-        nb::arg("other_value"));
-    cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
-            [](DerivedTy &self) { return self.maybeDownCast(); });
-    DerivedTy::bindDerived(cls);
-  }
-
-  /// Implemented by derived classes to add methods to the Python subclass.
-  static void bindDerived(ClassTy &m) {}
-};
 
 /// Python wrapper for MlirBlockArgument.
 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
@@ -2345,39 +2450,6 @@ class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
   }
 };
 
-/// Python wrapper for MlirOpResult.
-class PyOpResult : public PyConcreteValue<PyOpResult> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
-  static constexpr const char *pyClassName = "OpResult";
-  using PyConcreteValue::PyConcreteValue;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_prop_ro("owner", [](PyOpResult &self) {
-      assert(
-          mlirOperationEqual(self.getParentOperation()->get(),
-                             mlirOpResultGetOwner(self.get())) &&
-          "expected the owner of the value in Python to match that in the IR");
-      return self.getParentOperation().getObject();
-    });
-    c.def_prop_ro("result_number", [](PyOpResult &self) {
-      return mlirOpResultGetResultNumber(self.get());
-    });
-  }
-};
-
-/// Returns the list of types of the values held by container.
-template <typename Container>
-static std::vector<MlirType> getValueTypes(Container &container,
-                                           PyMlirContextRef &context) {
-  std::vector<MlirType> result;
-  result.reserve(container.size());
-  for (int i = 0, e = container.size(); i < e; ++i) {
-    result.push_back(mlirValueGetType(container.getElement(i).get()));
-  }
-  return result;
-}
-
 /// A list of block arguments. Internally, these are stored as consecutive
 /// elements, random access is cheap. The argument list is associated with the
 /// operation that contains the block (detached blocks are not allowed in
@@ -2484,53 +2556,6 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
   PyOperationRef operation;
 };
 
-/// A list of operation results. Internally, these are stored as consecutive
-/// elements, random access is cheap. The (returned) result list is associated
-/// with the operation whose results these are, and thus extends the lifetime of
-/// this operation.
-class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
-public:
-  static constexpr const char *pyClassName = "OpResultList";
-  using SliceableT = Sliceable<PyOpResultList, PyOpResult>;
-
-  PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
-                 intptr_t length = -1, intptr_t step = 1)
-      : Sliceable(startIndex,
-                  length == -1 ? mlirOperationGetNumResults(operation->get())
-                               : length,
-                  step),
-        operation(std::move(operation)) {}
-
-  static void bindDerived(ClassTy &c) {
-    c.def_prop_ro("types", [](PyOpResultList &self) {
-      return getValueTypes(self, self.operation->getContext());
-    });
-    c.def_prop_ro("owner", [](PyOpResultList &self) {
-      return self.operation->createOpView();
-    });
-  }
-
-private:
-  /// Give the parent CRTP class access to hook implementations below.
-  friend class Sliceable<PyOpResultList, PyOpResult>;
-
-  intptr_t getRawNumElements() {
-    operation->checkValid();
-    return mlirOperationGetNumResults(operation->get());
-  }
-
-  PyOpResult getRawElement(intptr_t index) {
-    PyValue value(operation, mlirOperationGetResult(operation->get(), index));
-    return PyOpResult(value);
-  }
-
-  PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
-    return PyOpResultList(operation, startIndex, length, step);
-  }
-
-  PyOperationRef operation;
-};
-
 /// A list of operation successors. Internally, these are stored as consecutive
 /// elements, random access is cheap. The (returned) successor list is
 /// associated with the operation whose successors these are, and thus extends
@@ -3123,20 +3148,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           "result",
           [](PyOperationBase &self) {
             auto &operation = self.getOperation();
-            auto numResults = mlirOperationGetNumResults(operation);
-            if (numResults != 1) {
-              auto name = mlirIdentifierStr(mlirOperationGetName(operation));
-              throw nb::value_error(
-                  (Twine("Cannot call .result on operation ") +
-                   StringRef(name.data, name.length) + " which has " +
-                   Twine(numResults) +
-                   " results (it is only valid for operations with a "
-                   "single result)")
-                      .str()
-                      .c_str());
-            }
-            return PyOpResult(operation.getRef(),
-                              mlirOperationGetResult(operation, 0))
+            return PyOpResult(operation.getRef(), getUniqueResult(operation))
                 .maybeDownCast();
           },
           "Shortcut to get an op result if it has only one (throws an error "
@@ -3233,14 +3245,36 @@ void mlir::python::populateIRCore(nb::module_ &m) {
            nb::arg("walk_order") = MlirWalkPostOrder);
 
   nb::class_<PyOperation, PyOperationBase>(m, "Operation")
-      .def_static("create", &PyOperation::create, nb::arg("name"),
-                  nb::arg("results").none() = nb::none(),
-                  nb::arg("operands").none() = nb::none(),
-                  nb::arg("attributes").none() = nb::none(),
-                  nb::arg("successors").none() = nb::none(),
-                  nb::arg("regions") = 0, nb::arg("loc").none() = nb::none(),
-                  nb::arg("ip").none() = nb::none(),
-                  nb::arg("infer_type") = false, kOperationCreateDocstring)
+      .def_static(
+          "create",
+          [](std::string_view name,
+             std::optional<std::vector<PyType *>> results,
+             std::optional<std::vector<PyValue *>> operands,
+             std::optional<nb::dict> attributes,
+             std::optional<std::vector<PyBlock *>> successors, int regions,
+             DefaultingPyLocation location, const nb::object &maybeIp,
+             bool inferType) {
+            // Unpack/validate operands.
+            llvm::SmallVector<MlirValue, 4> mlirOperands;
+            if (operands) {
+              mlirOperands.reserve(operands->size());
+              for (PyValue *operand : *operands) {
+                if (!operand)
+                  throw nb::value_error("operand value cannot be None");
+                mlirOperands.push_back(operand->get());
+              }
+            }
+
+            return PyOperation::create(name, results, mlirOperands, attributes,
+                                       successors, regions, location, maybeIp,
+                                       inferType);
+          },
+          nb::arg("name"), nb::arg("results").none() = nb::none(),
+          nb::arg("operands").none() = nb::none(),
+          nb::arg("attributes").none() = nb::none(),
+          nb::arg("successors").none() = nb::none(), nb::arg("regions") = 0,
+          nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(),
+          nb::arg("infer_type") = false, kOperationCreateDocstring)
       .def_static(
           "parse",
           [](const std::string &sourceStr, const std::string &sourceName,

diff  --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index fd70ac7ac6ec39..dd6e7ef9123746 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -686,7 +686,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
   /// Creates an operation. See corresponding python docstring.
   static nanobind::object
   create(std::string_view name, std::optional<std::vector<PyType *>> results,
-         std::optional<std::vector<PyValue *>> operands,
+         llvm::ArrayRef<MlirValue> operands,
          std::optional<nanobind::dict> attributes,
          std::optional<std::vector<PyBlock *>> successors, int regions,
          DefaultingPyLocation location, const nanobind::object &ip,

diff  --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 5b67ab03d6f494..d3dbdc604ef4c6 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -115,7 +115,10 @@ def get_op_results_or_values(
         _cext.ir.Operation,
         _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
     ]
-) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]:
+) -> _Union[
+    _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]],
+    _cext.ir.OpResultList,
+]:
-    """Returns the given sequence of values or the results of the given op.
+    """Returns the results of the given op, or the given sequence unchanged.
 
     This is useful to implement op constructors so that they can take other ops as
@@ -127,7 +130,7 @@ def get_op_results_or_values(
     elif isinstance(arg, _cext.ir.Operation):
         return arg.results
     else:
-        return [get_op_result_or_value(element) for element in arg]
+        return arg
 
 
 def get_op_result_or_op_results(

diff  --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 25833779c2f71f..72963cac64d54a 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -27,8 +27,8 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
   // CHECK:   attributes = {}
   // CHECK:   regions = None
   // CHECK:   operands.append(_get_op_results_or_values(variadic1))
-  // CHECK:   operands.append(_get_op_result_or_value(non_variadic))
-  // CHECK:   operands.append(_get_op_result_or_value(variadic2) if variadic2 is not None else None)
+  // CHECK:   operands.append(non_variadic)
+  // CHECK:   operands.append(variadic2)
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(
   // CHECK:     self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS,
@@ -173,8 +173,8 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
   // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
-  // CHECK:   operands.append(_get_op_result_or_value(_gen_arg_0))
-  // CHECK:   operands.append(_get_op_result_or_value(_gen_arg_2))
+  // CHECK:   operands.append(_gen_arg_0)
+  // CHECK:   operands.append(_gen_arg_2)
   // CHECK:   if bool(in_): attributes["in"] = _ods_ir.UnitAttr.get(
   // CHECK:     _ods_get_default_loc_context(loc))
   // CHECK:   if is_ is not None: attributes["is"] = (is_
@@ -307,9 +307,9 @@ def MissingNamesOp : TestOp<"missing_names"> {
   // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
-  // CHECK:   operands.append(_get_op_result_or_value(_gen_arg_0))
-  // CHECK:   operands.append(_get_op_result_or_value(f32))
-  // CHECK:   operands.append(_get_op_result_or_value(_gen_arg_2))
+  // CHECK:   operands.append(_gen_arg_0)
+  // CHECK:   operands.append(f32)
+  // CHECK:   operands.append(_gen_arg_2)
   // CHECK:   results.append(i32)
   // CHECK:   results.append(_gen_res_1)
   // CHECK:   results.append(i64)
@@ -349,8 +349,8 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
   // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
-  // CHECK:   operands.append(_get_op_result_or_value(non_optional))
-  // CHECK:   if optional is not None: operands.append(_get_op_result_or_value(optional))
+  // CHECK:   operands.append(non_optional)
+  // CHECK:   if optional is not None: operands.append(optional)
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(
   // CHECK:     self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS
@@ -380,7 +380,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
   // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
-  // CHECK:   operands.append(_get_op_result_or_value(non_variadic))
+  // CHECK:   operands.append(non_variadic)
   // CHECK:   operands.extend(_get_op_results_or_values(variadic))
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(
@@ -445,7 +445,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
   // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
-  // CHECK:   operands.append(_get_op_result_or_value(in_))
+  // CHECK:   operands.append(in_)
   // CHECK:   _ods_successors = None
   // CHECK:   super().__init__(
   // CHECK:     self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS
@@ -547,8 +547,8 @@ def SimpleOp : TestOp<"simple"> {
   // CHECK:   results = []
   // CHECK:   attributes = {}
   // CHECK:   regions = None
-  // CHECK:   operands.append(_get_op_result_or_value(i32))
-  // CHECK:   operands.append(_get_op_result_or_value(f32))
+  // CHECK:   operands.append(i32)
+  // CHECK:   operands.append(f32)
   // CHECK:   results.append(i64)
   // CHECK:   results.append(f64)
   // CHECK:   _ods_successors = None

diff  --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index e1540d1750ff19..604d2376052a86 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -37,7 +37,6 @@ from ._ods_common import (
     equally_sized_accessor as _ods_equally_sized_accessor,
     get_default_loc_context as _ods_get_default_loc_context,
     get_op_result_or_op_results as _get_op_result_or_op_results,
-    get_op_result_or_value as _get_op_result_or_value,
     get_op_results_or_values as _get_op_results_or_values,
     segmented_accessor as _ods_segmented_accessor,
 )
@@ -501,17 +500,17 @@ constexpr const char *initTemplate = R"Py(
 
 /// Template for appending a single element to the operand/result list.
 ///   {0} is the field name.
-constexpr const char *singleOperandAppendTemplate =
-    "operands.append(_get_op_result_or_value({0}))";
+constexpr const char *singleOperandAppendTemplate = "operands.append({0})";
 constexpr const char *singleResultAppendTemplate = "results.append({0})";
 
 /// Template for appending an optional element to the operand/result list.
 ///   {0} is the field name.
 constexpr const char *optionalAppendOperandTemplate =
-    "if {0} is not None: operands.append(_get_op_result_or_value({0}))";
+    "if {0} is not None: operands.append({0})";
+/// With attribute-sized operands, a None value is still appended so that the
+/// C++ operand unpacking records a zero-length segment for it.
 constexpr const char *optionalAppendAttrSizedOperandsTemplate =
-    "operands.append(_get_op_result_or_value({0}) if {0} is not None else "
-    "None)";
+    "operands.append({0})";
 constexpr const char *optionalAppendResultTemplate =
     "if {0} is not None: results.append({0})";
 


        


More information about the Mlir-commits mailing list