[Mlir-commits] [mlir] 2c0f179 - [mlir] Added OpPrintingFlags to AsmState and SSANameState.

Mehdi Amini llvmlistbot at llvm.org
Sat Jul 10 10:20:29 PDT 2021


Author: Itai Zukerman
Date: 2021-07-10T16:40:00Z
New Revision: 2c0f17982f39b14c7ed13069d6ed959ef43d02d9

URL: https://github.com/llvm/llvm-project/commit/2c0f17982f39b14c7ed13069d6ed959ef43d02d9
DIFF: https://github.com/llvm/llvm-project/commit/2c0f17982f39b14c7ed13069d6ed959ef43d02d9.diff

LOG: [mlir] Added OpPrintingFlags to AsmState and SSANameState.

This enables checking the printing flags when formatting names
in SSANameState.

Depends On D105299

Reviewed By: mehdi_amini, bondhugula

Differential Revision: https://reviews.llvm.org/D105300

Added: 
    

Modified: 
    mlir/include/mlir/IR/AsmState.h
    mlir/include/mlir/IR/Operation.h
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/Transforms/LocationSnapshot.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index 4f8ebba3f4a1f..691e6c71e6206 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -13,6 +13,7 @@
 #ifndef MLIR_IR_ASMSTATE_H_
 #define MLIR_IR_ASMSTATE_H_
 
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/Support/LLVM.h"
 
 #include <memory>
