[Mlir-commits] [mlir] 34300ee - [mlir] Add fallback support for parsing/printing unknown external resources

River Riddle llvmlistbot at llvm.org
Tue Sep 13 11:39:44 PDT 2022


Author: River Riddle
Date: 2022-09-13T11:39:20-07:00
New Revision: 34300ee3697e32926e998d1036925d0f59ffae02

URL: https://github.com/llvm/llvm-project/commit/34300ee3697e32926e998d1036925d0f59ffae02
DIFF: https://github.com/llvm/llvm-project/commit/34300ee3697e32926e998d1036925d0f59ffae02.diff

LOG: [mlir] Add fallback support for parsing/printing unknown external resources

This is necessary/useful for building generic tooling that can roundtrip external
resources without needing to explicitly handle them. For example, this allows
for viewing the resources encoded within a bytecode file without having to
explicitly know how to process them (e.g. making it easier to interact with a
reproducer encoded in bytecode).

Differential Revision: https://reviews.llvm.org/D133460

Added: 
    

Modified: 
    mlir/include/mlir/Bytecode/BytecodeWriter.h
    mlir/include/mlir/IR/AsmState.h
    mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
    mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
    mlir/test/IR/file-metadata-resources.mlir
    mlir/test/IR/invalid-file-metadata.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h
index cd1d8c71a80de..fb4329e4d66f0 100644
--- a/mlir/include/mlir/Bytecode/BytecodeWriter.h
+++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h
@@ -27,6 +27,10 @@ class BytecodeWriterConfig {
   /// of the bytecode when reading. It has no functional effect on the bytecode
   /// serialization.
   BytecodeWriterConfig(StringRef producer = "MLIR" LLVM_VERSION_STRING);
+  /// `map` is a fallback resource map, which when provided will attach resource
+  /// printers for the fallback resources within the map.
+  BytecodeWriterConfig(FallbackAsmResourceMap &map,
+                       StringRef producer = "MLIR" LLVM_VERSION_STRING);
   ~BytecodeWriterConfig();
 
   /// An internal implementation class that contains the state of the
@@ -53,6 +57,13 @@ class BytecodeWriterConfig {
         name, std::forward<CallableT>(printFn)));
   }
 
+  /// Attach to this writer config one resource printer for each fallback
+  /// resource collection held by `map`.
+  void attachFallbackResourcePrinter(FallbackAsmResourceMap &map) {
+    for (std::unique_ptr<AsmResourcePrinter> &fallback : map.getPrinters())
+      attachResourcePrinter(std::move(fallback));
+  }
+
 private:
   /// A pointer to allocated storage for the impl state.
   std::unique_ptr<Impl> impl;

diff  --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index 3ff4cfdfad2f4..87f3a37b637dd 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -16,6 +16,8 @@
 
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/StringMap.h"
 
 #include <memory>
 
@@ -401,6 +403,50 @@ class AsmResourcePrinter {
   std::string name;
 };
 
+/// A fallback map that collects external resources which are not explicitly
+/// handled by another parser/printer.
+class FallbackAsmResourceMap {
+public:
+  /// An opaque resource: a key paired with the value parsed for it.
+  struct OpaqueAsmResource {
+    OpaqueAsmResource(StringRef key,
+                      std::variant<AsmResourceBlob, bool, std::string> value)
+        : key(key.str()), value(std::move(value)) {}
+
+    /// The key identifying the resource.
+    std::string key;
+    /// The opaque value; alternatives align 1-1 with AsmResourceEntryKind.
+    // NOTE(review): no <variant> include is added by this patch -- confirm.
+    std::variant<AsmResourceBlob, bool, std::string> value;
+  };
+
+  /// Return a parser that can be used for parsing entries for the given
+  /// identifier key; the parser is created on first request for that key.
+  AsmResourceParser &getParserFor(StringRef key);
+
+  /// Build printers for the resources in this map; they point into the map.
+  std::vector<std::unique_ptr<AsmResourcePrinter>> getPrinters();
+
+private:
+  struct ResourceCollection : public AsmResourceParser {
+    ResourceCollection(StringRef name) : AsmResourceParser(name) {}
+
+    /// Parse `entry` and append it to this collection's `resources`.
+    LogicalResult parseResource(AsmParsedResourceEntry &entry) final;
+
+    /// Emit every held resource to `builder`, in the order they were parsed.
+    void buildResources(Operation *op, AsmResourceBuilder &builder) const;
+
+    /// The set of resources parsed into this collection.
+    SmallVector<OpaqueAsmResource> resources;
+  };
+
+  /// The resource collections, keyed by the name passed to getParserFor.
+  llvm::MapVector<std::string, std::unique_ptr<ResourceCollection>,
+                  llvm::StringMap<unsigned>>
+      keyToResources;
+};
+
 //===----------------------------------------------------------------------===//
 // ParserConfig
 //===----------------------------------------------------------------------===//
