[Mlir-commits] [mlir] [mlir] Start rewrite tool (PR #77668)

Jacques Pienaar llvmlistbot at llvm.org
Sun Aug 18 20:24:32 PDT 2024


https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/77668

>From 702ac0f10815b662610ebeb73f3de3c64ad84a8e Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Sun, 7 Jan 2024 21:31:41 -0800
Subject: [PATCH] [mlir] Start rewrite tool

Initial commit of a tool to help with textual rewrites of .mlir files.
This tool builds on top of AsmParserState and is rather simple. I took
some inspiration from my use of clang's AST rewrites, where I'd often
treat it as a "localizing" regex applicator in fallback cases, and
started with that as the functionality. There, though, one has access to
lower-level info than here, but this is still a step up from running sed
over the entire file.

This aims to be helpful (e.g., rewrite syntax including best effort
inside comments) rather than a bulletproof tool. It may even be better
suited under utils than tools. Most of the rewrites would be rather
short-lived and might never make it upstream (while the helpers of those
rewrites may, for future rewrites).

Started it as a single file to prototype more easily, planning to
refactor later into include and lib for use outside this file.
---
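To make the "localizing regex applicator" idea concrete, here is a minimal
standalone sketch, not the code in this patch. It assumes std::regex in place
of llvm::Regex and plain byte offsets in place of the SMRange values from
AsmParserState; Range and renameInRanges are hypothetical names. Like
llvm::Regex::sub, it replaces only the first match within each range.

```cpp
#include <cassert>
#include <cstddef>
#include <regex>
#include <string>
#include <vector>

// Hypothetical stand-in for an op's source range: half-open byte offsets
// [begin, end) into the original text.
struct Range {
  std::size_t begin;
  std::size_t end;
};

// Apply `pattern` -> `replacement` only inside the given ranges, replacing the
// first match in each range. Text outside the ranges is copied unchanged.
// Ranges must be sorted, non-overlapping, and within bounds.
std::string renameInRanges(const std::string &source,
                           const std::vector<Range> &ranges,
                           const std::string &pattern,
                           const std::string &replacement) {
  const std::regex re(pattern);
  std::string out;
  std::size_t cursor = 0;
  for (const Range &r : ranges) {
    assert(cursor <= r.begin && r.begin <= r.end && r.end <= source.size());
    // Copy the text between the previous range and this one as-is.
    out.append(source, cursor, r.begin - cursor);
    // Rewrite only the text inside the range.
    out += std::regex_replace(source.substr(r.begin, r.end - r.begin), re,
                              replacement,
                              std::regex_constants::format_first_only);
    cursor = r.end;
  }
  // Copy whatever follows the last range.
  out.append(source, cursor, std::string::npos);
  return out;
}
```

With a range covering only the second attribute dictionary of
"a {axis = 0} b {axis = 1}", the first `axis` is left alone, which is the
difference from running sed over the whole file.
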
 mlir/CMakeLists.txt                      |   1 +
 mlir/docs/Tools/mlir-rewrite.md          |  29 ++
 mlir/test/CMakeLists.txt                 |   1 +
 mlir/test/mlir-rewrite/simple.mlir       |  11 +
 mlir/tools/CMakeLists.txt                |   2 +
 mlir/tools/mlir-rewrite/CMakeLists.txt   |  36 +++
 mlir/tools/mlir-rewrite/mlir-rewrite.cpp | 390 +++++++++++++++++++++++
 7 files changed, 470 insertions(+)
 create mode 100644 mlir/docs/Tools/mlir-rewrite.md
 create mode 100644 mlir/test/mlir-rewrite/simple.mlir
 create mode 100644 mlir/tools/mlir-rewrite/CMakeLists.txt
 create mode 100644 mlir/tools/mlir-rewrite/mlir-rewrite.cpp
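
The mark-ranges rewrite relies on every insertion being addressed by an offset
into the original buffer, so earlier insertions do not shift later ones. A
rough standalone sketch of that property follows, assuming an insert-only
buffer; InsertBuffer is a hypothetical miniature and not llvm::RewriteBuffer,
which also supports replacement and removal.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>

// Hypothetical insert-only buffer: edits are keyed by offsets into the
// ORIGINAL text and only applied when the result is materialized.
class InsertBuffer {
public:
  explicit InsertBuffer(std::string text) : original(std::move(text)) {}

  // Queue `text` for insertion before the original character at `offset`.
  // Repeated insertions at one offset are concatenated in call order.
  void insertText(std::size_t offset, const std::string &text) {
    assert(offset <= original.size() && "offset past end of original text");
    pending[offset] += text;
  }

  // Build the rewritten text; the original is left untouched.
  std::string str() const {
    std::string out;
    std::size_t cursor = 0;
    // std::map iterates in increasing offset order.
    for (const auto &entry : pending) {
      out.append(original, cursor, entry.first - cursor);
      out += entry.second;
      cursor = entry.first;
    }
    out.append(original, cursor, std::string::npos);
    return out;
  }

private:
  std::string original;
  std::map<std::size_t, std::string> pending;
};
```

Markers for the whole op and for the op name can then be queued in any order,
each using offsets computed from the unmodified input.
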

diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index c6d44908a1111d..549dee8e2be33c 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -285,3 +285,4 @@ endif()
 if(MLIR_STANDALONE_BUILD)
   llvm_distribution_add_targets()
 endif()
+
diff --git a/mlir/docs/Tools/mlir-rewrite.md b/mlir/docs/Tools/mlir-rewrite.md
new file mode 100644
index 00000000000000..178f92f72cbb6e
--- /dev/null
+++ b/mlir/docs/Tools/mlir-rewrite.md
@@ -0,0 +1,29 @@
+# mlir-rewrite
+
+Tool to simplify rewriting .mlir files. There are a couple of built-in rewrites
+discussed below along with usage.
+
+Note: This is still at a very early stage. It's so early it's less a tool than a
+growing collection of useful functions: to use it, it's best to do what's needed
+on a branch by just hacking it (dialects registered, rewrites, etc.) to, say,
+help ease a rename, upstream useful utility functions, point others to it to
+ease their migration, and then bin it eventually. Once there are actually
+useful parts it should be refactored the same way as mlir-opt.
+
+[TOC]
+
+## simple-rename
+
+Rename a substring to a target within the text of each matching op. The match
+and replace use LLVM's regex sub, while the op name is matched via regular
+string comparison. E.g.,
+
+```
+mlir-rewrite input.mlir -o output.mlir --simple-rename \
+   --simple-rename-op-name="test.concat" --simple-rename-match="axis" \
+                                         --simple-rename-replace="bxis"
+```
+
+to replace the `axis` substring in the text of the range corresponding to
+`test.concat` ops with `bxis`.
+
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index df95e5db11f1e0..89297127772637 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -115,6 +115,7 @@ set(MLIR_TEST_DEPENDS
   mlir-opt
   mlir-query
   mlir-reduce
+  mlir-rewrite
   mlir-tblgen
   mlir-translate
   tblgen-lsp-server
diff --git a/mlir/test/mlir-rewrite/simple.mlir b/mlir/test/mlir-rewrite/simple.mlir
new file mode 100644
index 00000000000000..ab6bfe24fccf03
--- /dev/null
+++ b/mlir/test/mlir-rewrite/simple.mlir
@@ -0,0 +1,11 @@
+// RUN: mlir-opt %s | mlir-rewrite --simple-rename --simple-rename-op-name="test.concat" --simple-rename-match="axis" --simple-rename-replace="bxis" | FileCheck %s -check-prefix=RENAME
+// RUN: mlir-opt %s | mlir-rewrite --mark-ranges | FileCheck %s -check-prefix=RANGE
+// Note: running through mlir-opt to just strip out comments & avoid self matches.
+
+func.func @two_dynamic_one_direct_shape(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) -> tensor<?x4x?xf32> {
+  // RENAME: "test.concat"({{.*}}) {bxis = 0 : i64}
+  // RANGE: 《%{{.*}} = 〖"test.concat"〗({{.*}}) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>》
+  %5 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+  return %5 : tensor<?x4x?xf32>
+}
+
diff --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt
index 9b474385fdae18..55c773942ecead 100644
--- a/mlir/tools/CMakeLists.txt
+++ b/mlir/tools/CMakeLists.txt
@@ -4,6 +4,7 @@ add_subdirectory(mlir-parser-fuzzer)
 add_subdirectory(mlir-pdll-lsp-server)
 add_subdirectory(mlir-query)
 add_subdirectory(mlir-reduce)
+add_subdirectory(mlir-rewrite)
 add_subdirectory(mlir-shlib)
 add_subdirectory(mlir-spirv-cpu-runner)
 add_subdirectory(mlir-translate)
@@ -15,3 +16,4 @@ add_subdirectory(tblgen-to-irdl)
 if(MLIR_ENABLE_EXECUTION_ENGINE)
   add_subdirectory(mlir-cpu-runner)
 endif()
+
diff --git a/mlir/tools/mlir-rewrite/CMakeLists.txt b/mlir/tools/mlir-rewrite/CMakeLists.txt
new file mode 100644
index 00000000000000..fc114c878f1bbe
--- /dev/null
+++ b/mlir/tools/mlir-rewrite/CMakeLists.txt
@@ -0,0 +1,36 @@
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+set(LLVM_LINK_COMPONENTS
+  Support
+  )
+
+set(LIBS
+  ${dialect_libs}
+  ${test_libs}
+
+  MLIRAffineAnalysis
+  MLIRAnalysis
+  MLIRCastInterfaces
+  MLIRDialect
+  MLIROptLib
+  MLIRParser
+  MLIRPass
+  MLIRTransforms
+  MLIRTransformUtils
+  MLIRSupport
+  MLIRIR
+  )
+
+include_directories(../../../clang/include)
+
+add_mlir_tool(mlir-rewrite
+  mlir-rewrite.cpp
+
+  DEPENDS
+  ${LIBS}
+  SUPPORT_PLUGINS
+  )
+target_link_libraries(mlir-rewrite PRIVATE ${LIBS})
+llvm_update_compile_flags(mlir-rewrite)
+
+mlir_check_all_link_libraries(mlir-rewrite)
+export_executable_symbols_for_plugins(mlir-rewrite)
diff --git a/mlir/tools/mlir-rewrite/mlir-rewrite.cpp b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp
new file mode 100644
index 00000000000000..48a2a46fc2cd82
--- /dev/null
+++ b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp
@@ -0,0 +1,390 @@
+//===- mlir-rewrite.cpp - MLIR Rewrite Driver -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Main entry function for mlir-rewrite.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/AsmParser/AsmParser.h"
+#include "mlir/AsmParser/AsmParserState.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Tools/ParseUtilities.h"
+#include "llvm/ADT/RewriteBuffer.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/LineIterator.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace mlir;
+
+namespace mlir {
+using OperationDefinition = AsmParserState::OperationDefinition;
+
+/// Return the source code associated with the OperationDefinition.
+SMRange getOpRange(const OperationDefinition &op) {
+  const char *startOp = op.scopeLoc.Start.getPointer();
+  const char *endOp = op.scopeLoc.End.getPointer();
+
+  for (auto res : op.resultGroups) {
+    SMRange range = res.definition.loc;
+    startOp = std::min(startOp, range.Start.getPointer());
+  }
+  return {SMLoc::getFromPointer(startOp), SMLoc::getFromPointer(endOp)};
+}
+
+/// Helper to simplify rewriting the source file.
+class RewriteBuffer {
+public:
+  static std::unique_ptr<RewriteBuffer> init(StringRef inputFilename,
+                                             StringRef outputFilename);
+
+  /// Return the context the file was parsed into.
+  MLIRContext *getContext() { return &context; }
+
+  /// Return the OperationDefinitions of the operations parsed.
+  auto getOpDefs() { return asmState.getOpDefs(); }
+
+  /// Insert the specified string at the specified location in the original
+  /// buffer.
+  void insertText(SMLoc pos, StringRef str, bool insertAfter = true) {
+    rewriteBuffer.InsertText(pos.getPointer() - start, str, insertAfter);
+  }
+
+  /// Replace the range of the source text with the corresponding string in the
+  /// output.
+  void replaceRange(SMRange range, StringRef str) {
+    rewriteBuffer.ReplaceText(range.Start.getPointer() - start,
+                              range.End.getPointer() - range.Start.getPointer(),
+                              str);
+  }
+
+  /// Replace the range of the operation in the source text with the
+  /// corresponding string in the output.
+  void replaceDef(const OperationDefinition &opDef, StringRef newDef) {
+    replaceRange(getOpRange(opDef), newDef);
+  }
+
+  /// Return the source string corresponding to the source range.
+  StringRef getSourceString(SMRange range) {
+    return StringRef(range.Start.getPointer(),
+                     range.End.getPointer() - range.Start.getPointer());
+  }
+
+  /// Return the source string corresponding to operation definition.
+  StringRef getSourceString(const OperationDefinition &opDef) {
+    auto range = getOpRange(opDef);
+    return getSourceString(range);
+  }
+
+  /// Write to stream the result of applying all changes to the
+  /// original buffer.
+  /// Note that it isn't safe to use this function to overwrite memory mapped
+  /// files in-place (PR17960).
+  ///
+  /// The original buffer is not actually changed.
+  raw_ostream &write(raw_ostream &stream) const {
+    return rewriteBuffer.write(stream);
+  }
+
+  /// Return lines that are purely comments.
+  SmallVector<SMRange> getSingleLineComments() {
+    unsigned curBuf = sourceMgr.getMainFileID();
+    const llvm::MemoryBuffer *curMB = sourceMgr.getMemoryBuffer(curBuf);
+    llvm::line_iterator lineIterator(*curMB);
+    SmallVector<SMRange> ret;
+    for (; !lineIterator.is_at_end(); ++lineIterator) {
+      StringRef trimmed = lineIterator->ltrim();
+      if (trimmed.starts_with("//")) {
+        ret.emplace_back(
+            SMLoc::getFromPointer(trimmed.data()),
+            SMLoc::getFromPointer(trimmed.data() + trimmed.size()));
+      }
+    }
+    return ret;
+  }
+
+  /// Return the IR from parsed file.
+  Block *getParsed() { return &parsedIR; }
+
+  /// Return the definition for the given operation. The operation must have a
+  /// definition; the lookup result is dereferenced unconditionally.
+  const OperationDefinition &getOpDef(Operation *op) const {
+    return *asmState.getOpDef(op);
+  }
+
+private:
+  // The context and state required to parse.
+  MLIRContext context;
+  llvm::SourceMgr sourceMgr;
+  DialectRegistry registry;
+  FallbackAsmResourceMap fallbackResourceMap;
+
+  // Storage of textual parsing results.
+  AsmParserState asmState;
+
+  // Parsed IR.
+  Block parsedIR;
+
+  // The RewriteBuffer is doing most of the real work.
+  llvm::RewriteBuffer rewriteBuffer;
+
+  // Start of the original input, used to compute offset.
+  const char *start;
+};
+
+std::unique_ptr<RewriteBuffer> RewriteBuffer::init(StringRef inputFilename,
+                                                   StringRef outputFilename) {
+  std::unique_ptr<RewriteBuffer> r = std::make_unique<RewriteBuffer>();
+
+  // Register all the dialects needed.
+  registerAllDialects(r->registry);
+
+  // Set up the input file.
+  std::string errorMessage;
+  std::unique_ptr<llvm::MemoryBuffer> file =
+      openInputFile(inputFilename, &errorMessage);
+  if (!file) {
+    llvm::errs() << errorMessage << "\n";
+    return nullptr;
+  }
+  r->sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+
+  // Set up the MLIR context and error handling.
+  r->context.appendDialectRegistry(r->registry);
+
+  // Record the start of the buffer to compute offsets with.
+  unsigned curBuf = r->sourceMgr.getMainFileID();
+  const llvm::MemoryBuffer *curMB = r->sourceMgr.getMemoryBuffer(curBuf);
+  r->start = curMB->getBufferStart();
+  r->rewriteBuffer.Initialize(curMB->getBuffer());
+
+  // Parse and populate the AsmParserState.
+  ParserConfig parseConfig(&r->context, /*verifyAfterParse=*/true,
+                           &r->fallbackResourceMap);
+  // Always allow unregistered.
+  r->context.allowUnregisteredDialects(true);
+  if (failed(parseAsmSourceFile(r->sourceMgr, &r->parsedIR, parseConfig,
+                                &r->asmState)))
+    return nullptr;
+
+  return r;
+}
+
+/// Return the source code associated with the operation name.
+SMRange getOpNameRange(const OperationDefinition &op) { return op.loc; }
+
+/// Return whether the operation was printed using generic syntax in original
+/// buffer.
+bool isGeneric(const OperationDefinition &op) {
+  return op.loc.Start.getPointer()[0] == '"';
+}
+
+inline int asMainReturnCode(LogicalResult r) {
+  return r.succeeded() ? EXIT_SUCCESS : EXIT_FAILURE;
+}
+
+/// Rewriter function to invoke.
+using RewriterFunction = std::function<mlir::LogicalResult(
+    mlir::RewriteBuffer &rewriteBuffer, llvm::raw_ostream &os)>;
+
+/// Structure to group information about a rewriter (argument to invoke via
+/// mlir-rewrite, description, and rewriter function).
+class RewriterInfo {
+public:
+  /// RewriterInfo constructor should not be invoked directly, instead use
+  /// RewriterRegistration or registerRewriter.
+  RewriterInfo(StringRef arg, StringRef description, RewriterFunction rewriter)
+      : arg(arg), description(description), rewriter(std::move(rewriter)) {}
+
+  /// Invokes the rewriter and returns whether the rewriter failed.
+  LogicalResult invoke(mlir::RewriteBuffer &rewriteBuffer,
+                       raw_ostream &os) const {
+    assert(rewriter && "Cannot call rewriter with null rewriter");
+    return rewriter(rewriteBuffer, os);
+  }
+
+  /// Returns the command line option that may be passed to 'mlir-rewrite' to
+  /// invoke this rewriter.
+  StringRef getRewriterArgument() const { return arg; }
+
+  /// Returns a description for the rewriter.
+  StringRef getRewriterDescription() const { return description; }
+
+private:
+  // The argument with which to invoke the rewriter via mlir-rewrite.
+  StringRef arg;
+
+  // Description of the rewriter.
+  StringRef description;
+
+  // Rewriter function.
+  RewriterFunction rewriter;
+};
+
+static llvm::ManagedStatic<std::vector<RewriterInfo>> rewriterRegistry;
+
+/// Adds command line option for each registered rewriter.
+struct RewriterNameParser : public llvm::cl::parser<const RewriterInfo *> {
+  RewriterNameParser(llvm::cl::Option &opt);
+
+  void printOptionInfo(const llvm::cl::Option &o,
+                       size_t globalWidth) const override;
+};
+
+/// RewriterRegistration provides a global initializer that registers a rewriter
+/// function.
+struct RewriterRegistration {
+  RewriterRegistration(StringRef arg, StringRef description,
+                       const RewriterFunction &function);
+};
+
+RewriterRegistration::RewriterRegistration(StringRef arg, StringRef description,
+                                           const RewriterFunction &function) {
+  rewriterRegistry->emplace_back(arg, description, function);
+}
+
+RewriterNameParser::RewriterNameParser(llvm::cl::Option &opt)
+    : llvm::cl::parser<const RewriterInfo *>(opt) {
+  for (const auto &kv : *rewriterRegistry) {
+    addLiteralOption(kv.getRewriterArgument(), &kv,
+                     kv.getRewriterDescription());
+  }
+}
+
+void RewriterNameParser::printOptionInfo(const llvm::cl::Option &o,
+                                         size_t globalWidth) const {
+  RewriterNameParser *tp = const_cast<RewriterNameParser *>(this);
+  llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(),
+                       [](const RewriterNameParser::OptionInfo *vT1,
+                          const RewriterNameParser::OptionInfo *vT2) {
+                         return vT1->Name.compare(vT2->Name);
+                       });
+  using llvm::cl::parser;
+  parser<const RewriterInfo *>::printOptionInfo(o, globalWidth);
+}
+
+} // namespace mlir
+
+// TODO: Make these injectable too in non-global way.
+static llvm::cl::OptionCategory clSimpleRenameCategory{"simple-rename options"};
+static llvm::cl::opt<std::string> simpleRenameOpName{
+    "simple-rename-op-name", llvm::cl::desc("Name of op to match on"),
+    llvm::cl::cat(clSimpleRenameCategory)};
+static llvm::cl::opt<std::string> simpleRenameMatch{
+    "simple-rename-match", llvm::cl::desc("Match string for rename"),
+    llvm::cl::cat(clSimpleRenameCategory)};
+static llvm::cl::opt<std::string> simpleRenameReplace{
+    "simple-rename-replace", llvm::cl::desc("Replace string for rename"),
+    llvm::cl::cat(clSimpleRenameCategory)};
+
+// Rewriter that does simple renames.
+LogicalResult simpleRename(RewriteBuffer &rewriteBuffer, raw_ostream &os) {
+  StringRef opName = simpleRenameOpName;
+  StringRef match = simpleRenameMatch;
+  StringRef replace = simpleRenameReplace;
+  llvm::Regex regex(match);
+
+  rewriteBuffer.getParsed()->walk([&](Operation *op) {
+    if (op->getName().getStringRef() != opName)
+      return;
+
+    const OperationDefinition &opDef = rewriteBuffer.getOpDef(op);
+    SMRange range = getOpRange(opDef);
+    // This is a little bit overkill for a simple rename.
+    std::string str = regex.sub(replace, rewriteBuffer.getSourceString(range));
+    rewriteBuffer.replaceRange(range, str);
+  });
+  return success();
+}
+
+static mlir::RewriterRegistration rewriteSimpleRename("simple-rename",
+                                                      "Perform a simple rename",
+                                                      simpleRename);
+
+// Rewriter that inserts range markers.
+LogicalResult markRanges(RewriteBuffer &rewriteBuffer, raw_ostream &os) {
+  for (auto it : rewriteBuffer.getOpDefs()) {
+    auto [startOp, endOp] = getOpRange(it);
+
+    rewriteBuffer.insertText(startOp, "《");
+    rewriteBuffer.insertText(endOp, "》");
+
+    auto nameRange = getOpNameRange(it);
+
+    if (isGeneric(it)) {
+      rewriteBuffer.insertText(nameRange.Start, "〖");
+      rewriteBuffer.insertText(nameRange.End, "〗");
+    } else {
+      rewriteBuffer.insertText(nameRange.Start, "〔");
+      rewriteBuffer.insertText(nameRange.End, "〕");
+    }
+  }
+
+  // Highlight all comment lines.
+  // TODO: Could be replaced if this is kept in memory.
+  for (auto commentLine : rewriteBuffer.getSingleLineComments()) {
+    rewriteBuffer.insertText(commentLine.Start, "❰");
+    rewriteBuffer.insertText(commentLine.End, "❱");
+  }
+
+  return success();
+}
+
+static mlir::RewriterRegistration
+    rewriteMarkRanges("mark-ranges", "Indicate ranges parsed", markRanges);
+
+int main(int argc, char **argv) {
+  static llvm::cl::opt<std::string> inputFilename(
+      llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"));
+
+  static llvm::cl::opt<std::string> outputFilename(
+      "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
+      llvm::cl::init("-"));
+
+  llvm::cl::opt<const mlir::RewriterInfo *, false, mlir::RewriterNameParser>
+      rewriter("", llvm::cl::desc("Rewriter to run"));
+
+  std::string helpHeader = "mlir-rewrite";
+
+  llvm::cl::ParseCommandLineOptions(argc, argv, helpHeader);
+
+  // If no rewriter has been selected, exit with error code. Could also just
+  // return but it's unlikely this was intentionally being used as `cp`.
+  if (!rewriter) {
+    llvm::errs() << "No rewriter selected!\n";
+    return mlir::asMainReturnCode(mlir::failure());
+  }
+
+  // Set up rewrite buffer.
+  auto rewriterOr = RewriteBuffer::init(inputFilename, outputFilename);
+  if (!rewriterOr)
+    return mlir::asMainReturnCode(mlir::failure());
+
+  // Set up the output file.
+  std::string errorMessage;
+  auto output = openOutputFile(outputFilename, &errorMessage);
+  if (!output) {
+    llvm::errs() << errorMessage << "\n";
+    return mlir::asMainReturnCode(mlir::failure());
+  }
+
+  LogicalResult result = rewriter->invoke(*rewriterOr, output->os());
+  if (succeeded(result)) {
+    rewriterOr->write(output->os());
+    output->keep();
+  }
+  return mlir::asMainReturnCode(result);
+}


