[Mlir-commits] [mlir] [mlir] Start rewrite tool (PR #77668)

Jacques Pienaar llvmlistbot at llvm.org
Wed Jan 10 11:01:38 PST 2024


https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/77668

From eeaa830c123037e1d21b18bb66b245d7f6a51753 Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Sun, 7 Jan 2024 21:31:41 -0800
Subject: [PATCH] [mlir] Start rewrite tool

Initial commit of a tool to help with textual rewrites of .mlir files. The tool
builds on AsmParserState and is rather simple. It takes some inspiration from
my use of clang's AST rewrites, where I'd often treat the rewriter as a
"localizing" regex applicator in fallback cases, and it starts with that as its
functionality. Clang does give access to lower-level information than is
available here, but this is still a step up from running sed over an entire file.
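To make the "localizing" idea concrete: the substitution is applied only inside
source ranges (which the tool gets from AsmParserState) instead of across the
whole file the way sed would. A minimal standalone sketch, not the tool's code:
it uses std::regex in place of llvm::Regex, and `Range` and `localizedSub` are
hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <regex>
#include <string>
#include <vector>

// Hypothetical byte range [begin, end) in the original buffer, standing in for
// the per-operation source ranges that AsmParserState records.
struct Range {
  std::size_t begin;
  std::size_t end;
};

// Apply `pattern` -> `replacement` only inside `ranges` (which must be sorted,
// non-overlapping, and within `buffer`). Everything else, including comments
// outside the ranges, is copied through unchanged. As with llvm::Regex::sub,
// only the first match inside each range is replaced.
std::string localizedSub(const std::string &buffer,
                         const std::vector<Range> &ranges,
                         const std::regex &pattern,
                         const std::string &replacement) {
  std::string out;
  std::size_t pos = 0;
  for (const Range &r : ranges) {
    assert(pos <= r.begin && r.begin <= r.end && r.end <= buffer.size());
    out.append(buffer, pos, r.begin - pos); // Untouched text before the range.
    std::string slice = buffer.substr(r.begin, r.end - r.begin);
    out += std::regex_replace(slice, pattern, replacement,
                              std::regex_constants::format_first_only);
    pos = r.end;
  }
  out.append(buffer, pos, std::string::npos); // Untouched tail of the buffer.
  return out;
}
```

E.g., with a single range covering only a `test.concat` line, `{axis = 0 : i64}`
becomes `{bxis = 0 : i64}` while a preceding `// axis ...` comment is left
alone. Note that std::regex's replacement syntax differs from llvm::Regex's
(`$&` versus `\0` for the whole match).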

This aims to be a helpful tool (e.g., rewriting syntax, including best-effort
rewrites inside comments) rather than a bulletproof one. It may even be better
suited under utils than tools. Most rewrites would be rather short-lived and
might never make it upstream, though the helpers behind them may, for use in
future rewrites.

The layering is not ideal at the moment, as the tool reuses the RewriteBuffer
class from clang's rewrite engine, so it is only enabled when clang is also
enabled. There doesn't seem to be anything clang-specific in that class (the
dependency does pull in more than ideal), but I'm leaving both refactorings
(fixing the layering and trimming the dependencies) for later.
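The core rewrite-buffer idea is indeed not clang-specific: edits are recorded
against offsets in the original buffer and only applied when the output is
produced, so earlier edits never shift the offsets of later ones. A minimal
sketch of that idea, assuming a sorted edit list rather than the rope and delta
tree that clang::RewriteBuffer actually uses (`SimpleRewriteBuffer` is a
hypothetical name):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical minimal stand-in for clang::RewriteBuffer. All offsets refer to
// the *original* buffer; edits are only applied when str() is called.
class SimpleRewriteBuffer {
public:
  explicit SimpleRewriteBuffer(std::string source)
      : original(std::move(source)) {}

  // Insert `text` before the original character at `offset`.
  void insertText(std::size_t offset, std::string text) {
    edits.push_back({offset, 0, std::move(text)});
  }

  // Replace `length` original characters starting at `offset` with `text`.
  void replaceText(std::size_t offset, std::size_t length, std::string text) {
    edits.push_back({offset, length, std::move(text)});
  }

  // Produce the rewritten text; the original buffer itself is not modified.
  std::string str() const {
    std::vector<Edit> sorted = edits;
    // Order by offset; at equal offsets pure insertions go before a
    // replacement. stable_sort keeps same-kind edits in the order they were made.
    std::stable_sort(sorted.begin(), sorted.end(),
                     [](const Edit &a, const Edit &b) {
                       if (a.offset != b.offset)
                         return a.offset < b.offset;
                       return a.length == 0 && b.length != 0;
                     });
    std::string out;
    std::size_t pos = 0;
    for (const Edit &e : sorted) {
      assert(e.offset >= pos && "edit starts inside a replaced range");
      assert(e.offset + e.length <= original.size() && "edit out of bounds");
      out.append(original, pos, e.offset - pos); // Unchanged text.
      out += e.text;
      pos = e.offset + e.length; // Skip the replaced original characters.
    }
    out.append(original, pos, std::string::npos);
    return out;
  }

private:
  struct Edit {
    std::size_t offset;
    std::size_t length; // Number of original characters removed.
    std::string text;
  };
  std::string original;
  std::vector<Edit> edits;
};
```

This is roughly the surface the `insertText`/`replaceRange` helpers in
mlir-rewrite.cpp rely on. Unlike a rope-based buffer, this sketch simply
rejects (via assert) an edit that starts inside a range already being replaced.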

Additionally, I started it as a single file to make prototyping easier, and plan
to refactor it later into headers and libraries for use outside this file.
---
 mlir/CMakeLists.txt                      |   5 +
 mlir/docs/Tools/mlir-rewrite.md          |  29 ++
 mlir/test/CMakeLists.txt                 |   7 +
 mlir/test/lit.cfg.py                     |   4 +
 mlir/test/lit.site.cfg.py.in             |   1 +
 mlir/test/mlir-rewrite/simple.mlir       |  12 +
 mlir/tools/CMakeLists.txt                |   5 +
 mlir/tools/mlir-rewrite/CMakeLists.txt   |  37 +++
 mlir/tools/mlir-rewrite/mlir-rewrite.cpp | 394 +++++++++++++++++++++++
 9 files changed, 494 insertions(+)
 create mode 100644 mlir/docs/Tools/mlir-rewrite.md
 create mode 100644 mlir/test/mlir-rewrite/simple.mlir
 create mode 100644 mlir/tools/mlir-rewrite/CMakeLists.txt
 create mode 100644 mlir/tools/mlir-rewrite/mlir-rewrite.cpp

diff --git a/mlir/CMakeLists.txt b/mlir/CMakeLists.txt
index 2d9f78e03ba76b..64aad84e90a5ad 100644
--- a/mlir/CMakeLists.txt
+++ b/mlir/CMakeLists.txt
@@ -285,3 +285,8 @@ endif()
 if(MLIR_STANDALONE_BUILD)
   llvm_distribution_add_targets()
 endif()
+
+# FIXME: Currently depends on utility functions inside clang.
+if ("clang" IN_LIST LLVM_ENABLE_PROJECTS)
+  set(MLIR_ENABLE_REWRITE ON CACHE BOOL "mlir-rewrite enabled")
+endif()
diff --git a/mlir/docs/Tools/mlir-rewrite.md b/mlir/docs/Tools/mlir-rewrite.md
new file mode 100644
index 00000000000000..178f92f72cbb6e
--- /dev/null
+++ b/mlir/docs/Tools/mlir-rewrite.md
@@ -0,0 +1,29 @@
+# mlir-rewrite
+
+Tool to simplify rewriting .mlir files. There are a couple of built-in rewrites,
+discussed below along with their usage.
+
+Note: This is still at a very early stage, so early that it is less a tool than
+a growing collection of useful functions. It is best used by hacking it as
+needed on a branch (registered dialects, rewrites, etc.), e.g., to ease a
+rename; then upstreaming useful utility functions, pointing others to it to
+ease their migration, and eventually binning the rest. Once there are actually
+useful parts, it should be refactored the same way as mlir-opt.
+
+[TOC]
+
+## simple-rename
+
+Rename, per op, a matched substring to a target string. The match and replace
+use LLVM's `Regex::sub`, which replaces the first match in the op's text, while
+the op name is matched via regular string comparison. E.g.,
+
+```
+mlir-rewrite input.mlir -o output.mlir --simple-rename \
+   --simple-rename-op-name="test.concat" --simple-rename-match="axis" \
+                                         --simple-rename-replace="bxis"
+```
+
+to replace the `axis` substring with `bxis` in the source text of each
+`test.concat` op.
+
diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt
index 8ce030feeded92..397a2efcf5e9f2 100644
--- a/mlir/test/CMakeLists.txt
+++ b/mlir/test/CMakeLists.txt
@@ -197,6 +197,13 @@ if(MLIR_ENABLE_BINDINGS_PYTHON)
   )
 endif()
 
+# FIXME: Currently depends on utility functions inside clang.
+if ("clang" IN_LIST LLVM_ENABLE_PROJECTS)
+  list(APPEND MLIR_TEST_DEPENDS
+    mlir-rewrite
+  )
+endif()
+
 # This target can be used to just build the dependencies
 # for the check-mlir target without executing the tests.
 # This is useful for bots when splitting the build step
diff --git a/mlir/test/lit.cfg.py b/mlir/test/lit.cfg.py
index 0a1ea1d16da452..35d6a3bd1f5636 100644
--- a/mlir/test/lit.cfg.py
+++ b/mlir/test/lit.cfg.py
@@ -144,6 +144,10 @@ def add_runtime(name):
         )
     )
 
+if config.enable_mlir_rewrite:
+    tools.extend(["mlir-rewrite"])
+    config.available_features.add('mlir-rewrite')
+
 # The following tools are optional
 tools.extend(
     [
diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in
index c0fa1b8980e539..d35e3701198a56 100644
--- a/mlir/test/lit.site.cfg.py.in
+++ b/mlir/test/lit.site.cfg.py.in
@@ -23,6 +23,7 @@ config.mlir_obj_root = "@MLIR_BINARY_DIR@"
 config.mlir_tools_dir = "@MLIR_TOOLS_DIR@"
 config.mlir_cmake_dir = "@MLIR_CMAKE_DIR@"
 config.mlir_lib_dir = "@MLIR_LIB_DIR@"
+config.enable_mlir_rewrite = "@MLIR_ENABLE_REWRITE@"
 
 config.build_examples = @LLVM_BUILD_EXAMPLES@
 config.run_cuda_tests = @MLIR_ENABLE_CUDA_CONVERSIONS@
diff --git a/mlir/test/mlir-rewrite/simple.mlir b/mlir/test/mlir-rewrite/simple.mlir
new file mode 100644
index 00000000000000..cf3a029b0653b0
--- /dev/null
+++ b/mlir/test/mlir-rewrite/simple.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-opt %s | mlir-rewrite --simple-rename --simple-rename-op-name="test.concat" --simple-rename-match="axis" --simple-rename-replace="bxis" | FileCheck %s -check-prefix=RENAME
+// RUN: mlir-opt %s | mlir-rewrite --mark-ranges | FileCheck %s -check-prefix=RANGE
+// Note: input is run through mlir-opt first to strip comments (including these CHECK lines) and so avoid self-matches.
+// REQUIRES: mlir-rewrite
+
+func.func @two_dynamic_one_direct_shape(%arg0: tensor<?x4x?xf32>, %arg1: tensor<2x4x?xf32>) -> tensor<?x4x?xf32> {
+  // RENAME: "test.concat"({{.*}}) {bxis = 0 : i64}
+  // RANGE: 《%{{.*}} = 〖"test.concat"〗({{.*}}) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>》
+  %5 = "test.concat"(%arg0, %arg1) {axis = 0 : i64} : (tensor<?x4x?xf32>, tensor<2x4x?xf32>) -> tensor<?x4x?xf32>
+  return %5 : tensor<?x4x?xf32>
+}
+
diff --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt
index 9b474385fdae18..7d330e124a2ca2 100644
--- a/mlir/tools/CMakeLists.txt
+++ b/mlir/tools/CMakeLists.txt
@@ -15,3 +15,8 @@ add_subdirectory(tblgen-to-irdl)
 if(MLIR_ENABLE_EXECUTION_ENGINE)
   add_subdirectory(mlir-cpu-runner)
 endif()
+
+# FIXME: Currently depends on utility functions inside clang.
+if ("clang" IN_LIST LLVM_ENABLE_PROJECTS)
+  add_subdirectory(mlir-rewrite)
+endif()
diff --git a/mlir/tools/mlir-rewrite/CMakeLists.txt b/mlir/tools/mlir-rewrite/CMakeLists.txt
new file mode 100644
index 00000000000000..29126432d2de5d
--- /dev/null
+++ b/mlir/tools/mlir-rewrite/CMakeLists.txt
@@ -0,0 +1,37 @@
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+set(LLVM_LINK_COMPONENTS
+  Support
+  )
+
+set(LIBS
+  ${dialect_libs}
+  ${test_libs}
+
+  clangRewrite
+  MLIRAffineAnalysis
+  MLIRAnalysis
+  MLIRCastInterfaces
+  MLIRDialect
+  MLIROptLib
+  MLIRParser
+  MLIRPass
+  MLIRTransforms
+  MLIRTransformUtils
+  MLIRSupport
+  MLIRIR
+  )
+
+include_directories(../../../clang/include)
+
+add_mlir_tool(mlir-rewrite
+  mlir-rewrite.cpp
+
+  DEPENDS
+  ${LIBS}
+  SUPPORT_PLUGINS
+  )
+target_link_libraries(mlir-rewrite PRIVATE ${LIBS})
+llvm_update_compile_flags(mlir-rewrite)
+
+mlir_check_all_link_libraries(mlir-rewrite)
+export_executable_symbols_for_plugins(mlir-rewrite)
diff --git a/mlir/tools/mlir-rewrite/mlir-rewrite.cpp b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp
new file mode 100644
index 00000000000000..b704fd4c59c3fe
--- /dev/null
+++ b/mlir/tools/mlir-rewrite/mlir-rewrite.cpp
@@ -0,0 +1,394 @@
+//===- mlir-rewrite.cpp - MLIR Rewrite Driver -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Main entry function for mlir-rewrite.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/AsmParser/AsmParser.h"
+#include "mlir/AsmParser/AsmParserState.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Tools/ParseUtilities.h"
+// TODO: Refactor the RewriteBuffer out to avoid the weird Clang dep.
+#include "clang/Rewrite/Core/RewriteBuffer.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/LineIterator.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace llvm;
+using namespace mlir;
+
+namespace mlir {
+using OperationDefinition = AsmParserState::OperationDefinition;
+
+/// Return the source range associated with the OperationDefinition.
+SMRange getOpRange(const OperationDefinition &op) {
+  const char *startOp = op.scopeLoc.Start.getPointer();
+  const char *endOp = op.scopeLoc.End.getPointer();
+
+  for (auto res : op.resultGroups) {
+    SMRange range = res.definition.loc;
+    startOp = std::min(startOp, range.Start.getPointer());
+  }
+  return {SMLoc::getFromPointer(startOp), SMLoc::getFromPointer(endOp)};
+}
+
+/// Helper to simplify rewriting the source file.
+class RewriteBuffer {
+public:
+  static std::unique_ptr<RewriteBuffer> init(StringRef inputFilename,
+                                             StringRef outputFilename);
+
+  /// Return the context the file was parsed into.
+  MLIRContext *getContext() { return &context; }
+
+  /// Return the OperationDefinitions of the parsed operations.
+  auto getOpDefs() { return asmState.getOpDefs(); }
+
+  /// Insert the specified string at the specified location in the original
+  /// buffer.
+  void insertText(SMLoc pos, StringRef str, bool insertAfter = true) {
+    rewriteBuffer.InsertText(pos.getPointer() - start, str, insertAfter);
+  }
+
+  /// Replace the range of the source text with the corresponding string in the
+  /// output.
+  void replaceRange(SMRange range, StringRef str) {
+    rewriteBuffer.ReplaceText(range.Start.getPointer() - start,
+                              range.End.getPointer() - range.Start.getPointer(),
+                              str);
+  }
+
+  /// Replace the range of the operation in the source text with the
+  /// corresponding string in the output.
+  void replaceDef(const OperationDefinition &opDef, StringRef newDef) {
+    replaceRange(getOpRange(opDef), newDef);
+  }
+
+  /// Return the source string corresponding to the source range.
+  StringRef getSourceString(SMRange range) {
+    return StringRef(range.Start.getPointer(),
+                     range.End.getPointer() - range.Start.getPointer());
+  }
+
+  /// Return the source string corresponding to the operation definition.
+  StringRef getSourceString(const OperationDefinition &opDef) {
+    auto range = getOpRange(opDef);
+    return getSourceString(range);
+  }
+
+  /// Write to stream the result of applying all changes to the
+  /// original buffer.
+  /// Note that it isn't safe to use this function to overwrite memory mapped
+  /// files in-place (PR17960).
+  ///
+  /// The original buffer is not actually changed.
+  raw_ostream &write(raw_ostream &stream) const {
+    return rewriteBuffer.write(stream);
+  }
+
+  /// Return lines that are purely comments.
+  SmallVector<SMRange> getSingleLineComments() {
+    unsigned curBuf = sourceMgr.getMainFileID();
+    const MemoryBuffer *curMB = sourceMgr.getMemoryBuffer(curBuf);
+    auto lineIterator = line_iterator(*curMB);
+    SmallVector<SMRange> ret;
+    for (; !lineIterator.is_at_end(); ++lineIterator) {
+      StringRef trimmed = lineIterator->ltrim();
+      if (trimmed.starts_with("//")) {
+        ret.emplace_back(
+            SMLoc::getFromPointer(trimmed.data()),
+            SMLoc::getFromPointer(trimmed.data() + trimmed.size()));
+      }
+    }
+    return ret;
+  }
+
+  /// Return the IR parsed from the input file.
+  Block *getParsed() { return &parsedIR; }
+
+  /// Return the definition for the given operation. The operation must have
+  /// been parsed from the input (i.e., have a definition).
+  const OperationDefinition &getOpDef(Operation *op) const {
+    return *asmState.getOpDef(op);
+  }
+
+private:
+  // The context and state required to parse.
+  MLIRContext context;
+  SourceMgr sourceMgr;
+  DialectRegistry registry;
+  FallbackAsmResourceMap fallbackResourceMap;
+
+  // Storage of textual parsing results.
+  AsmParserState asmState;
+
+  // Parsed IR.
+  Block parsedIR;
+
+  // The RewriteBuffer from clang-rewrite is doing most of the real work.
+  clang::RewriteBuffer rewriteBuffer;
+
+  // Start of the original input, used to compute offset.
+  const char *start;
+};
+
+std::unique_ptr<RewriteBuffer> RewriteBuffer::init(StringRef inputFilename,
+                                                   StringRef outputFilename) {
+  std::unique_ptr<RewriteBuffer> r = std::make_unique<RewriteBuffer>();
+
+  // Register all the dialects needed.
+  registerAllDialects(r->registry);
+
+  // Set up the input file.
+  std::string errorMessage;
+  std::unique_ptr<llvm::MemoryBuffer> file =
+      openInputFile(inputFilename, &errorMessage);
+  if (!file) {
+    llvm::errs() << errorMessage << "\n";
+    return nullptr;
+  }
+  r->sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+
+  // Set up the MLIR context.
+  r->context.appendDialectRegistry(r->registry);
+
+  // Record the start of the buffer to compute offsets with.
+  unsigned curBuf = r->sourceMgr.getMainFileID();
+  const MemoryBuffer *curMB = r->sourceMgr.getMemoryBuffer(curBuf);
+  r->start = curMB->getBufferStart();
+  r->rewriteBuffer.Initialize(curMB->getBuffer());
+
+  // Parse and populate the AsmParserState.
+  ParserConfig parseConfig(&r->context, /*verifyAfterParse=*/true,
+                           &r->fallbackResourceMap);
+  // Always allow unregistered dialects.
+  r->context.allowUnregisteredDialects(true);
+  if (failed(parseAsmSourceFile(r->sourceMgr, &r->parsedIR, parseConfig,
+                                &r->asmState)))
+    return nullptr;
+
+  return r;
+}
+
+/// Return the source range associated with the operation name.
+SMRange getOpNameRange(const OperationDefinition &op) { return op.loc; }
+
+/// Return whether the operation was printed using generic syntax in original
+/// buffer.
+bool isGeneric(const OperationDefinition &op) {
+  return op.loc.Start.getPointer()[0] == '"';
+}
+
+inline int asMainReturnCode(LogicalResult r) {
+  return r.succeeded() ? EXIT_SUCCESS : EXIT_FAILURE;
+}
+
+/// Rewriter function to invoke.
+using RewriterFunction = std::function<mlir::LogicalResult(
+    mlir::RewriteBuffer &rewriteBuffer, llvm::raw_ostream &os)>;
+
+/// Structure to group information about a rewriter (argument to invoke via
+/// mlir-rewrite, description, and rewriter function).
+class RewriterInfo {
+public:
+  /// RewriterInfo constructor should not be invoked directly, instead use
+  /// RewriterRegistration or registerRewriter.
+  RewriterInfo(StringRef arg, StringRef description, RewriterFunction rewriter)
+      : arg(arg), description(description), rewriter(std::move(rewriter)) {}
+
+  /// Invokes the rewriter and returns whether the rewriter failed.
+  LogicalResult invoke(mlir::RewriteBuffer &rewriteBuffer,
+                       raw_ostream &os) const {
+    assert(rewriter && "Cannot call rewriter with null rewriter");
+    return rewriter(rewriteBuffer, os);
+  }
+
+  /// Returns the command line option that may be passed to 'mlir-rewrite' to
+  /// invoke this rewriter.
+  StringRef getRewriterArgument() const { return arg; }
+
+  /// Returns a description for the rewriter.
+  StringRef getRewriterDescription() const { return description; }
+
+private:
+  // The argument with which to invoke the rewriter via mlir-rewrite.
+  StringRef arg;
+
+  // Description of the rewriter.
+  StringRef description;
+
+  // Rewriter function.
+  RewriterFunction rewriter;
+};
+
+static llvm::ManagedStatic<std::vector<RewriterInfo>> rewriterRegistry;
+
+/// Adds command line option for each registered rewriter.
+struct RewriterNameParser : public llvm::cl::parser<const RewriterInfo *> {
+  RewriterNameParser(llvm::cl::Option &opt);
+
+  void printOptionInfo(const llvm::cl::Option &o,
+                       size_t globalWidth) const override;
+};
+
+/// RewriterRegistration provides a global initializer that registers a rewriter
+/// function.
+struct RewriterRegistration {
+  RewriterRegistration(StringRef arg, StringRef description,
+                       const RewriterFunction &function);
+};
+
+RewriterRegistration::RewriterRegistration(StringRef arg, StringRef description,
+                                           const RewriterFunction &function) {
+  rewriterRegistry->emplace_back(arg, description, function);
+}
+
+RewriterNameParser::RewriterNameParser(llvm::cl::Option &opt)
+    : llvm::cl::parser<const RewriterInfo *>(opt) {
+  for (const auto &kv : *rewriterRegistry) {
+    addLiteralOption(kv.getRewriterArgument(), &kv,
+                     kv.getRewriterDescription());
+  }
+}
+
+void RewriterNameParser::printOptionInfo(const llvm::cl::Option &o,
+                                         size_t globalWidth) const {
+  RewriterNameParser *tp = const_cast<RewriterNameParser *>(this);
+  llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(),
+                       [](const RewriterNameParser::OptionInfo *vT1,
+                          const RewriterNameParser::OptionInfo *vT2) {
+                         return vT1->Name.compare(vT2->Name);
+                       });
+  using llvm::cl::parser;
+  parser<const RewriterInfo *>::printOptionInfo(o, globalWidth);
+}
+
+} // namespace mlir
+
+// TODO: Make these injectable too in non-global way.
+static llvm::cl::OptionCategory clSimpleRenameCategory{"simple-rename options"};
+static llvm::cl::opt<std::string> simpleRenameOpName{
+    "simple-rename-op-name", llvm::cl::desc("Name of op to match on"),
+    llvm::cl::cat(clSimpleRenameCategory)};
+static llvm::cl::opt<std::string> simpleRenameMatch{
+    "simple-rename-match", llvm::cl::desc("Match string for rename"),
+    llvm::cl::cat(clSimpleRenameCategory)};
+static llvm::cl::opt<std::string> simpleRenameReplace{
+    "simple-rename-replace", llvm::cl::desc("Replace string for rename"),
+    llvm::cl::cat(clSimpleRenameCategory)};
+
+// Rewriter that does simple renames.
+LogicalResult simpleRename(RewriteBuffer &rewriteBuffer, raw_ostream &os) {
+  StringRef opName = simpleRenameOpName;
+  StringRef match = simpleRenameMatch;
+  StringRef replace = simpleRenameReplace;
+  llvm::Regex regex(match);
+
+  rewriteBuffer.getParsed()->walk([&](Operation *op) {
+    if (op->getName().getStringRef() != opName)
+      return;
+
+    const OperationDefinition &opDef = rewriteBuffer.getOpDef(op);
+    SMRange range = getOpRange(opDef);
+    // A regex is a little overkill for a simple rename.
+    std::string str = regex.sub(replace, rewriteBuffer.getSourceString(range));
+    rewriteBuffer.replaceRange(range, str);
+  });
+  return success();
+}
+
+static mlir::RewriterRegistration rewriteSimpleRename("simple-rename",
+                                                      "Perform a simple rename",
+                                                      simpleRename);
+
+// Rewriter that inserts range markers.
+LogicalResult markRanges(RewriteBuffer &rewriteBuffer, raw_ostream &os) {
+  int i = 0;
+  for (auto it : rewriteBuffer.getOpDefs()) {
+    auto [startOp, endOp] = getOpRange(it);
+
+    rewriteBuffer.insertText(startOp, "《");
+    rewriteBuffer.insertText(endOp, "》");
+
+    auto nameRange = getOpNameRange(it);
+
+    if (isGeneric(it)) {
+      rewriteBuffer.insertText(nameRange.Start, "〖");
+      rewriteBuffer.insertText(nameRange.End, "〗");
+    } else {
+      rewriteBuffer.insertText(nameRange.Start, "〔");
+      rewriteBuffer.insertText(nameRange.End, "〕");
+    }
+    ++i;
+  }
+
+  // Highlight all comment lines.
+  // TODO: Could be replaced if this is kept in memory.
+  for (auto commentLine : rewriteBuffer.getSingleLineComments()) {
+    rewriteBuffer.insertText(commentLine.Start, "❰");
+    rewriteBuffer.insertText(commentLine.End, "❱");
+  }
+
+  return success();
+}
+
+static mlir::RewriterRegistration
+    rewriteMarkRanges("mark-ranges", "Indicate ranges parsed", markRanges);
+
+int main(int argc, char **argv) {
+  static cl::opt<std::string> inputFilename(
+      cl::Positional, cl::desc("<input file>"), cl::init("-"));
+
+  static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
+                                             cl::value_desc("filename"),
+                                             cl::init("-"));
+
+  llvm::cl::opt<const mlir::RewriterInfo *, false, mlir::RewriterNameParser>
+      rewriter("", llvm::cl::desc("Rewriter to run"));
+
+  std::string helpHeader = "mlir-rewrite";
+
+  cl::ParseCommandLineOptions(argc, argv, helpHeader);
+
+  // If no rewriter has been selected, exit with an error code. We could also
+  // just return, but it's unlikely this was intentionally being used as `cp`.
+  if (!rewriter) {
+    llvm::errs() << "No rewriter selected!\n";
+    return mlir::asMainReturnCode(mlir::failure());
+  }
+
+  // Set up rewrite buffer.
+  auto rewriterOr = RewriteBuffer::init(inputFilename, outputFilename);
+  if (!rewriterOr)
+    return mlir::asMainReturnCode(mlir::failure());
+
+  // Set up the output file.
+  std::string errorMessage;
+  auto output = openOutputFile(outputFilename, &errorMessage);
+  if (!output) {
+    llvm::errs() << errorMessage << "\n";
+    return mlir::asMainReturnCode(mlir::failure());
+  }
+
+  LogicalResult result = rewriter->invoke(*rewriterOr, output->os());
+  if (succeeded(result)) {
+    rewriterOr->write(output->os());
+    output->keep();
+  }
+  return mlir::asMainReturnCode(result);
+}