@@ -409,7 +455,12 @@ class AsmResourcePrinter {
 /// contains all of the necessary state to parse a MLIR source file.
 class ParserConfig {
 public:
-  ParserConfig(MLIRContext *context) : context(context) {
+  /// Construct a parser configuration with the given context.
+  /// `fallbackResourceMap` is an optional fallback handler that can be used to
+  /// parse external resources not explicitly handled by another parser.
+  ParserConfig(MLIRContext *context,
+               FallbackAsmResourceMap *fallbackResourceMap = nullptr)
+      : context(context), fallbackResourceMap(fallbackResourceMap) {
     assert(context && "expected valid MLIR context");
   }
 
@@ -420,7 +471,11 @@ class ParserConfig {
   /// parser with `name` is registered.
   AsmResourceParser *getResourceParser(StringRef name) const {
     auto it = resourceParsers.find(name);
-    return it == resourceParsers.end() ? nullptr : it->second.get();
+    if (it != resourceParsers.end())
+      return it->second.get();
+    if (fallbackResourceMap)
+      return &fallbackResourceMap->getParserFor(name);
+    return nullptr;
   }
 
   /// Attach the given resource parser.
@@ -444,6 +499,7 @@ class ParserConfig {
 private:
   MLIRContext *context;
   DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
+  FallbackAsmResourceMap *fallbackResourceMap;
 };
 
 //===----------------------------------------------------------------------===//
@@ -466,13 +522,17 @@ class AsmState {
   using LocationMap = DenseMap<Operation *, std::pair<unsigned, unsigned>>;
 
   /// Initialize the asm state at the level of the given operation. A location
-  /// map may optionally be provided to be populated when printing.
+  /// map may optionally be provided to be populated when printing. `map` is an
+  /// optional fallback resource map, which when provided will attach resource
+  /// printers for the fallback resources within the map.
   AsmState(Operation *op,
            const OpPrintingFlags &printerFlags = OpPrintingFlags(),
-           LocationMap *locationMap = nullptr);
+           LocationMap *locationMap = nullptr,
+           FallbackAsmResourceMap *map = nullptr);
   AsmState(MLIRContext *ctx,
            const OpPrintingFlags &printerFlags = OpPrintingFlags(),
-           LocationMap *locationMap = nullptr);
+           LocationMap *locationMap = nullptr,
+           FallbackAsmResourceMap *map = nullptr);
   ~AsmState();
 
   /// Get the printer flags.
@@ -498,6 +558,13 @@ class AsmState {
         name, std::forward<CallableT>(printFn)));
   }
 
+  /// Attach to this AsmState one resource printer for each fallback resource
+  /// collection held by `map`.
+  void attachFallbackResourcePrinter(FallbackAsmResourceMap &map) {
+    for (std::unique_ptr<AsmResourcePrinter> &fallback : map.getPrinters())
+      attachResourcePrinter(std::move(fallback));
+  }
+
   /// Returns a map of dialect resources that were referenced when using this
   /// state to print IR.
   DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> &

diff  --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index ff53cec15d779..7bcc1a841c245 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -39,6 +39,11 @@ struct BytecodeWriterConfig::Impl {
 
 BytecodeWriterConfig::BytecodeWriterConfig(StringRef producer)
     : impl(std::make_unique<Impl>(producer)) {}
+BytecodeWriterConfig::BytecodeWriterConfig(FallbackAsmResourceMap &map,
+                                           StringRef producer)
+    : BytecodeWriterConfig(producer) {
+  attachFallbackResourcePrinter(map); // One printer per collection in `map`.
+}
 BytecodeWriterConfig::~BytecodeWriterConfig() = default;
 
 void BytecodeWriterConfig::attachResourcePrinter(

diff  --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 395bd03bb5794..310e5efbd8f89 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1283,6 +1283,69 @@ StringRef mlir::toString(AsmResourceEntryKind kind) {
   llvm_unreachable("unknown AsmResourceEntryKind");
 }
 
+AsmResourceParser &FallbackAsmResourceMap::getParserFor(StringRef key) {
+  auto &slot = keyToResources[key.str()]; // Inserts a null slot if absent.
+  if (slot == nullptr)
+    slot = std::make_unique<ResourceCollection>(key);
+  return *slot;
+}
+
+std::vector<std::unique_ptr<AsmResourcePrinter>>
+FallbackAsmResourceMap::getPrinters() {
+  std::vector<std::unique_ptr<AsmResourcePrinter>> out;
+  for (auto &entry : keyToResources) {
+    // Each printer only borrows its collection; it must not outlive the map.
+    ResourceCollection *set = entry.second.get();
+    auto emit = [set](Operation *op, AsmResourceBuilder &builder) {
+      return set->buildResources(op, builder);
+    };
+    out.push_back(AsmResourcePrinter::fromCallable(set->getName(), emit));
+  }
+  return out;
+}
+
+LogicalResult FallbackAsmResourceMap::ResourceCollection::parseResource(
+    AsmParsedResourceEntry &entry) {
+  switch (entry.getKind()) {
+  case AsmResourceEntryKind::Blob: {
+    FailureOr<AsmResourceBlob> parsedBlob = entry.parseAsBlob();
+    if (failed(parsedBlob))
+      return failure();
+    resources.emplace_back(entry.getKey(), std::move(*parsedBlob));
+    break;
+  }
+  case AsmResourceEntryKind::Bool: {
+    FailureOr<bool> parsedBool = entry.parseAsBool();
+    if (failed(parsedBool))
+      return failure();
+    resources.emplace_back(entry.getKey(), *parsedBool);
+    break;
+  }
+  case AsmResourceEntryKind::String: {
+    FailureOr<std::string> parsedString = entry.parseAsString();
+    if (failed(parsedString))
+      return failure();
+    resources.emplace_back(entry.getKey(), std::move(*parsedString));
+    break;
+  }
+  }
+  return success(); // Parse failures returned early above.
+}
+
+void FallbackAsmResourceMap::ResourceCollection::buildResources(
+    Operation *op, AsmResourceBuilder &builder) const {
+  for (const OpaqueAsmResource &resource : resources) {
+    if (const auto *blob = std::get_if<AsmResourceBlob>(&resource.value))
+      builder.buildBlob(resource.key, *blob);
+    else if (const auto *flag = std::get_if<bool>(&resource.value))
+      builder.buildBool(resource.key, *flag);
+    else if (const auto *text = std::get_if<std::string>(&resource.value))
+      builder.buildString(resource.key, *text);
+    else
+      llvm_unreachable("unknown AsmResourceEntryKind");
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // AsmState
 //===----------------------------------------------------------------------===//
@@ -1401,12 +1464,18 @@ static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
 }
 
 AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
-                   LocationMap *locationMap)
+                   LocationMap *locationMap, FallbackAsmResourceMap *map)
     : impl(std::make_unique<AsmStateImpl>(
-          op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {}
+          op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {
+  if (map)
+    attachFallbackResourcePrinter(*map);
+}
 AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
-                   LocationMap *locationMap)
-    : impl(std::make_unique<AsmStateImpl>(ctx, printerFlags, locationMap)) {}
+                   LocationMap *locationMap, FallbackAsmResourceMap *map)
+    : impl(std::make_unique<AsmStateImpl>(ctx, printerFlags, locationMap)) {
+  if (map)
+    attachFallbackResourcePrinter(*map);
+}
 AsmState::~AsmState() = default;
 
 const OpPrintingFlags &AsmState::getPrinterFlags() const {
@@ -3308,14 +3377,6 @@ void Value::printAsOperand(raw_ostream &os, AsmState &state) {
 }
 
 void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
-  // If this is a top level operation, we also print aliases.
-  if (!getParent() && !printerFlags.shouldUseLocalScope()) {
-    AsmState state(this, printerFlags);
-    state.getImpl().initializeAliases(this);
-    print(os, state);
-    return;
-  }
-
   // Find the operation to number from based upon the provided flags.
   Operation *op = this;
   bool shouldUseLocalScope = printerFlags.shouldUseLocalScope();
@@ -3337,10 +3398,12 @@ void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
 }
 void Operation::print(raw_ostream &os, AsmState &state) {
   OperationPrinter printer(os, state.getImpl());
-  if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope())
+  if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) {
+    state.getImpl().initializeAliases(this);
     printer.printTopLevelOperation(this);
-  else
+  } else {
     printer.print(this);
+  }
 }
 
 void Operation::dump() {

diff  --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
index 22f01bf719d5e..93cc4dc335389 100644
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
@@ -319,6 +319,10 @@ struct MLIRDocument {
   /// The container for the IR parsed from the input file.
   Block parsedIR;
 
+  /// A collection of external resources, which we want to propagate up to the
+  /// user.
+  FallbackAsmResourceMap fallbackResourceMap;
+
   /// The source manager containing the contents of the input file.
   llvm::SourceMgr sourceMgr;
 };
@@ -338,11 +342,13 @@ MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
     return;
   }
 
