[Mlir-commits] [mlir] f3502af - [mlir] Allow passing AsmState when printing Attributes and Types

River Riddle llvmlistbot at llvm.org
Tue Sep 6 14:45:30 PDT 2022


Author: River Riddle
Date: 2022-09-06T14:45:12-07:00
New Revision: f3502afe852693a19848e9e328f2c2a55fc9e9bb

URL: https://github.com/llvm/llvm-project/commit/f3502afe852693a19848e9e328f2c2a55fc9e9bb
DIFF: https://github.com/llvm/llvm-project/commit/f3502afe852693a19848e9e328f2c2a55fc9e9bb.diff

LOG: [mlir] Allow passing AsmState when printing Attributes and Types

This allows for extracting assembly information when printing an attribute
or type, such as the dialect resources referenced. This functionality is used in
a follow-up commit that adds resource support to the bytecode. This change also
results in a nice cleanup of AsmPrinter now that we don't need to awkwardly work
around optional AsmStates.

Differential Revision: https://reviews.llvm.org/D132728

Added: 
    

Modified: 
    mlir/include/mlir/IR/AsmState.h
    mlir/include/mlir/IR/Attributes.h
    mlir/include/mlir/IR/Types.h
    mlir/lib/IR/AsmPrinter.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index 51a66310b25d9..d3ef630898037 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -21,6 +21,7 @@
 
 namespace mlir {
 class AsmResourcePrinter;
+class AsmDialectResourceHandle;
 class Operation;
 
 namespace detail {
@@ -455,6 +456,9 @@ class AsmState {
   AsmState(Operation *op,
            const OpPrintingFlags &printerFlags = OpPrintingFlags(),
            LocationMap *locationMap = nullptr);
+  AsmState(MLIRContext *ctx,
+           const OpPrintingFlags &printerFlags = OpPrintingFlags(),
+           LocationMap *locationMap = nullptr);
   ~AsmState();
 
   /// Get the printer flags.
@@ -480,6 +484,11 @@ class AsmState {
         name, std::forward<CallableT>(printFn)));
   }
 
+  /// Returns a map of dialect resources that were referenced when using this
+  /// state to print IR.
+  DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> &
+  getDialectResources() const;
+
 private:
   AsmState() = delete;
 

diff  --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h
index 6ebb0449da336..a3ee5d2db56eb 100644
--- a/mlir/include/mlir/IR/Attributes.h
+++ b/mlir/include/mlir/IR/Attributes.h
@@ -13,6 +13,7 @@
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
 namespace mlir {
+class AsmState;
 class StringAttr;
 
 /// Attributes are known-constant values of operations.
@@ -76,6 +77,7 @@ class Attribute {
 
   /// Print the attribute.
   void print(raw_ostream &os) const;
+  void print(raw_ostream &os, AsmState &state) const;
   void dump() const;
 
   /// Get an opaque pointer to the attribute.

diff  --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
index 5cac1e240d653..28cccd15dc8d9 100644
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -15,6 +15,8 @@
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
 namespace mlir {
+class AsmState;
+
 /// Instances of the Type class are uniqued, have an immutable identifier and an
 /// optional mutable component.  They wrap a pointer to the storage object owned
 /// by MLIRContext.  Therefore, instances of Type are passed around by value.
@@ -162,6 +164,7 @@ class Type {
 
   /// Print the current type.
   void print(raw_ostream &os) const;
+  void print(raw_ostream &os, AsmState &state) const;
   void dump() const;
 
   friend ::llvm::hash_code hash_value(Type arg);

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 9cf9501850a6b..841cc56dfa594 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -853,6 +853,7 @@ class SSANameState {
   enum : unsigned { NameSentinel = ~0U };
 
   SSANameState(Operation *op, const OpPrintingFlags &printerFlags);
+  SSANameState() = default;
 
   /// Print the SSA identifier for the given value to 'stream'. If
   /// 'printResultNo' is true, it also presents the result number ('#' number)
@@ -1282,6 +1283,9 @@ class AsmStateImpl {
                         AsmState::LocationMap *locationMap)
       : interfaces(op->getContext()), nameState(op, printerFlags),
         printerFlags(printerFlags), locationMap(locationMap) {}
+  explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
+                        AsmState::LocationMap *locationMap)
+      : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {}
 
   /// Initialize the alias state to enable the printing of aliases.
   void initializeAliases(Operation *op) {
@@ -1315,6 +1319,12 @@ class AsmStateImpl {
       (*locationMap)[op] = std::make_pair(line, col);
   }
 
+  /// Return the referenced dialect resources within the printer.
+  DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> &
+  getDialectResources() {
+    return dialectResources;
+  }
+
 private:
   /// Collection of OpAsm interfaces implemented in the context.
   DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
@@ -1322,6 +1332,9 @@ class AsmStateImpl {
   /// A collection of non-dialect resource printers.
   SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters;
 
+  /// A set of dialect resources that were referenced during printing.
+  DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> dialectResources;
+
   /// The state used for attribute and type aliases.
   AliasState aliasState;
 
@@ -1379,6 +1392,9 @@ AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
                    LocationMap *locationMap)
     : impl(std::make_unique<AsmStateImpl>(
           op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {}
+AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
+                   LocationMap *locationMap)
+    : impl(std::make_unique<AsmStateImpl>(ctx, printerFlags, locationMap)) {}
 AsmState::~AsmState() = default;
 
 const OpPrintingFlags &AsmState::getPrinterFlags() const {
@@ -1390,6 +1406,11 @@ void AsmState::attachResourcePrinter(
   impl->externalResourcePrinters.emplace_back(std::move(printer));
 }
 
+DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> &
+AsmState::getDialectResources() const {
+  return impl->getDialectResources();
+}
+
 //===----------------------------------------------------------------------===//
 // AsmPrinter::Impl
 //===----------------------------------------------------------------------===//
@@ -1397,11 +1418,9 @@ void AsmState::attachResourcePrinter(
 namespace mlir {
 class AsmPrinter::Impl {
 public:
-  Impl(raw_ostream &os, OpPrintingFlags flags = llvm::None,
-       AsmStateImpl *state = nullptr)
-      : os(os), printerFlags(flags), state(state) {}
-  explicit Impl(Impl &other)
-      : Impl(other.os, other.printerFlags, other.state) {}
+  Impl(raw_ostream &os, AsmStateImpl &state)
+      : os(os), state(state), printerFlags(state.getPrinterFlags()) {}
+  explicit Impl(Impl &other) : Impl(other.os, other.state) {}
 
   /// Returns the output stream of the printer.
   raw_ostream &getStream() { return os; }
@@ -1446,7 +1465,7 @@ class AsmPrinter::Impl {
   void printResourceHandle(const AsmDialectResourceHandle &resource) {
     auto *interface = cast<OpAsmDialectInterface>(resource.getDialect());
     os << interface->getResourceKey(resource);
-    dialectResources[resource.getDialect()].insert(resource);
+    state.getDialectResources()[resource.getDialect()].insert(resource);
   }
 
   void printAffineMap(AffineMap map);
@@ -1503,17 +1522,14 @@ class AsmPrinter::Impl {
   /// The output stream for the printer.
   raw_ostream &os;
 
+  /// An underlying assembly printer state.
+  AsmStateImpl &state;
+
   /// A set of flags to control the printer's behavior.
   OpPrintingFlags printerFlags;
 
-  /// An optional printer state for the module.
-  AsmStateImpl *state;
-
   /// A tracker for the number of new lines emitted during printing.
   NewLineCounter newLine;
-
-  /// A set of dialect resources that were referenced during printing.
-  DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> dialectResources;
 };
 } // namespace mlir
 
@@ -1647,7 +1663,7 @@ void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) {
     return printLocationInternal(loc, /*pretty=*/true);
 
   os << "loc(";
-  if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os)))
+  if (!allowAlias || failed(printAlias(loc)))
     printLocationInternal(loc);
   os << ')';
 }
@@ -1734,11 +1750,11 @@ static void printElidedElementsAttr(raw_ostream &os) {
 }
 
 LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
-  return success(state && succeeded(state->getAliasState().getAlias(attr, os)));
+  return state.getAliasState().getAlias(attr, os);
 }
 
 LogicalResult AsmPrinter::Impl::printAlias(Type type) {
-  return success(state && succeeded(state->getAliasState().getAlias(type, os)));
+  return state.getAliasState().getAlias(type, os);
 }
 
 void AsmPrinter::Impl::printAttribute(Attribute attr,
@@ -2068,7 +2084,7 @@ void AsmPrinter::Impl::printType(Type type) {
   }
 
   // Try to print an alias for this type.
-  if (state && succeeded(state->getAliasState().getAlias(type, os)))
+  if (succeeded(printAlias(type)))
     return;
 
   TypeSwitch<Type>(type)
@@ -2242,14 +2258,9 @@ void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
   std::string attrName;
   {
     llvm::raw_string_ostream attrNameStr(attrName);
-    Impl subPrinter(attrNameStr, printerFlags, state);
+    Impl subPrinter(attrNameStr, state);
     DialectAsmPrinter printer(subPrinter);
     dialect.printAttribute(attr, printer);
-
-    // FIXME: Delete this when we no longer require a nested printer.
-    for (auto &it : subPrinter.dialectResources)
-      for (const auto &resource : it.second)
-        dialectResources[it.first].insert(resource);
   }
   printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
 }
@@ -2261,14 +2272,9 @@ void AsmPrinter::Impl::printDialectType(Type type) {
   std::string typeName;
   {
     llvm::raw_string_ostream typeNameStr(typeName);
-    Impl subPrinter(typeNameStr, printerFlags, state);
+    Impl subPrinter(typeNameStr, state);
     DialectAsmPrinter printer(subPrinter);
     dialect.printType(type, printer);
-
-    // FIXME: Delete this when we no longer require a nested printer.
-    for (auto &it : subPrinter.dialectResources)
-      for (const auto &resource : it.second)
-        dialectResources[it.first].insert(resource);
   }
   printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
 }
@@ -2561,8 +2567,7 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
   using Impl::printType;
 
   explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state)
-      : Impl(os, state.getPrinterFlags(), &state),
-        OpAsmPrinter(static_cast<Impl &>(*this)) {}
+      : Impl(os, state), OpAsmPrinter(static_cast<Impl &>(*this)) {}
 
   /// Print the given top-level operation.
   void printTopLevelOperation(Operation *op);
@@ -2646,7 +2651,7 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
   /// operations. If any entry in namesToUse is null, the corresponding
   /// argument name is left alone.
   void shadowRegionArgs(Region &region, ValueRange namesToUse) override {
-    state->getSSANameState().shadowRegionArgs(region, namesToUse);
+    state.getSSANameState().shadowRegionArgs(region, namesToUse);
   }
 
   /// Print the given affine map with the symbol and dimension operands printed
@@ -2736,14 +2741,14 @@ class OperationPrinter : public AsmPrinter::Impl, private OpAsmPrinter {
 
 void OperationPrinter::printTopLevelOperation(Operation *op) {
   // Output the aliases at the top level that can't be deferred.
-  state->getAliasState().printNonDeferredAliases(os, newLine);
+  state.getAliasState().printNonDeferredAliases(os, newLine);
 
   // Print the module.
   print(op);
   os << newLine;
 
   // Output the aliases at the top level that can be deferred.
-  state->getAliasState().printDeferredAliases(os, newLine);
+  state.getAliasState().printDeferredAliases(os, newLine);
 
   // Output any file level metadata.
   printFileMetadataDictionary(op);
@@ -2795,7 +2800,8 @@ void OperationPrinter::printResourceFileMetadata(
 
   // Print the `dialect_resources` section if we have any dialects with
   // resources.
-  for (const OpAsmDialectInterface &interface : state->getDialectInterfaces()) {
+  for (const OpAsmDialectInterface &interface : state.getDialectInterfaces()) {
+    auto &dialectResources = state.getDialectResources();
     StringRef name = interface.getDialect()->getNamespace();
     auto it = dialectResources.find(interface.getDialect());
     if (it != dialectResources.end())
@@ -2810,7 +2816,7 @@ void OperationPrinter::printResourceFileMetadata(
   // Print the `external_resources` section if we have any external clients with
   // resources.
   hadResource = false;
-  for (const auto &printer : state->getResourcePrinters())
+  for (const auto &printer : state.getResourcePrinters())
     processProvider("external", printer.getName(), printer);
   if (hadResource)
     os << newLine << "  }";
@@ -2836,7 +2842,7 @@ void OperationPrinter::printRegionArgument(BlockArgument arg,
 
 void OperationPrinter::print(Operation *op) {
   // Track the location of this operation.
-  state->registerOperationLocation(op, newLine.curLine, currentIndent);
+  state.registerOperationLocation(op, newLine.curLine, currentIndent);
 
   os.indent(currentIndent);
   printOperation(op);
@@ -2854,7 +2860,7 @@ void OperationPrinter::printOperation(Operation *op) {
     };
 
     // Check to see if this operation has multiple result groups.
-    ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op);
+    ArrayRef<int> resultGroups = state.getSSANameState().getOpResultGroups(op);
     if (!resultGroups.empty()) {
       // Interleave the groups excluding the last one, this one will be handled
       // separately.
@@ -3010,7 +3016,7 @@ void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
 }
 
 void OperationPrinter::printBlockName(Block *block) {
-  os << state->getSSANameState().getBlockInfo(block).name;
+  os << state.getSSANameState().getBlockInfo(block).name;
 }
 
 void OperationPrinter::print(Block *block, bool printBlockArgs,
@@ -3048,7 +3054,7 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
       // whatever order the use-list is in, so gather and sort them.
       SmallVector<BlockInfo, 4> predIDs;
       for (auto *pred : block->getPredecessors())
-        predIDs.push_back(state->getSSANameState().getBlockInfo(pred));
+        predIDs.push_back(state.getSSANameState().getBlockInfo(pred));
       llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) {
         return lhs.ordering < rhs.ordering;
       });
@@ -3084,14 +3090,14 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
 
 void OperationPrinter::printValueID(Value value, bool printResultNo,
                                     raw_ostream *streamOverride) const {
-  state->getSSANameState().printValueID(value, printResultNo,
-                                        streamOverride ? *streamOverride : os);
+  state.getSSANameState().printValueID(value, printResultNo,
+                                       streamOverride ? *streamOverride : os);
 }
 
 void OperationPrinter::printOperationID(Operation *op,
                                         raw_ostream *streamOverride) const {
-  state->getSSANameState().printOperationID(op, streamOverride ? *streamOverride
-                                                               : os);
+  state.getSSANameState().printOperationID(op, streamOverride ? *streamOverride
+                                                              : os);
 }
 
 void OperationPrinter::printSuccessor(Block *successor) {
@@ -3176,7 +3182,16 @@ void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr,
 //===----------------------------------------------------------------------===//
 
 void Attribute::print(raw_ostream &os) const {
-  AsmPrinter::Impl(os).printAttribute(*this);
+  if (!*this) {
+    os << "<<NULL ATTRIBUTE>>";
+    return;
+  }
+
+  AsmState state(getContext());
+  print(os, state);
+}
+void Attribute::print(raw_ostream &os, AsmState &state) const {
+  AsmPrinter::Impl(os, state.getImpl()).printAttribute(*this);
 }
 
 void Attribute::dump() const {
@@ -3185,7 +3200,16 @@ void Attribute::dump() const {
 }
 
 void Type::print(raw_ostream &os) const {
-  AsmPrinter::Impl(os).printType(*this);
+  if (!*this) {
+    os << "<<NULL TYPE>>";
+    return;
+  }
+
+  AsmState state(getContext());
+  print(os, state);
+}
+void Type::print(raw_ostream &os, AsmState &state) const {
+  AsmPrinter::Impl(os, state.getImpl()).printType(*this);
 }
 
 void Type::dump() const { print(llvm::errs()); }
@@ -3205,7 +3229,8 @@ void AffineExpr::print(raw_ostream &os) const {
     os << "<<NULL AFFINE EXPR>>";
     return;
   }
-  AsmPrinter::Impl(os).printAffineExpr(*this);
+  AsmState state(getContext());
+  AsmPrinter::Impl(os, state.getImpl()).printAffineExpr(*this);
 }
 
 void AffineExpr::dump() const {
@@ -3218,11 +3243,13 @@ void AffineMap::print(raw_ostream &os) const {
     os << "<<NULL AFFINE MAP>>";
     return;
   }
-  AsmPrinter::Impl(os).printAffineMap(*this);
+  AsmState state(getContext());
+  AsmPrinter::Impl(os, state.getImpl()).printAffineMap(*this);
 }
 
 void IntegerSet::print(raw_ostream &os) const {
-  AsmPrinter::Impl(os).printIntegerSet(*this);
+  AsmState state(getContext());
+  AsmPrinter::Impl(os, state.getImpl()).printIntegerSet(*this);
 }
 
 void Value::print(raw_ostream &os) { print(os, OpPrintingFlags()); }


        


More information about the Mlir-commits mailing list