[Mlir-commits] [mlir] Add initial `mlir-format` PoC (PR #121260)
Perry Gibson
llvmlistbot at llvm.org
Mon Dec 30 02:55:24 PST 2024
https://github.com/Wheest updated https://github.com/llvm/llvm-project/pull/121260
>From 3eebcfdf91ecbbe572624e8a0f360e48e85b6fa6 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at fractile.ai>
Date: Sat, 28 Dec 2024 10:18:12 +0000
Subject: [PATCH 1/2] Add initial mlir-format PoC
---
mlir/test/mlir-format/annotate_locs.mlir | 13 +
mlir/test/mlir-format/simple.mlir | 15 +
mlir/test/mlir-format/simple2.mlir | 19 ++
mlir/tools/CMakeLists.txt | 1 +
mlir/tools/mlir-format/CMakeLists.txt | 32 ++
mlir/tools/mlir-format/mlir-format.cpp | 385 +++++++++++++++++++++++
6 files changed, 465 insertions(+)
create mode 100644 mlir/test/mlir-format/annotate_locs.mlir
create mode 100644 mlir/test/mlir-format/simple.mlir
create mode 100644 mlir/test/mlir-format/simple2.mlir
create mode 100644 mlir/tools/mlir-format/CMakeLists.txt
create mode 100644 mlir/tools/mlir-format/mlir-format.cpp
diff --git a/mlir/test/mlir-format/annotate_locs.mlir b/mlir/test/mlir-format/annotate_locs.mlir
new file mode 100644
index 00000000000000..c557e8447bcf7b
--- /dev/null
+++ b/mlir/test/mlir-format/annotate_locs.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-format %s --mlir-use-nameloc-as-prefix --insert-name-loc-only | FileCheck %s
+
+// Append NameLocs (`loc("[ssa_name]")`) to operations and block arguments
+
+// CHECK: func.func @add_one(%my_input: f64 loc("my_input"), %my_input2: f64 loc("my_input2")) -> f64 {
+func.func @add_one(%my_input: f64, %my_input2: f64) -> f64 {
+ // CHECK: %my_constant = arith.constant 1.00000e+00 : f64 loc("my_constant")
+ %my_constant = arith.constant 1.00000e+00 : f64
+
+ %my_output = arith.addf %my_input, %my_constant : f64
+ // CHECK: %my_output = arith.addf %my_input, %my_constant : f64 loc("my_output")
+ return %my_output : f64
+}
diff --git a/mlir/test/mlir-format/simple.mlir b/mlir/test/mlir-format/simple.mlir
new file mode 100644
index 00000000000000..0b05fa2a63a1ec
--- /dev/null
+++ b/mlir/test/mlir-format/simple.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-format %s --mlir-use-nameloc-as-prefix | FileCheck %s
+
+// CHECK: func.func @add_one(%my_input: f64) -> f64 {
+func.func @add_one(%my_input: f64) -> f64 {
+ // CHECK: %my_constant = arith.constant 1.00000e+00 : f64
+ %my_constant = arith.constant 1.00000e+00 : f64
+ // CHECK: // Dinnae drop this comment!
+ // Dinnae drop this comment!
+ %my_output = arith.addf
+ %my_input,
+ %my_constant : f64
+ // CHECK-STRICT: %my_output = arith.addf %my_input, %my_constant : f64
+ return %my_output : f64
+ // CHECK: return %my_output : f64
+}
diff --git a/mlir/test/mlir-format/simple2.mlir b/mlir/test/mlir-format/simple2.mlir
new file mode 100644
index 00000000000000..beb22ec933fb98
--- /dev/null
+++ b/mlir/test/mlir-format/simple2.mlir
@@ -0,0 +1,19 @@
+// RUN: mlir-format %s --mlir-use-nameloc-as-prefix | FileCheck %s
+
+// CHECK: func.func @my_func(%x: f64, %y: f64) -> f64 {
+func.func @my_func(%x: f64, %y: f64) -> f64 {
+ // CHECK: %cst1 = arith.constant 1.00000e+00 : f64
+ %cst1 = arith.constant 1.00000e+00 : f64
+ // CHECK: %cst2 = arith.constant 2.00000e+00 : f64
+ %cst2 = arith.constant 2.00000e+00 : f64
+ // CHECK-STRICT: %x1 = arith.addf %x, %cst1 : f64
+ %x1 = arith.addf
+%x,
+%cst1 : f64
+ // CHECK-STRICT: %y2 = arith.addf %y, %cst2 : f64
+ %y2 = arith.addf %y, %cst2 : f64
+ // CHECK: %z = arith.addf %x1, %y2 : f64
+ %z = arith.addf %x1, %y2 : f64
+ // CHECK: return %z : f64
+ return %z : f64
+}
diff --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt
index 072e83c5d45ea1..0f3d5e6f9731ae 100644
--- a/mlir/tools/CMakeLists.txt
+++ b/mlir/tools/CMakeLists.txt
@@ -5,6 +5,7 @@ add_subdirectory(mlir-pdll-lsp-server)
add_subdirectory(mlir-query)
add_subdirectory(mlir-reduce)
add_subdirectory(mlir-rewrite)
+add_subdirectory(mlir-format)
add_subdirectory(mlir-shlib)
add_subdirectory(mlir-translate)
add_subdirectory(mlir-vulkan-runner)
diff --git a/mlir/tools/mlir-format/CMakeLists.txt b/mlir/tools/mlir-format/CMakeLists.txt
new file mode 100644
index 00000000000000..8ed6f095f0e4b5
--- /dev/null
+++ b/mlir/tools/mlir-format/CMakeLists.txt
@@ -0,0 +1,32 @@
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+set(LLVM_LINK_COMPONENTS
+ Support
+ )
+
+set(LIBS
+ ${dialect_libs}
+
+ MLIRAffineAnalysis
+ MLIRAnalysis
+ MLIRCastInterfaces
+ MLIRDialect
+ MLIRParser
+ MLIRPass
+ MLIRTransforms
+ MLIRTransformUtils
+ MLIRSupport
+ MLIRIR
+ )
+
+include_directories(../../../clang/include)
+
+add_mlir_tool(mlir-format
+ mlir-format.cpp
+
+ SUPPORT_PLUGINS
+ )
+mlir_target_link_libraries(mlir-format PRIVATE ${LIBS})
+llvm_update_compile_flags(mlir-format)
+
+mlir_check_all_link_libraries(mlir-format)
+export_executable_symbols_for_plugins(mlir-format)
diff --git a/mlir/tools/mlir-format/mlir-format.cpp b/mlir/tools/mlir-format/mlir-format.cpp
new file mode 100644
index 00000000000000..ae0122a6627afa
--- /dev/null
+++ b/mlir/tools/mlir-format/mlir-format.cpp
@@ -0,0 +1,385 @@
+#include "mlir/AsmParser/AsmParser.h"
+#include "mlir/AsmParser/AsmParserState.h"
+#include "mlir/IR/AsmState.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/InitAllDialects.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Tools/ParseUtilities.h"
+#include "llvm/ADT/RewriteBuffer.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/LineIterator.h"
+#include "llvm/Support/Regex.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace mlir;
+
+namespace mlir {
+using OperationDefinition = AsmParserState::OperationDefinition;
+using BlockDefinition = AsmParserState::BlockDefinition;
+using SMDefinition = AsmParserState::SMDefinition;
+
+/// Map an MLIR LogicalResult onto a process exit status: EXIT_SUCCESS when the
+/// result succeeded, EXIT_FAILURE otherwise.
+inline int asMainReturnCode(LogicalResult r) {
+  if (r.succeeded())
+    return EXIT_SUCCESS;
+  return EXIT_FAILURE;
+}
+
+/// Compute the source range covered by an operation definition. The range
+/// begins at the earliest of the op's scope start and the start of any of its
+/// result-group definitions, and ends where the op's scope ends.
+SMRange getOpRange(const OperationDefinition &op) {
+  const char *rangeBegin = op.scopeLoc.Start.getPointer();
+  for (const auto &group : op.resultGroups) {
+    const char *groupBegin = group.definition.loc.Start.getPointer();
+    if (groupBegin < rangeBegin)
+      rangeBegin = groupBegin;
+  }
+
+  const char *rangeEnd = op.scopeLoc.End.getPointer();
+  return {SMLoc::getFromPointer(rangeBegin), SMLoc::getFromPointer(rangeEnd)};
+}
+
+/// Walks two AsmParserState operation-definition sequences in lockstep,
+/// yielding the pair of definitions found at the same position in each.
+class CombinedOpDefIterator {
+public:
+  using BaseIterator = AsmParserState::OperationDefIterator;
+  using value_type = std::pair<OperationDefinition &, OperationDefinition &>;
+
+  /// Pair an iterator over the original definitions with one over the
+  /// formatted definitions.
+  CombinedOpDefIterator(BaseIterator opIter, BaseIterator fmtIter)
+      : opIter(opIter), fmtIter(fmtIter) {}
+
+  /// Yield references to the current definition of each sequence.
+  value_type operator*() const { return value_type(*opIter, *fmtIter); }
+
+  /// Step both underlying iterators forward by one element.
+  CombinedOpDefIterator &operator++() {
+    ++opIter;
+    ++fmtIter;
+    return *this;
+  }
+
+  /// Two combined iterators are equal only when both halves are equal.
+  bool operator==(const CombinedOpDefIterator &other) const {
+    if (!(opIter == other.opIter))
+      return false;
+    return fmtIter == other.fmtIter;
+  }
+
+  /// Negation of operator==: differs as soon as either half differs.
+  bool operator!=(const CombinedOpDefIterator &other) const {
+    return !(opIter == other.opIter) || !(fmtIter == other.fmtIter);
+  }
+
+private:
+  BaseIterator opIter;
+  BaseIterator fmtIter;
+};
+
+// Function to find the previous comma: scans backwards from the character
+// just before `start` down to `stop_point` (inclusive) and returns a pointer
+// to the first ',' encountered. Logs an error and returns nullptr when `start`
+// is null or no comma lies in that span.
+const char *findPrevComma(const char *start, const char *stop_point) {
+  if (!start) {
+    llvm::errs() << "Error: Input pointer is null.\n";
+    return nullptr;
+  }
+
+  for (const char *cursor = start - 1; cursor >= stop_point; --cursor)
+    if (*cursor == ',')
+      return cursor;
+
+  llvm::errs() << "Error: No previous comma found before provided pointer.\n";
+  return nullptr;
+}
+
+// Scan forward from `start` through a NUL-terminated buffer and return a
+// pointer to the first ')' found. Logs an error and returns nullptr when
+// `start` is null or the terminator is reached first.
+const char *findNextCloseParenth(const char *start) {
+  if (start == nullptr) {
+    llvm::errs() << "Error: Input pointer is null.\n";
+    return nullptr;
+  }
+
+  for (const char *cursor = start; *cursor != '\0'; ++cursor) {
+    if (*cursor == ')')
+      return cursor;
+  }
+
+  llvm::errs() << "Error: No closing parenthesis found in the string.\n";
+  return nullptr;
+}
+
+/// Holds everything needed to parse an MLIR source file and rewrite its text:
+/// the MLIR context, the source managers, the parser state for the original
+/// and the re-parsed (formatted) input, and two RewriteBuffers that accumulate
+/// textual edits against the original buffer.
+class Formatter {
+public:
+ /// Create a Formatter for `inputFilename` and parse the file. Returns
+ /// nullptr if the file cannot be opened or fails to parse.
+ static std::unique_ptr<Formatter> init(StringRef inputFilename,
+ StringRef outputFilename);
+
+ /// Return the OperationDefinition's of the operations parsed.
+ iterator_range<AsmParserState::OperationDefIterator> getOpDefs() {
+ return asmState.getOpDefs();
+ }
+
+ /// Return the BlockDefinition's of the blocks parsed.
+ iterator_range<AsmParserState::BlockDefIterator> getBlockDefs() {
+ return asmState.getBlockDefs();
+ }
+
+ /// Return a range that walks the OperationDefinitions of asmState and
+ /// fmtAsmState in lockstep, yielding one pair per position. Asserts that
+ /// both sequences have the same length.
+ iterator_range<CombinedOpDefIterator> getCombinedOpDefs() {
+ auto opBegin = asmState.getOpDefs().begin();
+ auto opEnd = asmState.getOpDefs().end();
+ auto fmtBegin = fmtAsmState.getOpDefs().begin();
+ auto fmtEnd = fmtAsmState.getOpDefs().end();
+
+ assert(std::distance(opBegin, opEnd) == std::distance(fmtBegin, fmtEnd) &&
+ "Both iterators must have the same length");
+
+ return llvm::make_range(CombinedOpDefIterator(opBegin, fmtBegin),
+ CombinedOpDefIterator(opEnd, fmtEnd));
+ }
+
+ /// Print the parsed operations to the provided output stream.
+ void printOps(raw_ostream &os) {
+ // Iterate over each operation in the parsedIR block and print it.
+ for (Operation &op : parsedIR) {
+ op.print(os);
+ os << "\n";
+ }
+ }
+
+ /// Insert the specified string at the specified location in the original
+ /// buffer. `pos` must point into the original input; it is converted to an
+ /// offset relative to `start`.
+ void insertText(SMLoc pos, StringRef str, bool insertAfter = true) {
+ rewriteBuffer.InsertText(pos.getPointer() - start, str, insertAfter);
+ }
+
+ /// Replace the range of the source text with the corresponding string in the
+ /// output.
+ void replaceRange(SMRange range, StringRef str) {
+ rewriteBuffer.ReplaceText(range.Start.getPointer() - start,
+ range.End.getPointer() - range.Start.getPointer(),
+ str);
+ }
+
+ /// Same as replaceRange, but the edit is recorded in fmtRewriteBuffer rather
+ /// than rewriteBuffer. Offsets are still computed relative to `start`, i.e.
+ /// `range` must point into the original input buffer.
+ void replaceRangeFmt(SMRange range, StringRef str) {
+ fmtRewriteBuffer.ReplaceText(
+ range.Start.getPointer() - start,
+ range.End.getPointer() - range.Start.getPointer(), str);
+ }
+
+ /// Write to stream the result of applying all changes to the
+ /// original buffer.
+ /// Note that it isn't safe to use this function to overwrite memory mapped
+ /// files in-place (PR17960).
+ ///
+ /// The original buffer is not actually changed.
+ raw_ostream &write(raw_ostream &stream) const {
+ return rewriteBuffer.write(stream);
+ }
+
+ /// Write to stream the result of applying all changes recorded in
+ /// fmtRewriteBuffer to the original buffer.
+ raw_ostream &writeFmt(raw_ostream &stream) const {
+ return fmtRewriteBuffer.write(stream);
+ }
+
+ /// Re-parse the contents of rewriteBuffer to populate fmtAsmState, replace
+ /// the original operations with their printed form in fmtRewriteBuffer, and
+ /// write the result to llvm::outs().
+ void formatOps();
+
+private:
+ // The context and state required to parse.
+ MLIRContext context;
+ llvm::SourceMgr sourceMgr;
+ llvm::SourceMgr fmtSourceMgr;
+ DialectRegistry registry;
+ FallbackAsmResourceMap fallbackResourceMap;
+
+ // Storage of textual parsing results.
+ AsmParserState asmState;
+
+ // Storage of initial formatted ops.
+ AsmParserState fmtAsmState;
+
+ // Parsed IR: parsedIR holds the original input, parsedFmtIR the re-parsed
+ // contents of rewriteBuffer.
+ Block parsedIR;
+ Block parsedFmtIR;
+
+ // The RewriteBuffer is doing most of the real work. Both buffers are
+ // initialized from the original input.
+ llvm::RewriteBuffer rewriteBuffer;
+ llvm::RewriteBuffer fmtRewriteBuffer;
+
+ // Start of the original input, used to compute offset.
+ const char *start;
+};
+
+/// Build a Formatter for `inputFilename`: register all dialects, load the file
+/// into the source manager, initialize both rewrite buffers from its contents
+/// and parse it into `parsedIR` / `asmState`.
+///
+/// Returns nullptr if the file cannot be opened (the error message is printed
+/// to llvm::errs()) or if parsing fails.
+///
+/// NOTE(review): `outputFilename` is not referenced anywhere in this function.
+std::unique_ptr<Formatter> Formatter::init(StringRef inputFilename,
+ StringRef outputFilename) {
+
+ std::unique_ptr<Formatter> f = std::make_unique<Formatter>();
+ // Register all the dialects needed.
+ registerAllDialects(f->registry);
+
+ // Set up the input file.
+ std::string errorMessage;
+ std::unique_ptr<llvm::MemoryBuffer> file =
+ openInputFile(inputFilename, &errorMessage);
+ if (!file) {
+ llvm::errs() << errorMessage << "\n";
+ return nullptr;
+ }
+ // Ownership of the buffer moves to the source manager; from here on it is
+ // only reachable through sourceMgr.
+ f->sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+
+ // Make the registered dialects available to the MLIR context.
+ f->context.appendDialectRegistry(f->registry);
+
+ // Record the start of the buffer to compute offsets with.
+ unsigned curBuf = f->sourceMgr.getMainFileID();
+ const llvm::MemoryBuffer *curMB = f->sourceMgr.getMemoryBuffer(curBuf);
+ f->start = curMB->getBufferStart();
+ f->rewriteBuffer.Initialize(curMB->getBuffer());
+ f->fmtRewriteBuffer.Initialize(curMB->getBuffer());
+
+ // Parse and populate the AsmParserState.
+ ParserConfig parseConfig(&f->context, /*verifyAfterParse=*/true,
+ &f->fallbackResourceMap);
+ // Always allow unregistered.
+ f->context.allowUnregisteredDialects(true);
+ if (failed(parseAsmSourceFile(f->sourceMgr, &f->parsedIR, parseConfig,
+ &f->asmState)))
+ return nullptr;
+
+ return f;
+}
+
+/// Re-parse the text currently held in `rewriteBuffer`, print every parsed
+/// operation with the MLIR printer, and splice each printed form over the
+/// source range of the matching original operation in `fmtRewriteBuffer`.
+/// The resulting buffer is then written to llvm::outs().
+///
+/// If the text in `rewriteBuffer` fails to parse, the function returns early
+/// and nothing is written.
+void Formatter::formatOps() {
+  ParserConfig parseConfig(&context, /*verifyAfterParse=*/true,
+                           &fallbackResourceMap);
+
+  // Write the rewriteBuffer to a string that we can then parse.
+  std::string bufferContent;
+  llvm::raw_string_ostream bufferStream(bufferContent);
+  rewriteBuffer.write(bufferStream);
+  bufferStream.flush();
+
+  // Hand a copy of that text to the source manager used for the second parse.
+  fmtSourceMgr.AddNewSourceBuffer(
+      llvm::MemoryBuffer::getMemBufferCopy(bufferContent), SMLoc());
+
+  // Parse and populate the format AsmParserState.
+  if (failed(parseAsmSourceFile(fmtSourceMgr, &parsedFmtIR, parseConfig,
+                                &fmtAsmState)))
+    return;
+
+  // Insert the formatted ops. Block args should be untouched,
+  // and their references will use the correct SSA ID.
+  for (auto [opDef, fmtDef] : getCombinedOpDefs()) {
+    // Skip if the op is a FuncOp (we format ops in its body)
+    // or a ReturnOp (we want to keep the user's preference for
+    // `func.return` or plain `return`). Only the type is tested, so use
+    // `isa` rather than `dyn_cast`.
+    if (llvm::isa<mlir::func::FuncOp, mlir::func::ReturnOp>(fmtDef.op))
+      continue;
+
+    // Print the fmtDef op and store as a string.
+    std::string formattedStr;
+    llvm::raw_string_ostream opStream(formattedStr);
+    fmtDef.op->print(opStream);
+
+    // Replace the source range of the original op with the formatted string.
+    replaceRangeFmt(getOpRange(opDef), formattedStr);
+  }
+
+  // Write the updated buffer to llvm::outs()
+  writeFmt(llvm::outs());
+}
+
+/// Record, in the formatter's rewrite buffer, a ` loc("<name>")` suffix for
+/// every named SSA definition found by the parser:
+///  * for operations, one suffix per result group, inserted at the end of the
+///    operation's source range;
+///  * for block arguments, one suffix per argument, inserted before the comma
+///    that precedes the next argument or, for the last argument, before the
+///    next ')'.
+///
+/// An argument whose insertion point cannot be located is skipped; the helper
+/// that failed has already reported the problem on llvm::errs().
+///
+/// `os` is currently unused.
+void markNames(Formatter &formatState, raw_ostream &os) {
+  // Build ` loc("name")` from the source range of an SSA definition, dropping
+  // the leading '%' of the name.
+  auto makeNameLoc = [](SMRange defRange) {
+    const char *defStart = defRange.Start.getPointer();
+    const int defLen = defRange.End.getPointer() - defStart;
+    StringRef name(defStart + 1, defLen - 1);
+    return " loc(\"" + name.str() + "\")";
+  };
+
+  // Append a NameLoc for each result group to the end of its operation.
+  for (OperationDefinition &opDef : formatState.getOpDefs()) {
+    const SMLoc opEnd = getOpRange(opDef).End;
+    for (const auto &resultGroup : opDef.resultGroups)
+      formatState.insertText(opEnd, makeNameLoc(resultGroup.definition.loc));
+  }
+
+  // Insert the NameLocs for the block arguments.
+  for (BlockDefinition &block : formatState.getBlockDefs()) {
+    for (size_t i = 0, e = block.arguments.size(); i < e; ++i) {
+      SMDefinition &arg = block.arguments[i];
+      const char *argEnd = arg.loc.End.getPointer();
+
+      // Find where to insert the NameLoc. Either before the next argument,
+      // or at the end of the arg list.
+      const char *insertPointPtr =
+          (i + 1 < e)
+              ? findPrevComma(block.arguments[i + 1].loc.Start.getPointer(),
+                              argEnd)
+              : findNextCloseParenth(argEnd);
+
+      // Without this guard a null insertion point would be turned into a
+      // meaningless buffer offset by insertText.
+      if (!insertPointPtr)
+        continue;
+
+      formatState.insertText(SMLoc::getFromPointer(insertPointPtr),
+                             makeNameLoc(arg.loc));
+    }
+  }
+}
+} // namespace mlir
+
+/// Entry point of mlir-format: parse the command line, load and parse the
+/// input file, append NameLocs for the SSA names, then either emit the
+/// annotated buffer as-is (`--insert-name-loc-only`) or run the formatter.
+/// Exits with a failure status if the input cannot be opened or parsed.
+int main(int argc, char **argv) {
+  registerAsmPrinterCLOptions();
+  static llvm::cl::opt<std::string> inputFilename(
+      llvm::cl::Positional, llvm::cl::desc("<input file>"),
+      llvm::cl::init("-"));
+
+  static llvm::cl::opt<std::string> outputFilename(
+      "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
+      llvm::cl::init("-"));
+
+  llvm::cl::opt<bool> nameLocOnly{
+      "insert-name-loc-only", llvm::cl::init(false),
+      llvm::cl::desc("Only return a buffer with the NameLocs appended")};
+
+  std::string helpHeader = "mlir-format";
+
+  llvm::cl::ParseCommandLineOptions(argc, argv, helpHeader);
+
+  // Set up formatter buffer. Formatter::init returns nullptr when the input
+  // cannot be opened or parsed, so bail out before dereferencing it.
+  auto f = Formatter::init(inputFilename, outputFilename);
+  if (!f)
+    return mlir::asMainReturnCode(mlir::failure());
+
+  // Append the SSA names as NameLocs
+  markNames(*f, llvm::outs());
+
+  if (nameLocOnly) {
+    // Return the original buffer with NameLocs appended to ops
+    // e.g., `%alice = memref.load %0[] : memref<i32> loc("alice")`
+    f->write(llvm::outs());
+    return mlir::asMainReturnCode(mlir::success());
+  }
+
+  f->formatOps();
+
+  return mlir::asMainReturnCode(mlir::success());
+}
>From 68ae2fadf2bb0c11152c6223263de6fe330b0214 Mon Sep 17 00:00:00 2001
From: Perry Gibson <perry at fractile.ai>
Date: Mon, 30 Dec 2024 10:54:43 +0000
Subject: [PATCH 2/2] Add support for type alias via ad-hoc parser
---
mlir/test/mlir-format/type_alias.mlir | 21 ++
mlir/tools/mlir-format/mlir-format.cpp | 275 ++++++++++++++++++++-----
2 files changed, 247 insertions(+), 49 deletions(-)
create mode 100644 mlir/test/mlir-format/type_alias.mlir
diff --git a/mlir/test/mlir-format/type_alias.mlir b/mlir/test/mlir-format/type_alias.mlir
new file mode 100644
index 00000000000000..10a8a5a532eace
--- /dev/null
+++ b/mlir/test/mlir-format/type_alias.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-format %s --mlir-use-nameloc-as-prefix | FileCheck %s
+
+// CHECK: !funky64 = f64
+!funky64 = f64
+// CHECK: !fancy64 = f64
+!fancy64 = f64
+
+// CHECK: func.func @add_one(%b: f64) -> (f64, !funky64, !fancy64) {
+func.func @add_one(%b: f64) -> (f64, !funky64, !fancy64) {
+ // CHECK: %c = arith.constant 1.00000e+00 : !funky64
+ %c = arith.constant 1.00000e+00 : !funky64
+ // CHECK: %x1 = arith.addf %b, %c : f64
+ %x1 = arith.addf %b,
+ %c : f64
+ // CHECK: %x2 = arith.addf %b, %b : !funky64
+ %x2 = arith.addf %b, %b : !funky64
+ // CHECK: %x3 = arith.addf %x2, %b : !fancy64
+ %x3 = arith.addf %x2, %b : !fancy64
+ // CHECK: return %x1, %x2, %x3 : f64, !funky64, !fancy64
+ return %x1, %x2, %x3 : f64, !funky64, !fancy64
+}
diff --git a/mlir/tools/mlir-format/mlir-format.cpp b/mlir/tools/mlir-format/mlir-format.cpp
index ae0122a6627afa..fa187ff64f0d37 100644
--- a/mlir/tools/mlir-format/mlir-format.cpp
+++ b/mlir/tools/mlir-format/mlir-format.cpp
@@ -73,6 +73,157 @@ class CombinedOpDefIterator {
BaseIterator fmtIter;
};
+// Given the scopeLoc of an operation, extract the source locations of its
+// input and output types.
+//
+// This is an ad-hoc textual scan, not a real type parser: everything after the
+// last ':' in the op text is treated as the type signature. If it contains
+// "->", the part before the arrow is the input type list and the part after is
+// the output type list; otherwise the whole signature is treated as output
+// types. Lists are split on ','.
+//
+// Returns {input type ranges, output type ranges}; both are empty when the op
+// text contains no ':'. The ranges point into the buffer `op_loc` refers to.
+std::pair<SmallVector<llvm::SMRange>, SmallVector<llvm::SMRange>>
+getOpTypeLoc(llvm::SMRange op_loc) {
+  SmallVector<llvm::SMRange> inputTypeRanges;
+  SmallVector<llvm::SMRange> outputTypeRanges;
+
+  // Extract the string from the range
+  const char *startPtr = op_loc.Start.getPointer();
+  const char *endPtr = op_loc.End.getPointer();
+  StringRef opString(startPtr, endPtr - startPtr);
+
+  // Find the position of the last ':' in the string
+  size_t colonPos = opString.rfind(':');
+  if (colonPos == StringRef::npos) {
+    // No ':' found, return empty vectors
+    return {inputTypeRanges, outputTypeRanges};
+  }
+
+  // Split a comma-separated type list and record the range of each
+  // whitespace-trimmed element.
+  auto appendTypeRanges = [](StringRef typeList,
+                             SmallVector<llvm::SMRange> &ranges) {
+    SmallVector<StringRef> typeParts;
+    typeList.split(typeParts, ',');
+    for (StringRef part : typeParts) {
+      StringRef type = part.trim();
+      const char *start = type.data();
+      const char *end = start + type.size();
+      ranges.push_back(llvm::SMRange(llvm::SMLoc::getFromPointer(start),
+                                     llvm::SMLoc::getFromPointer(end)));
+    }
+  };
+
+  // Extract the type definition substring
+  StringRef typeDefStr = opString.substr(colonPos + 1).trim();
+
+  // Check if the type definition substring contains '->' (input -> output
+  // types)
+  size_t arrowPos = typeDefStr.find("->");
+  if (arrowPos == StringRef::npos) {
+    // Single type definition (no '->'), assume it's an output type
+    appendTypeRanges(typeDefStr, outputTypeRanges);
+    return {inputTypeRanges, outputTypeRanges};
+  }
+
+  // Split into input and output type strings
+  StringRef inputTypeStr = typeDefStr.substr(0, arrowPos).trim();
+  StringRef outputTypeStr = typeDefStr.substr(arrowPos + 2).trim();
+
+  // Parse input type ranges (if any)
+  if (!inputTypeStr.empty() && inputTypeStr != "()") {
+    // Strip the surrounding parentheses only when they are really there;
+    // dropping characters unconditionally would chop the first and last
+    // character off an unparenthesised type.
+    if (inputTypeStr.front() == '(' && inputTypeStr.back() == ')')
+      inputTypeStr = inputTypeStr.drop_front().drop_back();
+    appendTypeRanges(inputTypeStr, inputTypeRanges);
+  }
+
+  // Parse output type ranges (if any)
+  if (!outputTypeStr.empty() && outputTypeStr != "()")
+    appendTypeRanges(outputTypeStr, outputTypeRanges);
+
+  return {inputTypeRanges, outputTypeRanges};
+}
+
+/// Build an SMRange spanning the whole character buffer of `str`. The range
+/// points directly into `str`'s storage.
+llvm::SMRange getSMRangeFromString(const std::string &str) {
+  const char *first = str.data();
+  const char *last = first + str.size();
+  llvm::SMLoc beginLoc = llvm::SMLoc::getFromPointer(first);
+  llvm::SMLoc endLoc = llvm::SMLoc::getFromPointer(last);
+  return llvm::SMRange(beginLoc, endLoc);
+}
+
+/// Rewrite the types in a printed operation so that they use the spelling
+/// found in the original source (e.g. a type alias instead of the expanded
+/// type).
+///
+/// `formattedStr` is the printed operation and is modified in place.
+/// `inputTypes` / `outputTypes` are the ranges of the original spellings, as
+/// produced by getOpTypeLoc on the original op text. The same scan is run on
+/// `formattedStr`; the i-th type found there is replaced by the i-th original
+/// spelling. If the number of input or output types differs between the two,
+/// an error is printed and `formattedStr` is left untouched.
+void replaceTypesInString(std::string &formattedStr,
+                          const SmallVector<llvm::SMRange> &inputTypes,
+                          const SmallVector<llvm::SMRange> &outputTypes) {
+  // Get type locations from the formatted string
+  llvm::SMRange formattedLoc = getSMRangeFromString(formattedStr);
+  auto formattedTypes = getOpTypeLoc(formattedLoc);
+
+  // Ensure the number of types matches
+  if (inputTypes.size() != formattedTypes.first.size() ||
+      outputTypes.size() != formattedTypes.second.size()) {
+    llvm::errs() << "Error: Mismatched number of input/output types in "
+                    "replacement operation.\n";
+    return;
+  }
+
+  // One pending replacement: where in formattedStr, how many characters, and
+  // the text to put there.
+  struct Splice {
+    size_t offset;
+    size_t length;
+    std::string text;
+  };
+  SmallVector<Splice> splices;
+
+  // The ranges in formattedTypes point into formattedStr's buffer, which is
+  // invalidated as soon as the string is modified. Convert every range to an
+  // offset and copy every replacement text before touching the string.
+  const char *base = formattedStr.data();
+  auto planSplices = [&](const SmallVector<llvm::SMRange> &targets,
+                         const SmallVector<llvm::SMRange> &sources) {
+    for (size_t i = 0, e = targets.size(); i < e; ++i) {
+      const char *targetStart = targets[i].Start.getPointer();
+      const char *targetEnd = targets[i].End.getPointer();
+      const char *sourceStart = sources[i].Start.getPointer();
+      const char *sourceEnd = sources[i].End.getPointer();
+      splices.push_back({static_cast<size_t>(targetStart - base),
+                         static_cast<size_t>(targetEnd - targetStart),
+                         std::string(sourceStart, sourceEnd)});
+    }
+  };
+  // Input types precede "->" and output types follow it, and each list is in
+  // source order, so the splices end up sorted by ascending offset.
+  planSplices(formattedTypes.first, inputTypes);
+  planSplices(formattedTypes.second, outputTypes);
+
+  // Apply back to front so that earlier offsets stay valid while later parts
+  // of the string change length.
+  for (size_t i = splices.size(); i-- > 0;)
+    formattedStr.replace(splices[i].offset, splices[i].length,
+                         splices[i].text);
+}
+
// Function to find the character before the previous comma
const char *findPrevComma(const char *start, const char *stop_point) {
if (!start) {
@@ -256,13 +407,11 @@ void Formatter::formatOps() {
ParserConfig parseConfig(&context, /*verifyAfterParse=*/true,
&fallbackResourceMap);
- // Write the rewriteBuffer to a stream, that we can then parse
std::string bufferContent;
llvm::raw_string_ostream stream(bufferContent);
rewriteBuffer.write(stream);
stream.flush();
- // Print the bufferContent to llvm::outs() for debugging.
fmtSourceMgr.AddNewSourceBuffer(
llvm::MemoryBuffer::getMemBufferCopy(bufferContent), SMLoc());
@@ -285,67 +434,93 @@ void Formatter::formatOps() {
continue;
// Print the fmtDef op and store as a string.
- // Replace the opDef with this formatted string.
std::string formattedStr;
llvm::raw_string_ostream stream(formattedStr);
fmtDef.op->print(stream);
- // Replacing the range:
+ // Use the original type aliases
+ auto orig_types = getOpTypeLoc(opDef.scopeLoc);
+ replaceTypesInString(formattedStr, orig_types.first, orig_types.second);
+
+ // Replace the opDef with this formatted string.
replaceRangeFmt({startOp, endOp}, formattedStr);
+
+ // Write the updated buffer to llvm::outs()
+ writeFmt(llvm::outs());
}
- // Write the updated buffer to llvm::outs()
- writeFmt(llvm::outs());
-}
+ std::string getNamedLoc(
+ const OperationDefinition::ResultGroupDefinition &resultGroup) {
+ auto sm_range = resultGroup.definition.loc;
+ const char *start = sm_range.Start.getPointer();
+ const int len = sm_range.End.getPointer() - start;
-void markNames(Formatter &formatState, raw_ostream &os) {
- // Get the operation definitions from the AsmParserState.
- for (OperationDefinition &it : formatState.getOpDefs()) {
- auto [startOp, endOp] = getOpRange(it);
- // loop through the resultgroups
- for (auto &resultGroup : it.resultGroups) {
- auto def = resultGroup.definition;
- auto sm_range = def.loc;
- const char *start = sm_range.Start.getPointer();
- int len = sm_range.End.getPointer() - start;
- // Drop the % prefix, and put in new string with `loc("name")` format.
- auto name = StringRef(start + 1, len - 1);
-
- // Add loc("{name}") to the end of the op
- std::string formattedStr = " loc(\"" + name.str() + "\")";
- StringRef namedLoc(formattedStr);
- formatState.insertText(endOp, namedLoc);
- }
+ // Drop the '%' prefix and construct the `loc("name")` string
+ auto name = llvm::StringRef(start + 1,
+ len - 1); // Assumes the '%' is always present
+ std::string formattedStr = " loc(\"" + name.str() + "\")";
+
+ return formattedStr;
}
- // Insert the NameLocs for the block arguments
- for (BlockDefinition &block : formatState.getBlockDefs()) {
- for (size_t i = 0; i < block.arguments.size(); ++i) {
- SMDefinition &arg = block.arguments[i];
-
- // Find where to insert the NameLoc. Either before the next argument,
- // or at the end of the arg list
- const char *insertPointPtr;
- const char *arg_end = arg.loc.End.getPointer();
- SMDefinition *nextArg =
- (i + 1 < block.arguments.size()) ? &block.arguments[i + 1] : nullptr;
- if (nextArg) {
- const char *nextStart = nextArg->loc.Start.getPointer();
- insertPointPtr = findPrevComma(nextStart, arg_end);
+ // To handle ops with multiple result groups, create a dummy "alias" op
+ // so that we can give each group its own NameLoc
+ void insertAliasOp() {}
+
+ LogicalResult markNames(Formatter & formatState, raw_ostream & os) {
+ // Get the operation definitions from the AsmParserState.
+ for (OperationDefinition &it : formatState.getOpDefs()) {
+ auto [startOp, endOp] = getOpRange(it);
+
+ if (it.resultGroups.size() == 1) {
+ // Simple case, where we have only one result group for the op,
+ // e.g., `%v = op` or `%v:2 = op`
+ auto resultGroup = it.resultGroups[0];
+ auto nameLoc = getNamedLoc(resultGroup);
+ formatState.insertText(endOp, StringRef(nameLoc));
} else {
- insertPointPtr = findNextCloseParenth(arg.loc.End.getPointer());
+ // Complex case, where we have more than one result group, e.g.,
+ // `%x, %y = op` or `%xs:2, %ys:3 = op`.
+ // In this case we need to insert some aliasing ops.
+ for (auto &resultGroup : it.resultGroups) {
+ auto nameLoc = getNamedLoc(resultGroup);
+ // StringRef namedLoc(getNamedLoc(resultGroup));
+ llvm::errs() << "Not implemented yet\n";
+ return failure();
+ }
}
+ }
- // Drop the % prefix, and put in new string with `loc("name")` format.
- const char *start = arg.loc.Start.getPointer();
- const int len = arg_end - start;
- auto name = StringRef(start + 1, len - 1);
- std::string formattedStr = " loc(\"" + name.str() + "\")";
- StringRef namedLoc(formattedStr);
- formatState.insertText(SMLoc::getFromPointer(insertPointPtr), namedLoc);
+ // Insert the NameLocs for the block arguments
+ for (BlockDefinition &block : formatState.getBlockDefs()) {
+ for (size_t i = 0; i < block.arguments.size(); ++i) {
+ SMDefinition &arg = block.arguments[i];
+
+ // Find where to insert the NameLoc. Either before the next argument,
+ // or at the end of the arg list
+ const char *insertPointPtr;
+ const char *arg_end = arg.loc.End.getPointer();
+ SMDefinition *nextArg = (i + 1 < block.arguments.size())
+ ? &block.arguments[i + 1]
+ : nullptr;
+ if (nextArg) {
+ const char *nextStart = nextArg->loc.Start.getPointer();
+ insertPointPtr = findPrevComma(nextStart, arg_end);
+ } else {
+ insertPointPtr = findNextCloseParenth(arg.loc.End.getPointer());
+ }
+
+ // Drop the % prefix, and put in new string with `loc("name")` format.
+ const char *start = arg.loc.Start.getPointer();
+ const int len = arg_end - start;
+ auto name = StringRef(start + 1, len - 1);
+ std::string formattedStr = " loc(\"" + name.str() + "\")";
+ StringRef namedLoc(formattedStr);
+ formatState.insertText(SMLoc::getFromPointer(insertPointPtr), namedLoc);
+ }
}
+ return success();
}
-}
} // namespace mlir
int main(int argc, char **argv) {
@@ -370,7 +545,9 @@ int main(int argc, char **argv) {
auto f = Formatter::init(inputFilename, outputFilename);
// Append the SSA names as NameLocs
- markNames(*f, llvm::outs());
+ LogicalResult result = markNames(*f, llvm::outs());
+ if (!succeeded(result))
+ return mlir::asMainReturnCode(mlir::failure());
if (nameLocOnly) {
// Return the original buffer with NameLocs appended to ops
More information about the Mlir-commits
mailing list