[Mlir-commits] [mlir] [MLIR][Python] Move operation/dialect name retrieving as a util function (PR #184605)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 4 05:05:01 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

<details>
<summary>Changes</summary>

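We have a common pattern that retrieves an operation name or dialect name from a `type` or `str` in the rewrite nanobind module, so it is better to make it a common utility function. As a side effect, `ConversionTarget.add_legal_op`, `add_illegal_op`, `add_legal_dialect` and `add_illegal_dialect` now accept plain strings in addition to op/dialect classes, matching `RewritePatternSet.add` and `add_conversion`.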

---
Full diff: https://github.com/llvm/llvm-project/pull/184605.diff


1 File Affected:

- (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+28-28) 


``````````diff
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 77bc572e4ef62..99403892d6bbc 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -39,6 +39,28 @@ static MlirLogicalResult logicalResultFromObject(const nb::object &obj) {
                              : mlirLogicalResultSuccess();
 }
 
+static std::string operationNameFromObject(nb::handle root) {
+  if (root.is_type()) {
+    return nb::cast<std::string>(root.attr("OPERATION_NAME"));
+  }
+  if (nb::isinstance<nb::str>(root)) {
+    return nb::cast<std::string>(root);
+  }
+
+  throw nb::type_error("the root argument must be a type or a string");
+}
+
+static std::string dialectNameFromObject(nb::handle root) {
+  if (root.is_type()) {
+    return nb::cast<std::string>(root.attr("DIALECT_NAMESPACE"));
+  }
+  if (nb::isinstance<nb::str>(root)) {
+    return nb::cast<std::string>(root);
+  }
+
+  throw nb::type_error("the root argument must be a type or a string");
+}
+
 class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
 public:
   static constexpr const char *pyClassName = "PatternRewriter";
@@ -69,14 +91,7 @@ bool PyRewritePatternSet::isOwned() const { return owned; }
 void PyRewritePatternSet::add(nb::handle root,
                               const nb::callable &matchAndRewrite,
                               unsigned benefit) {
-  std::string opName;
-  if (root.is_type()) {
-    opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
-  } else if (nb::isinstance<nb::str>(root)) {
-    opName = nb::cast<std::string>(root);
-  } else {
-    throw nb::type_error("the root argument must be a type or a string");
-  }
+  std::string opName = operationNameFromObject(root);
   MlirStringRef rootName = mlirStringRefCreate(opName.data(), opName.size());
 
   MlirRewritePatternCallbacks callbacks;
@@ -212,14 +227,7 @@ void PyRewritePatternSet::addConversion(nb::handle root,
                                         const nb::callable &matchAndRewrite,
                                         PyTypeConverter &typeConverter,
                                         unsigned benefit) {
-  std::string opName;
-  if (root.is_type()) {
-    opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
-  } else if (nb::isinstance<nb::str>(root)) {
-    opName = nb::cast<std::string>(root);
-  } else {
-    throw nb::type_error("the root argument must be a type or a string");
-  }
+  std::string opName = operationNameFromObject(root);
   MlirStringRef rootName = mlirStringRefCreate(opName.data(), opName.size());
 
   MlirConversionPatternCallbacks callbacks;
@@ -604,9 +612,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
           "add_legal_op",
           [](PyConversionTarget &self, const nb::args &ops) {
             for (auto op : ops) {
-              std::string opName =
-                  nb::cast<std::string>(op.attr("OPERATION_NAME"));
-              self.addLegalOp(opName);
+              self.addLegalOp(operationNameFromObject(op));
             }
           },
           "ops"_a, "Mark the given operations as legal.")
@@ -614,9 +620,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
           "add_illegal_op",
           [](PyConversionTarget &self, const nb::args &ops) {
             for (auto op : ops) {
-              std::string opName =
-                  nb::cast<std::string>(op.attr("OPERATION_NAME"));
-              self.addIllegalOp(opName);
+              self.addIllegalOp(operationNameFromObject(op));
             }
           },
           "ops"_a, "Mark the given operations as illegal.")
@@ -624,9 +628,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
           "add_legal_dialect",
           [](PyConversionTarget &self, const nb::args &dialects) {
             for (auto dialect : dialects) {
-              std::string dialectName =
-                  nb::cast<std::string>(dialect.attr("DIALECT_NAMESPACE"));
-              self.addLegalDialect(dialectName);
+              self.addLegalDialect(dialectNameFromObject(dialect));
             }
           },
           "dialects"_a, "Mark the given dialects as legal.")
@@ -634,9 +636,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
           "add_illegal_dialect",
           [](PyConversionTarget &self, const nb::args &dialects) {
             for (auto dialect : dialects) {
-              std::string dialectName =
-                  nb::cast<std::string>(dialect.attr("DIALECT_NAMESPACE"));
-              self.addIllegalDialect(dialectName);
+              self.addIllegalDialect(dialectNameFromObject(dialect));
             }
           },
           "dialects"_a, "Mark the given dialect as illegal.");

``````````

</details>


https://github.com/llvm/llvm-project/pull/184605


More information about the Mlir-commits mailing list