[Mlir-commits] [mlir] Add initial `mlir-format` PoC (PR #121260)
Perry Gibson
llvmlistbot at llvm.org
Sat Dec 28 02:29:43 PST 2024
https://github.com/Wheest created https://github.com/llvm/llvm-project/pull/121260
This is a draft PR with a proof of concept (PoC) of a basic formatting tool for MLIR. The idea was first described in [this forum thread](https://discourse.llvm.org/t/clang-format-or-some-other-auto-format-for-mlir-files/75258).
It does not aim to be anywhere near as capable as `clang-format`. Instead, it reuses the existing dialect printers to format IR.
An example use case might be: "I'm making some small edits to this MLIR test case, and I want it to be better formatted." Right now, users would need to either run it through `mlir-opt` (and reinsert their comments and SSA names by hand) or format it manually.
```mlir
// Add two values
%x1 = arith.addf
%x, %cst1 : f64
```
:arrow_down: :magic_wand:
```mlir
// Add two values
%x1 = arith.addf %x, %cst1 : f64
```
This tool introduces two features that cannot be achieved simply by using `mlir-opt`:
1. retain comments
2. retain SSA names
Feature 1 is achieved by working on the original source buffer and replacing only the character range of each formatted op, so comments are left as-is.
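To make feature 1 concrete, here is a minimal standalone sketch of the replacement step. It uses `std::string` in place of `llvm::RewriteBuffer` (the patch calls `RewriteBuffer::ReplaceText` with an offset and length derived from the op's `SMRange`), and `replaceOpRange` is a hypothetical helper name. Only the op's own characters change, so a comment on the preceding line survives.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <string_view>

// Replace the half-open character range [begin, end) of `source` with
// `replacement`. Everything outside the range (for example, a comment on the
// line above the op) is copied through unchanged.
std::string replaceOpRange(std::string_view source, std::size_t begin,
                           std::size_t end, std::string_view replacement) {
  assert(begin <= end && end <= source.size() && "range must lie in source");
  std::string result;
  result.reserve(source.size() - (end - begin) + replacement.size());
  result.append(source.substr(0, begin)); // text before the op
  result.append(replacement);             // the freshly printed op
  result.append(source.substr(end));      // text after the op
  return result;
}
```

For the example at the top of this PR, the range would run from `%x1` to the end of `f64`, since `getOpRange` in the patch extends an op's range back to its first result.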
Feature 2 is achieved using the recently added [NameLoc support](https://github.com/llvm/llvm-project/pull/119996), which lets us retain identifier names for debugging: `%alice = op()` -> `%0 = op() loc("alice")`, which can then be printed again as `%alice = op()`.
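A minimal sketch of the text-level step behind feature 2, assuming the SSA name is available as a string starting with `%` (in the patch it comes from the `AsmParserState` result-group definitions). `nameLocFor` and `annotateOp` are hypothetical names; re-parsing the annotated text and printing it back with `--mlir-use-nameloc-as-prefix` is left to MLIR itself.

```cpp
#include <cassert>
#include <string>
#include <string_view>

// Build the ` loc("name")` suffix for an SSA value such as `%alice`.
// The leading `%` is dropped, mirroring what markNames() does in the patch.
std::string nameLocFor(std::string_view ssaName) {
  assert(!ssaName.empty() && ssaName.front() == '%' && "expected an SSA name");
  std::string loc = " loc(\"";
  loc.append(ssaName.substr(1));
  loc.append("\")");
  return loc;
}

// Append the NameLoc for `ssaName` to the end of an op's textual form.
std::string annotateOp(std::string_view opText, std::string_view ssaName) {
  std::string out(opText);
  out += nameLocFor(ssaName);
  return out;
}
```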
From a design perspective, this tool adapts [`mlir-rewrite`](https://github.com/llvm/llvm-project/pull/77668). It creates two rewrite buffers: the first is used to insert the `loc` annotations, and re-parsing its contents produces the formatted ops; the second is where these formatted ops are inserted into the original text. I might be able to do this with a single buffer, and there are some other code improvements that could be made.
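The two-buffer design relies on every edit being recorded against offsets into the original text, which is what lets both buffers be driven by the same source ranges. The sketch below uses a toy, hypothetical `ToyRewriteBuffer` (not the LLVM class) to show this; the formatted op text is hard-coded here, whereas the tool obtains it by re-parsing the first buffer's output and printing each op.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <string_view>
#include <vector>

// Toy stand-in for llvm::RewriteBuffer: every edit is expressed as an offset
// into the ORIGINAL text, and edits are only materialised by write().
class ToyRewriteBuffer {
public:
  explicit ToyRewriteBuffer(std::string_view text) : original(text) {}

  // Insert `text` at original offset `pos`.
  void insertText(std::size_t pos, std::string_view text) {
    replaceText(pos, 0, text);
  }

  // Replace `len` characters starting at original offset `pos`.
  void replaceText(std::size_t pos, std::size_t len, std::string_view text) {
    assert(pos <= original.size() && len <= original.size() - pos &&
           "edit out of range");
    edits.push_back(Edit{pos, len, std::string(text)});
  }

  // Apply edits from the back of the text to the front, so each edit's
  // original offset is still valid when it is applied. Inserts at the same
  // offset keep the order in which they were recorded. Overlapping edits
  // are not supported.
  std::string write() const {
    std::vector<Edit> sorted = edits;
    std::stable_sort(sorted.begin(), sorted.end(),
                     [](const Edit &a, const Edit &b) { return a.pos < b.pos; });
    std::string out = original;
    for (auto it = sorted.rbegin(); it != sorted.rend(); ++it)
      out.replace(it->pos, it->len, it->text);
    return out;
  }

private:
  struct Edit {
    std::size_t pos;
    std::size_t len;
    std::string text;
  };
  std::string original;
  std::vector<Edit> edits;
};
```

`llvm::RewriteBuffer` reaches the same effect with an internal mapping from original to current offsets, so callers can keep addressing the original buffer regardless of the order of their edits.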
There's a laundry list of small features that would be needed to make this a good tool:
1. if there are in-op comments, do not format the op
2. Identify whether an op uses any of the cases unsupported by `NameLoc` (e.g., `%group1:2, %group2:3 = return_5vals`); if so, do not format those ops.
3. text editor integration
4. reduce fragility of named block args
For item 2, I have an idea for a workaround. For item 1, we could try to get clever and reinsert the comment, but I'd prefer to keep the tool simple for now.
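For item 1, a first approximation of "does this op contain comments?" could scan the op's source range for `//` outside string literals. This is a heuristic sketch: `hasInOpComment` is a hypothetical helper, it assumes the `//` line comment is MLIR's only comment form, and it treats only double-quoted strings (with backslash escapes) specially.

```cpp
#include <cassert>
#include <cstddef>
#include <string_view>

// Return true if `opText` contains a `//` line comment outside of a
// double-quoted string literal. Knows nothing else about MLIR syntax.
bool hasInOpComment(std::string_view opText) {
  bool inString = false;
  for (std::size_t i = 0; i < opText.size(); ++i) {
    char c = opText[i];
    if (inString) {
      if (c == '\\')
        ++i; // skip the escaped character
      else if (c == '"')
        inString = false;
    } else if (c == '"') {
      inString = true;
    } else if (c == '/' && i + 1 < opText.size() && opText[i + 1] == '/') {
      return true;
    }
  }
  return false;
}
```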
From 3eebcfdf91ecbbe572624e8a0f360e48e85b6fa6 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at fractile.ai>
Date: Sat, 28 Dec 2024 10:18:12 +0000
Subject: [PATCH] Add initial mlir-format PoC
---
mlir/test/mlir-format/annotate_locs.mlir | 13 +
mlir/test/mlir-format/simple.mlir | 15 +
mlir/test/mlir-format/simple2.mlir | 19 ++
mlir/tools/CMakeLists.txt | 1 +
mlir/tools/mlir-format/CMakeLists.txt | 32 ++
mlir/tools/mlir-format/mlir-format.cpp | 385 +++++++++++++++++++++++
6 files changed, 465 insertions(+)
create mode 100644 mlir/test/mlir-format/annotate_locs.mlir
create mode 100644 mlir/test/mlir-format/simple.mlir
create mode 100644 mlir/test/mlir-format/simple2.mlir
create mode 100644 mlir/tools/mlir-format/CMakeLists.txt
create mode 100644 mlir/tools/mlir-format/mlir-format.cpp
diff --git a/mlir/test/mlir-format/annotate_locs.mlir b/mlir/test/mlir-format/annotate_locs.mlir
new file mode 100644
index 00000000000000..c557e8447bcf7b
--- /dev/null
+++ b/mlir/test/mlir-format/annotate_locs.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-format %s --mlir-use-nameloc-as-prefix --insert-name-loc-only | FileCheck %s
+
+// Append NameLocs (`loc("[ssa_name]")`) to operations and block arguments
+
+// CHECK: func.func @add_one(%my_input: f64 loc("my_input"), %my_input2: f64 loc("my_input2")) -> f64 {
+func.func @add_one(%my_input: f64, %my_input2: f64) -> f64 {
+ // CHECK: %my_constant = arith.constant 1.00000e+00 : f64 loc("my_constant")
+ %my_constant = arith.constant 1.00000e+00 : f64
+
+ %my_output = arith.addf %my_input, %my_constant : f64
+ // CHECK: %my_output = arith.addf %my_input, %my_constant : f64 loc("my_output")
+ return %my_output : f64
+}
diff --git a/mlir/test/mlir-format/simple.mlir b/mlir/test/mlir-format/simple.mlir
new file mode 100644
index 00000000000000..0b05fa2a63a1ec
--- /dev/null
+++ b/mlir/test/mlir-format/simple.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-format %s --mlir-use-nameloc-as-prefix | FileCheck %s
+
+// CHECK: func.func @add_one(%my_input: f64) -> f64 {
+func.func @add_one(%my_input: f64) -> f64 {
+ // CHECK: %my_constant = arith.constant 1.00000e+00 : f64
+ %my_constant = arith.constant 1.00000e+00 : f64
+ // CHECK: // Dinnae drop this comment!
+ // Dinnae drop this comment!
+ %my_output = arith.addf
+ %my_input,
+ %my_constant : f64
+ // CHECK-STRICT: %my_output = arith.addf %my_input, %my_constant : f64
+ return %my_output : f64
+ // CHECK: return %my_output : f64
+}
diff --git a/mlir/test/mlir-format/simple2.mlir b/mlir/test/mlir-format/simple2.mlir
new file mode 100644
index 00000000000000..beb22ec933fb98
--- /dev/null
+++ b/mlir/test/mlir-format/simple2.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-format %s --mlir-use-nameloc-as-prefix | FileCheck %s
+
+// CHECK: func.func @my_func(%x: f64, %y: f64) -> f64 {
+func.func @my_func(%x: f64, %y: f64) -> f64 {
+ // CHECK: %cst1 = arith.constant 1.00000e+00 : f64
+ %cst1 = arith.constant 1.00000e+00 : f64
+ // CHECK: %cst2 = arith.constant 2.00000e+00 : f64
+ %cst2 = arith.constant 2.00000e+00 : f64
+ // CHECK-STRICT: %x1 = arith.addf %x, %cst1 : f64
+ %x1 = arith.addf
+%x,
+%cst1 : f64
+ // CHECK-STRICT: %y2 = arith.addf %y, %cst2 : f64
+ %y2 = arith.addf %y, %cst2 : f64
+ // CHECK: %z = arith.addf %x1, %y2 : f64
+ %z = arith.addf %x1, %y2 : f64
+ // return %z : f64
+ return %z : f64
+}
diff --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt
index 072e83c5d45ea1..0f3d5e6f9731ae 100644
--- a/mlir/tools/CMakeLists.txt
+++ b/mlir/tools/CMakeLists.txt
@@ -5,6 +5,7 @@ add_subdirectory(mlir-pdll-lsp-server)
add_subdirectory(mlir-query)
add_subdirectory(mlir-reduce)
add_subdirectory(mlir-rewrite)
+add_subdirectory(mlir-format)
add_subdirectory(mlir-shlib)
add_subdirectory(mlir-translate)
add_subdirectory(mlir-vulkan-runner)
diff --git a/mlir/tools/mlir-format/CMakeLists.txt b/mlir/tools/mlir-format/CMakeLists.txt
new file mode 100644
index 00000000000000..8ed6f095f0e4b5
--- /dev/null
+++ b/mlir/tools/mlir-format/CMakeLists.txt
@@ -0,0 +1,32 @@
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+set(LLVM_LINK_COMPONENTS
+ Support
+ )
+
+set(LIBS
+ ${dialect_libs}
+
+ MLIRAffineAnalysis
+ MLIRAnalysis
+ MLIRCastInterfaces
+ MLIRDialect
+ MLIRParser
+ MLIRPass
+ MLIRTransforms
+ MLIRTransformUtils
+ MLIRSupport
+ MLIRIR
+ )
+
+include_directories(../../../clang/include)
+
+add_mlir_tool(mlir-format
+ mlir-format.cpp
+
+ SUPPORT_PLUGINS
+ )
+mlir_target_link_libraries(mlir-format PRIVATE ${LIBS})
+llvm_update_compile_flags(mlir-format)
+
+mlir_check_all_link_libraries(mlir-format)
+export_executable_symbols_for_plugins(mlir-format)
diff --git a/mlir/tools/mlir-format/mlir-format.cpp b/mlir/tools/mlir-format/mlir-format.cpp
new file mode 100644
index 00000000000000..ae0122a6627afa
--- /dev/null
+++ b/mlir/tools/mlir-format/mlir-format.cpp
@@ -0,0 +1,385 @@
+#include "mlir/AsmParser/AsmParser.h"
+#include "mlir/AsmParser/AsmParserState.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Tools/ParseUtilities.h"
+#include "llvm/ADT/RewriteBuffer.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/LineIterator.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace mlir;
+
+namespace mlir {
+using OperationDefinition = AsmParserState::OperationDefinition;
+using BlockDefinition = AsmParserState::BlockDefinition;
+using SMDefinition = AsmParserState::SMDefinition;
+
+inline int asMainReturnCode(LogicalResult r) {
+ return r.succeeded() ? EXIT_SUCCESS : EXIT_FAILURE;
+}
+
+/// Return the source code associated with the OperationDefinition.
+SMRange getOpRange(const OperationDefinition &op) {
+ const char *startOp = op.scopeLoc.Start.getPointer();
+ const char *endOp = op.scopeLoc.End.getPointer();
+
+ for (const auto &res : op.resultGroups) {
+ SMRange range = res.definition.loc;
+ startOp = std::min(startOp, range.Start.getPointer());
+ }
+ return {SMLoc::getFromPointer(startOp), SMLoc::getFromPointer(endOp)};
+}
+
+class CombinedOpDefIterator {
+public:
+ using BaseIterator = AsmParserState::OperationDefIterator;
+ using value_type = std::pair<OperationDefinition &, OperationDefinition &>;
+
+ // Constructor
+ CombinedOpDefIterator(BaseIterator opIter, BaseIterator fmtIter)
+ : opIter(opIter), fmtIter(fmtIter) {}
+
+ // Dereference operator to return a pair of references
+ value_type operator*() const { return {*opIter, *fmtIter}; }
+
+ // Increment operator
+ CombinedOpDefIterator &operator++() {
+ ++opIter;
+ ++fmtIter;
+ return *this;
+ }
+
+ // Equality operator
+ bool operator==(const CombinedOpDefIterator &other) const {
+ return opIter == other.opIter && fmtIter == other.fmtIter;
+ }
+
+ // Inequality operator
+ bool operator!=(const CombinedOpDefIterator &other) const {
+ return !(*this == other);
+ }
+
+private:
+ BaseIterator opIter;
+ BaseIterator fmtIter;
+};
+
+// Find the previous comma, searching backwards from start to stop_point.
+const char *findPrevComma(const char *start, const char *stop_point) {
+ if (!start) {
+ llvm::errs() << "Error: Input pointer is null.\n";
+ return nullptr;
+ }
+
+ const char *current = start - 1; // Start checking backwards
+ while (current >= stop_point) {
+ if (*current == ',') {
+ return current;
+ }
+ --current;
+ }
+
+ llvm::errs() << "Error: No previous comma found before provided pointer.\n";
+ return nullptr;
+}
+
+// Function to find the next closing parenthesis
+const char *findNextCloseParenth(const char *start) {
+ if (!start) {
+ llvm::errs() << "Error: Input pointer is null.\n";
+ return nullptr;
+ }
+
+ while (*start != '\0') { // Traverse until null terminator
+ if (*start == ')') {
+ return start; // Return pointer to the closing parenthesis
+ }
+ ++start;
+ }
+
+ llvm::errs() << "Error: No closing parenthesis found in the string.\n";
+ return nullptr;
+}
+
+class Formatter {
+public:
+ static std::unique_ptr<Formatter> init(StringRef inputFilename,
+ StringRef outputFilename);
+
+ /// Return the OperationDefinitions of the operations parsed.
+ iterator_range<AsmParserState::OperationDefIterator> getOpDefs() {
+ return asmState.getOpDefs();
+ }
+
+ /// Return the BlockDefinitions of the blocks parsed.
+ iterator_range<AsmParserState::BlockDefIterator> getBlockDefs() {
+ return asmState.getBlockDefs();
+ }
+
+ /// Return a range of paired iterators over the OperationDefinitions of
+ /// asmState and fmtAsmState.
+ iterator_range<CombinedOpDefIterator> getCombinedOpDefs() {
+ auto opBegin = asmState.getOpDefs().begin();
+ auto opEnd = asmState.getOpDefs().end();
+ auto fmtBegin = fmtAsmState.getOpDefs().begin();
+ auto fmtEnd = fmtAsmState.getOpDefs().end();
+
+ assert(std::distance(opBegin, opEnd) == std::distance(fmtBegin, fmtEnd) &&
+ "Both iterators must have the same length");
+
+ return llvm::make_range(CombinedOpDefIterator(opBegin, fmtBegin),
+ CombinedOpDefIterator(opEnd, fmtEnd));
+ }
+
+ /// Print the parsed operations to the provided output stream.
+ void printOps(raw_ostream &os) {
+ // Iterate over each operation in the parsedIR block and print it.
+ for (Operation &op : parsedIR) {
+ op.print(os);
+ os << "\n";
+ }
+ }
+
+ /// Insert the specified string at the specified location in the original
+ /// buffer.
+ void insertText(SMLoc pos, StringRef str, bool insertAfter = true) {
+ rewriteBuffer.InsertText(pos.getPointer() - start, str, insertAfter);
+ }
+
+ /// Replace the range of the source text with the corresponding string in the
+ /// output.
+ void replaceRange(SMRange range, StringRef str) {
+ rewriteBuffer.ReplaceText(range.Start.getPointer() - start,
+ range.End.getPointer() - range.Start.getPointer(),
+ str);
+ }
+
+ void replaceRangeFmt(SMRange range, StringRef str) {
+ fmtRewriteBuffer.ReplaceText(
+ range.Start.getPointer() - start,
+ range.End.getPointer() - range.Start.getPointer(), str);
+ }
+
+ /// Write to stream the result of applying all changes to the
+ /// original buffer.
+ /// Note that it isn't safe to use this function to overwrite memory mapped
+ /// files in-place (PR17960).
+ ///
+ /// The original buffer is not actually changed.
+ raw_ostream &write(raw_ostream &stream) const {
+ return rewriteBuffer.write(stream);
+ }
+
+ raw_ostream &writeFmt(raw_ostream &stream) const {
+ return fmtRewriteBuffer.write(stream);
+ }
+
+ /// Parse the rewriteBuffer into fmtAsmState and print the formatted ops.
+ void formatOps();
+
+private:
+ // The context and state required to parse.
+ MLIRContext context;
+ llvm::SourceMgr sourceMgr;
+ llvm::SourceMgr fmtSourceMgr;
+ DialectRegistry registry;
+ FallbackAsmResourceMap fallbackResourceMap;
+
+ // Storage of textual parsing results.
+ AsmParserState asmState;
+
+ // Storage of initial formatted ops.
+ AsmParserState fmtAsmState;
+
+ // Parsed IR.
+ Block parsedIR;
+ Block parsedFmtIR;
+
+ // The RewriteBuffer is doing most of the real work.
+ llvm::RewriteBuffer rewriteBuffer;
+ llvm::RewriteBuffer fmtRewriteBuffer;
+
+ // Start of the original input, used to compute offset.
+ const char *start;
+};
+
+std::unique_ptr<Formatter> Formatter::init(StringRef inputFilename,
+ StringRef outputFilename) {
+
+ std::unique_ptr<Formatter> f = std::make_unique<Formatter>();
+ // Register all the dialects needed.
+ registerAllDialects(f->registry);
+
+ // Set up the input file.
+ std::string errorMessage;
+ std::unique_ptr<llvm::MemoryBuffer> file =
+ openInputFile(inputFilename, &errorMessage);
+ if (!file) {
+ llvm::errs() << errorMessage << "\n";
+ return nullptr;
+ }
+ f->sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+
+ // Set up the MLIR context and error handling.
+ f->context.appendDialectRegistry(f->registry);
+
+ // Record the start of the buffer to compute offsets with.
+ unsigned curBuf = f->sourceMgr.getMainFileID();
+ const llvm::MemoryBuffer *curMB = f->sourceMgr.getMemoryBuffer(curBuf);
+ f->start = curMB->getBufferStart();
+ f->rewriteBuffer.Initialize(curMB->getBuffer());
+ f->fmtRewriteBuffer.Initialize(curMB->getBuffer());
+
+ // Parse and populate the AsmParserState.
+ ParserConfig parseConfig(&f->context, /*verifyAfterParse=*/true,
+ &f->fallbackResourceMap);
+ // Always allow unregistered.
+ f->context.allowUnregisteredDialects(true);
+ if (failed(parseAsmSourceFile(f->sourceMgr, &f->parsedIR, parseConfig,
+ &f->asmState)))
+ return nullptr;
+
+ return f;
+}
+
+void Formatter::formatOps() {
+ // Generate fmtAsmState from the rewriteBuffer.
+ ParserConfig parseConfig(&context, /*verifyAfterParse=*/true,
+ &fallbackResourceMap);
+
+ // Write the rewriteBuffer to a string that we can then parse.
+ std::string bufferContent;
+ llvm::raw_string_ostream stream(bufferContent);
+ rewriteBuffer.write(stream);
+ stream.flush();
+
+ // Register bufferContent with fmtSourceMgr so that it can be parsed.
+ fmtSourceMgr.AddNewSourceBuffer(
+ llvm::MemoryBuffer::getMemBufferCopy(bufferContent), SMLoc());
+
+ // Parse and populate the formatted AsmParserState.
+ if (failed(parseAsmSourceFile(fmtSourceMgr, &parsedFmtIR, parseConfig,
+ &fmtAsmState)))
+ return;
+
+ // Insert the formatted ops. Block args should be untouched,
+ // and their references will use the correct SSA ID.
+ for (auto [opDef, fmtDef] : getCombinedOpDefs()) {
+ auto [startOp, endOp] = getOpRange(opDef);
+
+ // Skip if the op is a FuncOp (we format ops in its body)
+ // or a ReturnOp (we want to keep the user's preference for
+ // `func.return` or plain `return`)
+ if (llvm::dyn_cast<mlir::func::FuncOp>(fmtDef.op))
+ continue;
+ else if (llvm::dyn_cast<mlir::func::ReturnOp>(fmtDef.op))
+ continue;
+
+ // Print the fmtDef op and store as a string.
+ // Replace the opDef with this formatted string.
+ std::string formattedStr;
+ llvm::raw_string_ostream stream(formattedStr);
+ fmtDef.op->print(stream);
+
+ // Replacing the range:
+ replaceRangeFmt({startOp, endOp}, formattedStr);
+ }
+
+ // Write the updated buffer to llvm::outs()
+ writeFmt(llvm::outs());
+}
+
+void markNames(Formatter &formatState, raw_ostream &os) {
+ // Get the operation definitions from the AsmParserState.
+ for (OperationDefinition &it : formatState.getOpDefs()) {
+ auto [startOp, endOp] = getOpRange(it);
+ // Loop through the result groups.
+ for (auto &resultGroup : it.resultGroups) {
+ auto def = resultGroup.definition;
+ auto sm_range = def.loc;
+ const char *start = sm_range.Start.getPointer();
+ int len = sm_range.End.getPointer() - start;
+ // Drop the % prefix, and put in new string with `loc("name")` format.
+ auto name = StringRef(start + 1, len - 1);
+
+ // Add loc("{name}") to the end of the op
+ std::string formattedStr = " loc(\"" + name.str() + "\")";
+ StringRef namedLoc(formattedStr);
+ formatState.insertText(endOp, namedLoc);
+ }
+ }
+
+ // Insert the NameLocs for the block arguments
+ for (BlockDefinition &block : formatState.getBlockDefs()) {
+ for (size_t i = 0; i < block.arguments.size(); ++i) {
+ SMDefinition &arg = block.arguments[i];
+
+ // Find where to insert the NameLoc. Either before the next argument,
+ // or at the end of the arg list
+ const char *insertPointPtr;
+ const char *arg_end = arg.loc.End.getPointer();
+ SMDefinition *nextArg =
+ (i + 1 < block.arguments.size()) ? &block.arguments[i + 1] : nullptr;
+ if (nextArg) {
+ const char *nextStart = nextArg->loc.Start.getPointer();
+ insertPointPtr = findPrevComma(nextStart, arg_end);
+ } else {
+ insertPointPtr = findNextCloseParenth(arg.loc.End.getPointer());
+ }
+
+ // Drop the % prefix, and put in new string with `loc("name")` format.
+ const char *start = arg.loc.Start.getPointer();
+ const int len = arg_end - start;
+ auto name = StringRef(start + 1, len - 1);
+ std::string formattedStr = " loc(\"" + name.str() + "\")";
+ StringRef namedLoc(formattedStr);
+ formatState.insertText(SMLoc::getFromPointer(insertPointPtr), namedLoc);
+ }
+ }
+}
+} // namespace mlir
+
+int main(int argc, char **argv) {
+ registerAsmPrinterCLOptions();
+ static llvm::cl::opt<std::string> inputFilename(
+ llvm::cl::Positional, llvm::cl::desc("<input file>"),
+ llvm::cl::init("-"));
+
+ static llvm::cl::opt<std::string> outputFilename(
+ "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
+ llvm::cl::init("-"));
+
+ llvm::cl::opt<bool> nameLocOnly{
+ "insert-name-loc-only", llvm::cl::init(false),
+ llvm::cl::desc("Only return a buffer with the NameLocs appended")};
+
+ std::string helpHeader = "mlir-format";
+
+ llvm::cl::ParseCommandLineOptions(argc, argv, helpHeader);
+
+ // Set up the formatter (null on failure), then append SSA names as NameLocs.
+ auto f = Formatter::init(inputFilename, outputFilename);
+ if (!f)
+ return mlir::asMainReturnCode(mlir::failure());
+ markNames(*f, llvm::outs());
+
+ if (nameLocOnly) {
+ // Return the original buffer with NameLocs appended to ops
+ // e.g., `%alice = memref.load %0[] : memref<i32> loc("alice")`
+ f->write(llvm::outs());
+ return mlir::asMainReturnCode(mlir::success());
+ }
+
+ f->formatOps();
+
+ return mlir::asMainReturnCode(mlir::success());
+}
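One way to reduce the fragility of the block-argument handling (item 4 in the PR description) is to bound both searches by the buffer and report failure explicitly. As written, `findPrevComma` and `findNextCloseParenth` return `nullptr` on failure, and `markNames` passes the result straight to `insertText`. The sketch below works on offsets into a `std::string_view`; `findArgLocInsertPoint` is a hypothetical name, not part of the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string_view>

// Find where a block argument's NameLoc should be inserted, given the offset
// just past the argument's name (`argEnd`) and, if there is a following
// argument, the offset where its name starts (`nextStart`).
//  - Not last: insert just before the comma separating the two arguments.
//  - Last: insert just before the closing `)` of the argument list.
// Returns std::nullopt (rather than a null pointer) when nothing is found,
// and never reads past the end of `text`.
std::optional<std::size_t>
findArgLocInsertPoint(std::string_view text, std::size_t argEnd,
                      std::optional<std::size_t> nextStart) {
  if (nextStart) {
    // Search backwards from the next argument, but not past this one.
    std::size_t pos = text.rfind(',', *nextStart);
    if (pos != std::string_view::npos && pos >= argEnd)
      return pos;
    return std::nullopt;
  }
  std::size_t pos = text.find(')', argEnd);
  if (pos != std::string_view::npos)
    return pos;
  return std::nullopt;
}
```

Searching backwards from the next argument copes with commas inside types such as `!llvm.struct<(i32, f32)>`, but the forward search for `)` on the last argument does not, so a sturdier version would need to track bracket nesting.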