@@ -41,7 +42,9 @@ class AsmState {
 
   /// Initialize the asm state at the level of the given operation. A location
   /// map may optionally be provided to be populated when printing.
-  AsmState(Operation *op, LocationMap *locationMap = nullptr);
+  AsmState(Operation *op,
+           const OpPrintingFlags &printerFlags = OpPrintingFlags(),
+           LocationMap *locationMap = nullptr);
   ~AsmState();
 
   /// Return an instance of the internal implementation. Returns nullptr if the

diff  --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 6baa799f2c639..8b22695db3d1b 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -204,9 +204,9 @@ class alignas(8) Operation final
   /// take O(N) where N is the number of operations within the parent block.
   bool isBeforeInBlock(Operation *other);
 
-  void print(raw_ostream &os, OpPrintingFlags flags = llvm::None);
+  void print(raw_ostream &os, const OpPrintingFlags &flags = llvm::None);
   void print(raw_ostream &os, AsmState &state,
-             OpPrintingFlags flags = llvm::None);
+             const OpPrintingFlags &flags = llvm::None);
   void dump();
 
   //===--------------------------------------------------------------------===//

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 49c4e472109c3..d654781c18d8d 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -356,9 +356,9 @@ class AliasInitializer {
 /// in the output, and trims down unnecessary output.
 class DummyAliasOperationPrinter : private OpAsmPrinter {
 public:
-  explicit DummyAliasOperationPrinter(const OpPrintingFlags &flags,
+  explicit DummyAliasOperationPrinter(const OpPrintingFlags &printerFlags,
                                       AliasInitializer &initializer)
-      : printerFlags(flags), initializer(initializer) {}
+      : printerFlags(printerFlags), initializer(initializer) {}
 
   /// Print the given operation.
   void print(Operation *op) {
@@ -767,7 +767,7 @@ class SSANameState {
   /// A sentinel value used for values with names set.
   enum : unsigned { NameSentinel = ~0U };
 
-  SSANameState(Operation *op,
+  SSANameState(Operation *op, const OpPrintingFlags &printerFlags,
                DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
 
   /// Print the SSA identifier for the given value to 'stream'. If
@@ -833,14 +833,18 @@ class SSANameState {
   /// This is the next ID to assign when a name conflict is detected.
   unsigned nextConflictID = 0;
 
+  /// These are the printing flags.  They control, e.g., whether to print in
+  /// generic form.
+  OpPrintingFlags printerFlags;
+
   DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
 };
 } // end anonymous namespace
 
 SSANameState::SSANameState(
-    Operation *op,
+    Operation *op, const OpPrintingFlags &printerFlags,
     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces)
-    : interfaces(interfaces) {
+    : printerFlags(printerFlags), interfaces(interfaces) {
   llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
   llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
   llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID);
@@ -1134,12 +1138,13 @@ namespace mlir {
 namespace detail {
 class AsmStateImpl {
 public:
-  explicit AsmStateImpl(Operation *op, AsmState::LocationMap *locationMap)
-      : interfaces(op->getContext()), nameState(op, interfaces),
-        locationMap(locationMap) {}
+  explicit AsmStateImpl(Operation *op, const OpPrintingFlags &printerFlags,
+                        AsmState::LocationMap *locationMap)
+      : interfaces(op->getContext()), nameState(op, printerFlags, interfaces),
+        printerFlags(printerFlags), locationMap(locationMap) {}
 
   /// Initialize the alias state to enable the printing of aliases.
-  void initializeAliases(Operation *op, const OpPrintingFlags &printerFlags) {
+  void initializeAliases(Operation *op) {
     aliasState.initialize(op, printerFlags, interfaces);
   }
 
@@ -1172,14 +1177,18 @@ class AsmStateImpl {
   /// The state used for SSA value names.
   SSANameState nameState;
 
+  /// Flags that control op output.
+  OpPrintingFlags printerFlags;
+
   /// An optional location map to be populated.
   AsmState::LocationMap *locationMap;
 };
 } // end namespace detail
 } // end namespace mlir
 
-AsmState::AsmState(Operation *op, LocationMap *locationMap)
-    : impl(std::make_unique<AsmStateImpl>(op, locationMap)) {}
+AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
+                   LocationMap *locationMap)
+    : impl(std::make_unique<AsmStateImpl>(op, printerFlags, locationMap)) {}
 AsmState::~AsmState() {}
 
 //===----------------------------------------------------------------------===//
@@ -2760,18 +2769,18 @@ void Value::printAsOperand(raw_ostream &os, AsmState &state) {
                                                  os);
 }
 
-void Operation::print(raw_ostream &os, OpPrintingFlags flags) {
+void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
   // If this is a top level operation, we also print aliases.
-  if (!getParent() && !flags.shouldUseLocalScope()) {
-    AsmState state(this);
-    state.getImpl().initializeAliases(this, flags);
-    print(os, state, flags);
+  if (!getParent() && !printerFlags.shouldUseLocalScope()) {
+    AsmState state(this, printerFlags);
+    state.getImpl().initializeAliases(this);
+    print(os, state, printerFlags);
     return;
   }
 
   // Find the operation to number from based upon the provided flags.
   Operation *op = this;
-  bool shouldUseLocalScope = flags.shouldUseLocalScope();
+  bool shouldUseLocalScope = printerFlags.shouldUseLocalScope();
   do {
     // If we are printing local scope, stop at the first operation that is
     // isolated from above.
@@ -2785,10 +2794,11 @@ void Operation::print(raw_ostream &os, OpPrintingFlags flags) {
     op = parentOp;
   } while (true);
 
-  AsmState state(op);
-  print(os, state, flags);
+  AsmState state(op, printerFlags);
+  print(os, state, printerFlags);
 }
-void Operation::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) {
+void Operation::print(raw_ostream &os, AsmState &state,
+                      const OpPrintingFlags &flags) {
   OperationPrinter printer(os, flags, state.getImpl());
   if (!getParent() && !flags.shouldUseLocalScope())
     printer.printTopLevelOperation(this);

diff  --git a/mlir/lib/Transforms/LocationSnapshot.cpp b/mlir/lib/Transforms/LocationSnapshot.cpp
index 7d4284a35c2fc..a7e08723bcec9 100644
--- a/mlir/lib/Transforms/LocationSnapshot.cpp
+++ b/mlir/lib/Transforms/LocationSnapshot.cpp
@@ -22,11 +22,11 @@ using namespace mlir;
 /// NameLoc with the given tag as the name, and then fused with the existing
 /// locations. Otherwise, the existing locations are replaced.
 static void generateLocationsFromIR(raw_ostream &os, StringRef fileName,
-                                    Operation *op, OpPrintingFlags flags,
+                                    Operation *op, const OpPrintingFlags &flags,
                                     StringRef tag) {
   // Print the IR to the stream, and collect the raw line+column information.
   AsmState::LocationMap opToLineCol;
-  AsmState state(op, &opToLineCol);
+  AsmState state(op, flags, &opToLineCol);
   op->print(os, state, flags);
 
   Builder builder(op->getContext());


        


More information about the Mlir-commits mailing list