[Mlir-commits] [mlir] 4155be3 - [mlir][Translation] Allow specifying an expected input alignment for "ToMLIR" translations
River Riddle
llvmlistbot at llvm.org
Tue Nov 15 17:22:51 PST 2022
Author: River Riddle
Date: 2022-11-15T17:22:41-08:00
New Revision: 4155be339ba80fef8fef0423bbd83217e8e9ca48
URL: https://github.com/llvm/llvm-project/commit/4155be339ba80fef8fef0423bbd83217e8e9ca48
DIFF: https://github.com/llvm/llvm-project/commit/4155be339ba80fef8fef0423bbd83217e8e9ca48.diff
LOG: [mlir][Translation] Allow specifying an expected input alignment for "ToMLIR" translations
This allows for ensuring that alignment requirements on translation
inputs are satisfied.
Differential Revision: https://reviews.llvm.org/D137999
Added:
Modified:
mlir/include/mlir/Tools/mlir-translate/Translation.h
mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
mlir/lib/Tools/mlir-translate/Translation.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Tools/mlir-translate/Translation.h b/mlir/include/mlir/Tools/mlir-translate/Translation.h
index d3cd817d3e7f0..80c4e37f47caa 100644
--- a/mlir/include/mlir/Tools/mlir-translate/Translation.h
+++ b/mlir/include/mlir/Tools/mlir-translate/Translation.h
@@ -47,9 +47,44 @@ using TranslateFromMLIRFunction =
using TranslateFunction = std::function<LogicalResult(
llvm::SourceMgr &sourceMgr, llvm::raw_ostream &output, MLIRContext *)>;
+/// This class contains all of the components necessary for performing a
+/// translation.
+class Translation {
+public:
+ Translation() = default;
+ Translation(TranslateFunction function, StringRef description,
+ Optional<llvm::Align> inputAlignment)
+ : function(std::move(function)), description(description),
+ inputAlignment(inputAlignment) {}
+
+ /// Return the description of this translation.
+ StringRef getDescription() const { return description; }
+
+ /// Return the optional alignment desired for the input of the translation.
+ Optional<llvm::Align> getInputAlignment() const { return inputAlignment; }
+
+ /// Invoke the translation function with the given input and output streams.
+ LogicalResult operator()(llvm::SourceMgr &sourceMgr,
+ llvm::raw_ostream &output,
+ MLIRContext *context) const {
+ return function(sourceMgr, output, context);
+ }
+
+private:
+ /// The underlying translation function.
+ TranslateFunction function;
+
+ /// The description of the translation.
+ StringRef description;
+
+ /// An optional alignment desired for the input of the translation.
+ Optional<llvm::Align> inputAlignment;
+};
+
/// Use Translate[ToMLIR|FromMLIR]Registration as an initializer that
/// registers a function and associates it with name. This requires that a
-/// translation has not been registered to a given name.
+/// translation has not been registered to a given name. `inputAlign` is an
+/// optional expected alignment for the input data.
///
/// Usage:
///
@@ -62,10 +97,14 @@ using TranslateFunction = std::function<LogicalResult(
///
/// \{
struct TranslateToMLIRRegistration {
- TranslateToMLIRRegistration(llvm::StringRef name, llvm::StringRef description,
- const TranslateSourceMgrToMLIRFunction &function);
- TranslateToMLIRRegistration(llvm::StringRef name, llvm::StringRef description,
- const TranslateStringRefToMLIRFunction &function);
+ TranslateToMLIRRegistration(
+ llvm::StringRef name, llvm::StringRef description,
+ const TranslateSourceMgrToMLIRFunction &function,
+ Optional<llvm::Align> inputAlignment = llvm::None);
+ TranslateToMLIRRegistration(
+ llvm::StringRef name, llvm::StringRef description,
+ const TranslateStringRefToMLIRFunction &function,
+ Optional<llvm::Align> inputAlignment = llvm::None);
};
struct TranslateFromMLIRRegistration {
@@ -99,7 +138,7 @@ struct TranslateRegistration {
/// \}
/// A command line parser for translation functions.
-struct TranslationParser : public llvm::cl::parser<const TranslateFunction *> {
+struct TranslationParser : public llvm::cl::parser<const Translation *> {
TranslationParser(llvm::cl::Option &opt);
void printOptionInfo(const llvm::cl::Option &o,
diff --git a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
index ef2545bd46beb..51b21f251747a 100644
--- a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
+++ b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
@@ -56,7 +56,7 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
llvm::InitLLVM y(argc, argv);
// Add flags for all the registered translations.
- llvm::cl::opt<const TranslateFunction *, false, TranslationParser>
+ llvm::cl::opt<const Translation *, false, TranslationParser>
translationRequested("", llvm::cl::desc("Translation to perform"),
llvm::cl::Required);
registerAsmPrinterCLOptions();
@@ -65,7 +65,11 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
llvm::cl::ParseCommandLineOptions(argc, argv, toolName);
std::string errorMessage;
- auto input = openInputFile(inputFilename, &errorMessage);
+ std::unique_ptr<llvm::MemoryBuffer> input;
+ if (auto inputAlignment = translationRequested->getInputAlignment())
+ input = openInputFile(inputFilename, *inputAlignment, &errorMessage);
+ else
+ input = openInputFile(inputFilename, &errorMessage);
if (!input) {
llvm::errs() << errorMessage << "\n";
return failure();
diff --git a/mlir/lib/Tools/mlir-translate/Translation.cpp b/mlir/lib/Tools/mlir-translate/Translation.cpp
index ab86cd0000b99..548e3f9825b37 100644
--- a/mlir/lib/Tools/mlir-translate/Translation.cpp
+++ b/mlir/lib/Tools/mlir-translate/Translation.cpp
@@ -40,34 +40,30 @@ void mlir::registerTranslationCLOptions() { *clOptions; }
// Translation Registry
//===----------------------------------------------------------------------===//
-struct TranslationBundle {
- TranslateFunction translateFunction;
- StringRef translateDescription;
-};
-
-/// Get the mutable static map between registered file-to-file MLIR translations
-/// and TranslateFunctions with its description that perform those translations.
-static llvm::StringMap<TranslationBundle> &getTranslationRegistry() {
- static llvm::StringMap<TranslationBundle> translationBundle;
+/// Get the mutable static map between registered file-to-file MLIR
+/// translations.
+static llvm::StringMap<Translation> &getTranslationRegistry() {
+ static llvm::StringMap<Translation> translationBundle;
return translationBundle;
}
/// Register the given translation.
static void registerTranslation(StringRef name, StringRef description,
+ Optional<llvm::Align> inputAlignment,
const TranslateFunction &function) {
- auto &translationRegistry = getTranslationRegistry();
- if (translationRegistry.find(name) != translationRegistry.end())
+ auto &registry = getTranslationRegistry();
+ if (registry.count(name))
llvm::report_fatal_error(
"Attempting to overwrite an existing <file-to-file> function");
assert(function &&
"Attempting to register an empty translate <file-to-file> function");
- translationRegistry[name].translateFunction = function;
- translationRegistry[name].translateDescription = description;
+ registry[name] = Translation(function, description, inputAlignment);
}
TranslateRegistration::TranslateRegistration(
StringRef name, StringRef description, const TranslateFunction &function) {
- registerTranslation(name, description, function);
+ registerTranslation(name, description, /*inputAlignment=*/llvm::None,
+ function);
}
//===----------------------------------------------------------------------===//
@@ -77,7 +73,7 @@ TranslateRegistration::TranslateRegistration(
// Puts `function` into the to-MLIR translation registry unless there is already
// a function registered for the same name.
static void registerTranslateToMLIRFunction(
- StringRef name, StringRef description,
+ StringRef name, StringRef description, Optional<llvm::Align> inputAlignment,
const TranslateSourceMgrToMLIRFunction &function) {
auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output,
MLIRContext *context) {
@@ -87,21 +83,23 @@ static void registerTranslateToMLIRFunction(
op.get()->print(output);
return success();
};
- registerTranslation(name, description, wrappedFn);
+ registerTranslation(name, description, inputAlignment, wrappedFn);
}
TranslateToMLIRRegistration::TranslateToMLIRRegistration(
StringRef name, StringRef description,
- const TranslateSourceMgrToMLIRFunction &function) {
- registerTranslateToMLIRFunction(name, description, function);
+ const TranslateSourceMgrToMLIRFunction &function,
+ Optional<llvm::Align> inputAlignment) {
+ registerTranslateToMLIRFunction(name, description, inputAlignment, function);
}
/// Wraps `function` with a lambda that extracts a StringRef from a source
/// manager and registers the wrapper lambda as a to-MLIR conversion.
TranslateToMLIRRegistration::TranslateToMLIRRegistration(
StringRef name, StringRef description,
- const TranslateStringRefToMLIRFunction &function) {
+ const TranslateStringRefToMLIRFunction &function,
+ Optional<llvm::Align> inputAlignment) {
registerTranslateToMLIRFunction(
- name, description,
+ name, description, inputAlignment,
[function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) {
const llvm::MemoryBuffer *buffer =
sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
@@ -117,9 +115,8 @@ TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
StringRef name, StringRef description,
const TranslateFromMLIRFunction &function,
const std::function<void(DialectRegistry &)> &dialectRegistration) {
-
registerTranslation(
- name, description,
+ name, description, /*inputAlignment=*/llvm::None,
[function, dialectRegistration](llvm::SourceMgr &sourceMgr,
raw_ostream &output,
MLIRContext *context) {
@@ -141,11 +138,9 @@ TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
//===----------------------------------------------------------------------===//
TranslationParser::TranslationParser(llvm::cl::Option &opt)
- : llvm::cl::parser<const TranslateFunction *>(opt) {
- for (const auto &kv : getTranslationRegistry()) {
- addLiteralOption(kv.first(), &kv.second.translateFunction,
- kv.second.translateDescription);
- }
+ : llvm::cl::parser<const Translation *>(opt) {
+ for (const auto &kv : getTranslationRegistry())
+ addLiteralOption(kv.first(), &kv.second, kv.second.getDescription());
}
void TranslationParser::printOptionInfo(const llvm::cl::Option &o,
@@ -156,5 +151,5 @@ void TranslationParser::printOptionInfo(const llvm::cl::Option &o,
const TranslationParser::OptionInfo *rhs) {
return lhs->Name.compare(rhs->Name);
});
- llvm::cl::parser<const TranslateFunction *>::printOptionInfo(o, globalWidth);
+ llvm::cl::parser<const Translation *>::printOptionInfo(o, globalWidth);
}
More information about the Mlir-commits
mailing list