+  ParserConfig config(&context, &fallbackResourceMap);
   sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
-  if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, &context, &asmState))) {
+  if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) {
     // If parsing failed, clear out any of the current state.
     parsedIR.clear();
     asmState = AsmParserState();
+    fallbackResourceMap = FallbackAsmResourceMap();
     return;
   }
 }
@@ -875,9 +881,11 @@ MLIRDocument::convertToBytecode() {
 
   lsp::MLIRConvertBytecodeResult result;
   {
+    BytecodeWriterConfig writerConfig(fallbackResourceMap);
+
     std::string rawBytecodeBuffer;
     llvm::raw_string_ostream os(rawBytecodeBuffer);
-    writeBytecodeToFile(&parsedIR.front(), os);
+    writeBytecodeToFile(&parsedIR.front(), os, writerConfig);
     result.output = llvm::encodeBase64(rawBytecodeBuffer);
   }
   return result;
@@ -1284,11 +1292,15 @@ lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) {
       &tempContext,
       [&](mlir::Diagnostic &diag) { errorMsg += diag.str() + "\n"; });
 
+  // Handling for external resources, which we want to propagate up to the user.
+  FallbackAsmResourceMap fallbackResourceMap;
+
+  // Setup the parser config.
+  ParserConfig parserConfig(&tempContext, &fallbackResourceMap);
+
   // Try to parse the given source file.
