[Mlir-commits] [mlir] Reformat whitespace in dependent dialects codegen (PR #78090)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jan 14 00:20:00 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: mlevesquedion (mlevesquedion)
<details>
<summary>Changes</summary>
The generated code for dependent dialects is awkwardly formatted, making the code harder to read. This change reformats the whitespace to align code in its context and avoid unnecessary empty lines.
Also included are some typo fixes.
Below are examples of the codegen for a dialect before and after the change.
Before:
```
GPUDialect::GPUDialect(::mlir::MLIRContext *context)
: ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<GPUDialect>()) {
getContext()->loadDialect<arith::ArithDialect>();
initialize();
}
```
After:
```
GPUDialect::GPUDialect(::mlir::MLIRContext *context)
: ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<GPUDialect>()) {
getContext()->loadDialect<arith::ArithDialect>();
initialize();
}
```
Below are examples of the codegen for a pass before and after the change.
Before:
```
/// Return the dialect that must be loaded in the context before this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
registry.insert<func::FuncDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<tosa::TosaDialect>();
}
```
After:
```
/// Register the dialects that must be loaded in the context before this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
registry.insert<func::FuncDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<tosa::TosaDialect>();
}
```
---
Full diff: https://github.com/llvm/llvm-project/pull/78090.diff
4 Files Affected:
- (added) mlir/test/mlir-tblgen/dialect-with-dependents.td (+15)
- (added) mlir/test/mlir-tblgen/pass.td (+22)
- (modified) mlir/tools/mlir-tblgen/DialectGen.cpp (+11-9)
- (modified) mlir/tools/mlir-tblgen/PassGen.cpp (+16-11)
``````````diff
diff --git a/mlir/test/mlir-tblgen/dialect-with-dependents.td b/mlir/test/mlir-tblgen/dialect-with-dependents.td
new file mode 100644
index 00000000000000..e915e13841b5e5
--- /dev/null
+++ b/mlir/test/mlir-tblgen/dialect-with-dependents.td
@@ -0,0 +1,15 @@
+// RUN: mlir-tblgen -gen-dialect-defs -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def FooDialect : Dialect {
+ let name = "foo";
+ let dependentDialects = ["func::FuncDialect", "shape::ShapeDialect"];
+}
+
+// CHECK-LABEL: FooDialect::FooDialect
+// CHECK: {
+// CHECK-NEXT: getContext()->loadDialect<func::FuncDialect>();
+// CHECK-NEXT: getContext()->loadDialect<shape::ShapeDialect>();
+// CHECK-NEXT: initialize();
+// CHECK-NEXT: }
diff --git a/mlir/test/mlir-tblgen/pass.td b/mlir/test/mlir-tblgen/pass.td
new file mode 100644
index 00000000000000..fb5580a2e3dd16
--- /dev/null
+++ b/mlir/test/mlir-tblgen/pass.td
@@ -0,0 +1,22 @@
+// RUN: mlir-tblgen -gen-pass-decls -I %S/../../include %s | FileCheck %s
+
+include "mlir/Pass/PassBase.td"
+
+def FooPass : Pass<"foo", "ModuleOp"> {
+ let summary = "A pass for testing pass code generation.";
+ let dependentDialects = ["func::FuncDialect", "shape::ShapeDialect"];
+}
+
+// CHECK-LABEL: GEN_PASS_DEF_FOO
+// CHECK: FooPassBase
+// CHECK: void getDependentDialects
+// CHECK-NEXT: registry.insert<func::FuncDialect>();
+// CHECK-NEXT: registry.insert<shape::ShapeDialect>();
+// CHECK-NEXT: }
+
+// CHECK-LABEL: GEN_PASS_CLASSES
+// CHECK: FooPassBase
+// CHECK: void getDependentDialects
+// CHECK-NEXT: registry.insert<func::FuncDialect>();
+// CHECK-NEXT: registry.insert<shape::ShapeDialect>();
+// CHECK-NEXT: }
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index f22434f755abe3..5d4c294bcd8004 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -106,9 +106,8 @@ class {0} : public ::mlir::{2} {
/// Registration for a single dependent dialect: to be inserted in the ctor
/// above for each dependent dialect.
-const char *const dialectRegistrationTemplate = R"(
- getContext()->loadDialect<{0}>();
-)";
+const char *const dialectRegistrationTemplate =
+ "getContext()->loadDialect<{0}>();";
/// The code block for the attribute parser/printer hooks.
static const char *const attrParserDecl = R"(
@@ -250,8 +249,8 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
/// The code block to generate a dialect constructor definition.
///
/// {0}: The name of the dialect class.
-/// {1}: initialization code that is emitted in the ctor body before calling
-/// initialize().
+/// {1}: Initialization code that is emitted in the ctor body before calling
+/// initialize(), such as dependent dialect registration.
/// {2}: The dialect parent class.
static const char *const dialectConstructorStr = R"(
{0}::{0}(::mlir::MLIRContext *context)
@@ -261,7 +260,7 @@ static const char *const dialectConstructorStr = R"(
}
)";
-/// The code block to generate a default desturctor definition.
+/// The code block to generate a default destructor definition.
///
/// {0}: The name of the dialect class.
static const char *const dialectDestructorStr = R"(
@@ -284,9 +283,12 @@ static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
std::string dependentDialectRegistrations;
{
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
- for (StringRef dependentDialect : dialect.getDependentDialects())
- dialectsOs << llvm::formatv(dialectRegistrationTemplate,
- dependentDialect);
+ llvm::interleave(
+ dialect.getDependentDialects(), dialectsOs,
+ [&dialectsOs](StringRef dd) {
+ dialectsOs << llvm::formatv(dialectRegistrationTemplate, dd);
+ },
+ "\n ");
}
// Emit the constructor and destructor.
diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
index de159d144ffbb4..4c895266177429 100644
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -173,7 +173,8 @@ static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2): The command line argument for the pass.
-/// {3}: The dependent dialects registration.
+/// {3}: The summary for the pass.
+/// {4}: The dependent dialects registration.
const char *const baseClassBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
@@ -221,9 +222,7 @@ class {0}Base : public {1} {
/// Registration for a single dependent dialect, to be inserted for each
/// dependent dialect in the `getDependentDialects` above.
-const char *const dialectRegistrationTemplate = R"(
- registry.insert<{0}>();
-)";
+const char *const dialectRegistrationTemplate = "registry.insert<{0}>();";
const char *const friendDefaultConstructorDeclTemplate = R"(
namespace impl {{
@@ -307,9 +306,12 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
std::string dependentDialectRegistrations;
{
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
- for (StringRef dependentDialect : pass.getDependentDialects())
- dialectsOs << llvm::formatv(dialectRegistrationTemplate,
- dependentDialect);
+ llvm::interleave(
+ pass.getDependentDialects(), dialectsOs,
+ [&dialectsOs](StringRef dd) {
+ dialectsOs << llvm::formatv(dialectRegistrationTemplate, dd);
+ },
+ "\n ");
}
os << "namespace impl {\n";
@@ -402,7 +404,7 @@ class {0}Base : public {1} {
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}
- /// Return the dialect that must be loaded in the context before this pass.
+ /// Register the dialects that must be loaded in the context before this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
{4}
}
@@ -422,9 +424,12 @@ static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
std::string dependentDialectRegistrations;
{
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
- for (StringRef dependentDialect : pass.getDependentDialects())
- dialectsOs << llvm::formatv(dialectRegistrationTemplate,
- dependentDialect);
+ llvm::interleave(
+ pass.getDependentDialects(), dialectsOs,
+ [&dialectsOs](StringRef dd) {
+ dialectsOs << llvm::formatv(dialectRegistrationTemplate, dd);
+ },
+ "\n ");
}
os << llvm::formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
pass.getArgument(), pass.getSummary(),
``````````
</details>
https://github.com/llvm/llvm-project/pull/78090
More information about the Mlir-commits
mailing list