[Mlir-commits] [mlir] [MLIR][Python] Move operation/dialect name retrieving as a util function (PR #184605)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Mar 4 05:04:21 PST 2026


https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/184605

The rewrite nanobind module repeats a common pattern that retrieves an operation name or dialect name from a `type` or `str`, so it is better to factor that pattern out into common utility functions.

>From e6d939f329fe90b733aba84799e1b5cf67680cc5 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Wed, 4 Mar 2026 21:01:21 +0800
Subject: [PATCH] [MLIR][Python] Move operation/dialect name retrieving as a
 common util function

---
 mlir/lib/Bindings/Python/Rewrite.cpp | 56 ++++++++++++++--------------
 1 file changed, 28 insertions(+), 28 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 77bc572e4ef62..99403892d6bbc 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -39,6 +39,28 @@ static MlirLogicalResult logicalResultFromObject(const nb::object &obj) {
                              : mlirLogicalResultSuccess();
 }
 
+static std::string operationNameFromObject(nb::handle root) {
+  if (root.is_type()) {
+    return nb::cast<std::string>(root.attr("OPERATION_NAME"));
+  }
+  if (nb::isinstance<nb::str>(root)) {
+    return nb::cast<std::string>(root);
+  }
+
+  throw nb::type_error("the root argument must be a type or a string");
+}
+
+static std::string dialectNameFromObject(nb::handle root) {
+  if (root.is_type()) {
+    return nb::cast<std::string>(root.attr("DIALECT_NAMESPACE"));
+  }
+  if (nb::isinstance<nb::str>(root)) {
+    return nb::cast<std::string>(root);
+  }
+
+  throw nb::type_error("the root argument must be a type or a string");
+}
+
 class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
 public:
   static constexpr const char *pyClassName = "PatternRewriter";
@@ -69,14 +91,7 @@ bool PyRewritePatternSet::isOwned() const { return owned; }
 void PyRewritePatternSet::add(nb::handle root,
                               const nb::callable &matchAndRewrite,
                               unsigned benefit) {
-  std::string opName;
-  if (root.is_type()) {
-    opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
-  } else if (nb::isinstance<nb::str>(root)) {
-    opName = nb::cast<std::string>(root);
-  } else {
-    throw nb::type_error("the root argument must be a type or a string");
-  }
+  std::string opName = operationNameFromObject(root);
   MlirStringRef rootName = mlirStringRefCreate(opName.data(), opName.size());
 
   MlirRewritePatternCallbacks callbacks;
@@ -212,14 +227,7 @@ void PyRewritePatternSet::addConversion(nb::handle root,
                                         const nb::callable &matchAndRewrite,
                                         PyTypeConverter &typeConverter,
                                         unsigned benefit) {
-  std::string opName;
-  if (root.is_type()) {
-    opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
-  } else if (nb::isinstance<nb::str>(root)) {
-    opName = nb::cast<std::string>(root);
-  } else {
-    throw nb::type_error("the root argument must be a type or a string");
-  }
+  std::string opName = operationNameFromObject(root);
   MlirStringRef rootName = mlirStringRefCreate(opName.data(), opName.size());
 
   MlirConversionPatternCallbacks callbacks;
@@ -604,9 +612,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
           "add_legal_op",
           [](PyConversionTarget &self, const nb::args &ops) {
             for (auto op : ops) {
-              std::string opName =
-                  nb::cast<std::string>(op.attr("OPERATION_NAME"));
-              self.addLegalOp(opName);
+              self.addLegalOp(operationNameFromObject(op));
             }
           },
           "ops"_a, "Mark the given operations as legal.")
@@ -614,9 +620,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
           "add_illegal_op",
           [](PyConversionTarget &self, const nb::args &ops) {
             for (auto op : ops) {
-              std::string opName =
-                  nb::cast<std::string>(op.attr("OPERATION_NAME"));
-              self.addIllegalOp(opName);
+              self.addIllegalOp(operationNameFromObject(op));
             }
           },
           "ops"_a, "Mark the given operations as illegal.")
@@ -624,9 +628,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
           "add_legal_dialect",
           [](PyConversionTarget &self, const nb::args &dialects) {
             for (auto dialect : dialects) {
-              std::string dialectName =
-                  nb::cast<std::string>(dialect.attr("DIALECT_NAMESPACE"));
-              self.addLegalDialect(dialectName);
+              self.addLegalDialect(dialectNameFromObject(dialect));
             }
           },
           "dialects"_a, "Mark the given dialects as legal.")
@@ -634,9 +636,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
           "add_illegal_dialect",
           [](PyConversionTarget &self, const nb::args &dialects) {
             for (auto dialect : dialects) {
-              std::string dialectName =
-                  nb::cast<std::string>(dialect.attr("DIALECT_NAMESPACE"));
-              self.addIllegalDialect(dialectName);
+              self.addIllegalDialect(dialectNameFromObject(dialect));
             }
           },
           "dialects"_a, "Mark the given dialect as illegal.");



More information about the Mlir-commits mailing list