[Mlir-commits] [mlir] Add initial `mlir-format` PoC (PR #121260)

Perry Gibson llvmlistbot at llvm.org
Sat Dec 28 02:29:43 PST 2024


https://github.com/Wheest created https://github.com/llvm/llvm-project/pull/121260

This draft PR is a proof of concept (PoC) of a basic formatting tool for MLIR.  The idea was first described in [this forum thread](https://discourse.llvm.org/t/clang-format-or-some-other-auto-format-for-mlir-files/75258).

It does not aim to be anything close to `clang-format`.  Instead, it uses the existing dialect printers to format IR.  

An example use case might be "I'm making some small edits to this MLIR test case, and I want it to be better formatted".  Right now, users need to either run `mlir-opt` (and reinsert their comments and SSA names by hand) or do the formatting manually.

```mlir
// Add two values
%x1 = arith.addf       
            %x,     %cst1 : f64
```
:arrow_down:  :magic_wand: 
```mlir
// Add two values
%x1 = arith.addf %x, %cst1 : f64
```

This tool introduces two features that cannot be achieved with `mlir-opt` alone:
1. retain comments
2. retain SSA names

Feature 1 is achieved by editing the original source buffer and replacing only the character range of each formatted op.  Everything outside those ranges, including comments, is left as-is.
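
As a rough illustration of the range-replacement idea (not the code in this patch, which relies on `llvm::RewriteBuffer`), here is a minimal standalone sketch using only the C++ standard library. The `OpEdit` struct and `applyEdits` function are hypothetical names invented for this example, and the byte offsets are assumed to come from a parser that recorded where each op starts and ends.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for one formatted op: the half-open byte range
// [begin, end) it occupied in the original text, plus its printed form.
struct OpEdit {
  std::size_t begin;
  std::size_t end;
  std::string formatted;
};

// Splice every edit into a copy of `source`. Bytes outside the edited
// ranges (comments, blank lines, indentation) are copied through unchanged.
// Assumes `edits` is sorted by `begin` and that the ranges do not overlap.
std::string applyEdits(const std::string &source,
                       const std::vector<OpEdit> &edits) {
  std::string out;
  std::size_t cursor = 0;
  for (const OpEdit &e : edits) {
    assert(cursor <= e.begin && e.begin <= e.end && e.end <= source.size());
    out.append(source, cursor, e.begin - cursor); // untouched text
    out += e.formatted;                           // formatted op
    cursor = e.end;
  }
  out.append(source, cursor, std::string::npos); // trailing text
  return out;
}
```

Because only the bytes inside each op's range are rewritten, a comment on the line above an op survives without the tool ever having to parse or understand it. A comment inside an op's range would be overwritten, which is why in-op comments need special handling.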

Feature 2 is achieved by using the recently added [NameLoc support](https://github.com/llvm/llvm-project/pull/119996), where we can retain identifier names for debugging. `%alice = op()` -> `%0 = op() loc("alice")`, which can then be printed again as `%alice = op()`.
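
To make the name round trip concrete, here is a small standalone sketch of the annotation step for the simplest case: one op on one line, with a single result and no `:N` result-group suffix. `makeNameLocSuffix` and `annotateSingleResultOp` are hypothetical helpers written for this illustration. The patch itself derives the name from the parser's recorded source ranges rather than from string searching.

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Build the ` loc("name")` suffix for an SSA definition such as `%alice`.
// The leading `%` is dropped, since the location stores the bare name.
std::string makeNameLocSuffix(const std::string &ssaName) {
  assert(ssaName.size() > 1 && ssaName[0] == '%');
  return " loc(\"" + ssaName.substr(1) + "\")";
}

// Append a NameLoc to the text of a single-result op such as
// `%alice = op()`. Assumption: the text starts with the result name and
// contains no result-group suffix like `%group1:2`.
std::string annotateSingleResultOp(const std::string &opText) {
  assert(!opText.empty() && opText[0] == '%');
  std::size_t nameEnd = opText.find_first_of(" \t=");
  assert(nameEnd != std::string::npos);
  return opText + makeNameLocSuffix(opText.substr(0, nameEnd));
}
```

For the input `%alice = op()`, this yields `%alice = op() loc("alice")`. Once the annotated text is re-parsed, the name lives in the op's location, and the printer can use it as the result name again (the tests in this patch pass `--mlir-use-nameloc-as-prefix` for that).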

From a design perspective, this tool adapts [`mlir-rewrite`](https://github.com/llvm/llvm-project/pull/77668).  It creates two rewrite buffers.  The first is used to insert the `loc` annotations; its contents are then re-parsed to generate the formatted ops.  The second, which starts from the original text, is where these formatted ops are inserted.  I might be able to do it with one buffer, and there are some other code improvements that could be made.

There's a laundry list of small features that would be good to have to make this a good tool:
1. if there are in-op comments, do not format the op 
2. identify any cases that are unsupported by `NameLoc` (e.g., `%group1:2, %group2:3 = return_5vals`), and do not format those ops
3. text editor integration 
4. reduce fragility of named block args

For 2, I have an idea for a workaround to enable this.  For 1, we could try to get clever with reinserting the comment, but I'd prefer to keep the tool simple for now.
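
One possible way to implement the "skip ops with in-op comments" check is a plain text scan over the op's source range. The sketch below is not part of this patch. `hasInOpComment` is a hypothetical helper, and it assumes `//` can only start a comment when it appears outside a double-quoted string literal.

```cpp
#include <cstddef>
#include <string>

// Return true if `opText` contains a `//` comment marker outside of a
// double-quoted string literal. Hypothetical helper: the formatter could
// call this on an op's source range and leave the op untouched on `true`.
bool hasInOpComment(const std::string &opText) {
  bool inString = false;
  for (std::size_t i = 0; i < opText.size(); ++i) {
    char c = opText[i];
    if (inString) {
      if (c == '\\')
        ++i; // skip the escaped character
      else if (c == '"')
        inString = false;
    } else if (c == '"') {
      inString = true;
    } else if (c == '/' && i + 1 < opText.size() && opText[i + 1] == '/') {
      return true;
    }
  }
  return false;
}
```

Skipping such ops keeps the tool conservative: an op with an embedded comment stays exactly as the user wrote it, at the cost of not being formatted.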



From 3eebcfdf91ecbbe572624e8a0f360e48e85b6fa6 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at fractile.ai>
Date: Sat, 28 Dec 2024 10:18:12 +0000
Subject: [PATCH] Add initial mlir-format PoC

---
 mlir/test/mlir-format/annotate_locs.mlir |  13 +
 mlir/test/mlir-format/simple.mlir        |  15 +
 mlir/test/mlir-format/simple2.mlir       |  19 ++
 mlir/tools/CMakeLists.txt                |   1 +
 mlir/tools/mlir-format/CMakeLists.txt    |  32 ++
 mlir/tools/mlir-format/mlir-format.cpp   | 385 +++++++++++++++++++++++
 6 files changed, 465 insertions(+)
 create mode 100644 mlir/test/mlir-format/annotate_locs.mlir
 create mode 100644 mlir/test/mlir-format/simple.mlir
 create mode 100644 mlir/test/mlir-format/simple2.mlir
 create mode 100644 mlir/tools/mlir-format/CMakeLists.txt
 create mode 100644 mlir/tools/mlir-format/mlir-format.cpp

diff --git a/mlir/test/mlir-format/annotate_locs.mlir b/mlir/test/mlir-format/annotate_locs.mlir
new file mode 100644
index 00000000000000..c557e8447bcf7b
--- /dev/null
+++ b/mlir/test/mlir-format/annotate_locs.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-format %s --mlir-use-nameloc-as-prefix --insert-name-loc-only | FileCheck %s
+
+// Append NameLocs (`loc("[ssa_name]")`) to operations and block arguments
+
+// CHECK: func.func @add_one(%my_input: f64 loc("my_input"), %my_input2: f64 loc("my_input2")) -> f64 {
+func.func @add_one(%my_input: f64, %my_input2: f64) -> f64 {
+    // CHECK: %my_constant = arith.constant 1.00000e+00 : f64 loc("my_constant")
+    %my_constant = arith.constant 1.00000e+00 : f64
+
+    %my_output = arith.addf %my_input, %my_constant : f64
+    // CHECK: %my_output = arith.addf %my_input, %my_constant : f64 loc("my_output")
+    return %my_output : f64
+}
diff --git a/mlir/test/mlir-format/simple.mlir b/mlir/test/mlir-format/simple.mlir
new file mode 100644
index 00000000000000..0b05fa2a63a1ec
--- /dev/null
+++ b/mlir/test/mlir-format/simple.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-format %s --mlir-use-nameloc-as-prefix  | FileCheck %s
+
+// CHECK: func.func @add_one(%my_input: f64) -> f64 {
+func.func @add_one(%my_input: f64) -> f64 {
+    // CHECK: %my_constant = arith.constant 1.00000e+00 : f64
+    %my_constant = arith.constant 1.00000e+00 : f64
+    // CHECK: // Dinnae drop this comment!
+    // Dinnae drop this comment!
+    %my_output = arith.addf
+               %my_input,
+               %my_constant : f64
+    // CHECK-STRICT: %my_output = arith.addf %my_input, %my_constant : f64
+    return %my_output : f64
+    // CHECK: return %my_output : f64
+}
diff --git a/mlir/test/mlir-format/simple2.mlir b/mlir/test/mlir-format/simple2.mlir
new file mode 100644
index 00000000000000..beb22ec933fb98
--- /dev/null
+++ b/mlir/test/mlir-format/simple2.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-format %s --mlir-use-nameloc-as-prefix  | FileCheck %s
+
+// CHECK: func.func @my_func(%x: f64, %y: f64) -> f64 {
+func.func @my_func(%x: f64, %y: f64) -> f64 {
+    // CHECK: %cst1 = arith.constant 1.00000e+00 : f64
+    %cst1 = arith.constant 1.00000e+00 : f64
+    // CHECK: %cst2 = arith.constant 2.00000e+00 : f64
+    %cst2 = arith.constant 2.00000e+00 : f64
+    // CHECK-STRICT: %x1 = arith.addf %x, %cst1 : f64
+    %x1 = arith.addf
+%x,
+%cst1 : f64
+    // CHECK-STRICT: %y2 = arith.addf %y, %cst2 : f64
+    %y2 = arith.addf                %y, %cst2 : f64
+    // CHECK: %z = arith.addf %x1, %y2 : f64
+    %z = arith.addf %x1, %y2 : f64
+    // return %z : f64
+    return %z : f64
+}
diff --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt
index 072e83c5d45ea1..0f3d5e6f9731ae 100644
--- a/mlir/tools/CMakeLists.txt
+++ b/mlir/tools/CMakeLists.txt
@@ -5,6 +5,7 @@ add_subdirectory(mlir-pdll-lsp-server)
 add_subdirectory(mlir-query)
 add_subdirectory(mlir-reduce)
 add_subdirectory(mlir-rewrite)
+add_subdirectory(mlir-format)
 add_subdirectory(mlir-shlib)
 add_subdirectory(mlir-translate)
 add_subdirectory(mlir-vulkan-runner)
diff --git a/mlir/tools/mlir-format/CMakeLists.txt b/mlir/tools/mlir-format/CMakeLists.txt
new file mode 100644
index 00000000000000..8ed6f095f0e4b5
--- /dev/null
+++ b/mlir/tools/mlir-format/CMakeLists.txt
@@ -0,0 +1,32 @@
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+set(LLVM_LINK_COMPONENTS
+  Support
+  )
+
+set(LIBS
+  ${dialect_libs}
+
+  MLIRAffineAnalysis
+  MLIRAnalysis
+  MLIRCastInterfaces
+  MLIRDialect
+  MLIRParser
+  MLIRPass
+  MLIRTransforms
+  MLIRTransformUtils
+  MLIRSupport
+  MLIRIR
+  )
+
+include_directories(../../../clang/include)
+
+add_mlir_tool(mlir-format
+  mlir-format.cpp
+
+  SUPPORT_PLUGINS
+  )
+mlir_target_link_libraries(mlir-format PRIVATE ${LIBS})
+llvm_update_compile_flags(mlir-format)
+
+mlir_check_all_link_libraries(mlir-format)
+export_executable_symbols_for_plugins(mlir-format)
diff --git a/mlir/tools/mlir-format/mlir-format.cpp b/mlir/tools/mlir-format/mlir-format.cpp
new file mode 100644
index 00000000000000..ae0122a6627afa
--- /dev/null
+++ b/mlir/tools/mlir-format/mlir-format.cpp
@@ -0,0 +1,385 @@
+#include "mlir/AsmParser/AsmParser.h"
+#include "mlir/AsmParser/AsmParserState.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Tools/ParseUtilities.h"
+#include "llvm/ADT/RewriteBuffer.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/LineIterator.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace mlir;
+
+namespace mlir {
+using OperationDefinition = AsmParserState::OperationDefinition;
+using BlockDefinition = AsmParserState::BlockDefinition;
+using SMDefinition = AsmParserState::SMDefinition;
+
+inline int asMainReturnCode(LogicalResult r) {
+  return r.succeeded() ? EXIT_SUCCESS : EXIT_FAILURE;
+}
+
+/// Return the source code associated with the OperationDefinition.
+SMRange getOpRange(const OperationDefinition &op) {
+  const char *startOp = op.scopeLoc.Start.getPointer();
+  const char *endOp = op.scopeLoc.End.getPointer();
+
+  for (const auto &res : op.resultGroups) {
+    SMRange range = res.definition.loc;
+    startOp = std::min(startOp, range.Start.getPointer());
+  }
+  return {SMLoc::getFromPointer(startOp), SMLoc::getFromPointer(endOp)};
+}
+
+class CombinedOpDefIterator {
+public:
+  using BaseIterator = AsmParserState::OperationDefIterator;
+  using value_type = std::pair<OperationDefinition &, OperationDefinition &>;
+
+  // Constructor
+  CombinedOpDefIterator(BaseIterator opIter, BaseIterator fmtIter)
+      : opIter(opIter), fmtIter(fmtIter) {}
+
+  // Dereference operator to return a pair of references
+  value_type operator*() const { return {*opIter, *fmtIter}; }
+
+  // Increment operator
+  CombinedOpDefIterator &operator++() {
+    ++opIter;
+    ++fmtIter;
+    return *this;
+  }
+
+  // Equality operator
+  bool operator==(const CombinedOpDefIterator &other) const {
+    return opIter == other.opIter && fmtIter == other.fmtIter;
+  }
+
+  // Inequality operator
+  bool operator!=(const CombinedOpDefIterator &other) const {
+    return !(*this == other);
+  }
+
+private:
+  BaseIterator opIter;
+  BaseIterator fmtIter;
+};
+
+// Function to find the character before the previous comma
+const char *findPrevComma(const char *start, const char *stop_point) {
+  if (!start) {
+    llvm::errs() << "Error: Input pointer is null.\n";
+    return nullptr;
+  }
+
+  const char *current = start - 1; // Start checking backwards
+  while (current >= stop_point) {
+    if (*current == ',') {
+      return current;
+    }
+    --current;
+  }
+
+  llvm::errs() << "Error: No previous comma found before provided pointer.\n";
+  return nullptr;
+}
+
+// Function to find the next closing parenthesis
+const char *findNextCloseParenth(const char *start) {
+  if (!start) {
+    llvm::errs() << "Error: Input pointer is null.\n";
+    return nullptr;
+  }
+
+  while (*start != '\0') { // Traverse until null terminator
+    if (*start == ')') {
+      return start; // Return pointer to the closing parenthesis
+    }
+    ++start;
+  }
+
+  llvm::errs() << "Error: No closing parenthesis found in the string.\n";
+  return nullptr;
+}
+
+class Formatter {
+public:
+  static std::unique_ptr<Formatter> init(StringRef inputFilename,
+                                         StringRef outputFilename);
+
+  /// Return the OperationDefinition's of the operations parsed.
+  iterator_range<AsmParserState::OperationDefIterator> getOpDefs() {
+    return asmState.getOpDefs();
+  }
+
+  /// Return the BlockDefinitions of the blocks parsed.
+  iterator_range<AsmParserState::BlockDefIterator> getBlockDefs() {
+    return asmState.getBlockDefs();
+  }
+
+  /// Return a pair iterators of OperationDefinitions for the asmState and
+  /// fmtAsmState.
+  iterator_range<CombinedOpDefIterator> getCombinedOpDefs() {
+    auto opBegin = asmState.getOpDefs().begin();
+    auto opEnd = asmState.getOpDefs().end();
+    auto fmtBegin = fmtAsmState.getOpDefs().begin();
+    auto fmtEnd = fmtAsmState.getOpDefs().end();
+
+    assert(std::distance(opBegin, opEnd) == std::distance(fmtBegin, fmtEnd) &&
+           "Both iterators must have the same length");
+
+    return llvm::make_range(CombinedOpDefIterator(opBegin, fmtBegin),
+                            CombinedOpDefIterator(opEnd, fmtEnd));
+  }
+
+  /// Print the parsed operations to the provided output stream.
+  void printOps(raw_ostream &os) {
+    // Iterate over each operation in the parsedIR block and print it.
+    for (Operation &op : parsedIR) {
+      op.print(os);
+      os << "\n";
+    }
+  }
+
+  /// Insert the specified string at the specified location in the original
+  /// buffer.
+  void insertText(SMLoc pos, StringRef str, bool insertAfter = true) {
+    rewriteBuffer.InsertText(pos.getPointer() - start, str, insertAfter);
+  }
+
+  /// Replace the range of the source text with the corresponding string in the
+  /// output.
+  void replaceRange(SMRange range, StringRef str) {
+    rewriteBuffer.ReplaceText(range.Start.getPointer() - start,
+                              range.End.getPointer() - range.Start.getPointer(),
+                              str);
+  }
+
+  void replaceRangeFmt(SMRange range, StringRef str) {
+    fmtRewriteBuffer.ReplaceText(
+        range.Start.getPointer() - start,
+        range.End.getPointer() - range.Start.getPointer(), str);
+  }
+
+  /// Write to stream the result of applying all changes to the
+  /// original buffer.
+  /// Note that it isn't safe to use this function to overwrite memory mapped
+  /// files in-place (PR17960).
+  ///
+  /// The original buffer is not actually changed.
+  raw_ostream &write(raw_ostream &stream) const {
+    return rewriteBuffer.write(stream);
+  }
+
+  raw_ostream &writeFmt(raw_ostream &stream) const {
+    return fmtRewriteBuffer.write(stream);
+  }
+
+  /// Generate formatAsmState from the rewriteBuffer
+  void formatOps();
+
+private:
+  // The context and state required to parse.
+  MLIRContext context;
+  llvm::SourceMgr sourceMgr;
+  llvm::SourceMgr fmtSourceMgr;
+  DialectRegistry registry;
+  FallbackAsmResourceMap fallbackResourceMap;
+
+  // Storage of textual parsing results.
+  AsmParserState asmState;
+
+  // Storage of initial formatted ops.
+  AsmParserState fmtAsmState;
+
+  // Parsed IR.
+  Block parsedIR;
+  Block parsedFmtIR;
+
+  // The RewriteBuffer  is doing most of the real work.
+  llvm::RewriteBuffer rewriteBuffer;
+  llvm::RewriteBuffer fmtRewriteBuffer;
+
+  // Start of the original input, used to compute offset.
+  const char *start;
+};
+
+std::unique_ptr<Formatter> Formatter::init(StringRef inputFilename,
+                                           StringRef outputFilename) {
+
+  std::unique_ptr<Formatter> f = std::make_unique<Formatter>();
+  // Register all the dialects needed.
+  registerAllDialects(f->registry);
+
+  // Set up the input file.
+  std::string errorMessage;
+  std::unique_ptr<llvm::MemoryBuffer> file =
+      openInputFile(inputFilename, &errorMessage);
+  if (!file) {
+    llvm::errs() << errorMessage << "\n";
+    return nullptr;
+  }
+  f->sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+
+  // Set up the MLIR context and error handling.
+  f->context.appendDialectRegistry(f->registry);
+
+  // Record the start of the buffer to compute offsets with.
+  unsigned curBuf = f->sourceMgr.getMainFileID();
+  const llvm::MemoryBuffer *curMB = f->sourceMgr.getMemoryBuffer(curBuf);
+  f->start = curMB->getBufferStart();
+  f->rewriteBuffer.Initialize(curMB->getBuffer());
+  f->fmtRewriteBuffer.Initialize(curMB->getBuffer());
+
+  // Parse and populate the AsmParserState.
+  ParserConfig parseConfig(&f->context, /*verifyAfterParse=*/true,
+                           &f->fallbackResourceMap);
+  // Always allow unregistered.
+  f->context.allowUnregisteredDialects(true);
+  if (failed(parseAsmSourceFile(f->sourceMgr, &f->parsedIR, parseConfig,
+                                &f->asmState)))
+    return nullptr;
+
+  return f;
+}
+
+void Formatter::formatOps() {
+  // Generate formatAsmState from the rewriteBuffer
+  ParserConfig parseConfig(&context, /*verifyAfterParse=*/true,
+                           &fallbackResourceMap);
+
+  // Write the rewriteBuffer to a stream, that we can then parse
+  std::string bufferContent;
+  llvm::raw_string_ostream stream(bufferContent);
+  rewriteBuffer.write(stream);
+  stream.flush();
+
+  // Hand the annotated text to a second SourceMgr so it can be re-parsed.
+  fmtSourceMgr.AddNewSourceBuffer(
+      llvm::MemoryBuffer::getMemBufferCopy(bufferContent), SMLoc());
+
+  // Parse and populate the format AsmParserState.
+  if (failed(parseAsmSourceFile(fmtSourceMgr, &parsedFmtIR, parseConfig,
+                                &fmtAsmState)))
+    return;
+
+  // Insert the formatted ops.  Block args should be untouched,
+  // and their references will use the correct SSA ID.
+  for (auto [opDef, fmtDef] : getCombinedOpDefs()) {
+    auto [startOp, endOp] = getOpRange(opDef);
+
+    // Skip if the op is a FuncOp (we format ops in its body)
+    // or a ReturnOp (we want to keep the user's preference for
+    // `func.return` or plain `return`)
+    if (llvm::dyn_cast<mlir::func::FuncOp>(fmtDef.op))
+      continue;
+    else if (llvm::dyn_cast<mlir::func::ReturnOp>(fmtDef.op))
+      continue;
+
+    // Print the fmtDef op and store as a string.
+    // Replace the opDef with this formatted string.
+    std::string formattedStr;
+    llvm::raw_string_ostream stream(formattedStr);
+    fmtDef.op->print(stream);
+
+    // Replacing the range:
+    replaceRangeFmt({startOp, endOp}, formattedStr);
+  }
+
+  // Write the updated buffer to llvm::outs()
+  writeFmt(llvm::outs());
+}
+
+void markNames(Formatter &formatState, raw_ostream &os) {
+  // Get the operation definitions from the AsmParserState.
+  for (OperationDefinition &it : formatState.getOpDefs()) {
+    auto [startOp, endOp] = getOpRange(it);
+    // loop through the resultgroups
+    for (auto &resultGroup : it.resultGroups) {
+      auto def = resultGroup.definition;
+      auto sm_range = def.loc;
+      const char *start = sm_range.Start.getPointer();
+      int len = sm_range.End.getPointer() - start;
+      // Drop the % prefix, and put in new string with  `loc("name")` format.
+      auto name = StringRef(start + 1, len - 1);
+
+      // Add loc("{name}") to the end of the op
+      std::string formattedStr = " loc(\"" + name.str() + "\")";
+      StringRef namedLoc(formattedStr);
+      formatState.insertText(endOp, namedLoc);
+    }
+  }
+
+  // Insert the NameLocs for the block arguments
+  for (BlockDefinition &block : formatState.getBlockDefs()) {
+    for (size_t i = 0; i < block.arguments.size(); ++i) {
+      SMDefinition &arg = block.arguments[i];
+
+      // Find where to insert the NameLoc.  Either before the next argument,
+      // or at the end of the arg list
+      const char *insertPointPtr;
+      const char *arg_end = arg.loc.End.getPointer();
+      SMDefinition *nextArg =
+          (i + 1 < block.arguments.size()) ? &block.arguments[i + 1] : nullptr;
+      if (nextArg) {
+        const char *nextStart = nextArg->loc.Start.getPointer();
+        insertPointPtr = findPrevComma(nextStart, arg_end);
+      } else {
+        insertPointPtr = findNextCloseParenth(arg.loc.End.getPointer());
+      }
+
+      // Drop the % prefix, and put in new string with  `loc("name")` format.
+      const char *start = arg.loc.Start.getPointer();
+      const int len = arg_end - start;
+      auto name = StringRef(start + 1, len - 1);
+      std::string formattedStr = " loc(\"" + name.str() + "\")";
+      StringRef namedLoc(formattedStr);
+      formatState.insertText(SMLoc::getFromPointer(insertPointPtr), namedLoc);
+    }
+  }
+}
+} // namespace mlir
+
+int main(int argc, char **argv) {
+  registerAsmPrinterCLOptions();
+  static llvm::cl::opt<std::string> inputFilename(
+      llvm::cl::Positional, llvm::cl::desc("<input file>"),
+      llvm::cl::init("-"));
+
+  static llvm::cl::opt<std::string> outputFilename(
+      "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
+      llvm::cl::init("-"));
+
+  llvm::cl::opt<bool> nameLocOnly{
+      "insert-name-loc-only", llvm::cl::init(false),
+      llvm::cl::desc("Only return a buffer with the NameLocs appended")};
+
+  std::string helpHeader = "mlir-format";
+
+  llvm::cl::ParseCommandLineOptions(argc, argv, helpHeader);
+
+  // Set up formatter buffer.
+  auto f = Formatter::init(inputFilename, outputFilename);
+
+  // Append the SSA names as NameLocs
+  markNames(*f, llvm::outs());
+
+  if (nameLocOnly) {
+    // Return the original buffer with NameLocs appended to ops
+    // e.g., `%alice = memref.load %0[] : memref<i32> loc("alice")`
+    f->write(llvm::outs());
+    return mlir::asMainReturnCode(mlir::success());
+  }
+
+  f->formatOps();
+
+  return mlir::asMainReturnCode(mlir::success());
+}


