[Mlir-commits] [mlir] f164534 - Add a `dialect_registration` callback for "translations" registered with mlir-translate

Mehdi Amini llvmlistbot at llvm.org
Sat Aug 22 18:02:09 PDT 2020


Author: Mehdi Amini
Date: 2020-08-23T01:00:39Z
New Revision: f164534ca8e042ab7bbc25516f88adf027ebe12d

URL: https://github.com/llvm/llvm-project/commit/f164534ca8e042ab7bbc25516f88adf027ebe12d
DIFF: https://github.com/llvm/llvm-project/commit/f164534ca8e042ab7bbc25516f88adf027ebe12d.diff

LOG: Add a `dialect_registration` callback for "translations" registered with mlir-translate

This will allow out-of-tree translations to register the dialects they expect
to see in their input, following the model of getDependentDialects() for passes.

Differential Revision: https://reviews.llvm.org/D86409

Added: 
    

Modified: 
    mlir/include/mlir/Translation.h
    mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
    mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
    mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
    mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp
    mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp
    mlir/lib/Translation/Translation.cpp
    mlir/tools/mlir-translate/mlir-translate.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Translation.h b/mlir/include/mlir/Translation.h
index bdc391ac1e5b..ab9dad8b6d91 100644
--- a/mlir/include/mlir/Translation.h
+++ b/mlir/include/mlir/Translation.h
@@ -76,8 +76,10 @@ struct TranslateToMLIRRegistration {
 };
 
 struct TranslateFromMLIRRegistration {
-  TranslateFromMLIRRegistration(llvm::StringRef name,
-                                const TranslateFromMLIRFunction &function);
+  TranslateFromMLIRRegistration(
+      llvm::StringRef name, const TranslateFromMLIRFunction &function,
+      std::function<void(DialectRegistry &)> dialectRegistration =
+          [](DialectRegistry &) {});
 };
 struct TranslateRegistration {
   TranslateRegistration(llvm::StringRef name,

diff  --git a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
index d0aa2774a34c..8a6032f8f417 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
@@ -11,10 +11,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVModule.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Serialization.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Module.h"
 #include "mlir/Parser.h"
@@ -105,8 +107,12 @@ static LogicalResult serializeModule(ModuleOp module, raw_ostream &output) {
 namespace mlir {
 void registerToSPIRVTranslation() {
   TranslateFromMLIRRegistration toBinary(
-      "serialize-spirv", [](ModuleOp module, raw_ostream &output) {
+      "serialize-spirv",
+      [](ModuleOp module, raw_ostream &output) {
         return serializeModule(module, output);
+      },
+      [](DialectRegistry &registry) {
+        registry.insert<spirv::SPIRVDialect>();
       });
 }
 } // namespace mlir
@@ -147,15 +153,23 @@ static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo,
 namespace mlir {
 void registerTestRoundtripSPIRV() {
   TranslateFromMLIRRegistration roundtrip(
-      "test-spirv-roundtrip", [](ModuleOp module, raw_ostream &output) {
+      "test-spirv-roundtrip",
+      [](ModuleOp module, raw_ostream &output) {
         return roundTripModule(module, /*emitDebugInfo=*/false, output);
+      },
+      [](DialectRegistry &registry) {
+        registry.insert<spirv::SPIRVDialect>();
       });
 }
 
 void registerTestRoundtripDebugSPIRV() {
   TranslateFromMLIRRegistration roundtrip(
-      "test-spirv-roundtrip-debug", [](ModuleOp module, raw_ostream &output) {
+      "test-spirv-roundtrip-debug",
+      [](ModuleOp module, raw_ostream &output) {
         return roundTripModule(module, /*emitDebugInfo=*/true, output);
+      },
+      [](DialectRegistry &registry) {
+        registry.insert<spirv::SPIRVDialect>();
       });
 }
 } // namespace mlir

diff  --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
index 6027ab344a56..89fb5b1e2214 100644
--- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
@@ -30,7 +30,8 @@ mlir::translateModuleToLLVMIR(ModuleOp m, llvm::LLVMContext &llvmContext,
 namespace mlir {
 void registerToLLVMIRTranslation() {
   TranslateFromMLIRRegistration registration(
-      "mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) {
+      "mlir-to-llvmir",
+      [](ModuleOp module, raw_ostream &output) {
         llvm::LLVMContext llvmContext;
         auto llvmModule = LLVM::ModuleTranslation::translateModule<>(
             module, llvmContext, "LLVMDialectModule");
@@ -39,6 +40,7 @@ void registerToLLVMIRTranslation() {
 
         llvmModule->print(output, nullptr);
         return success();
-      });
+      },
+      [](DialectRegistry &registry) { registry.insert<LLVM::LLVMDialect>(); });
 }
 } // namespace mlir

diff  --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
index fc2f650840c1..bee7d18d7536 100644
--- a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
@@ -99,7 +99,8 @@ mlir::translateModuleToNVVMIR(Operation *m, llvm::LLVMContext &llvmContext,
 namespace mlir {
 void registerToNVVMIRTranslation() {
   TranslateFromMLIRRegistration registration(
-      "mlir-to-nvvmir", [](ModuleOp module, raw_ostream &output) {
+      "mlir-to-nvvmir",
+      [](ModuleOp module, raw_ostream &output) {
         llvm::LLVMContext llvmContext;
         auto llvmModule = mlir::translateModuleToNVVMIR(module, llvmContext);
         if (!llvmModule)
@@ -107,6 +108,9 @@ void registerToNVVMIRTranslation() {
 
         llvmModule->print(output, nullptr);
         return success();
+      },
+      [](DialectRegistry &registry) {
+        registry.insert<LLVM::LLVMDialect, NVVM::NVVMDialect>();
       });
 }
 } // namespace mlir

diff  --git a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp
index 5bd04e4e3ef8..6bfcf197f8c3 100644
--- a/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp
@@ -103,7 +103,8 @@ mlir::translateModuleToROCDLIR(Operation *m, llvm::LLVMContext &llvmContext,
 namespace mlir {
 void registerToROCDLIRTranslation() {
   TranslateFromMLIRRegistration registration(
-      "mlir-to-rocdlir", [](ModuleOp module, raw_ostream &output) {
+      "mlir-to-rocdlir",
+      [](ModuleOp module, raw_ostream &output) {
         llvm::LLVMContext llvmContext;
         auto llvmModule = mlir::translateModuleToROCDLIR(module, llvmContext);
         if (!llvmModule)
@@ -111,6 +112,9 @@ void registerToROCDLIRTranslation() {
 
         llvmModule->print(output, nullptr);
         return success();
+      },
+      [](DialectRegistry &registry) {
+        registry.insert<ROCDL::ROCDLDialect, LLVM::LLVMDialect>();
       });
 }
 } // namespace mlir

diff  --git a/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp b/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp
index 64f9fef91847..52f1792ba54c 100644
--- a/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp
+++ b/mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp
@@ -45,7 +45,8 @@ translateLLVMAVX512ModuleToLLVMIR(Operation *m, llvm::LLVMContext &llvmContext,
 namespace mlir {
 void registerAVX512ToLLVMIRTranslation() {
   TranslateFromMLIRRegistration reg(
-      "avx512-mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) {
+      "avx512-mlir-to-llvmir",
+      [](ModuleOp module, raw_ostream &output) {
         llvm::LLVMContext llvmContext;
         auto llvmModule = translateLLVMAVX512ModuleToLLVMIR(
             module, llvmContext, "LLVMDialectModule");
@@ -54,6 +55,9 @@ void registerAVX512ToLLVMIRTranslation() {
 
         llvmModule->print(output, nullptr);
         return success();
+      },
+      [](DialectRegistry &registry) {
+        registry.insert<LLVM::LLVMAVX512Dialect, LLVM::LLVMDialect>();
       });
 }
 } // namespace mlir

diff  --git a/mlir/lib/Translation/Translation.cpp b/mlir/lib/Translation/Translation.cpp
index 99f0acdd1b95..991bdf95c6cd 100644
--- a/mlir/lib/Translation/Translation.cpp
+++ b/mlir/lib/Translation/Translation.cpp
@@ -92,10 +92,12 @@ TranslateToMLIRRegistration::TranslateToMLIRRegistration(
 //===----------------------------------------------------------------------===//
 
 TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
-    StringRef name, const TranslateFromMLIRFunction &function) {
-  registerTranslation(name, [function](llvm::SourceMgr &sourceMgr,
-                                       raw_ostream &output,
-                                       MLIRContext *context) {
+    StringRef name, const TranslateFromMLIRFunction &function,
+    std::function<void(DialectRegistry &)> dialectRegistration) {
+  registerTranslation(name, [function, dialectRegistration](
+                                llvm::SourceMgr &sourceMgr, raw_ostream &output,
+                                MLIRContext *context) {
+    dialectRegistration(context->getDialectRegistry());
     auto module = OwningModuleRef(parseSourceFile(sourceMgr, context));
     if (!module)
       return failure();
@@ -173,7 +175,7 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
   // Processes the memory buffer with a new MLIRContext.
   auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
                            raw_ostream &os) {
-    MLIRContext context;
+    MLIRContext context(false);
     context.printOpOnDiagnostic(!verifyDiagnostics);
     llvm::SourceMgr sourceMgr;
     sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());

diff  --git a/mlir/tools/mlir-translate/mlir-translate.cpp b/mlir/tools/mlir-translate/mlir-translate.cpp
index 6388a4c8b954..cf84856ddb84 100644
--- a/mlir/tools/mlir-translate/mlir-translate.cpp
+++ b/mlir/tools/mlir-translate/mlir-translate.cpp
@@ -32,7 +32,5 @@ static void registerTestTranslations() {
 int main(int argc, char **argv) {
   registerAllTranslations();
   registerTestTranslations();
-  // TODO: remove the global dialect registry
-  registerAllDialects();
   return failed(mlirTranslateMain(argc, argv, "MLIR Translation Testing Tool"));
 }


        


More information about the Mlir-commits mailing list