[llvm] [mlir] [mlir] Start rewrite tool (PR #77668)

Jacques Pienaar via llvm-commits llvm-commits at lists.llvm.org
Sat Oct 12 14:32:47 PDT 2024

https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/77668

From 95da1e08b77b55ec417341f9147634211ac88e35 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Sun, 7 Jan 2024 21:31:41 -0800
Subject: [PATCH] [mlir] Start rewrite tool

Initial commit of a tool to help in textual rewrites of .mlir files.
This tool builds off of AsmParserState and is rather simple. Took some
inspiration from when I used clang's AST rewrites where I'd often treat
it as a "localizing" regex applicator in fallback cases, and started
with that as functionality. There, though, one has access to lower-level
info than here, but this is still a step up over sed over the entire file.

This aims to be helpful (e.g., rewrite syntax including best effort
inside comments) rather than a bulletproof tool. It may even be better
suited under utils than tools. And most of the rewrites would be rather
short lived and might never make it upstream (while the helpers of those
rewrites may for future rewrites).

Started it as a single file to prototype more easily, planning to
refactor later into include and lib directories for out-of-file usage.
 mlir/docs/Tools/mlir-rewrite.md               |  29 ++
 mlir/test/CMakeLists.txt                      |   1 +
 mlir/test/mlir-rewrite/simple.mlir            |  11 +
 mlir/tools/CMakeLists.txt                     |   1 +
 mlir/tools/mlir-rewrite/CMakeLists.txt        |  35 ++
 mlir/tools/mlir-rewrite/mlir-rewrite.cpp      | 392 ++++++++++++++++++
 .../llvm-project-overlay/mlir/BUILD.bazel     |  24 ++
 7 files changed, 493 insertions(+)
 create mode 100644 mlir/docs/Tools/mlir-rewrite.md
 create mode 100644 mlir/test/mlir-rewrite/simple.mlir
 create mode 100644 mlir/tools/mlir-rewrite/CMakeLists.txt
 create mode 100644 mlir/tools/mlir-rewrite/mlir-rewrite.cpp

diff --git a/mlir/docs/Tools/mlir-rewrite.md b/mlir/docs/Tools/mlir-rewrite.md
new file mode 100644
index 00000000000000..178f92f72cbb6e
--- /dev/null
+++ b/mlir/docs/Tools/mlir-rewrite.md
@@ -0,0 +1,29 @@
+# mlir-rewrite
+Tool to simplify rewriting .mlir files. There are a couple of built-in rewrites
+discussed below along with usage.
+Note: This is still in a very early stage. It's so early it's less a tool than a
+growing collection of useful functions: to use it, it's best to do what's needed
+on a branch by just hacking it (dialects registered, rewrites, etc.) to, say,
+help ease a rename, upstream useful utility functions, point to ease others
+migrating, and then bin eventually. Once there are actually useful parts it
+should be refactored same as mlir-opt.
+## simple-rename
+Rename per op given a substring to a target. The match and replace use LLVM's
+regex sub, while the op name is matched via regular string comparison. E.g.,
+mlir-rewrite input.mlir -o output.mlir --simple-rename \
+   --simple-rename-op-name="test.concat" --simple-rename-match="axis" \
+                                         --simple-rename-replace="bxis"
+to replace the `axis` substring in the text of the range corresponding to
+`test.concat` ops with `bxis`.
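
The rename described above can be pictured with a small standalone sketch. It is not the code in this patch: `renameInRange` is a hypothetical helper, and `std::regex` is swapped in for `llvm::Regex` so the example compiles without LLVM. Like `llvm::Regex::sub`, it replaces only the first match, and only inside the given range, so text outside the op (such as a comment mentioning `axis`) is left alone. Note that the two regex libraries differ in syntax details such as backreferences.

```cpp
#include <cassert>
#include <cstddef>
#include <regex>
#include <string>

// Hypothetical stand-in for the per-op step of simple-rename: replace the
// first match of `pattern` inside text[begin, end) and keep everything
// outside that range byte-for-byte identical.
// Assumption: std::regex (POSIX extended syntax) is used instead of
// llvm::Regex so that the sketch is self-contained.
std::string renameInRange(const std::string &text, std::size_t begin,
                          std::size_t end, const std::string &pattern,
                          const std::string &replacement) {
  assert(begin <= end && end <= text.size() && "range must lie inside text");
  std::regex re(pattern, std::regex::extended);
  // format_first_only: substitute only the first match in the range.
  std::string rewritten =
      std::regex_replace(text.substr(begin, end - begin), re, replacement,
                         std::regex_constants::format_first_only);
  return text.substr(0, begin) + rewritten + text.substr(end);
}
```

Calling `renameInRange` with the byte range of a `test.concat` op, the pattern `axis` and the replacement `bxis` would rewrite the attribute name on that op only.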
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 4d2d738b734ec6..361981605a76b7 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -115,6 +115,7 @@ set(MLIR_TEST_DEPENDS
+  mlir-rewrite
diff --git a/mlir/test/mlir-rewrite/simple.mlir b/mlir/test/mlir-rewrite/simple.mlir
new file mode 100644
index 00000000000000..ab6bfe24fccf03
--- /dev/null
+++ b/mlir/test/mlir-rewrite/simple.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s | mlir-rewrite --simple-rename --simple-rename-op-name="test.concat" --simple-rename-match="axis" --simple-rename-replace="bxis" | FileCheck %s -check-prefix=RENAME
+// RUN: mlir-opt %s | mlir-rewrite --mark-ranges | FileCheck %s -check-prefix=RANGE
+// Note: running through mlir-opt to just strip out comments & avoid self matches.
+func.func @two_dynamic_one_direct_shape(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) -> tensor<?x4x?xf32> {
+  // RENAME: "test.concat"({{.*}}) {bxis = 0 : i64}
+  // RANGE: 《%{{.*}} = 〖"test.concat"〗({{.*}}) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>》
+  %5 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+  return %5 : tensor<?x4x?xf32>
diff --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt
index 9b474385fdae18..0a2d0ff2915099 100644
--- a/mlir/tools/CMakeLists.txt
+++ b/mlir/tools/CMakeLists.txt
@@ -4,6 +4,7 @@ add_subdirectory(mlir-parser-fuzzer)
diff --git a/mlir/tools/mlir-rewrite/CMakeLists.txt b/mlir/tools/mlir-rewrite/CMakeLists.txt
new file mode 100644
index 00000000000000..5b8c1cd4553997
--- /dev/null
+++ b/mlir/tools/mlir-rewrite/CMakeLists.txt
@@ -0,0 +1,35 @@
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+  Support
+  )
+  ${dialect_libs}
+  ${test_libs}
+  MLIRAffineAnalysis
+  MLIRAnalysis
+  MLIRCastInterfaces
+  MLIRDialect
+  MLIRParser
+  MLIRPass
+  MLIRTransforms
+  MLIRTransformUtils
+  MLIRSupport
+  )
+  mlir-rewrite.cpp
+  ${LIBS}
+  )
+target_link_libraries(mlir-rewrite PRIVATE ${LIBS})
diff --git a/mlir/tools/mlir-rewrite/mlir-rewrite.cpp b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp
new file mode 100644
index 00000000000000..308e6490726c86
--- /dev/null
+++ b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp
@@ -0,0 +1,392 @@
+//===- mlir-rewrite.cpp - MLIR Rewrite Driver -----------------------------===//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Main entry function for mlir-rewrite.
+#include "mlir/AsmParser/AsmParser.h"
+#include "mlir/AsmParser/AsmParserState.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Tools/ParseUtilities.h"
+#include "llvm/ADT/RewriteBuffer.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/LineIterator.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+using namespace mlir;
+namespace mlir {
+using OperationDefinition = AsmParserState::OperationDefinition;
+/// Return the source code associated with the OperationDefinition.
+SMRange getOpRange(const OperationDefinition &op) {
+  const char *startOp = op.scopeLoc.Start.getPointer();
+  const char *endOp = op.scopeLoc.End.getPointer();
+  for (auto res : op.resultGroups) {
+    SMRange range = res.definition.loc;
+    startOp = std::min(startOp, range.Start.getPointer());
+  }
+  return {SMLoc::getFromPointer(startOp), SMLoc::getFromPointer(endOp)};
+/// Helper to simplify rewriting the source file.
+class RewritePad {
+  static std::unique_ptr<RewritePad> init(StringRef inputFilename,
+                                          StringRef outputFilename);
+  /// Return the context the file was parsed into.
+  MLIRContext *getContext() { return &context; }
+  /// Return the OperationDefinitions of the operations parsed.
+  iterator_range<AsmParserState::OperationDefIterator> getOpDefs() {
+    return asmState.getOpDefs();
+  }
+  /// Insert the specified string at the specified location in the original
+  /// buffer.
+  void insertText(SMLoc pos, StringRef str, bool insertAfter = true) {
+    rewriteBuffer.InsertText(pos.getPointer() - start, str, insertAfter);
+  }
+  /// Replace the range of the source text with the corresponding string in the
+  /// output.
+  void replaceRange(SMRange range, StringRef str) {
+    rewriteBuffer.ReplaceText(range.Start.getPointer() - start,
+                              range.End.getPointer() - range.Start.getPointer(),
+                              str);
+  }
+  /// Replace the range of the operation in the source text with the
+  /// corresponding string in the output.
+  void replaceDef(const OperationDefinition &opDef, StringRef newDef) {
+    replaceRange(getOpRange(opDef), newDef);
+  }
+  /// Return the source string corresponding to the source range.
+  StringRef getSourceString(SMRange range) {
+    return StringRef(range.Start.getPointer(),
+                     range.End.getPointer() - range.Start.getPointer());
+  }
+  /// Return the source string corresponding to operation definition.
+  StringRef getSourceString(const OperationDefinition &opDef) {
+    auto range = getOpRange(opDef);
+    return getSourceString(range);
+  }
+  /// Write to stream the result of applying all changes to the
+  /// original buffer.
+  /// Note that it isn't safe to use this function to overwrite memory mapped
+  /// files in-place (PR17960).
+  ///
+  /// The original buffer is not actually changed.
+  raw_ostream &write(raw_ostream &stream) const {
+    return rewriteBuffer.write(stream);
+  }
+  /// Return lines that are purely comments.
+  SmallVector<SMRange> getSingleLineComments() {
+    unsigned curBuf = sourceMgr.getMainFileID();
+    const llvm::MemoryBuffer *curMB = sourceMgr.getMemoryBuffer(curBuf);
+    llvm::line_iterator lineIterator(*curMB);
+    SmallVector<SMRange> ret;
+    for (; !lineIterator.is_at_end(); ++lineIterator) {
+      StringRef trimmed = lineIterator->ltrim();
+      if (trimmed.starts_with("//")) {
+        ret.emplace_back(
+            SMLoc::getFromPointer(trimmed.data()),
+            SMLoc::getFromPointer(trimmed.data() + trimmed.size()));
+      }
+    }
+    return ret;
+  }
+  /// Return the IR from parsed file.
+  Block *getParsed() { return &parsedIR; }
+  /// Return the definition for the given operation. The operation must have
+  /// a definition, as the result is dereferenced unconditionally.
+  const OperationDefinition &getOpDef(Operation *op) const {
+    return *asmState.getOpDef(op);
+  }
+  // The context and state required to parse.
+  MLIRContext context;
+  llvm::SourceMgr sourceMgr;
+  DialectRegistry registry;
+  FallbackAsmResourceMap fallbackResourceMap;
+  // Storage of textual parsing results.
+  AsmParserState asmState;
+  // Parsed IR.
+  Block parsedIR;
+  // The RewriteBuffer is doing most of the real work.
+  llvm::RewriteBuffer rewriteBuffer;
+  // Start of the original input, used to compute offset.
+  const char *start;
+std::unique_ptr<RewritePad> RewritePad::init(StringRef inputFilename,
+                                             StringRef outputFilename) {
+  std::unique_ptr<RewritePad> r = std::make_unique<RewritePad>();
+  // Register all the dialects needed.
+  registerAllDialects(r->registry);
+  // Set up the input file.
+  std::string errorMessage;
+  std::unique_ptr<llvm::MemoryBuffer> file =
+      openInputFile(inputFilename, &errorMessage);
+  if (!file) {
+    llvm::errs() << errorMessage << "\n";
+    return nullptr;
+  }
+  r->sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+  // Set up the MLIR context and error handling.
+  r->context.appendDialectRegistry(r->registry);
+  // Record the start of the buffer to compute offsets with.
+  unsigned curBuf = r->sourceMgr.getMainFileID();
+  const llvm::MemoryBuffer *curMB = r->sourceMgr.getMemoryBuffer(curBuf);
+  r->start = curMB->getBufferStart();
+  r->rewriteBuffer.Initialize(curMB->getBuffer());
+  // Parse and populate the AsmParserState.
+  ParserConfig parseConfig(&r->context, /*verifyAfterParse=*/true,
+                           &r->fallbackResourceMap);
+  // Always allow unregistered.
+  r->context.allowUnregisteredDialects(true);
+  if (failed(parseAsmSourceFile(r->sourceMgr, &r->parsedIR, parseConfig,
+                                &r->asmState)))
+    return nullptr;
+  return r;
+/// Return the source code associated with the operation name.
+SMRange getOpNameRange(const OperationDefinition &op) { return op.loc; }
+/// Return whether the operation was printed using generic syntax in original
+/// buffer.
+bool isGeneric(const OperationDefinition &op) {
+  return op.loc.Start.getPointer()[0] == '"';
+inline int asMainReturnCode(LogicalResult r) {
+  return r.succeeded() ? EXIT_SUCCESS : EXIT_FAILURE;
+/// Rewriter function to invoke.
+using RewriterFunction = std::function<mlir::LogicalResult(
+    mlir::RewritePad &rewriteState, llvm::raw_ostream &os)>;
+/// Structure to group information about a rewriter (argument to invoke via
+/// mlir-rewrite, description, and rewriter function).
+class RewriterInfo {
+  /// RewriterInfo constructor should not be invoked directly, instead use
+  /// RewriterRegistration or registerRewriter.
+  RewriterInfo(StringRef arg, StringRef description, RewriterFunction rewriter)
+      : arg(arg), description(description), rewriter(std::move(rewriter)) {}
+  /// Invokes the rewriter and returns whether the rewriter failed.
+  LogicalResult invoke(mlir::RewritePad &rewriteState, raw_ostream &os) const {
+    assert(rewriter && "Cannot call rewriter with null rewriter");
+    return rewriter(rewriteState, os);
+  }
+  /// Returns the command line option that may be passed to 'mlir-rewrite' to
+  /// invoke this rewriter.
+  StringRef getRewriterArgument() const { return arg; }
+  /// Returns a description for the rewriter.
+  StringRef getRewriterDescription() const { return description; }
+  // The argument with which to invoke the rewriter via mlir-rewrite.
+  StringRef arg;
+  // Description of the rewriter.
+  StringRef description;
+  // Rewriter function.
+  RewriterFunction rewriter;
+static llvm::ManagedStatic<std::vector<RewriterInfo>> rewriterRegistry;
+/// Adds command line option for each registered rewriter.
+struct RewriterNameParser : public llvm::cl::parser<const RewriterInfo *> {
+  RewriterNameParser(llvm::cl::Option &opt);
+  void printOptionInfo(const llvm::cl::Option &o,
+                       size_t globalWidth) const override;
+/// RewriterRegistration provides a global initializer that registers a rewriter
+/// function.
+struct RewriterRegistration {
+  RewriterRegistration(StringRef arg, StringRef description,
+                       const RewriterFunction &function);
+RewriterRegistration::RewriterRegistration(StringRef arg, StringRef description,
+                                           const RewriterFunction &function) {
+  rewriterRegistry->emplace_back(arg, description, function);
+RewriterNameParser::RewriterNameParser(llvm::cl::Option &opt)
+    : llvm::cl::parser<const RewriterInfo *>(opt) {
+  for (const auto &kv : *rewriterRegistry) {
+    addLiteralOption(kv.getRewriterArgument(), &kv,
+                     kv.getRewriterDescription());
+  }
+void RewriterNameParser::printOptionInfo(const llvm::cl::Option &o,
+                                         size_t globalWidth) const {
+  RewriterNameParser *tp = const_cast<RewriterNameParser *>(this);
+  llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(),
+                       [](const RewriterNameParser::OptionInfo *vT1,
+                          const RewriterNameParser::OptionInfo *vT2) {
+                         return vT1->Name.compare(vT2->Name);
+                       });
+  using llvm::cl::parser;
+  parser<const RewriterInfo *>::printOptionInfo(o, globalWidth);
+} // namespace mlir
+// TODO: Make these injectable too in non-global way.
+static llvm::cl::OptionCategory clSimpleRenameCategory{"simple-rename options"};
+static llvm::cl::opt<std::string> simpleRenameOpName{
+    "simple-rename-op-name", llvm::cl::desc("Name of op to match on"),
+    llvm::cl::cat(clSimpleRenameCategory)};
+static llvm::cl::opt<std::string> simpleRenameMatch{
+    "simple-rename-match", llvm::cl::desc("Match string for rename"),
+    llvm::cl::cat(clSimpleRenameCategory)};
+static llvm::cl::opt<std::string> simpleRenameReplace{
+    "simple-rename-replace", llvm::cl::desc("Replace string for rename"),
+    llvm::cl::cat(clSimpleRenameCategory)};
+// Rewriter that does simple renames.
+LogicalResult simpleRename(RewritePad &rewriteState, raw_ostream &os) {
+  StringRef opName = simpleRenameOpName;
+  StringRef match = simpleRenameMatch;
+  StringRef replace = simpleRenameReplace;
+  llvm::Regex regex(match);
+  rewriteState.getParsed()->walk([&](Operation *op) {
+    if (op->getName().getStringRef() != opName)
+      return;
+    const OperationDefinition &opDef = rewriteState.getOpDef(op);
+    SMRange range = getOpRange(opDef);
+    // This is a little bit overkill for simple.
+    std::string str = regex.sub(replace, rewriteState.getSourceString(range));
+    rewriteState.replaceRange(range, str);
+  });
+  return success();
+static mlir::RewriterRegistration rewriteSimpleRename("simple-rename",
+                                                      "Perform a simple rename",
+                                                      simpleRename);
+// Rewriter that inserts range markers.
+LogicalResult markRanges(RewritePad &rewriteState, raw_ostream &os) {
+  for (auto it : rewriteState.getOpDefs()) {
+    auto [startOp, endOp] = getOpRange(it);
+    rewriteState.insertText(startOp, "《");
+    rewriteState.insertText(endOp, "》");
+    auto nameRange = getOpNameRange(it);
+    if (isGeneric(it)) {
+      rewriteState.insertText(nameRange.Start, "〖");
+      rewriteState.insertText(nameRange.End, "〗");
+    } else {
+      rewriteState.insertText(nameRange.Start, "〔");
+      rewriteState.insertText(nameRange.End, "〕");
+    }
+  }
+  // Highlight all comment lines.
+  // TODO: Could be replaced if this is kept in memory.
+  for (auto commentLine : rewriteState.getSingleLineComments()) {
+    rewriteState.insertText(commentLine.Start, "❰");
+    rewriteState.insertText(commentLine.End, "❱");
+  }
+  return success();
+static mlir::RewriterRegistration
+    rewriteMarkRanges("mark-ranges", "Indicate ranges parsed", markRanges);
+int main(int argc, char **argv) {
+  static llvm::cl::opt<std::string> inputFilename(
+      llvm::cl::Positional, llvm::cl::desc("<input file>"),
+      llvm::cl::init("-"));
+  static llvm::cl::opt<std::string> outputFilename(
+      "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
+      llvm::cl::init("-"));
+  llvm::cl::opt<const mlir::RewriterInfo *, false, mlir::RewriterNameParser>
+      rewriter("", llvm::cl::desc("Rewriter to run"));
+  std::string helpHeader = "mlir-rewrite";
+  llvm::cl::ParseCommandLineOptions(argc, argv, helpHeader);
+  // If no rewriter has been selected, exit with error code. Could also just
+  // return but it's unlikely this was intentionally being used as `cp`.
+  if (!rewriter) {
+    llvm::errs() << "No rewriter selected!\n";
+    return mlir::asMainReturnCode(mlir::failure());
+  }
+  // Set up rewrite buffer.
+  auto rewriterOr = RewritePad::init(inputFilename, outputFilename);
+  if (!rewriterOr)
+    return mlir::asMainReturnCode(mlir::failure());
+  // Set up the output file.
+  std::string errorMessage;
+  auto output = openOutputFile(outputFilename, &errorMessage);
+  if (!output) {
+    llvm::errs() << errorMessage << "\n";
+    return mlir::asMainReturnCode(mlir::failure());
+  }
+  LogicalResult result = rewriter->invoke(*rewriterOr, output->os());
+  if (succeeded(result)) {
+    rewriterOr->write(output->os());
+    output->keep();
+  }
+  return mlir::asMainReturnCode(result);
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index e52439f00879fe..2e520fca978e26 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -9740,6 +9740,28 @@ cc_binary(
+    name = "mlir-rewrite",
+    srcs = ["tools/mlir-rewrite/mlir-rewrite.cpp"],
+    deps = [
+        ":AllExtensions",
+        ":AllPassesAndDialects",
+        ":AffineAnalysis",
+        ":Analysis",
+        ":AsmParser",
+        ":CastInterfaces",
+        ":Dialect",
+        ":Parser",
+        ":ParseUtilities",
+        ":Pass",
+        ":Transforms",
+        ":TransformUtils",
+        ":Support",
+        ":IR",
+        "//llvm:Support",
+    ]
     name = "MlirJitRunner",
     srcs = ["lib/ExecutionEngine/JitRunner.cpp"],
@@ -10603,6 +10625,7 @@ cc_library(
+        ":ConvertToLLVMInterface",
@@ -10650,6 +10673,7 @@ cc_library(
+        ":ConvertToLLVMInterface",
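
For readers unfamiliar with the pattern RewritePad relies on: every edit is expressed as an offset into the original buffer (pointer minus `start`), so earlier edits never shift the positions of later ones. The sketch below illustrates that idea only. `OffsetEdits` is a hypothetical, much simplified stand-in and not how `llvm::RewriteBuffer` is implemented; it assumes edits do not overlap and uses ASCII markers in place of the brackets emitted by `--mark-ranges`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-in for a rewrite buffer: edits are recorded
// against offsets in the ORIGINAL text and only applied when rendering.
class OffsetEdits {
public:
  explicit OffsetEdits(std::string text) : original(std::move(text)) {}

  // Insert `text` before the original character at `offset`.
  void insertText(std::size_t offset, std::string text) {
    replaceText(offset, 0, std::move(text));
  }

  // Replace `length` original characters starting at `offset` with `text`.
  void replaceText(std::size_t offset, std::size_t length, std::string text) {
    assert(offset + length <= original.size() && "edit outside of buffer");
    edits.push_back({offset, length, std::move(text)});
  }

  // Render the edited text; the original buffer is left unchanged.
  std::string str() const {
    std::vector<Edit> sorted = edits;
    // Order by offset; at equal offsets, pure insertions come first.
    std::stable_sort(sorted.begin(), sorted.end(),
                     [](const Edit &a, const Edit &b) {
                       return a.offset != b.offset ? a.offset < b.offset
                                                   : a.length < b.length;
                     });
    std::string out;
    std::size_t cursor = 0;
    for (const Edit &e : sorted) {
      assert(e.offset >= cursor && "overlapping edits are not supported");
      out.append(original, cursor, e.offset - cursor); // untouched text
      out += e.text;                                   // inserted/new text
      cursor = e.offset + e.length;                    // skip replaced text
    }
    out.append(original, cursor, std::string::npos);
    return out;
  }

private:
  struct Edit {
    std::size_t offset;
    std::size_t length;
    std::string text;
  };
  std::string original;
  std::vector<Edit> edits;
};
```

Marking an op range and its name range then amounts to four `insertText` calls at offsets taken from the original text, in any order, followed by one call to `str()`.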
