[Mlir-commits] [mlir] 09fd9ef - [mlir] Execute all requested translations in MlirTranslateMain
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Feb 18 17:12:37 PST 2023
Author: Fabian
Date: 2023-02-18T21:31:59+01:00
New Revision: 09fd9ef4f4cb5c7b25bdb9f97d7e92445aec3417
URL: https://github.com/llvm/llvm-project/commit/09fd9ef4f4cb5c7b25bdb9f97d7e92445aec3417
DIFF: https://github.com/llvm/llvm-project/commit/09fd9ef4f4cb5c7b25bdb9f97d7e92445aec3417.diff
LOG: [mlir] Execute all requested translations in MlirTranslateMain
Currently, MlirTranslateMain executes only one of the requested translations and does not report an error if multiple are specified. This commit allows translations to be chained and executes them in the order specified.
This makes round-trip tests easier, since existing import/export passes can be reused and no combined round-trip passes have to be registered (example: mlir-translate -serialize-spirv -deserialize-spirv).
Additionally, registering file-to-file TranslateFunctions via TranslateRegistration makes it possible to add generic pre- and post-processing steps before and after conversion to and from MLIR.
Reviewed By: lattner, Mogball
Differential Revision: https://reviews.llvm.org/D143719
Added:
mlir/test/Target/SPIRV/array-two-step-roundtrip.mlir
Modified:
mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
index 02c91285dbedf..5e460ed13fd62 100644
--- a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
+++ b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
@@ -13,6 +13,7 @@
#include "mlir/IR/Verifier.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/ToolUtilities.h"
#include "mlir/Tools/mlir-translate/Translation.h"
#include "llvm/Support/InitLLVM.h"
@@ -56,9 +57,9 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
llvm::InitLLVM y(argc, argv);
// Add flags for all the registered translations.
- llvm::cl::opt<const Translation *, false, TranslationParser>
- translationRequested("", llvm::cl::desc("Translation to perform"),
- llvm::cl::Required);
+ llvm::cl::list<const Translation *, bool, TranslationParser>
+ translationsRequested("", llvm::cl::desc("Translations to perform"),
+ llvm::cl::Required);
registerAsmPrinterCLOptions();
registerMLIRContextCLOptions();
registerTranslationCLOptions();
@@ -66,7 +67,7 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
std::string errorMessage;
std::unique_ptr<llvm::MemoryBuffer> input;
- if (auto inputAlignment = translationRequested->getInputAlignment())
+ if (auto inputAlignment = translationsRequested[0]->getInputAlignment())
input = openInputFile(inputFilename, *inputAlignment, &errorMessage);
else
input = openInputFile(inputFilename, &errorMessage);
@@ -84,23 +85,54 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
// Processes the memory buffer with a new MLIRContext.
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
raw_ostream &os) {
- MLIRContext context;
- context.allowUnregisteredDialects(allowUnregisteredDialects);
- context.printOpOnDiagnostic(!verifyDiagnostics);
- auto sourceMgr = std::make_shared<llvm::SourceMgr>();
- sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
-
- if (!verifyDiagnostics) {
- SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context);
- return (*translationRequested)(sourceMgr, os, &context);
+ // Buffers that carry each translation's output into the next translation.
+ std::string dataIn;
+ std::string dataOut;
+ LogicalResult result = LogicalResult::success();
+
+ for (size_t i = 0, e = translationsRequested.size(); i < e; ++i) {
+ llvm::raw_ostream *stream;
+ llvm::raw_string_ostream dataStream(dataOut);
+
+ if (i == e - 1) {
+ // Output last translation to output.
+ stream = &os;
+ } else {
+ // Output translation to temporary data buffer.
+ stream = &dataStream;
+ }
+
+ const Translation *translationRequested = translationsRequested[i];
+ MLIRContext context;
+ context.allowUnregisteredDialects(allowUnregisteredDialects);
+ context.printOpOnDiagnostic(!verifyDiagnostics);
+ auto sourceMgr = std::make_shared<llvm::SourceMgr>();
+ sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
+
+ if (verifyDiagnostics) {
+ // Verification ignores whether the translation failed (it usually is
+ // expected to) and checks the diagnostics instead; output still goes to
+ // *stream so a chained next translation receives it as its input.
+ SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr,
+ &context);
+ (void)(*translationRequested)(sourceMgr, *stream, &context);
+ result = sourceMgrHandler.verify();
+ } else {
+ SourceMgrDiagnosticHandler sourceMgrHandler(*sourceMgr, &context);
+ result = (*translationRequested)(sourceMgr, *stream, &context);
+ }
+ if (failed(result))
+ return result;
+
+ if (i < e - 1) {
+ // If there are further translations, create a new buffer with the
+ // output data.
+ dataIn = dataOut;
+ dataOut.clear();
+ ownedBuffer = llvm::MemoryBuffer::getMemBuffer(dataIn);
+ }
}
-
- // In the diagnostic verification flow, we ignore whether the translation
- // failed (in most cases, it is expected to fail). Instead, we check if the
- // diagnostics were produced as expected.
- SourceMgrDiagnosticVerifierHandler sourceMgrHandler(*sourceMgr, &context);
- (void)(*translationRequested)(sourceMgr, os, &context);
- return sourceMgrHandler.verify();
+ return result;
};
if (failed(splitAndProcessBuffer(std::move(input), processBuffer,
diff --git a/mlir/test/Target/SPIRV/array-two-step-roundtrip.mlir b/mlir/test/Target/SPIRV/array-two-step-roundtrip.mlir
new file mode 100644
index 0000000000000..427b926527240
--- /dev/null
+++ b/mlir/test/Target/SPIRV/array-two-step-roundtrip.mlir
@@ -0,0 +1,18 @@
+// RUN: mlir-translate -no-implicit-module -split-input-file -serialize-spirv -deserialize-spirv %s | FileCheck %s
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+ spirv.func @array_stride(%arg0 : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32, stride=4>, stride=128>, StorageBuffer>, %arg1 : i32, %arg2 : i32) "None" {
+ // CHECK: {{%.*}} = spirv.AccessChain {{%.*}}[{{%.*}}, {{%.*}}] : !spirv.ptr<!spirv.array<4 x !spirv.array<4 x f32, stride=4>, stride=128>, StorageBuffer>, i32, i32
+ %2 = spirv.AccessChain %arg0[%arg1, %arg2] : !spirv.ptr<!spirv.array<4x!spirv.array<4xf32, stride=4>, stride=128>, StorageBuffer>, i32, i32
+ spirv.Return
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+ // CHECK: spirv.GlobalVariable {{@.*}} : !spirv.ptr<!spirv.rtarray<f32, stride=4>, StorageBuffer>
+ spirv.GlobalVariable @var0 : !spirv.ptr<!spirv.rtarray<f32, stride=4>, StorageBuffer>
+ // CHECK: spirv.GlobalVariable {{@.*}} : !spirv.ptr<!spirv.rtarray<vector<4xf16>>, Input>
+ spirv.GlobalVariable @var1 : !spirv.ptr<!spirv.rtarray<vector<4xf16>>, Input>
+}
More information about the Mlir-commits
mailing list