[Mlir-commits] [mlir] 6a72635 - Revert "Remove global dialect registration"
Mehdi Amini
llvmlistbot at llvm.org
Fri Oct 23 14:27:02 PDT 2020
Author: Mehdi Amini
Date: 2020-10-23T21:26:48Z
New Revision: 6a72635881e98dbac458323fe9666af6507a09ec
URL: https://github.com/llvm/llvm-project/commit/6a72635881e98dbac458323fe9666af6507a09ec
DIFF: https://github.com/llvm/llvm-project/commit/6a72635881e98dbac458323fe9666af6507a09ec.diff
LOG: Revert "Remove global dialect registration"
This reverts commit b22e2e4c6e420b78a8a4c307f0cf002f51af9590.
Investigating broken builds
Added:
Modified:
mlir/examples/standalone/standalone-opt/standalone-opt.cpp
mlir/examples/toy/Ch2/toyc.cpp
mlir/examples/toy/Ch3/toyc.cpp
mlir/examples/toy/Ch4/toyc.cpp
mlir/examples/toy/Ch5/toyc.cpp
mlir/examples/toy/Ch6/toyc.cpp
mlir/examples/toy/Ch7/toyc.cpp
mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/MLIRContext.h
mlir/include/mlir/InitAllDialects.h
mlir/lib/CAPI/IR/IR.cpp
mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
mlir/lib/ExecutionEngine/JitRunner.cpp
mlir/lib/IR/Dialect.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/Support/MlirOptMain.cpp
mlir/lib/Translation/Translation.cpp
mlir/test/EDSC/builder-api-test.cpp
mlir/test/SDBM/sdbm-api-test.cpp
mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
mlir/tools/mlir-opt/mlir-opt.cpp
mlir/tools/mlir-reduce/mlir-reduce.cpp
mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp
mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
mlir/unittests/IR/AttributeTest.cpp
mlir/unittests/IR/DialectTest.cpp
mlir/unittests/IR/OperationSupportTest.cpp
mlir/unittests/Pass/AnalysisManagerTest.cpp
mlir/unittests/SDBM/SDBMTest.cpp
mlir/unittests/TableGen/OpBuildGen.cpp
mlir/unittests/TableGen/StructsGenTest.cpp
Removed:
################################################################################
diff --git a/mlir/examples/standalone/standalone-opt/standalone-opt.cpp b/mlir/examples/standalone/standalone-opt/standalone-opt.cpp
index 0fb211add2cb..86cf67918446 100644
--- a/mlir/examples/standalone/standalone-opt/standalone-opt.cpp
+++ b/mlir/examples/standalone/standalone-opt/standalone-opt.cpp
@@ -22,6 +22,7 @@
#include "Standalone/StandaloneDialect.h"
int main(int argc, char **argv) {
+ mlir::registerAllDialects();
mlir::registerAllPasses();
// TODO: Register standalone passes here.
diff --git a/mlir/examples/toy/Ch2/toyc.cpp b/mlir/examples/toy/Ch2/toyc.cpp
index 1a3e5ddefacf..99232d8f24a4 100644
--- a/mlir/examples/toy/Ch2/toyc.cpp
+++ b/mlir/examples/toy/Ch2/toyc.cpp
@@ -68,7 +68,7 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
}
int dumpMLIR() {
- mlir::MLIRContext context;
+ mlir::MLIRContext context(/*loadAllDialects=*/false);
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::toy::ToyDialect>();
diff --git a/mlir/examples/toy/Ch3/toyc.cpp b/mlir/examples/toy/Ch3/toyc.cpp
index baef6d47b037..d0430ce16e54 100644
--- a/mlir/examples/toy/Ch3/toyc.cpp
+++ b/mlir/examples/toy/Ch3/toyc.cpp
@@ -102,7 +102,7 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
}
int dumpMLIR() {
- mlir::MLIRContext context;
+ mlir::MLIRContext context(/*loadAllDialects=*/false);
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::toy::ToyDialect>();
diff --git a/mlir/examples/toy/Ch4/toyc.cpp b/mlir/examples/toy/Ch4/toyc.cpp
index af3ae8748d24..9f95887d2707 100644
--- a/mlir/examples/toy/Ch4/toyc.cpp
+++ b/mlir/examples/toy/Ch4/toyc.cpp
@@ -103,7 +103,7 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
}
int dumpMLIR() {
- mlir::MLIRContext context;
+ mlir::MLIRContext context(/*loadAllDialects=*/false);
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::toy::ToyDialect>();
diff --git a/mlir/examples/toy/Ch5/toyc.cpp b/mlir/examples/toy/Ch5/toyc.cpp
index 94c3bd573cdd..16faac02fc60 100644
--- a/mlir/examples/toy/Ch5/toyc.cpp
+++ b/mlir/examples/toy/Ch5/toyc.cpp
@@ -106,7 +106,7 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
}
int dumpMLIR() {
- mlir::MLIRContext context;
+ mlir::MLIRContext context(/*loadAllDialects=*/false);
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::toy::ToyDialect>();
@@ -172,6 +172,8 @@ int dumpAST() {
}
int main(int argc, char **argv) {
+ mlir::registerAllDialects();
+
// Register any command line options.
mlir::registerAsmPrinterCLOptions();
mlir::registerMLIRContextCLOptions();
diff --git a/mlir/examples/toy/Ch6/toyc.cpp b/mlir/examples/toy/Ch6/toyc.cpp
index 2051089a18d3..9504a38b8784 100644
--- a/mlir/examples/toy/Ch6/toyc.cpp
+++ b/mlir/examples/toy/Ch6/toyc.cpp
@@ -241,6 +241,8 @@ int runJit(mlir::ModuleOp module) {
}
int main(int argc, char **argv) {
+ mlir::registerAllDialects();
+
// Register any command line options.
mlir::registerAsmPrinterCLOptions();
mlir::registerMLIRContextCLOptions();
@@ -253,7 +255,7 @@ int main(int argc, char **argv) {
// If we aren't dumping the AST, then we are compiling with/to MLIR.
- mlir::MLIRContext context;
+ mlir::MLIRContext context(/*loadAllDialects=*/false);
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::toy::ToyDialect>();
diff --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp
index 2eb32a7290f7..cb3b455dc7ec 100644
--- a/mlir/examples/toy/Ch7/toyc.cpp
+++ b/mlir/examples/toy/Ch7/toyc.cpp
@@ -242,6 +242,8 @@ int runJit(mlir::ModuleOp module) {
}
int main(int argc, char **argv) {
+ mlir::registerAllDialects();
+
// Register any command line options.
mlir::registerAsmPrinterCLOptions();
mlir::registerMLIRContextCLOptions();
@@ -254,7 +256,7 @@ int main(int argc, char **argv) {
// If we aren't dumping the AST, then we are compiling with/to MLIR.
- mlir::MLIRContext context;
+ mlir::MLIRContext context(/*loadAllDialects=*/false);
// Load our Dialect in this MLIR Context.
context.getOrLoadDialect<mlir::toy::ToyDialect>();
diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index a3328c2d02bf..5bd8a745edce 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -284,6 +284,48 @@ class DialectRegistry {
MapTy registry;
};
+/// Deprecated: this provides a global registry for convenience, while we're
+/// transitioning the registration mechanism to a stateless approach.
+DialectRegistry &getGlobalDialectRegistry();
+
+/// This controls globally whether the dialect registry is / isn't enabled.
+/// This is deprecated and only intended to help the transition. It'll be
+/// removed soon.
+void enableGlobalDialectRegistry(bool);
+bool isGlobalDialectRegistryEnabled();
+
+/// Registers all dialects from the global registries with the
+/// specified MLIRContext. This won't load the dialects in the context,
+/// but only make them available for lazy loading by name.
+/// Note: This method is not thread-safe.
+/// Deprecated: this method will be deleted soon.
+void registerAllDialects(MLIRContext *context);
+
+/// Register and return the dialect with the given namespace in the provided
+/// context. Returns nullptr is there is no constructor registered for this
+/// dialect.
+inline Dialect *registerDialect(StringRef name, MLIRContext *context) {
+ return getGlobalDialectRegistry().loadByName(name, context);
+}
+
+/// Utility to register a dialect. Client can register their dialect with the
+/// global registry by calling registerDialect<MyDialect>();
+/// Note: This method is not thread-safe.
+template <typename ConcreteDialect> void registerDialect() {
+ getGlobalDialectRegistry().insert<ConcreteDialect>();
+}
+
+/// DialectRegistration provides a global initializer that registers a Dialect
+/// allocation routine.
+///
+/// Usage:
+///
+/// // At namespace scope.
+/// static DialectRegistration<MyDialect> Unused;
+template <typename ConcreteDialect> struct DialectRegistration {
+ DialectRegistration() { registerDialect<ConcreteDialect>(); }
+};
+
} // namespace mlir
namespace llvm {
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index 7fb941d344b7..b7fe642166dd 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -24,6 +24,7 @@ class InFlightDiagnostic;
class Location;
class MLIRContextImpl;
class StorageUniquer;
+DialectRegistry &getGlobalDialectRegistry();
/// MLIRContext is the top-level object for a collection of MLIR modules. It
/// holds immortal uniqued objects like types, and the tables used to unique
@@ -39,7 +40,7 @@ class MLIRContext {
/// The loadAllDialects parameters allows to load all dialects from the global
/// registry on Context construction. It is deprecated and will be removed
/// soon.
- explicit MLIRContext();
+ explicit MLIRContext(bool loadAllDialects = true);
~MLIRContext();
/// Return information about all IR dialects loaded in the context.
@@ -87,6 +88,11 @@ class MLIRContext {
loadDialect<OtherDialect, MoreDialects...>();
}
+ /// Deprecated: load all globally registered dialects into this context.
+ /// This method will be removed soon, it can be used temporarily as we're
+ /// phasing out the global registry.
+ void loadAllGloballyRegisteredDialects();
+
/// Get (or create) a dialect for the given derived dialect name.
/// The dialect will be loaded from the registry if no dialect is found.
/// If no dialect is loaded for this name and none is available in the
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index e32d877946e5..060acdec5a13 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -64,6 +64,13 @@ inline void registerAllDialects(DialectRegistry &registry) {
// clang-format on
}
+// This function should be called before creating any MLIRContext if one expect
+// all the possible dialects to be made available to the context automatically.
+inline void registerAllDialects() {
+ static bool initOnce =
+ ([]() { registerAllDialects(getGlobalDialectRegistry()); }(), true);
+ (void)initOnce;
+}
} // namespace mlir
#endif // MLIR_INITALLDIALECTS_H_
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index fdc40bc6c4f1..379770c8962f 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -26,7 +26,7 @@ using namespace mlir;
/* ========================================================================== */
MlirContext mlirContextCreate() {
- auto *context = new MLIRContext;
+ auto *context = new MLIRContext(/*loadAllDialects=*/false);
return wrap(context);
}
diff --git a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
index de148713315e..267c6b2aa9ac 100644
--- a/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
@@ -136,7 +136,7 @@ static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo,
if (failed(spirv::serialize(*spirvModules.begin(), binary, emitDebugInfo)))
return failure();
- MLIRContext deserializationContext;
+ MLIRContext deserializationContext(false);
context->getDialectRegistry().loadAll(&deserializationContext);
// Then deserialize to get back a SPIR-V module.
spirv::OwningSPIRVModuleRef spirvModule =
diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
index 3727306f19f4..7d141e90edda 100644
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -22,7 +22,6 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
-#include "mlir/InitAllDialects.h"
#include "mlir/Parser.h"
#include "mlir/Support/FileUtilities.h"
@@ -260,8 +259,8 @@ int mlir::JitRunnerMain(
}
}
- MLIRContext context;
- registerAllDialects(context.getDialectRegistry());
+ MLIRContext context(/*loadAllDialects=*/false);
+ registerAllDialects(&context);
auto m = parseMLIRInput(options.inputFilename, &context);
if (!m) {
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index ff2b7a04be76..f356e8265ecf 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -22,6 +22,27 @@ using namespace detail;
DialectAsmParser::~DialectAsmParser() {}
+//===----------------------------------------------------------------------===//
+// Dialect Registration (DEPRECATED)
+//===----------------------------------------------------------------------===//
+
+/// Registry for all dialect allocation functions.
+static llvm::ManagedStatic<DialectRegistry> dialectRegistry;
+DialectRegistry &mlir::getGlobalDialectRegistry() { return *dialectRegistry; }
+
+// Note: deprecated, will be removed soon.
+static bool isGlobalDialectRegistryEnabledFlag = false;
+void mlir::enableGlobalDialectRegistry(bool enable) {
+ isGlobalDialectRegistryEnabledFlag = enable;
+}
+bool mlir::isGlobalDialectRegistryEnabled() {
+ return isGlobalDialectRegistryEnabledFlag;
+}
+
+void mlir::registerAllDialects(MLIRContext *context) {
+ dialectRegistry->appendTo(context->getDialectRegistry());
+}
+
Dialect *DialectRegistry::loadByName(StringRef name, MLIRContext *context) {
auto it = registry.find(name.str());
if (it == registry.end())
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index eec679c771f6..7551bb929970 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -361,7 +361,7 @@ class MLIRContextImpl {
};
} // end namespace mlir
-MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
+MLIRContext::MLIRContext(bool loadAllDialects) : impl(new MLIRContextImpl()) {
// Initialize values based on the command line flags if they were provided.
if (clOptions.isConstructed()) {
disableMultithreading(clOptions->disableThreading);
@@ -369,8 +369,10 @@ MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
printStackTraceOnDiagnostic(clOptions->printStackTraceOnDiagnostic);
}
- // Ensure the builtin dialect is always pre-loaded.
+ // Register dialects with this context.
getOrLoadDialect<BuiltinDialect>();
+ if (loadAllDialects)
+ loadAllGloballyRegisteredDialects();
// Initialize several common attributes and types to avoid the need to lock
// the context when accessing them.
@@ -518,6 +520,12 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
return dialect.get();
}
+void MLIRContext::loadAllGloballyRegisteredDialects() {
+ if (!isGlobalDialectRegistryEnabled())
+ return;
+ getGlobalDialectRegistry().loadAll(this);
+}
+
bool MLIRContext::allowsUnregisteredDialects() {
return impl->allowUnregisteredDialects;
}
diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp
index 1c2a1fe18ca3..77b07605407b 100644
--- a/mlir/lib/Support/MlirOptMain.cpp
+++ b/mlir/lib/Support/MlirOptMain.cpp
@@ -89,7 +89,7 @@ static LogicalResult processBuffer(raw_ostream &os,
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
// Parse the input file.
- MLIRContext context;
+ MLIRContext context(/*loadAllDialects=*/preloadDialectsInContext);
registry.appendTo(context.getDialectRegistry());
if (preloadDialectsInContext)
registry.loadAll(&context);
diff --git a/mlir/lib/Translation/Translation.cpp b/mlir/lib/Translation/Translation.cpp
index 7f47cc25f720..991bdf95c6cd 100644
--- a/mlir/lib/Translation/Translation.cpp
+++ b/mlir/lib/Translation/Translation.cpp
@@ -175,7 +175,7 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
// Processes the memory buffer with a new MLIRContext.
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
raw_ostream &os) {
- MLIRContext context;
+ MLIRContext context(false);
context.printOpOnDiagnostic(!verifyDiagnostics);
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index 1a866066523e..f36b93ecc9dd 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -36,7 +36,7 @@ using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
static MLIRContext &globalContext() {
- static thread_local MLIRContext context;
+ static thread_local MLIRContext context(/*loadAllDialects=*/false);
static thread_local bool initOnce = [&]() {
// clang-format off
context.loadDialect<AffineDialect,
diff --git a/mlir/test/SDBM/sdbm-api-test.cpp b/mlir/test/SDBM/sdbm-api-test.cpp
index 027c584c7409..ddefc52fb461 100644
--- a/mlir/test/SDBM/sdbm-api-test.cpp
+++ b/mlir/test/SDBM/sdbm-api-test.cpp
@@ -21,7 +21,7 @@ using namespace mlir;
static MLIRContext *ctx() {
- static thread_local MLIRContext context;
+ static thread_local MLIRContext context(/*loadAllDialects=*/false);
static thread_local bool once =
(context.getOrLoadDialect<SDBMDialect>(), true);
(void)once;
diff --git a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
index 7667908c39b3..3e93bc6cec64 100644
--- a/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
+++ b/mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
@@ -19,6 +19,7 @@
#include "llvm/Support/TargetSelect.h"
int main(int argc, char **argv) {
+ mlir::registerAllDialects();
llvm::InitLLVM y(argc, argv);
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
diff --git a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
index be00646bd0ce..208e27a75f78 100644
--- a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
+++ b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
@@ -125,6 +125,7 @@ static LogicalResult runMLIRPasses(ModuleOp m) {
int main(int argc, char **argv) {
registerPassManagerCLOptions();
+ mlir::registerAllDialects();
llvm::InitLLVM y(argc, argv);
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 82b90ec53b26..b0e38ef83dae 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1757,7 +1757,7 @@ int main(int argc, char **argv) {
if (testEmitIncludeTdHeader)
output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\"";
- MLIRContext context;
+ MLIRContext context(/*loadAllDialects=*/false);
llvm::SourceMgr mgr;
mgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
Parser parser(mgr, &context);
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 2322de20230d..ef55e4f0f436 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -144,6 +144,7 @@ void registerTestPasses() {
#endif
int main(int argc, char **argv) {
+ registerAllDialects();
registerAllPasses();
#ifdef MLIR_INCLUDE_TESTS
registerTestPasses();
diff --git a/mlir/tools/mlir-reduce/mlir-reduce.cpp b/mlir/tools/mlir-reduce/mlir-reduce.cpp
index e493fabd6e32..3397bfc2aaf9 100644
--- a/mlir/tools/mlir-reduce/mlir-reduce.cpp
+++ b/mlir/tools/mlir-reduce/mlir-reduce.cpp
@@ -71,6 +71,7 @@ int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
+ registerAllDialects();
registerMLIRContextCLOptions();
registerPassManagerCLOptions();
diff --git a/mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp b/mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp
index 8cd391a0be46..41d03b21c8f7 100644
--- a/mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp
+++ b/mlir/tools/mlir-rocm-runner/mlir-rocm-runner.cpp
@@ -322,6 +322,7 @@ static LogicalResult runMLIRPasses(ModuleOp m) {
int main(int argc, char **argv) {
registerPassManagerCLOptions();
+ mlir::registerAllDialects();
llvm::InitLLVM y(argc, argv);
llvm::InitializeAllTargetInfos();
llvm::InitializeAllTargetMCs();
diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
index 905d2e422115..c792f38bdb82 100644
--- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
+++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp
@@ -53,6 +53,7 @@ int main(int argc, char **argv) {
llvm::llvm_shutdown_obj x;
registerPassManagerCLOptions();
+ mlir::registerAllDialects();
llvm::InitLLVM y(argc, argv);
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
diff --git a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
index 85735c040bce..bae95e1a13b6 100644
--- a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
@@ -75,7 +75,7 @@ UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) {
}
TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
- MLIRContext ctx;
+ MLIRContext ctx(/*loadAllDialects=*/false);
ctx.getOrLoadDialect<QuantizationDialect>();
IntegerType convertedType = IntegerType::get(8, &ctx);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
@@ -93,7 +93,7 @@ TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
}
TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
- MLIRContext ctx;
+ MLIRContext ctx(/*loadAllDialects=*/false);
ctx.getOrLoadDialect<QuantizationDialect>();
IntegerType convertedType = IntegerType::get(8, &ctx);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
@@ -118,7 +118,7 @@ TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
}
TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
- MLIRContext ctx;
+ MLIRContext ctx(/*loadAllDialects=*/false);
ctx.getOrLoadDialect<QuantizationDialect>();
IntegerType convertedType = IntegerType::get(8, &ctx);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
@@ -143,7 +143,7 @@ TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
}
TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
- MLIRContext ctx;
+ MLIRContext ctx(/*loadAllDialects=*/false);
ctx.getOrLoadDialect<QuantizationDialect>();
IntegerType convertedType = IntegerType::get(8, &ctx);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
diff --git a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
index ccec33951870..4aa2ffed7e2b 100644
--- a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
@@ -25,6 +25,9 @@
using namespace mlir;
+/// Load the SPIRV dialect.
+static DialectRegistration<spirv::SPIRVDialect> SPIRVRegistration;
+
using ::testing::StrEq;
//===----------------------------------------------------------------------===//
@@ -35,7 +38,7 @@ using ::testing::StrEq;
/// diagnostic checking utilities.
class DeserializationTest : public ::testing::Test {
protected:
- DeserializationTest() {
+ DeserializationTest() : context(/*loadAllDialects=*/false) {
context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
// Register a diagnostic handler to capture the diagnostic so that we can
// check it later.
diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index 6bd16b964c29..cb89cd61de7b 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -36,7 +36,7 @@ using namespace mlir;
class SerializationTest : public ::testing::Test {
protected:
- SerializationTest() {
+ SerializationTest() : context(/*loadAllDialects=*/false) {
context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
createModuleOp();
}
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index df449a0da75c..78f7dd53d8fd 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -32,7 +32,7 @@ static void testSplat(Type eltType, const EltTy &splatElt) {
namespace {
TEST(DenseSplatTest, BoolSplat) {
- MLIRContext context;
+ MLIRContext context(false);
IntegerType boolTy = IntegerType::get(1, &context);
RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
@@ -57,7 +57,7 @@ TEST(DenseSplatTest, BoolSplat) {
TEST(DenseSplatTest, LargeBoolSplat) {
constexpr int64_t boolCount = 56;
- MLIRContext context;
+ MLIRContext context(false);
IntegerType boolTy = IntegerType::get(1, &context);
RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
@@ -80,7 +80,7 @@ TEST(DenseSplatTest, LargeBoolSplat) {
}
TEST(DenseSplatTest, BoolNonSplat) {
- MLIRContext context;
+ MLIRContext context(false);
IntegerType boolTy = IntegerType::get(1, &context);
RankedTensorType shape = RankedTensorType::get({6}, boolTy);
@@ -92,7 +92,7 @@ TEST(DenseSplatTest, BoolNonSplat) {
TEST(DenseSplatTest, OddIntSplat) {
// Test detecting a splat with an odd(non 8-bit) integer bitwidth.
- MLIRContext context;
+ MLIRContext context(false);
constexpr size_t intWidth = 19;
IntegerType intTy = IntegerType::get(intWidth, &context);
APInt value(intWidth, 10);
@@ -101,7 +101,7 @@ TEST(DenseSplatTest, OddIntSplat) {
}
TEST(DenseSplatTest, Int32Splat) {
- MLIRContext context;
+ MLIRContext context(false);
IntegerType intTy = IntegerType::get(32, &context);
int value = 64;
@@ -109,7 +109,7 @@ TEST(DenseSplatTest, Int32Splat) {
}
TEST(DenseSplatTest, IntAttrSplat) {
- MLIRContext context;
+ MLIRContext context(false);
IntegerType intTy = IntegerType::get(85, &context);
Attribute value = IntegerAttr::get(intTy, 109);
@@ -117,7 +117,7 @@ TEST(DenseSplatTest, IntAttrSplat) {
}
TEST(DenseSplatTest, F32Splat) {
- MLIRContext context;
+ MLIRContext context(false);
FloatType floatTy = FloatType::getF32(&context);
float value = 10.0;
@@ -125,7 +125,7 @@ TEST(DenseSplatTest, F32Splat) {
}
TEST(DenseSplatTest, F64Splat) {
- MLIRContext context;
+ MLIRContext context(false);
FloatType floatTy = FloatType::getF64(&context);
double value = 10.0;
@@ -133,7 +133,7 @@ TEST(DenseSplatTest, F64Splat) {
}
TEST(DenseSplatTest, FloatAttrSplat) {
- MLIRContext context;
+ MLIRContext context(false);
FloatType floatTy = FloatType::getF32(&context);
Attribute value = FloatAttr::get(floatTy, 10.0);
@@ -141,7 +141,7 @@ TEST(DenseSplatTest, FloatAttrSplat) {
}
TEST(DenseSplatTest, BF16Splat) {
- MLIRContext context;
+ MLIRContext context(false);
FloatType floatTy = FloatType::getBF16(&context);
Attribute value = FloatAttr::get(floatTy, 10.0);
@@ -149,7 +149,7 @@ TEST(DenseSplatTest, BF16Splat) {
}
TEST(DenseSplatTest, StringSplat) {
- MLIRContext context;
+ MLIRContext context(false);
Type stringType =
OpaqueType::get(Identifier::get("test", &context), "string", &context);
StringRef value = "test-string";
@@ -157,7 +157,7 @@ TEST(DenseSplatTest, StringSplat) {
}
TEST(DenseSplatTest, StringAttrSplat) {
- MLIRContext context;
+ MLIRContext context(false);
Type stringType =
OpaqueType::get(Identifier::get("test", &context), "string", &context);
Attribute stringAttr = StringAttr::get("test-string", stringType);
@@ -165,28 +165,28 @@ TEST(DenseSplatTest, StringAttrSplat) {
}
TEST(DenseComplexTest, ComplexFloatSplat) {
- MLIRContext context;
+ MLIRContext context(false);
ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
std::complex<float> value(10.0, 15.0);
testSplat(complexType, value);
}
TEST(DenseComplexTest, ComplexIntSplat) {
- MLIRContext context;
+ MLIRContext context(false);
ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
std::complex<int64_t> value(10, 15);
testSplat(complexType, value);
}
TEST(DenseComplexTest, ComplexAPFloatSplat) {
- MLIRContext context;
+ MLIRContext context(false);
ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
testSplat(complexType, value);
}
TEST(DenseComplexTest, ComplexAPIntSplat) {
- MLIRContext context;
+ MLIRContext context(false);
ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
testSplat(complexType, value);
diff --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp
index 2410be0263b3..5a0a229d2c82 100644
--- a/mlir/unittests/IR/DialectTest.cpp
+++ b/mlir/unittests/IR/DialectTest.cpp
@@ -26,7 +26,7 @@ struct AnotherTestDialect : public Dialect {
};
TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
- MLIRContext context;
+ MLIRContext context(false);
// Registering a dialect with the same namespace twice should result in a
// failure.
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index 356e849bc495..3b905c2ac4b2 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -26,7 +26,7 @@ static Operation *createOp(MLIRContext *context,
namespace {
TEST(OperandStorageTest, NonResizable) {
- MLIRContext context;
+ MLIRContext context(false);
Builder builder(&context);
Operation *useOp =
@@ -50,7 +50,7 @@ TEST(OperandStorageTest, NonResizable) {
}
TEST(OperandStorageTest, Resizable) {
- MLIRContext context;
+ MLIRContext context(false);
Builder builder(&context);
Operation *useOp =
@@ -78,7 +78,7 @@ TEST(OperandStorageTest, Resizable) {
}
TEST(OperandStorageTest, RangeReplace) {
- MLIRContext context;
+ MLIRContext context(false);
Builder builder(&context);
Operation *useOp =
@@ -114,7 +114,7 @@ TEST(OperandStorageTest, RangeReplace) {
}
TEST(OperandStorageTest, MutableRange) {
- MLIRContext context;
+ MLIRContext context(false);
Builder builder(&context);
Operation *useOp =
@@ -151,7 +151,7 @@ TEST(OperandStorageTest, MutableRange) {
}
TEST(OperationOrderTest, OrderIsAlwaysValid) {
- MLIRContext context;
+ MLIRContext context(false);
Builder builder(&context);
Operation *containerOp =
diff --git a/mlir/unittests/Pass/AnalysisManagerTest.cpp b/mlir/unittests/Pass/AnalysisManagerTest.cpp
index 55f953aacaff..7cefd6870eeb 100644
--- a/mlir/unittests/Pass/AnalysisManagerTest.cpp
+++ b/mlir/unittests/Pass/AnalysisManagerTest.cpp
@@ -29,7 +29,7 @@ struct OpSpecificAnalysis {
};
TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) {
- MLIRContext context;
+ MLIRContext context(false);
// Test fine grain invalidation of the module analysis manager.
OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
@@ -50,7 +50,7 @@ TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) {
}
TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) {
- MLIRContext context;
+ MLIRContext context(false);
Builder builder(&context);
// Create a function and a module.
@@ -79,7 +79,7 @@ TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) {
}
TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) {
- MLIRContext context;
+ MLIRContext context(false);
Builder builder(&context);
// Create a function and a module.
@@ -122,7 +122,7 @@ struct CustomInvalidatingAnalysis {
};
TEST(AnalysisManagerTest, CustomInvalidation) {
- MLIRContext context;
+ MLIRContext context(false);
Builder builder(&context);
// Create a function and a module.
diff --git a/mlir/unittests/SDBM/SDBMTest.cpp b/mlir/unittests/SDBM/SDBMTest.cpp
index c907aed6258a..bbe87e3d292c 100644
--- a/mlir/unittests/SDBM/SDBMTest.cpp
+++ b/mlir/unittests/SDBM/SDBMTest.cpp
@@ -19,7 +19,7 @@ using namespace mlir;
static MLIRContext *ctx() {
- static thread_local MLIRContext context;
+ static thread_local MLIRContext context(false);
context.getOrLoadDialect<SDBMDialect>();
return &context;
}
diff --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp
index 1f3c6381d4e0..09a02b4f90fd 100644
--- a/mlir/unittests/TableGen/OpBuildGen.cpp
+++ b/mlir/unittests/TableGen/OpBuildGen.cpp
@@ -26,7 +26,7 @@ namespace mlir {
//===----------------------------------------------------------------------===//
static MLIRContext &getContext() {
- static MLIRContext ctx;
+ static MLIRContext ctx(false);
ctx.getOrLoadDialect<TestDialect>();
return ctx;
}
diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp
index d089ac5f37c8..d2acb28ebfb1 100644
--- a/mlir/unittests/TableGen/StructsGenTest.cpp
+++ b/mlir/unittests/TableGen/StructsGenTest.cpp
@@ -44,7 +44,7 @@ static test::TestStruct getTestStruct(mlir::MLIRContext *context) {
/// Validates that test::TestStruct::classof correctly identifies a valid
/// test::TestStruct.
TEST(StructsGenTest, ClassofTrue) {
- mlir::MLIRContext context;
+ mlir::MLIRContext context(false);
auto structAttr = getTestStruct(&context);
ASSERT_TRUE(test::TestStruct::classof(structAttr));
}
More information about the Mlir-commits
mailing list