[Mlir-commits] [mlir] 3dff20c - [mlir] Reformat whitespace in dependent dialects codegen (#78090)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Jan 15 02:11:57 PST 2024
Author: mlevesquedion
Date: 2024-01-15T11:11:52+01:00
New Revision: 3dff20cfa27e0988840d5d13a169482269aa4fa5
URL: https://github.com/llvm/llvm-project/commit/3dff20cfa27e0988840d5d13a169482269aa4fa5
DIFF: https://github.com/llvm/llvm-project/commit/3dff20cfa27e0988840d5d13a169482269aa4fa5.diff
LOG: [mlir] Reformat whitespace in dependent dialects codegen (#78090)
The generated code for dependent dialects is awkwardly formatted, making
the code harder to read. This change reformats the whitespace to align
code in its context and avoid unnecessary empty lines.
Also included are some typo fixes.
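Below are examples of the codegen for a dialect before and after the
change. The two versions differ only in indentation and blank lines, which this archived copy does not fully preserve.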
Before:
```
GPUDialect::GPUDialect(::mlir::MLIRContext *context)
: ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<GPUDialect>()) {
getContext()->loadDialect<arith::ArithDialect>();
initialize();
}
```
After:
```
GPUDialect::GPUDialect(::mlir::MLIRContext *context)
: ::mlir::Dialect(getDialectNamespace(), context, ::mlir::TypeID::get<GPUDialect>()) {
getContext()->loadDialect<arith::ArithDialect>();
initialize();
}
```
Below are examples of the codegen for a pass before and after the
change.
Before:
```
/// Return the dialect that must be loaded in the context before this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
registry.insert<func::FuncDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<tosa::TosaDialect>();
}
```
After:
```
/// Register the dialects that must be loaded in the context before this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
registry.insert<func::FuncDialect>();
registry.insert<tensor::TensorDialect>();
registry.insert<tosa::TosaDialect>();
}
```
Added:
Modified:
mlir/tools/mlir-tblgen/DialectGen.cpp
mlir/tools/mlir-tblgen/PassGen.cpp
Removed:
################################################################################
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index f22434f755abe3..4f2021083384fc 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -106,9 +106,8 @@ class {0} : public ::mlir::{2} {
/// Registration for a single dependent dialect: to be inserted in the ctor
/// above for each dependent dialect.
-const char *const dialectRegistrationTemplate = R"(
- getContext()->loadDialect<{0}>();
-)";
+const char *const dialectRegistrationTemplate =
+ "getContext()->loadDialect<{0}>();";
/// The code block for the attribute parser/printer hooks.
static const char *const attrParserDecl = R"(
@@ -250,8 +249,8 @@ static bool emitDialectDecls(const llvm::RecordKeeper &recordKeeper,
/// The code block to generate a dialect constructor definition.
///
/// {0}: The name of the dialect class.
-/// {1}: initialization code that is emitted in the ctor body before calling
-/// initialize().
+/// {1}: Initialization code that is emitted in the ctor body before calling
+/// initialize(), such as dependent dialect registration.
/// {2}: The dialect parent class.
static const char *const dialectConstructorStr = R"(
{0}::{0}(::mlir::MLIRContext *context)
@@ -261,7 +260,7 @@ static const char *const dialectConstructorStr = R"(
}
)";
-/// The code block to generate a default desturctor definition.
+/// The code block to generate a default destructor definition.
///
/// {0}: The name of the dialect class.
static const char *const dialectDestructorStr = R"(
@@ -284,9 +283,13 @@ static void emitDialectDef(Dialect &dialect, raw_ostream &os) {
std::string dependentDialectRegistrations;
{
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
- for (StringRef dependentDialect : dialect.getDependentDialects())
- dialectsOs << llvm::formatv(dialectRegistrationTemplate,
- dependentDialect);
+ llvm::interleave(
+ dialect.getDependentDialects(), dialectsOs,
+ [&](StringRef dependentDialect) {
+ dialectsOs << llvm::formatv(dialectRegistrationTemplate,
+ dependentDialect);
+ },
+ "\n ");
}
// Emit the constructor and destructor.
diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
index de159d144ffbb4..11af6497cecf50 100644
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -173,7 +173,8 @@ static void emitRegistrations(llvm::ArrayRef<Pass> passes, raw_ostream &os) {
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2): The command line argument for the pass.
-/// {3}: The dependent dialects registration.
+/// {3}: The summary for the pass.
+/// {4}: The dependent dialects registration.
const char *const baseClassBegin = R"(
template <typename DerivedT>
class {0}Base : public {1} {
@@ -221,9 +222,7 @@ class {0}Base : public {1} {
/// Registration for a single dependent dialect, to be inserted for each
/// dependent dialect in the `getDependentDialects` above.
-const char *const dialectRegistrationTemplate = R"(
- registry.insert<{0}>();
-)";
+const char *const dialectRegistrationTemplate = "registry.insert<{0}>();";
const char *const friendDefaultConstructorDeclTemplate = R"(
namespace impl {{
@@ -307,9 +306,13 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) {
std::string dependentDialectRegistrations;
{
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
- for (StringRef dependentDialect : pass.getDependentDialects())
- dialectsOs << llvm::formatv(dialectRegistrationTemplate,
- dependentDialect);
+ llvm::interleave(
+ pass.getDependentDialects(), dialectsOs,
+ [&](StringRef dependentDialect) {
+ dialectsOs << llvm::formatv(dialectRegistrationTemplate,
+ dependentDialect);
+ },
+ "\n ");
}
os << "namespace impl {\n";
@@ -402,7 +405,7 @@ class {0}Base : public {1} {
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}
- /// Return the dialect that must be loaded in the context before this pass.
+ /// Register the dialects that must be loaded in the context before this pass.
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
{4}
}
@@ -422,9 +425,13 @@ static void emitOldPassDecl(const Pass &pass, raw_ostream &os) {
std::string dependentDialectRegistrations;
{
llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
- for (StringRef dependentDialect : pass.getDependentDialects())
- dialectsOs << llvm::formatv(dialectRegistrationTemplate,
- dependentDialect);
+ llvm::interleave(
+ pass.getDependentDialects(), dialectsOs,
+ [&](StringRef dependentDialect) {
+ dialectsOs << llvm::formatv(dialectRegistrationTemplate,
+ dependentDialect);
+ },
+ "\n ");
}
os << llvm::formatv(oldPassDeclBegin, defName, pass.getBaseClass(),
pass.getArgument(), pass.getSummary(),
More information about the Mlir-commits
mailing list