-  // TODO: This won't preserve external resources or the producer, we should try
-  // to fix this.
   Block parsedBlock;
-  if (failed(parseSourceFile(uri.file(), &parsedBlock, &tempContext))) {
+  if (failed(parseSourceFile(uri.file(), &parsedBlock, parserConfig))) {
     return llvm::make_error<lsp::LSPError>(
         "failed to parse bytecode source file: " + errorMsg,
         lsp::ErrorCode::RequestFailed);
@@ -1310,8 +1322,11 @@ lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) {
     OwningOpRef<Operation *> topOp = &parsedBlock.front();
     (*topOp)->remove();
 
+    AsmState state(*topOp, OpPrintingFlags().enableDebugInfo().assumeVerified(),
+                   /*locationMap=*/nullptr, &fallbackResourceMap);
+
     llvm::raw_string_ostream os(result.output);
-    (*topOp)->print(os, OpPrintingFlags().enableDebugInfo().assumeVerified());
+    (*topOp)->print(os, state);
   }
   return std::move(result);
 }

diff  --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 519da97af0c70..ccd095dcf0bbb 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -66,8 +66,11 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
   pm.enableTiming(timing);
 
   // Prepare the parser config, and attach any useful/necessary resource
-  // handlers.
-  ParserConfig config(context);
+  // handlers. Unhandled external resources are treated as passthrough, i.e.
+  // they are not processed and will be emitted directly to the output
+  // untouched.
+  FallbackAsmResourceMap fallbackResourceMap;
+  ParserConfig config(context, &fallbackResourceMap);
   attachPassReproducerAsmResource(config, pm, wasThreadingEnabled);
 
   // Parse the input file and reset the context threading state.
