[Mlir-commits] [mlir] [MLIR][Python] Move operation/dialect name retrieving as a util function (PR #184605)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 4 05:04:21 PST 2026
https://github.com/PragmaTwice created https://github.com/llvm/llvm-project/pull/184605
The rewrite nanobind module repeatedly retrieves an operation name or a dialect name from either a `type` or a `str`, so it is better to factor this pattern out into common utility functions.
>From e6d939f329fe90b733aba84799e1b5cf67680cc5 Mon Sep 17 00:00:00 2001
From: PragmaTwice <twice at apache.org>
Date: Wed, 4 Mar 2026 21:01:21 +0800
Subject: [PATCH] [MLIR][Python] Move operation/dialect name retrieving as a
common util function
---
mlir/lib/Bindings/Python/Rewrite.cpp | 56 ++++++++++++++--------------
1 file changed, 28 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 77bc572e4ef62..99403892d6bbc 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -39,6 +39,28 @@ static MlirLogicalResult logicalResultFromObject(const nb::object &obj) {
: mlirLogicalResultSuccess();
}
+static std::string operationNameFromObject(nb::handle root) {
+  // Op classes carry their name in OPERATION_NAME; strings pass through.
+  if (root.is_type())
+    return nb::cast<std::string>(root.attr("OPERATION_NAME"));
+  if (nb::isinstance<nb::str>(root))
+    return nb::cast<std::string>(root);
+
+  // Any other kind of object is rejected with a Python TypeError.
+  throw nb::type_error("the root argument must be a type or a string");
+}
+
+static std::string dialectNameFromObject(nb::handle root) {
+  // Dialect classes carry DIALECT_NAMESPACE; strings pass through as-is.
+  if (root.is_type())
+    return nb::cast<std::string>(root.attr("DIALECT_NAMESPACE"));
+  if (nb::isinstance<nb::str>(root))
+    return nb::cast<std::string>(root);
+
+  // Callers (add_legal_dialect/add_illegal_dialect) take `dialects`, not root.
+  throw nb::type_error("the dialect argument must be a type or a string");
+}
+
class PyPatternRewriter : public PyRewriterBase<PyPatternRewriter> {
public:
static constexpr const char *pyClassName = "PatternRewriter";
@@ -69,14 +91,7 @@ bool PyRewritePatternSet::isOwned() const { return owned; }
void PyRewritePatternSet::add(nb::handle root,
const nb::callable &matchAndRewrite,
unsigned benefit) {
- std::string opName;
- if (root.is_type()) {
- opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
- } else if (nb::isinstance<nb::str>(root)) {
- opName = nb::cast<std::string>(root);
- } else {
- throw nb::type_error("the root argument must be a type or a string");
- }
+ std::string opName = operationNameFromObject(root);
MlirStringRef rootName = mlirStringRefCreate(opName.data(), opName.size());
MlirRewritePatternCallbacks callbacks;
@@ -212,14 +227,7 @@ void PyRewritePatternSet::addConversion(nb::handle root,
const nb::callable &matchAndRewrite,
PyTypeConverter &typeConverter,
unsigned benefit) {
- std::string opName;
- if (root.is_type()) {
- opName = nb::cast<std::string>(root.attr("OPERATION_NAME"));
- } else if (nb::isinstance<nb::str>(root)) {
- opName = nb::cast<std::string>(root);
- } else {
- throw nb::type_error("the root argument must be a type or a string");
- }
+ std::string opName = operationNameFromObject(root);
MlirStringRef rootName = mlirStringRefCreate(opName.data(), opName.size());
MlirConversionPatternCallbacks callbacks;
@@ -604,9 +612,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
"add_legal_op",
[](PyConversionTarget &self, const nb::args &ops) {
for (auto op : ops) {
- std::string opName =
- nb::cast<std::string>(op.attr("OPERATION_NAME"));
- self.addLegalOp(opName);
+ self.addLegalOp(operationNameFromObject(op));
}
},
"ops"_a, "Mark the given operations as legal.")
@@ -614,9 +620,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
"add_illegal_op",
[](PyConversionTarget &self, const nb::args &ops) {
for (auto op : ops) {
- std::string opName =
- nb::cast<std::string>(op.attr("OPERATION_NAME"));
- self.addIllegalOp(opName);
+ self.addIllegalOp(operationNameFromObject(op));
}
},
"ops"_a, "Mark the given operations as illegal.")
@@ -624,9 +628,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
"add_legal_dialect",
[](PyConversionTarget &self, const nb::args &dialects) {
for (auto dialect : dialects) {
- std::string dialectName =
- nb::cast<std::string>(dialect.attr("DIALECT_NAMESPACE"));
- self.addLegalDialect(dialectName);
+ self.addLegalDialect(dialectNameFromObject(dialect));
}
},
"dialects"_a, "Mark the given dialects as legal.")
@@ -634,9 +636,7 @@ void populateRewriteSubmodule(nb::module_ &m) {
"add_illegal_dialect",
[](PyConversionTarget &self, const nb::args &dialects) {
for (auto dialect : dialects) {
- std::string dialectName =
- nb::cast<std::string>(dialect.attr("DIALECT_NAMESPACE"));
- self.addIllegalDialect(dialectName);
+ self.addIllegalDialect(dialectNameFromObject(dialect));
}
},
"dialects"_a, "Mark the given dialect as illegal.");
More information about the Mlir-commits
mailing list