[Mlir-commits] [mlir] [mlir] Start rewrite tool (PR #77668)
Mehdi Amini
llvmlistbot at llvm.org
Thu Oct 10 16:04:28 PDT 2024
================
@@ -0,0 +1,391 @@
+//===- mlir-rewrite.cpp - MLIR Rewrite Driver -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Main entry function for mlir-rewrite.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/AsmParser/AsmParser.h"
+#include "mlir/AsmParser/AsmParserState.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Tools/ParseUtilities.h"
+#include "llvm/ADT/RewriteBuffer.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/LineIterator.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include <cassert>
+
+using namespace mlir;
+
+namespace mlir {
+using OperationDefinition = AsmParserState::OperationDefinition;
+
+/// Return the source range spanning the full textual definition of \p op.
+///
+/// The range ends where the operation's scope location ends. Its start is
+/// widened to the earliest result-group definition location, so any result
+/// names written before the operation (e.g. `%0 = ...`) are included.
+SMRange getOpRange(const OperationDefinition &op) {
+  const char *startOp = op.scopeLoc.Start.getPointer();
+  const char *endOp = op.scopeLoc.End.getPointer();
+
+  // Bind by const reference: a result group is not a trivial value (per
+  // AsmParserState.h it carries its definition's list of uses), so iterating
+  // by value would copy that on every iteration.
+  for (const auto &res : op.resultGroups)
+    startOp = std::min(startOp, res.definition.loc.Start.getPointer());
+  return {SMLoc::getFromPointer(startOp), SMLoc::getFromPointer(endOp)};
+}
+
+/// Helper to simplify rewriting the source file.
+///
+/// Holds the MLIRContext, SourceMgr and AsmParserState used to parse an input
+/// file, together with an llvm::RewriteBuffer over that file's text. Edits are
+/// expressed as SMLoc/SMRange values pointing into the original buffer (as
+/// recorded by the parser) and are translated into byte offsets relative to
+/// `start`.
+class RewriteBuffer {
+public:
+  /// Factory for a RewriteBuffer over \p inputFilename; see the out-of-line
+  /// definition for how the input is parsed.
+  static std::unique_ptr<RewriteBuffer> init(StringRef inputFilename,
+                                             StringRef outputFilename);
+
+  /// Return the context the file was parsed into.
+  MLIRContext *getContext() { return &context; }
+
+  /// Return the OperationDefinitions recorded while parsing.
+  auto getOpDefs() { return asmState.getOpDefs(); }
+
+  /// Insert \p str at \p pos, which must point into the original buffer.
+  /// \p insertAfter is forwarded to llvm::RewriteBuffer::InsertText, where it
+  /// orders this insertion relative to other text inserted at the same spot.
+  void insertText(SMLoc pos, StringRef str, bool insertAfter = true) {
+    rewriteBuffer.InsertText(pos.getPointer() - start, str, insertAfter);
+  }
+
+  /// Replace the text covered by \p range (which must lie within the original
+  /// buffer) with \p str in the output.
+  void replaceRange(SMRange range, StringRef str) {
+    rewriteBuffer.ReplaceText(range.Start.getPointer() - start,
+                              range.End.getPointer() - range.Start.getPointer(),
+                              str);
+  }
+
+  /// Replace the full textual definition of \p opDef (the range computed by
+  /// getOpRange, including any leading result names) with \p newDef.
+  void replaceDef(const OperationDefinition &opDef, StringRef newDef) {
+    replaceRange(getOpRange(opDef), newDef);
+  }
+
+  /// Return the original source text covered by \p range.
+  StringRef getSourceString(SMRange range) {
+    return StringRef(range.Start.getPointer(),
+                     range.End.getPointer() - range.Start.getPointer());
+  }
+
+  /// Return the original source text of the operation definition \p opDef.
+  StringRef getSourceString(const OperationDefinition &opDef) {
+    return getSourceString(getOpRange(opDef));
+  }
+
+  /// Write to \p stream the result of applying all changes to the original
+  /// buffer. The original buffer itself is not modified.
+  ///
+  /// Note that it isn't safe to use this function to overwrite memory mapped
+  /// files in-place (PR17960).
+  raw_ostream &write(raw_ostream &stream) const {
+    return rewriteBuffer.write(stream);
+  }
+
+  /// Return the ranges of lines in the main file that consist solely of a
+  /// `//` comment. Leading whitespace is excluded from each range; lines with
+  /// code followed by a trailing comment are not reported.
+  SmallVector<SMRange> getSingleLineComments() {
+    unsigned curBuf = sourceMgr.getMainFileID();
+    const llvm::MemoryBuffer *curMB = sourceMgr.getMemoryBuffer(curBuf);
+    llvm::line_iterator lineIterator(*curMB);
+    SmallVector<SMRange> ret;
+    for (; !lineIterator.is_at_end(); ++lineIterator) {
+      StringRef trimmed = lineIterator->ltrim();
+      if (trimmed.starts_with("//"))
+        ret.emplace_back(SMLoc::getFromPointer(trimmed.begin()),
+                         SMLoc::getFromPointer(trimmed.end()));
+    }
+    return ret;
+  }
+
+  /// Return the IR parsed from the input file.
+  Block *getParsed() { return &parsedIR; }
+
+  /// Return the definition recorded for \p op while parsing.
+  ///
+  /// Precondition: \p op must have a recorded definition (e.g. it was parsed
+  /// from this buffer). This is asserted, because dereferencing a missing
+  /// definition would be undefined behavior.
+  const OperationDefinition &getOpDef(Operation *op) const {
+    const OperationDefinition *def = asmState.getOpDef(op);
+    assert(def && "operation has no definition in the parsed source");
+    return *def;
+  }
+
+private:
+  // The context and state required to parse.
+  MLIRContext context;
+  llvm::SourceMgr sourceMgr;
+  DialectRegistry registry;
+  FallbackAsmResourceMap fallbackResourceMap;
+
+  // Storage of textual parsing results.
+  AsmParserState asmState;
+
+  // Parsed IR.
+  Block parsedIR;
+
+  // The RewriteBuffer is doing most of the real work.
+  llvm::RewriteBuffer rewriteBuffer;
+
+  // Start of the original input, used to compute offsets. Presumably set in
+  // init(); given a defined value so it is never read uninitialized.
+  const char *start = nullptr;
+};
+
+std::unique_ptr<RewriteBuffer> RewriteBuffer::init(StringRef inputFilename,
+ StringRef outputFilename) {
+ std::unique_ptr<RewriteBuffer> r = std::make_unique<RewriteBuffer>();
+
+ // Register all the dialects needed.
+ registerAllDialects(r->registry);
----------------
joker-eph wrote:
There will be some non-trivial refactoring needed to make it a properly usable utility (the way mlir-opt is separated from the `MlirOptMain` library, for example).
https://github.com/llvm/llvm-project/pull/77668
More information about the Mlir-commits
mailing list