[Mlir-commits] [mlir] Reformat whitespace in dependent dialects codegen (PR #78090)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jan 14 21:12:28 PST 2024
https://github.com/mlevesquedion updated https://github.com/llvm/llvm-project/pull/78090
>From dca03e2329a5657174d1449318292404e2f21899 Mon Sep 17 00:00:00 2001
From: Michael Levesque-Dion <mlevesquedion at google.com>
Date: Sat, 13 Jan 2024 23:46:56 -0800
Subject: [PATCH 1/2] Reformat whitespace in dependent dialects codegen
The generated code for dependent dialects is awkwardly formatted, making
the code harder to read. This change reformats the whitespace to align
code in its context and avoid unnecessary empty lines.
Also included are some typo fixes.
Below are examples of the codegen for a dialect before and after the
change.
Before:
```
GPUDialect::GPUDialect(::mlir::MLIRContext *context)
    : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<GPUDialect>()) {

    getContext()->loadDialect<arith::ArithDialect>();

  initialize();
}
```
After:
```
GPUDialect::GPUDialect(::mlir::MLIRContext *context)
    : ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<GPUDialect>()) {
  getContext()->loadDialect<arith::ArithDialect>();
  initialize();
}
```
Below are examples of the codegen for a pass before and after the change.
Before:
```
  /// Return the dialect that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {

    registry.insert<func::FuncDialect>();

    registry.insert<tensor::TensorDialect>();

    registry.insert<tosa::TosaDialect>();

  }
```
After:
```
  /// Register the dialects that must be loaded in the context before this pass.
  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
    registry.insert<func::FuncDialect>();
    registry.insert<tensor::TensorDialect>();
    registry.insert<tosa::TosaDialect>();
  }
```
---
.../mlir-tblgen/dialect-with-dependents.td | 15 +++++++++++
mlir/test/mlir-tblgen/pass.td | 22 +++++++++++++++
mlir/tools/mlir-tblgen/DialectGen.cpp | 20 +++++++-------
mlir/tools/mlir-tblgen/PassGen.cpp | 27 +++++++++++--------
4 files changed, 64 insertions(+), 20 deletions(-)
create mode 100644 mlir/test/mlir-tblgen/dialect-with-dependents.td
create mode 100644 mlir/test/mlir-tblgen/pass.td
diff --git a/mlir/test/mlir-tblgen/dialect-with-dependents.td b/mlir/test/mlir-tblgen/dialect-with-dependents.td
new file mode 100644
index 00000000000000..e915e13841b5e5
--- /dev/null
+++ b/mlir/test/mlir-tblgen/dialect-with-dependents.td
@@ -0,0 +1,15 @@
+// RUN: mlir-tblgen -gen-dialect-defs -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def FooDialect : Dialect {
+ let name = "foo";
+ let dependentDialects = ["func::FuncDialect", "shape::ShapeDialect"];
+}
+
+// CHECK-LABEL: FooDialect::FooDialect
+// CHECK: {
+// CHECK-NEXT: getContext()->loadDialect<func::FuncDialect>();
+// CHECK-NEXT: getContext()->loadDialect<shape::ShapeDialect>();
+// CHECK-NEXT: initialize();
+// CHECK-NEXT: }
diff --git a/mlir/test/mlir-tblgen/pass.td b/mlir/test/mlir-tblgen/pass.td
new file mode 100644
index 00000000000000..fb5580a2e3dd16
--- /dev/null
+++ b/mlir/test/mlir-tblgen/pass.td
@@ -0,0 +1,22 @@
+// RUN: mlir-tblgen -gen-pass-decls -I %S/../../include %s | FileCheck %s
+
+include "mlir/Pass/PassBase.td"
+
+def FooPass : Pass<"foo", "ModuleOp"> {
+ let summary = "A pass for testing pass code generation.";
+ let dependentDialects = ["func::FuncDialect", "shape::ShapeDialect"];
+}
+
+// CHECK-LABEL: GEN_PASS_DEF_FOO
+// CHECK: FooPassBase
+// CHECK: void getDependentDialects
+// CHECK-NEXT: registry.insert<func::FuncDialect>();
+// CHECK-NEXT: registry.insert<shape::ShapeDialect>();
+// CHECK-NEXT: }
+
+// CHECK-LABEL: GEN_PASS_CLASSES
+// CHECK: FooPassBase
+// CHECK: void getDependentDialects
+// CHECK-NEXT: registry.insert<func::FuncDialect>();
+// CHECK-NEXT: registry.insert<shape::ShapeDialect>();
+// CHECK-NEXT: }
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index f22434f755abe3..5d4c294bcd8004 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -106,9 +106,8 @@ class {0} : public ::mlir::{2} {
/// Registration for a single dependent dialect: to be inserted in the ctor
/// above for each dependent dialect.
-const char *const dialectRegistrationTemplate = R"(
- getContext()->loadDialect<{0}>();
-)";
+const char *const dialectRegistrationTemplate =
+ "getContext()->loadDialect<{0}>();";
/// The code block for the attribute parser/printer hooks.
static const char *const attrParserDecl = R"(
@@ -250,8 +249,8 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
/// The code block to generate a dialect constructor definition.
///
/// {0}: The name of the dialect class.
-/// {1}: initialization code that is emitted in the ctor body before calling
-/// initialize().
+/// {1}: Initialization code that is emitted in the ctor body before calling
+/// initialize(), such as dependent dialect registration.
/// {2}: The dialect parent class.
static const char *const dialectConstructorStr = R"(
{0}::{0}(::mlir::MLIRContext *context)
@@ -261,7 +260,7 @@ static const char *const dialectConstructorStr = R"(
}
)";
-/// The code block to generate a default desturctor definition.
+/// The code block to generate a default destructor definition.
///
/// {0}: The name of the dialect class.
static const char *const dialectDestructorStr = R"(
@@ -284,9 +283,12 @@ static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
std::string dependentDialectRegistrations;
{
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
- for (StringRef dependentDialect : dialect.getDependentDialects())
- dialectsOs << llvm::formatv(dialectRegistrationTemplate,
- dependentDialect);
+ llvm::interleave(
+ dialect.getDependentDialects(), dialectsOs,
+ [&dialectsOs](StringRef dd) {
+ dialectsOs << llvm::formatv(dialectRegistrationTemplate, dd);
+ },
+ "\n ");
}
// Emit the constructor and destructor.
diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
index de159d144ffbb4..4c895266177429 100644
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -173,7 +173,8 @@ static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2): The command line argument for the pass.
-/// {3}: The dependent dialects registration.
+/// {3}: The summary for the pass.
+/// {4}: The dependent dialects registration.
const char *const baseClassBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
@@ -221,9 +222,7 @@ class {0}Base : public {1} {
/// Registration for a single dependent dialect, to be inserted for each
/// dependent dialect in the `getDependentDialects` above.
-const char *const dialectRegistrationTemplate = R"(
- registry.insert<{0}>();
-)";
+const char *const dialectRegistrationTemplate = "registry.insert<{0}>();";
const char *const friendDefaultConstructorDeclTemplate = R"(
namespace impl {{
@@ -307,9 +306,12 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
std::string dependentDialectRegistrations;
{
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
- for (StringRef dependentDialect : pass.getDependentDialects())
- dialectsOs << llvm::formatv(dialectRegistrationTemplate,
- dependentDialect);
+ llvm::interleave(
+ pass.getDependentDialects(), dialectsOs,
+ [&dialectsOs](StringRef dd) {
+ dialectsOs << llvm::formatv(dialectRegistrationTemplate, dd);
+ },
+ "\n ");
}
os << "namespace impl {\n";
@@ -402,7 +404,7 @@ class {0}Base : public {1} {
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}
- /// Return the dialect that must be loaded in the context before this pass.
+ /// Register the dialects that must be loaded in the context before this pass.
void getDependentDialects(::mlir::DialectRegistry ®istry) const override {
{4}
}
@@ -422,9 +424,12 @@ static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
std::string dependentDialectRegistrations;
{
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
- for (StringRef dependentDialect : pass.getDependentDialects())
- dialectsOs << llvm::formatv(dialectRegistrationTemplate,
- dependentDialect);
+ llvm::interleave(
+ pass.getDependentDialects(), dialectsOs,
+ [&dialectsOs](StringRef dd) {
+ dialectsOs << llvm::formatv(dialectRegistrationTemplate, dd);
+ },
+ "\n ");
}
os << llvm::formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
pass.getArgument(), pass.getSummary(),
>From 5c42ffc1ad18e09a625c368c3601bced5236e93b Mon Sep 17 00:00:00 2001
From: Michael Levesque-Dion <mlevesquedion at google.com>
Date: Sun, 14 Jan 2024 10:28:04 -0800
Subject: [PATCH 2/2] Fixup: delete tests and address nits
---
.../mlir-tblgen/dialect-with-dependents.td | 15 -------------
mlir/test/mlir-tblgen/pass.td | 22 -------------------
mlir/tools/mlir-tblgen/DialectGen.cpp | 5 +++--
mlir/tools/mlir-tblgen/PassGen.cpp | 10 +++++----
4 files changed, 9 insertions(+), 43 deletions(-)
delete mode 100644 mlir/test/mlir-tblgen/dialect-with-dependents.td
delete mode 100644 mlir/test/mlir-tblgen/pass.td
diff --git a/mlir/test/mlir-tblgen/dialect-with-dependents.td b/mlir/test/mlir-tblgen/dialect-with-dependents.td
deleted file mode 100644
index e915e13841b5e5..00000000000000
--- a/mlir/test/mlir-tblgen/dialect-with-dependents.td
+++ /dev/null
@@ -1,15 +0,0 @@
-// RUN: mlir-tblgen -gen-dialect-defs -I %S/../../include %s | FileCheck %s
-
-include "mlir/IR/OpBase.td"
-
-def FooDialect : Dialect {
- let name = "foo";
- let dependentDialects = ["func::FuncDialect", "shape::ShapeDialect"];
-}
-
-// CHECK-LABEL: FooDialect::FooDialect
-// CHECK: {
-// CHECK-NEXT: getContext()->loadDialect<func::FuncDialect>();
-// CHECK-NEXT: getContext()->loadDialect<shape::ShapeDialect>();
-// CHECK-NEXT: initialize();
-// CHECK-NEXT: }
diff --git a/mlir/test/mlir-tblgen/pass.td b/mlir/test/mlir-tblgen/pass.td
deleted file mode 100644
index fb5580a2e3dd16..00000000000000
--- a/mlir/test/mlir-tblgen/pass.td
+++ /dev/null
@@ -1,22 +0,0 @@
-// RUN: mlir-tblgen -gen-pass-decls -I %S/../../include %s | FileCheck %s
-
-include "mlir/Pass/PassBase.td"
-
-def FooPass : Pass<"foo", "ModuleOp"> {
- let summary = "A pass for testing pass code generation.";
- let dependentDialects = ["func::FuncDialect", "shape::ShapeDialect"];
-}
-
-// CHECK-LABEL: GEN_PASS_DEF_FOO
-// CHECK: FooPassBase
-// CHECK: void getDependentDialects
-// CHECK-NEXT: registry.insert<func::FuncDialect>();
-// CHECK-NEXT: registry.insert<shape::ShapeDialect>();
-// CHECK-NEXT: }
-
-// CHECK-LABEL: GEN_PASS_CLASSES
-// CHECK: FooPassBase
-// CHECK: void getDependentDialects
-// CHECK-NEXT: registry.insert<func::FuncDialect>();
-// CHECK-NEXT: registry.insert<shape::ShapeDialect>();
-// CHECK-NEXT: }
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index 5d4c294bcd8004..4f2021083384fc 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -285,8 +285,9 @@ static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
llvm::interleave(
dialect.getDependentDialects(), dialectsOs,
- [&dialectsOs](StringRef dd) {
- dialectsOs << llvm::formatv(dialectRegistrationTemplate, dd);
+ [&](StringRef dependentDialect) {
+ dialectsOs << llvm::formatv(dialectRegistrationTemplate,
+ dependentDialect);
},
"\n ");
}
diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
index 4c895266177429..11af6497cecf50 100644
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -308,8 +308,9 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
llvm::interleave(
pass.getDependentDialects(), dialectsOs,
- [&dialectsOs](StringRef dd) {
- dialectsOs << llvm::formatv(dialectRegistrationTemplate, dd);
+ [&](StringRef dependentDialect) {
+ dialectsOs << llvm::formatv(dialectRegistrationTemplate,
+ dependentDialect);
},
"\n ");
}
@@ -426,8 +427,9 @@ static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
llvm::interleave(
pass.getDependentDialects(), dialectsOs,
- [&dialectsOs](StringRef dd) {
- dialectsOs << llvm::formatv(dialectRegistrationTemplate, dd);
+ [&](StringRef dependentDialect) {
+ dialectsOs << llvm::formatv(dialectRegistrationTemplate,
+ dependentDialect);
},
"\n ");
}
More information about the Mlir-commits
mailing list