[Mlir-commits] [mlir] 988a3ba - [mlir] Expose printer flags in AsmState
Sergei Grechanik
llvmlistbot at llvm.org
Tue Feb 15 17:38:24 PST 2022
Author: Sergei Grechanik
Date: 2022-02-15T17:27:45-08:00
New Revision: 988a3ba0d815110fe4df03cab1077ddef7b23252
URL: https://github.com/llvm/llvm-project/commit/988a3ba0d815110fe4df03cab1077ddef7b23252
DIFF: https://github.com/llvm/llvm-project/commit/988a3ba0d815110fe4df03cab1077ddef7b23252.diff
LOG: [mlir] Expose printer flags in AsmState
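This change exposes printer flags in AsmState and AsmStateImpl. All functions
that receive an AsmState as a parameter now use the flags stored in that AsmState
instead of taking an additional OpPrintingFlags parameter. As a consequence,
Block::print and Block::printAsOperand, which previously always printed with
default flags, now honor the flags the AsmState was constructed with.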
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D119870
Added: (none)
Modified:
mlir/include/mlir/IR/AsmState.h
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/IR/Operation.h
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/Transforms/LocationSnapshot.cpp
Removed: (none)
################################################################################
diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index eeb12adff0886..6c3792f867ea6 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -47,6 +47,9 @@ class AsmState {
LocationMap *locationMap = nullptr);
~AsmState();
+ /// Get the printer flags.
+ const OpPrintingFlags &getPrinterFlags() const;
+
/// Return an instance of the internal implementation. Returns nullptr if the
/// state has not been initialized.
detail::AsmStateImpl &getImpl() { return *impl; }
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index e29b35f211857..e38b0cb5db576 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -112,9 +112,8 @@ class OpState {
void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
state->print(os, flags);
}
- void print(raw_ostream &os, AsmState &asmState,
- OpPrintingFlags flags = llvm::None) {
- state->print(os, asmState, flags);
+ void print(raw_ostream &os, AsmState &asmState) {
+ state->print(os, asmState);
}
/// Dump this operation.
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 27a59382076e1..cd33ceb43fd6a 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -192,8 +192,7 @@ class alignas(8) Operation final
bool isBeforeInBlock(Operation *other);
void print(raw_ostream &os, const OpPrintingFlags &flags = llvm::None);
- void print(raw_ostream &os, AsmState &state,
- const OpPrintingFlags &flags = llvm::None);
+ void print(raw_ostream &os, AsmState &state);
void dump();
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 3346702423c1a..3fc07f37c4ac8 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1216,6 +1216,9 @@ class AsmStateImpl {
/// Get the state used for SSA names.
SSANameState &getSSANameState() { return nameState; }
+ /// Get the printer flags.
+ const OpPrintingFlags &getPrinterFlags() const { return printerFlags; }
+
/// Register the location, line and column, within the buffer that the given
/// operation was printed at.
void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
@@ -1247,6 +1250,10 @@ AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
: impl(std::make_unique<AsmStateImpl>(op, printerFlags, locationMap)) {}
AsmState::~AsmState() = default;
+const OpPrintingFlags &AsmState::getPrinterFlags() const {
+ return impl->getPrinterFlags();
+}
+
//===----------------------------------------------------------------------===//
// AsmPrinter::Impl
//===----------------------------------------------------------------------===//
@@ -2405,9 +2412,9 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
using Impl = AsmPrinter::Impl;
using Impl::printType;
- explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
- AsmStateImpl &state)
- : Impl(os, flags, &state), OpAsmPrinter(static_cast<Impl &>(*this)) {}
+ explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state)
+ : Impl(os, state.getPrinterFlags(), &state),
+ OpAsmPrinter(static_cast<Impl &>(*this)) {}
/// Print the given top-level operation.
void printTopLevelOperation(Operation *op);
@@ -2893,7 +2900,7 @@ void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
if (!getParent() && !printerFlags.shouldUseLocalScope()) {
AsmState state(this, printerFlags);
state.getImpl().initializeAliases(this);
- print(os, state, printerFlags);
+ print(os, state);
return;
}
@@ -2914,12 +2921,11 @@ void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
} while (true);
AsmState state(op, printerFlags);
- print(os, state, printerFlags);
+ print(os, state);
}
-void Operation::print(raw_ostream &os, AsmState &state,
- const OpPrintingFlags &flags) {
- OperationPrinter printer(os, flags, state.getImpl());
- if (!getParent() && !flags.shouldUseLocalScope())
+void Operation::print(raw_ostream &os, AsmState &state) {
+ OperationPrinter printer(os, state.getImpl());
+ if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope())
printer.printTopLevelOperation(this);
else
printer.print(this);
@@ -2944,7 +2950,7 @@ void Block::print(raw_ostream &os) {
print(os, state);
}
void Block::print(raw_ostream &os, AsmState &state) {
- OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this);
+ OperationPrinter(os, state.getImpl()).print(this);
}
void Block::dump() { print(llvm::errs()); }
@@ -2960,6 +2966,6 @@ void Block::printAsOperand(raw_ostream &os, bool printType) {
printAsOperand(os, state);
}
void Block::printAsOperand(raw_ostream &os, AsmState &state) {
- OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl());
+ OperationPrinter printer(os, state.getImpl());
printer.printBlockName(this);
}
diff --git a/mlir/lib/Transforms/LocationSnapshot.cpp b/mlir/lib/Transforms/LocationSnapshot.cpp
index f23a3eee15113..a042d07335bbc 100644
--- a/mlir/lib/Transforms/LocationSnapshot.cpp
+++ b/mlir/lib/Transforms/LocationSnapshot.cpp
@@ -27,7 +27,7 @@ static void generateLocationsFromIR(raw_ostream &os, StringRef fileName,
// Print the IR to the stream, and collect the raw line+column information.
AsmState::LocationMap opToLineCol;
AsmState state(op, flags, &opToLineCol);
- op->print(os, state, flags);
+ op->print(os, state);
Builder builder(op->getContext());
Optional<StringAttr> tagIdentifier;
More information about the Mlir-commits mailing list