@@ -89,9 +92,12 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
   // Print the output.
   TimingScope outputTiming = timing.nest("Output");
   if (emitBytecode) {
-    writeBytecodeToFile(module->getOperation(), os);
+    BytecodeWriterConfig writerConfig(fallbackResourceMap);
+    writeBytecodeToFile(module->getOperation(), os, writerConfig);
   } else {
-    module->print(os);
+    AsmState asmState(*module, OpPrintingFlags(), /*locationMap=*/nullptr,
+                      &fallbackResourceMap);
+    module->print(os, asmState);
     os << '\n';
   }
   return success();

diff  --git a/mlir/test/IR/file-metadata-resources.mlir b/mlir/test/IR/file-metadata-resources.mlir
index 57562555c9643..a531c7ce97563 100644
--- a/mlir/test/IR/file-metadata-resources.mlir
+++ b/mlir/test/IR/file-metadata-resources.mlir
@@ -5,6 +5,13 @@
 // CHECK-NEXT:   blob1: "0x08000000010000000000000002000000000000000300000000000000"
 // CHECK-NEXT: }
 
+// Check that we properly preserve unknown external resources.
+// CHECK:      external: {
+// CHECK-NEXT:   blob: "0x08000000010000000000000002000000000000000300000000000000"
+// CHECK-NEXT:   bool: true
+// CHECK-NEXT:   string: "string"
+// CHECK-NEXT: }
+
 module attributes { test.blob_ref = #test.e1di64_elements<blob1> : tensor<*xi1>} {}
 
 {-#
@@ -13,5 +20,12 @@ module attributes { test.blob_ref = #test.e1di64_elements<blob1> : tensor<*xi1>}
       blob1: "0x08000000010000000000000002000000000000000300000000000000",
       blob2: "0x08000000040000000000000005000000000000000600000000000000"
     }
+  },
+  external_resources: {
+    external: {
+      blob: "0x08000000010000000000000002000000000000000300000000000000",
+      bool: true,
+      string: "string"
+    }
   }
 #-}

diff  --git a/mlir/test/IR/invalid-file-metadata.mlir b/mlir/test/IR/invalid-file-metadata.mlir
index 352cf19f11bef..553bd43c6aeeb 100644
--- a/mlir/test/IR/invalid-file-metadata.mlir
+++ b/mlir/test/IR/invalid-file-metadata.mlir
@@ -129,14 +129,3 @@
     entry "value"
   }
 #-}
-
-// -----
-
-// expected-warning at +3 {{ignoring unknown external resources for 'foobar'}}
-{-#
-  external_resources: {
-    foobar: {
-      entry: "foo"
-    }
-  }
-#-}


        


More information about the Mlir-commits mailing list