[Mlir-commits] [mlir] 988a3ba - [mlir] Expose printer flags in AsmState

Sergei Grechanik llvmlistbot at llvm.org
Tue Feb 15 17:38:24 PST 2022


Author: Sergei Grechanik
Date: 2022-02-15T17:27:45-08:00
New Revision: 988a3ba0d815110fe4df03cab1077ddef7b23252

URL: https://github.com/llvm/llvm-project/commit/988a3ba0d815110fe4df03cab1077ddef7b23252
DIFF: https://github.com/llvm/llvm-project/commit/988a3ba0d815110fe4df03cab1077ddef7b23252.diff

LOG: [mlir] Expose printer flags in AsmState

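This change exposes the printer flags through AsmState and AsmStateImpl via a
new getPrinterFlags() accessor. All functions that receive an AsmState as a
parameter now use the flags stored in that AsmState instead of taking an
additional OpPrintingFlags parameter. Callers that previously passed flags
alongside an AsmState must now construct the AsmState with those flags instead.
Note that Block::print and Block::printAsOperand, which previously always used
default flags, now honor the flags stored in the AsmState.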

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D119870

Added: 
    

Modified: 
    mlir/include/mlir/IR/AsmState.h
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/IR/Operation.h
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/Transforms/LocationSnapshot.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index eeb12adff0886..6c3792f867ea6 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -47,6 +47,9 @@ class AsmState {
            LocationMap *locationMap = nullptr);
   ~AsmState();
 
+  /// Get the printer flags.
+  const OpPrintingFlags &getPrinterFlags() const;
+
   /// Return an instance of the internal implementation. Returns nullptr if the
   /// state has not been initialized.
   detail::AsmStateImpl &getImpl() { return *impl; }

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index e29b35f211857..e38b0cb5db576 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -112,9 +112,8 @@ class OpState {
   void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
     state->print(os, flags);
   }
-  void print(raw_ostream &os, AsmState &asmState,
-             OpPrintingFlags flags = llvm::None) {
-    state->print(os, asmState, flags);
+  void print(raw_ostream &os, AsmState &asmState) {
+    state->print(os, asmState);
   }
 
   /// Dump this operation.

diff  --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 27a59382076e1..cd33ceb43fd6a 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -192,8 +192,7 @@ class alignas(8) Operation final
   bool isBeforeInBlock(Operation *other);
 
   void print(raw_ostream &os, const OpPrintingFlags &flags = llvm::None);
-  void print(raw_ostream &os, AsmState &state,
-             const OpPrintingFlags &flags = llvm::None);
+  void print(raw_ostream &os, AsmState &state);
   void dump();
 
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 3346702423c1a..3fc07f37c4ac8 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1216,6 +1216,9 @@ class AsmStateImpl {
   /// Get the state used for SSA names.
   SSANameState &getSSANameState() { return nameState; }
 
+  /// Get the printer flags.
+  const OpPrintingFlags &getPrinterFlags() const { return printerFlags; }
+
   /// Register the location, line and column, within the buffer that the given
   /// operation was printed at.
   void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
@@ -1247,6 +1250,10 @@ AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
     : impl(std::make_unique<AsmStateImpl>(op, printerFlags, locationMap)) {}
 AsmState::~AsmState() = default;
 
+const OpPrintingFlags &AsmState::getPrinterFlags() const {
+  return impl->getPrinterFlags();
+}
+
 //===----------------------------------------------------------------------===//
 // AsmPrinter::Impl
 //===----------------------------------------------------------------------===//
@@ -2405,9 +2412,9 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
   using Impl = AsmPrinter::Impl;
   using Impl::printType;
 
-  explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
-                            AsmStateImpl &state)
-      : Impl(os, flags, &state), OpAsmPrinter(static_cast<Impl &>(*this)) {}
+  explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state)
+      : Impl(os, state.getPrinterFlags(), &state),
+        OpAsmPrinter(static_cast<Impl &>(*this)) {}
 
   /// Print the given top-level operation.
   void printTopLevelOperation(Operation *op);
@@ -2893,7 +2900,7 @@ void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
   if (!getParent() && !printerFlags.shouldUseLocalScope()) {
     AsmState state(this, printerFlags);
     state.getImpl().initializeAliases(this);
-    print(os, state, printerFlags);
+    print(os, state);
     return;
   }
 
@@ -2914,12 +2921,11 @@ void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
   } while (true);
 
   AsmState state(op, printerFlags);
-  print(os, state, printerFlags);
+  print(os, state);
 }
-void Operation::print(raw_ostream &os, AsmState &state,
-                      const OpPrintingFlags &flags) {
-  OperationPrinter printer(os, flags, state.getImpl());
-  if (!getParent() && !flags.shouldUseLocalScope())
+void Operation::print(raw_ostream &os, AsmState &state) {
+  OperationPrinter printer(os, state.getImpl());
+  if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope())
     printer.printTopLevelOperation(this);
   else
     printer.print(this);
@@ -2944,7 +2950,7 @@ void Block::print(raw_ostream &os) {
   print(os, state);
 }
 void Block::print(raw_ostream &os, AsmState &state) {
-  OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this);
+  OperationPrinter(os, state.getImpl()).print(this);
 }
 
 void Block::dump() { print(llvm::errs()); }
@@ -2960,6 +2966,6 @@ void Block::printAsOperand(raw_ostream &os, bool printType) {
   printAsOperand(os, state);
 }
 void Block::printAsOperand(raw_ostream &os, AsmState &state) {
-  OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl());
+  OperationPrinter printer(os, state.getImpl());
   printer.printBlockName(this);
 }

diff  --git a/mlir/lib/Transforms/LocationSnapshot.cpp b/mlir/lib/Transforms/LocationSnapshot.cpp
index f23a3eee15113..a042d07335bbc 100644
--- a/mlir/lib/Transforms/LocationSnapshot.cpp
+++ b/mlir/lib/Transforms/LocationSnapshot.cpp
@@ -27,7 +27,7 @@ static void generateLocationsFromIR(raw_ostream &os, StringRef fileName,
   // Print the IR to the stream, and collect the raw line+column information.
   AsmState::LocationMap opToLineCol;
   AsmState state(op, flags, &opToLineCol);
-  op->print(os, state, flags);
+  op->print(os, state);
 
   Builder builder(op->getContext());
   Optional<StringAttr> tagIdentifier;


        


More information about the Mlir-commits mailing list