[Mlir-commits] [mlir] [mlir] Start rewrite tool (PR #77668)

Mehdi Amini llvmlistbot at llvm.org
Thu Oct 10 16:04:28 PDT 2024


================
@@ -0,0 +1,391 @@
+//===- mlir-rewrite.cpp - MLIR Rewrite Driver -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Main entry function for mlir-rewrite.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/AsmParser/AsmParser.h"
+#include "mlir/AsmParser/AsmParserState.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Tools/ParseUtilities.h"
+#include "llvm/ADT/RewriteBuffer.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/LineIterator.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace mlir;
+
+namespace mlir {
+using OperationDefinition = AsmParserState::OperationDefinition;
+
+/// Return the source code associated with the OperationDefinition.
+/// Return the source range covered by the given OperationDefinition.
+///
+/// The range ends at `op.scopeLoc.End` and starts at the earliest of
+/// `op.scopeLoc.Start` and the start location of every result group
+/// definition, so text for the op's results (presumably the `%res = ` prefix)
+/// is included in the returned range.
+SMRange getOpRange(const OperationDefinition &op) {
+  const char *startOp = op.scopeLoc.Start.getPointer();
+  const char *endOp = op.scopeLoc.End.getPointer();
+
+  // Bind by const reference: iterating by value would copy each result group
+  // definition (including any containers it holds) on every iteration.
+  for (const auto &res : op.resultGroups)
+    startOp = std::min(startOp, res.definition.loc.Start.getPointer());
+  return {SMLoc::getFromPointer(startOp), SMLoc::getFromPointer(endOp)};
+}
+
+/// Helper to simplify rewriting the source file.
+/// Helper to simplify rewriting the source file.
+///
+/// Owns the parsing state (context, source manager, AsmParserState) for one
+/// input file together with an llvm::RewriteBuffer over that file's text.
+/// Edits are expressed as SMLoc/SMRange positions into the original buffer
+/// and converted to offsets relative to `start`.
+class RewriteBuffer {
+public:
+  static std::unique_ptr<RewriteBuffer> init(StringRef inputFilename,
+                                             StringRef outputFilename);
+
+  /// Return the context the file was parsed into.
+  MLIRContext *getContext() { return &context; }
+
+  /// Return the OperationDefinitions of the parsed operations.
+  auto getOpDefs() { return asmState.getOpDefs(); }
+
+  /// Insert `str` at `pos` in the original buffer. `insertAfter` is forwarded
+  /// to llvm::RewriteBuffer::InsertText.
+  void insertText(SMLoc pos, StringRef str, bool insertAfter = true) {
+    rewriteBuffer.InsertText(pos.getPointer() - start, str, insertAfter);
+  }
+
+  /// Replace the text covered by `range` in the original buffer with `str`
+  /// in the output.
+  void replaceRange(SMRange range, StringRef str) {
+    rewriteBuffer.ReplaceText(range.Start.getPointer() - start,
+                              range.End.getPointer() - range.Start.getPointer(),
+                              str);
+  }
+
+  /// Replace the text of the operation described by `opDef` (as computed by
+  /// getOpRange) with `newDef` in the output.
+  void replaceDef(const OperationDefinition &opDef, StringRef newDef) {
+    replaceRange(getOpRange(opDef), newDef);
+  }
+
+  /// Return the original source text covered by `range`.
+  StringRef getSourceString(SMRange range) const {
+    return StringRef(range.Start.getPointer(),
+                     range.End.getPointer() - range.Start.getPointer());
+  }
+
+  /// Return the original source text of the operation described by `opDef`.
+  StringRef getSourceString(const OperationDefinition &opDef) const {
+    return getSourceString(getOpRange(opDef));
+  }
+
+  /// Write to stream the result of applying all changes to the
+  /// original buffer.
+  /// Note that it isn't safe to use this function to overwrite memory mapped
+  /// files in-place (PR17960).
+  ///
+  /// The original buffer is not actually changed.
+  raw_ostream &write(raw_ostream &stream) const {
+    return rewriteBuffer.write(stream);
+  }
+
+  /// Return the ranges of lines that are purely comments: each range starts at
+  /// the leading `//` (after any indentation) and ends at the end of the line.
+  SmallVector<SMRange> getSingleLineComments() {
+    unsigned curBuf = sourceMgr.getMainFileID();
+    const llvm::MemoryBuffer *curMB = sourceMgr.getMemoryBuffer(curBuf);
+    llvm::line_iterator lineIterator(*curMB);
+    SmallVector<SMRange> ret;
+    for (; !lineIterator.is_at_end(); ++lineIterator) {
+      StringRef trimmed = lineIterator->ltrim();
+      if (trimmed.starts_with("//")) {
+        ret.emplace_back(
+            SMLoc::getFromPointer(trimmed.data()),
+            SMLoc::getFromPointer(trimmed.data() + trimmed.size()));
+      }
+    }
+    return ret;
+  }
+
+  /// Return the IR from the parsed file.
+  Block *getParsed() { return &parsedIR; }
+
+  /// Return the definition recorded for the given operation.
+  ///
+  /// The operation must have a recorded definition in the parsed source (for
+  /// example, it was parsed from the input rather than created afterwards);
+  /// this is checked with an assertion.
+  const OperationDefinition &getOpDef(Operation *op) const {
+    const OperationDefinition *def = asmState.getOpDef(op);
+    assert(def && "operation has no recorded definition in the parsed source");
+    return *def;
+  }
+
+private:
+  // The context and state required to parse.
+  MLIRContext context;
+  llvm::SourceMgr sourceMgr;
+  DialectRegistry registry;
+  FallbackAsmResourceMap fallbackResourceMap;
+
+  // Storage of textual parsing results.
+  AsmParserState asmState;
+
+  // Parsed IR.
+  Block parsedIR;
+
+  // The llvm::RewriteBuffer does most of the real work.
+  llvm::RewriteBuffer rewriteBuffer;
+
+  // Start of the original input, used to compute offsets into the buffer.
+  const char *start = nullptr;
+};
+
+std::unique_ptr<RewriteBuffer> RewriteBuffer::init(StringRef inputFilename,
+                                                   StringRef outputFilename) {
+  std::unique_ptr<RewriteBuffer> r = std::make_unique<RewriteBuffer>();
+
+  // Register all the dialects needed.
+  registerAllDialects(r->registry);
----------------
joker-eph wrote:

There will be some non-trivial refactoring needed to make this a properly usable utility (for example, the way mlir-opt is separated from the `MlirOptMain` library).

https://github.com/llvm/llvm-project/pull/77668


More information about the Mlir-commits mailing list