[Mlir-commits] [mlir] 34300ee - [mlir] Add fallback support for parsing/printing unknown external resources
River Riddle
llvmlistbot at llvm.org
Tue Sep 13 11:39:44 PDT 2022
Author: River Riddle
Date: 2022-09-13T11:39:20-07:00
New Revision: 34300ee3697e32926e998d1036925d0f59ffae02
URL: https://github.com/llvm/llvm-project/commit/34300ee3697e32926e998d1036925d0f59ffae02
DIFF: https://github.com/llvm/llvm-project/commit/34300ee3697e32926e998d1036925d0f59ffae02.diff
LOG: [mlir] Add fallback support for parsing/printing unknown external resources
This is useful (and in some cases necessary) for building generic tooling that
can round-trip external resources without explicitly handling them. For example,
it allows viewing the resources encoded within a bytecode file without knowing
how to process them (e.g., making it easier to interact with a reproducer
encoded in bytecode).
Differential Revision: https://reviews.llvm.org/D133460
Added:
Modified:
mlir/include/mlir/Bytecode/BytecodeWriter.h
mlir/include/mlir/IR/AsmState.h
mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
mlir/test/IR/file-metadata-resources.mlir
mlir/test/IR/invalid-file-metadata.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Bytecode/BytecodeWriter.h b/mlir/include/mlir/Bytecode/BytecodeWriter.h
index cd1d8c71a80de..fb4329e4d66f0 100644
--- a/mlir/include/mlir/Bytecode/BytecodeWriter.h
+++ b/mlir/include/mlir/Bytecode/BytecodeWriter.h
@@ -27,6 +27,10 @@ class BytecodeWriterConfig {
/// of the bytecode when reading. It has no functional effect on the bytecode
/// serialization.
BytecodeWriterConfig(StringRef producer = "MLIR" LLVM_VERSION_STRING);
+ /// `map` is a fallback resource map, which when provided will attach resource
+ /// printers for the fallback resources within the map.
+ BytecodeWriterConfig(FallbackAsmResourceMap &map,
+ StringRef producer = "MLIR" LLVM_VERSION_STRING);
~BytecodeWriterConfig();
/// An internal implementation class that contains the state of the
@@ -53,6 +57,13 @@ class BytecodeWriterConfig {
name, std::forward<CallableT>(printFn)));
}
+ /// Attach resource printers to the AsmState for the fallback resources
+ /// in the given map.
+ void attachFallbackResourcePrinter(FallbackAsmResourceMap &map) {
+ for (auto &printer : map.getPrinters())
+ attachResourcePrinter(std::move(printer));
+ }
+
private:
/// A pointer to allocated storage for the impl state.
std::unique_ptr<Impl> impl;
diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h
index 3ff4cfdfad2f4..87f3a37b637dd 100644
--- a/mlir/include/mlir/IR/AsmState.h
+++ b/mlir/include/mlir/IR/AsmState.h
@@ -16,6 +16,8 @@
#include "mlir/IR/OperationSupport.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/StringMap.h"
#include <memory>
@@ -401,6 +403,50 @@ class AsmResourcePrinter {
std::string name;
};
+/// A fallback map containing external resources not explicitly handled by
+/// another parser/printer.
+class FallbackAsmResourceMap {
+public:
+ /// This class represents an opaque resource.
+ struct OpaqueAsmResource {
+ OpaqueAsmResource(StringRef key,
+ std::variant<AsmResourceBlob, bool, std::string> value)
+ : key(key.str()), value(std::move(value)) {}
+
+ /// The key identifying the resource.
+ std::string key;
+ /// An opaque value for the resource, whose variant values align 1-1 with
+ /// the kinds defined in AsmResourceEntryKind.
+ std::variant<AsmResourceBlob, bool, std::string> value;
+ };
+
+ /// Return a parser than can be used for parsing entries for the given
+ /// identifier key.
+ AsmResourceParser &getParserFor(StringRef key);
+
+ /// Build a set of resource printers to print the resources within this map.
+ std::vector<std::unique_ptr<AsmResourcePrinter>> getPrinters();
+
+private:
+ struct ResourceCollection : public AsmResourceParser {
+ ResourceCollection(StringRef name) : AsmResourceParser(name) {}
+
+ /// Parse a resource into this collection.
+ LogicalResult parseResource(AsmParsedResourceEntry &entry) final;
+
+ /// Build the resources held by this collection.
+ void buildResources(Operation *op, AsmResourceBuilder &builder) const;
+
+ /// The set of resources parsed into this collection.
+ SmallVector<OpaqueAsmResource> resources;
+ };
+
+ /// The set of opaque resources.
+ llvm::MapVector<std::string, std::unique_ptr<ResourceCollection>,
+ llvm::StringMap<unsigned>>
+ keyToResources;
+};
+
//===----------------------------------------------------------------------===//
// ParserConfig
//===----------------------------------------------------------------------===//
@@ -409,7 +455,12 @@ class AsmResourcePrinter {
/// contains all of the necessary state to parse a MLIR source file.
class ParserConfig {
public:
- ParserConfig(MLIRContext *context) : context(context) {
+ /// Construct a parser configuration with the given context.
+ /// `fallbackResourceMap` is an optional fallback handler that can be used to
+ /// parse external resources not explicitly handled by another parser.
+ ParserConfig(MLIRContext *context,
+ FallbackAsmResourceMap *fallbackResourceMap = nullptr)
+ : context(context), fallbackResourceMap(fallbackResourceMap) {
assert(context && "expected valid MLIR context");
}
@@ -420,7 +471,11 @@ class ParserConfig {
/// parser with `name` is registered.
AsmResourceParser *getResourceParser(StringRef name) const {
auto it = resourceParsers.find(name);
- return it == resourceParsers.end() ? nullptr : it->second.get();
+ if (it != resourceParsers.end())
+ return it->second.get();
+ if (fallbackResourceMap)
+ return &fallbackResourceMap->getParserFor(name);
+ return nullptr;
}
/// Attach the given resource parser.
@@ -444,6 +499,7 @@ class ParserConfig {
private:
MLIRContext *context;
DenseMap<StringRef, std::unique_ptr<AsmResourceParser>> resourceParsers;
+ FallbackAsmResourceMap *fallbackResourceMap;
};
//===----------------------------------------------------------------------===//
@@ -466,13 +522,17 @@ class AsmState {
using LocationMap = DenseMap<Operation *, std::pair<unsigned, unsigned>>;
/// Initialize the asm state at the level of the given operation. A location
- /// map may optionally be provided to be populated when printing.
+ /// map may optionally be provided to be populated when printing. `map` is an
+ /// optional fallback resource map, which when provided will attach resource
+ /// printers for the fallback resources within the map.
AsmState(Operation *op,
const OpPrintingFlags &printerFlags = OpPrintingFlags(),
- LocationMap *locationMap = nullptr);
+ LocationMap *locationMap = nullptr,
+ FallbackAsmResourceMap *map = nullptr);
AsmState(MLIRContext *ctx,
const OpPrintingFlags &printerFlags = OpPrintingFlags(),
- LocationMap *locationMap = nullptr);
+ LocationMap *locationMap = nullptr,
+ FallbackAsmResourceMap *map = nullptr);
~AsmState();
/// Get the printer flags.
@@ -498,6 +558,13 @@ class AsmState {
name, std::forward<CallableT>(printFn)));
}
+ /// Attach resource printers to the AsmState for the fallback resources
+ /// in the given map.
+ void attachFallbackResourcePrinter(FallbackAsmResourceMap &map) {
+ for (auto &printer : map.getPrinters())
+ attachResourcePrinter(std::move(printer));
+ }
+
/// Returns a map of dialect resources that were referenced when using this
/// state to print IR.
DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> &
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
index ff53cec15d779..7bcc1a841c245 100644
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -39,6 +39,11 @@ struct BytecodeWriterConfig::Impl {
BytecodeWriterConfig::BytecodeWriterConfig(StringRef producer)
: impl(std::make_unique<Impl>(producer)) {}
+BytecodeWriterConfig::BytecodeWriterConfig(FallbackAsmResourceMap &map,
+ StringRef producer)
+ : BytecodeWriterConfig(producer) {
+ attachFallbackResourcePrinter(map);
+}
BytecodeWriterConfig::~BytecodeWriterConfig() = default;
void BytecodeWriterConfig::attachResourcePrinter(
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 395bd03bb5794..310e5efbd8f89 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1283,6 +1283,69 @@ StringRef mlir::toString(AsmResourceEntryKind kind) {
llvm_unreachable("unknown AsmResourceEntryKind");
}
+AsmResourceParser &FallbackAsmResourceMap::getParserFor(StringRef key) {
+ std::unique_ptr<ResourceCollection> &collection = keyToResources[key.str()];
+ if (!collection)
+ collection = std::make_unique<ResourceCollection>(key);
+ return *collection;
+}
+
+std::vector<std::unique_ptr<AsmResourcePrinter>>
+FallbackAsmResourceMap::getPrinters() {
+ std::vector<std::unique_ptr<AsmResourcePrinter>> printers;
+ for (auto &it : keyToResources) {
+ ResourceCollection *collection = it.second.get();
+ auto buildValues = [=](Operation *op, AsmResourceBuilder &builder) {
+ return collection->buildResources(op, builder);
+ };
+ printers.emplace_back(
+ AsmResourcePrinter::fromCallable(collection->getName(), buildValues));
+ }
+ return printers;
+}
+
+LogicalResult FallbackAsmResourceMap::ResourceCollection::parseResource(
+ AsmParsedResourceEntry &entry) {
+ switch (entry.getKind()) {
+ case AsmResourceEntryKind::Blob: {
+ FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
+ if (failed(blob))
+ return failure();
+ resources.emplace_back(entry.getKey(), std::move(*blob));
+ return success();
+ }
+ case AsmResourceEntryKind::Bool: {
+ FailureOr<bool> value = entry.parseAsBool();
+ if (failed(value))
+ return failure();
+ resources.emplace_back(entry.getKey(), *value);
+ break;
+ }
+ case AsmResourceEntryKind::String: {
+ FailureOr<std::string> str = entry.parseAsString();
+ if (failed(str))
+ return failure();
+ resources.emplace_back(entry.getKey(), std::move(*str));
+ break;
+ }
+ }
+ return success();
+}
+
+void FallbackAsmResourceMap::ResourceCollection::buildResources(
+ Operation *op, AsmResourceBuilder &builder) const {
+ for (const auto &entry : resources) {
+ if (const auto *value = std::get_if<AsmResourceBlob>(&entry.value))
+ builder.buildBlob(entry.key, *value);
+ else if (const auto *value = std::get_if<bool>(&entry.value))
+ builder.buildBool(entry.key, *value);
+ else if (const auto *value = std::get_if<std::string>(&entry.value))
+ builder.buildString(entry.key, *value);
+ else
+ llvm_unreachable("unknown AsmResourceEntryKind");
+ }
+}
+
//===----------------------------------------------------------------------===//
// AsmState
//===----------------------------------------------------------------------===//
@@ -1401,12 +1464,18 @@ static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
}
AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
- LocationMap *locationMap)
+ LocationMap *locationMap, FallbackAsmResourceMap *map)
: impl(std::make_unique<AsmStateImpl>(
- op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {}
+ op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {
+ if (map)
+ attachFallbackResourcePrinter(*map);
+}
AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
- LocationMap *locationMap)
- : impl(std::make_unique<AsmStateImpl>(ctx, printerFlags, locationMap)) {}
+ LocationMap *locationMap, FallbackAsmResourceMap *map)
+ : impl(std::make_unique<AsmStateImpl>(ctx, printerFlags, locationMap)) {
+ if (map)
+ attachFallbackResourcePrinter(*map);
+}
AsmState::~AsmState() = default;
const OpPrintingFlags &AsmState::getPrinterFlags() const {
@@ -3308,14 +3377,6 @@ void Value::printAsOperand(raw_ostream &os, AsmState &state) {
}
void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
- // If this is a top level operation, we also print aliases.
- if (!getParent() && !printerFlags.shouldUseLocalScope()) {
- AsmState state(this, printerFlags);
- state.getImpl().initializeAliases(this);
- print(os, state);
- return;
- }
-
// Find the operation to number from based upon the provided flags.
Operation *op = this;
bool shouldUseLocalScope = printerFlags.shouldUseLocalScope();
@@ -3337,10 +3398,12 @@ void Operation::print(raw_ostream &os, const OpPrintingFlags &printerFlags) {
}
void Operation::print(raw_ostream &os, AsmState &state) {
OperationPrinter printer(os, state.getImpl());
- if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope())
+ if (!getParent() && !state.getPrinterFlags().shouldUseLocalScope()) {
+ state.getImpl().initializeAliases(this);
printer.printTopLevelOperation(this);
- else
+ } else {
printer.print(this);
+ }
}
void Operation::dump() {
diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
index 22f01bf719d5e..93cc4dc335389 100644
--- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
+++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp
@@ -319,6 +319,10 @@ struct MLIRDocument {
/// The container for the IR parsed from the input file.
Block parsedIR;
+ /// A collection of external resources, which we want to propagate up to the
+ /// user.
+ FallbackAsmResourceMap fallbackResourceMap;
+
/// The source manager containing the contents of the input file.
llvm::SourceMgr sourceMgr;
};
@@ -338,11 +342,13 @@ MLIRDocument::MLIRDocument(MLIRContext &context, const lsp::URIForFile &uri,
return;
}
+ ParserConfig config(&context, &fallbackResourceMap);
sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
- if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, &context, &asmState))) {
+ if (failed(parseAsmSourceFile(sourceMgr, &parsedIR, config, &asmState))) {
// If parsing failed, clear out any of the current state.
parsedIR.clear();
asmState = AsmParserState();
+ fallbackResourceMap = FallbackAsmResourceMap();
return;
}
}
@@ -875,9 +881,11 @@ MLIRDocument::convertToBytecode() {
lsp::MLIRConvertBytecodeResult result;
{
+ BytecodeWriterConfig writerConfig(fallbackResourceMap);
+
std::string rawBytecodeBuffer;
llvm::raw_string_ostream os(rawBytecodeBuffer);
- writeBytecodeToFile(&parsedIR.front(), os);
+ writeBytecodeToFile(&parsedIR.front(), os, writerConfig);
result.output = llvm::encodeBase64(rawBytecodeBuffer);
}
return result;
@@ -1284,11 +1292,15 @@ lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) {
&tempContext,
[&](mlir::Diagnostic &diag) { errorMsg += diag.str() + "\n"; });
+ // Handling for external resources, which we want to propagate up to the user.
+ FallbackAsmResourceMap fallbackResourceMap;
+
+ // Setup the parser config.
+ ParserConfig parserConfig(&tempContext, &fallbackResourceMap);
+
// Try to parse the given source file.
- // TODO: This won't preserve external resources or the producer, we should try
- // to fix this.
Block parsedBlock;
- if (failed(parseSourceFile(uri.file(), &parsedBlock, &tempContext))) {
+ if (failed(parseSourceFile(uri.file(), &parsedBlock, parserConfig))) {
return llvm::make_error<lsp::LSPError>(
"failed to parse bytecode source file: " + errorMsg,
lsp::ErrorCode::RequestFailed);
@@ -1310,8 +1322,11 @@ lsp::MLIRServer::convertFromBytecode(const URIForFile &uri) {
OwningOpRef<Operation *> topOp = &parsedBlock.front();
(*topOp)->remove();
+ AsmState state(*topOp, OpPrintingFlags().enableDebugInfo().assumeVerified(),
+ /*locationMap=*/nullptr, &fallbackResourceMap);
+
llvm::raw_string_ostream os(result.output);
- (*topOp)->print(os, OpPrintingFlags().enableDebugInfo().assumeVerified());
+ (*topOp)->print(os, state);
}
return std::move(result);
}
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 519da97af0c70..ccd095dcf0bbb 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -66,8 +66,11 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
pm.enableTiming(timing);
// Prepare the parser config, and attach any useful/necessary resource
- // handlers.
- ParserConfig config(context);
+ // handlers. Unhandled external resources are treated as passthrough, i.e.
+ // they are not processed and will be emitted directly to the output
+ // untouched.
+ FallbackAsmResourceMap fallbackResourceMap;
+ ParserConfig config(context, &fallbackResourceMap);
attachPassReproducerAsmResource(config, pm, wasThreadingEnabled);
// Parse the input file and reset the context threading state.
@@ -89,9 +92,12 @@ static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
// Print the output.
TimingScope outputTiming = timing.nest("Output");
if (emitBytecode) {
- writeBytecodeToFile(module->getOperation(), os);
+ BytecodeWriterConfig writerConfig(fallbackResourceMap);
+ writeBytecodeToFile(module->getOperation(), os, writerConfig);
} else {
- module->print(os);
+ AsmState asmState(*module, OpPrintingFlags(), /*locationMap=*/nullptr,
+ &fallbackResourceMap);
+ module->print(os, asmState);
os << '\n';
}
return success();
diff --git a/mlir/test/IR/file-metadata-resources.mlir b/mlir/test/IR/file-metadata-resources.mlir
index 57562555c9643..a531c7ce97563 100644
--- a/mlir/test/IR/file-metadata-resources.mlir
+++ b/mlir/test/IR/file-metadata-resources.mlir
@@ -5,6 +5,13 @@
// CHECK-NEXT: blob1: "0x08000000010000000000000002000000000000000300000000000000"
// CHECK-NEXT: }
+// Check that we properly preserve unknown external resources.
+// CHECK: external: {
+// CHECK-NEXT: blob: "0x08000000010000000000000002000000000000000300000000000000"
+// CHECK-NEXT: bool: true
+// CHECK-NEXT: string: "string"
+// CHECK-NEXT: }
+
module attributes { test.blob_ref = #test.e1di64_elements<blob1> : tensor<*xi1>} {}
{-#
@@ -13,5 +20,12 @@ module attributes { test.blob_ref = #test.e1di64_elements<blob1> : tensor<*xi1>}
blob1: "0x08000000010000000000000002000000000000000300000000000000",
blob2: "0x08000000040000000000000005000000000000000600000000000000"
}
+ },
+ external_resources: {
+ external: {
+ blob: "0x08000000010000000000000002000000000000000300000000000000",
+ bool: true,
+ string: "string"
+ }
}
#-}
diff --git a/mlir/test/IR/invalid-file-metadata.mlir b/mlir/test/IR/invalid-file-metadata.mlir
index 352cf19f11bef..553bd43c6aeeb 100644
--- a/mlir/test/IR/invalid-file-metadata.mlir
+++ b/mlir/test/IR/invalid-file-metadata.mlir
@@ -129,14 +129,3 @@
entry "value"
}
#-}
-
-// -----
-
-// expected-warning at +3 {{ignoring unknown external resources for 'foobar'}}
-{-#
- external_resources: {
- foobar: {
- entry: "foo"
- }
- }
-#-}
More information about the Mlir-commits mailing list