[Mlir-commits] [mlir] 25ee851 - Revert "Separate the Registration from Loading dialects in the Context"
Mehdi Amini
llvmlistbot at llvm.org
Sat Aug 15 02:22:24 PDT 2020
Author: Mehdi Amini
Date: 2020-08-15T09:21:47Z
New Revision: 25ee851746dfb43588d6718e97ffe5b305c49d1f
URL: https://github.com/llvm/llvm-project/commit/25ee851746dfb43588d6718e97ffe5b305c49d1f
DIFF: https://github.com/llvm/llvm-project/commit/25ee851746dfb43588d6718e97ffe5b305c49d1f.diff
LOG: Revert "Separate the Registration from Loading dialects in the Context"
This reverts commit 20563933875a9396c8ace9c9770ecf6a988c4ea6.
Build is broken on a few bots
Added:
Modified:
flang/unittests/Lower/OpenMPLoweringTest.cpp
mlir/examples/standalone/standalone-opt/standalone-opt.cpp
mlir/examples/toy/Ch2/toyc.cpp
mlir/examples/toy/Ch3/toyc.cpp
mlir/examples/toy/Ch4/toyc.cpp
mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch5/toyc.cpp
mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch6/toyc.cpp
mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch7/toyc.cpp
mlir/include/mlir-c/IR.h
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Dialect/Affine/Passes.td
mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/SCF/Passes.td
mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/FunctionSupport.h
mlir/include/mlir/IR/MLIRContext.h
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/InitAllDialects.h
mlir/include/mlir/InitAllTranslations.h
mlir/include/mlir/Pass/Pass.h
mlir/include/mlir/Pass/PassBase.td
mlir/include/mlir/Pass/PassManager.h
mlir/include/mlir/Support/MlirOptMain.h
mlir/include/mlir/TableGen/Dialect.h
mlir/include/mlir/TableGen/Pass.h
mlir/include/mlir/Transforms/Passes.td
mlir/lib/CAPI/IR/IR.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
mlir/lib/Conversion/PassDetail.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
mlir/lib/Dialect/Affine/Transforms/PassDetail.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
mlir/lib/Dialect/SCF/Transforms/PassDetail.h
mlir/lib/Dialect/SDBM/SDBMExpr.cpp
mlir/lib/ExecutionEngine/JitRunner.cpp
mlir/lib/IR/Dialect.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/Verifier.cpp
mlir/lib/Parser/AttributeParser.cpp
mlir/lib/Parser/DialectSymbolParser.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Pass/Pass.cpp
mlir/lib/Pass/PassDetail.h
mlir/lib/Support/MlirOptMain.cpp
mlir/lib/TableGen/Dialect.cpp
mlir/lib/TableGen/Pass.cpp
mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/lib/Transforms/PassDetail.h
mlir/test/CAPI/ir.c
mlir/test/EDSC/builder-api-test.cpp
mlir/test/SDBM/sdbm-api-test.cpp
mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
mlir/test/lib/Dialect/Test/TestPatterns.cpp
mlir/test/lib/Transforms/TestAllReduceLowering.cpp
mlir/test/lib/Transforms/TestBufferPlacement.cpp
mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp
mlir/test/lib/Transforms/TestLinalgHoisting.cpp
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
mlir/test/lib/Transforms/TestVectorTransforms.cpp
mlir/test/mlir-opt/commandline.mlir
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
mlir/tools/mlir-opt/mlir-opt.cpp
mlir/tools/mlir-tblgen/DialectGen.cpp
mlir/tools/mlir-tblgen/PassGen.cpp
mlir/tools/mlir-translate/mlir-translate.cpp
mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
mlir/unittests/IR/AttributeTest.cpp
mlir/unittests/IR/DialectTest.cpp
mlir/unittests/IR/OperationSupportTest.cpp
mlir/unittests/Pass/AnalysisManagerTest.cpp
mlir/unittests/SDBM/SDBMTest.cpp
mlir/unittests/TableGen/OpBuildGen.cpp
mlir/unittests/TableGen/StructsGenTest.cpp
Removed:
################################################################################
diff --git a/flang/unittests/Lower/OpenMPLoweringTest.cpp b/flang/unittests/Lower/OpenMPLoweringTest.cpp
index 175b65837df3..ad6fe739d16b 100644
--- a/flang/unittests/Lower/OpenMPLoweringTest.cpp
+++ b/flang/unittests/Lower/OpenMPLoweringTest.cpp
@@ -15,7 +15,8 @@
class OpenMPLoweringTest : public testing::Test {
protected:
void SetUp() override {
- ctx.loadDialect<mlir::omp::OpenMPDialect>();
+ mlir::registerDialect<mlir::omp::OpenMPDialect>();
+ mlir::registerAllDialects(&ctx);
mlirOpBuilder.reset(new mlir::OpBuilder(&ctx));
}
diff --git a/mlir/examples/standalone/standalone-opt/standalone-opt.cpp b/mlir/examples/standalone/standalone-opt/standalone-opt.cpp
index eb624b3e8954..5c99058693c3 100644
--- a/mlir/examples/standalone/standalone-opt/standalone-opt.cpp
+++ b/mlir/examples/standalone/standalone-opt/standalone-opt.cpp
@@ -76,7 +76,7 @@ int main(int argc, char **argv) {
if (showDialects) {
mlir::MLIRContext context;
llvm::outs() << "Registered Dialects:\n";
- for (mlir::Dialect *dialect : context.getLoadedDialects()) {
+ for (mlir::Dialect *dialect : context.getRegisteredDialects()) {
llvm::outs() << dialect->getNamespace() << "\n";
}
return 0;
diff --git a/mlir/examples/toy/Ch2/toyc.cpp b/mlir/examples/toy/Ch2/toyc.cpp
index 99232d8f24a4..d0880ce0971b 100644
--- a/mlir/examples/toy/Ch2/toyc.cpp
+++ b/mlir/examples/toy/Ch2/toyc.cpp
@@ -68,9 +68,10 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
}
int dumpMLIR() {
- mlir::MLIRContext context(/*loadAllDialects=*/false);
- // Load our Dialect in this MLIR Context.
- context.getOrLoadDialect<mlir::toy::ToyDialect>();
+ // Register our Dialect with MLIR.
+ mlir::registerDialect<mlir::toy::ToyDialect>();
+
+ mlir::MLIRContext context;
// Handle '.toy' input to the compiler.
if (inputType != InputType::MLIR &&
diff --git a/mlir/examples/toy/Ch3/toyc.cpp b/mlir/examples/toy/Ch3/toyc.cpp
index d0430ce16e54..f9d5631719e8 100644
--- a/mlir/examples/toy/Ch3/toyc.cpp
+++ b/mlir/examples/toy/Ch3/toyc.cpp
@@ -102,10 +102,10 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
}
int dumpMLIR() {
- mlir::MLIRContext context(/*loadAllDialects=*/false);
- // Load our Dialect in this MLIR Context.
- context.getOrLoadDialect<mlir::toy::ToyDialect>();
+ // Register our Dialect with MLIR.
+ mlir::registerDialect<mlir::toy::ToyDialect>();
+ mlir::MLIRContext context;
mlir::OwningModuleRef module;
llvm::SourceMgr sourceMgr;
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
diff --git a/mlir/examples/toy/Ch4/toyc.cpp b/mlir/examples/toy/Ch4/toyc.cpp
index 9f95887d2707..e11f35c5f7e1 100644
--- a/mlir/examples/toy/Ch4/toyc.cpp
+++ b/mlir/examples/toy/Ch4/toyc.cpp
@@ -103,10 +103,10 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
}
int dumpMLIR() {
- mlir::MLIRContext context(/*loadAllDialects=*/false);
- // Load our Dialect in this MLIR Context.
- context.getOrLoadDialect<mlir::toy::ToyDialect>();
+ // Register our Dialect with MLIR.
+ mlir::registerDialect<mlir::toy::ToyDialect>();
+ mlir::MLIRContext context;
mlir::OwningModuleRef module;
llvm::SourceMgr sourceMgr;
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index 92fd246a1358..3097681ea3fa 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -256,9 +256,6 @@ struct TransposeOpLowering : public ConversionPattern {
namespace {
struct ToyToAffineLoweringPass
: public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<AffineDialect, StandardOpsDialect>();
- }
void runOnFunction() final;
};
} // end anonymous namespace.
diff --git a/mlir/examples/toy/Ch5/toyc.cpp b/mlir/examples/toy/Ch5/toyc.cpp
index 16faac02fc60..ed0496957093 100644
--- a/mlir/examples/toy/Ch5/toyc.cpp
+++ b/mlir/examples/toy/Ch5/toyc.cpp
@@ -106,10 +106,10 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
}
int dumpMLIR() {
- mlir::MLIRContext context(/*loadAllDialects=*/false);
- // Load our Dialect in this MLIR Context.
- context.getOrLoadDialect<mlir::toy::ToyDialect>();
+ // Register our Dialect with MLIR.
+ mlir::registerDialect<mlir::toy::ToyDialect>();
+ mlir::MLIRContext context;
mlir::OwningModuleRef module;
llvm::SourceMgr sourceMgr;
mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
index f3857f35e25c..cac3415f48d6 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -255,9 +255,6 @@ struct TransposeOpLowering : public ConversionPattern {
namespace {
struct ToyToAffineLoweringPass
: public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<AffineDialect, StandardOpsDialect>();
- }
void runOnFunction() final;
};
} // end anonymous namespace.
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
index 19bf27e1864d..74b32dc0ca11 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
@@ -159,9 +159,6 @@ class PrintOpLowering : public ConversionPattern {
namespace {
struct ToyToLLVMLoweringPass
: public PassWrapper<ToyToLLVMLoweringPass, OperationPass<ModuleOp>> {
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<LLVM::LLVMDialect, scf::SCFDialect>();
- }
void runOnOperation() final;
};
} // end anonymous namespace
diff --git a/mlir/examples/toy/Ch6/toyc.cpp b/mlir/examples/toy/Ch6/toyc.cpp
index 9504a38b8784..bdcdf1af7ea8 100644
--- a/mlir/examples/toy/Ch6/toyc.cpp
+++ b/mlir/examples/toy/Ch6/toyc.cpp
@@ -255,10 +255,10 @@ int main(int argc, char **argv) {
// If we aren't dumping the AST, then we are compiling with/to MLIR.
- mlir::MLIRContext context(/*loadAllDialects=*/false);
- // Load our Dialect in this MLIR Context.
- context.getOrLoadDialect<mlir::toy::ToyDialect>();
+ // Register our Dialect with MLIR.
+ mlir::registerDialect<mlir::toy::ToyDialect>();
+ mlir::MLIRContext context;
mlir::OwningModuleRef module;
if (int error = loadAndProcessMLIR(context, module))
return error;
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index 92fd246a1358..3097681ea3fa 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -256,9 +256,6 @@ struct TransposeOpLowering : public ConversionPattern {
namespace {
struct ToyToAffineLoweringPass
: public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<AffineDialect, StandardOpsDialect>();
- }
void runOnFunction() final;
};
} // end anonymous namespace.
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index 19bf27e1864d..74b32dc0ca11 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -159,9 +159,6 @@ class PrintOpLowering : public ConversionPattern {
namespace {
struct ToyToLLVMLoweringPass
: public PassWrapper<ToyToLLVMLoweringPass, OperationPass<ModuleOp>> {
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<LLVM::LLVMDialect, scf::SCFDialect>();
- }
void runOnOperation() final;
};
} // end anonymous namespace
diff --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp
index cb3b455dc7ec..c1cc207a406c 100644
--- a/mlir/examples/toy/Ch7/toyc.cpp
+++ b/mlir/examples/toy/Ch7/toyc.cpp
@@ -256,10 +256,10 @@ int main(int argc, char **argv) {
// If we aren't dumping the AST, then we are compiling with/to MLIR.
- mlir::MLIRContext context(/*loadAllDialects=*/false);
- // Load our Dialect in this MLIR Context.
- context.getOrLoadDialect<mlir::toy::ToyDialect>();
+ // Register our Dialect with MLIR.
+ mlir::registerDialect<mlir::toy::ToyDialect>();
+ mlir::MLIRContext context;
mlir::OwningModuleRef module;
if (int error = loadAndProcessMLIR(context, module))
return error;
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index f9ec4d132056..6b5be2d0195b 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -90,12 +90,6 @@ MlirContext mlirContextCreate();
/** Takes an MLIR context owned by the caller and destroys it. */
void mlirContextDestroy(MlirContext context);
-/** Load all the globally registered dialects in the provided context.
- * TODO: remove the concept of globally registered dialect by exposing the
- * DialectRegistry.
- */
-void mlirContextLoadAllDialects(MlirContext context);
-
/*============================================================================*/
/* Location API. */
/*============================================================================*/
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 0c40bb3bbfb6..4d4fe064a6bc 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -66,11 +66,6 @@ def ConvertAffineToStandard : Pass<"lower-affine"> {
`affine.apply`.
}];
let constructor = "mlir::createLowerAffinePass()";
- let dependentDialects = [
- "scf::SCFDialect",
- "StandardOpsDialect",
- "vector::VectorDialect"
- ];
}
//===----------------------------------------------------------------------===//
@@ -81,7 +76,6 @@ def ConvertAVX512ToLLVM : Pass<"convert-avx512-to-llvm", "ModuleOp"> {
let summary = "Convert the operations from the avx512 dialect into the LLVM "
"dialect";
let constructor = "mlir::createConvertAVX512ToLLVMPass()";
- let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMAVX512Dialect"];
}
//===----------------------------------------------------------------------===//
@@ -104,7 +98,6 @@ def GpuToLLVMConversionPass : Pass<"gpu-to-llvm", "ModuleOp"> {
def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
let summary = "Generate NVVM operations for gpu operations";
let constructor = "mlir::createLowerGpuOpsToNVVMOpsPass()";
- let dependentDialects = ["NVVM::NVVMDialect"];
let options = [
Option<"indexBitwidth", "index-bitwidth", "unsigned",
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
@@ -119,7 +112,6 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
let summary = "Generate ROCDL operations for gpu operations";
let constructor = "mlir::createLowerGpuOpsToROCDLOpsPass()";
- let dependentDialects = ["ROCDL::ROCDLDialect"];
let options = [
Option<"indexBitwidth", "index-bitwidth", "unsigned",
/*default=kDeriveIndexBitwidthFromDataLayout*/"0",
@@ -134,7 +126,6 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
let summary = "Convert GPU dialect to SPIR-V dialect";
let constructor = "mlir::createConvertGPUToSPIRVPass()";
- let dependentDialects = ["spirv::SPIRVDialect"];
}
//===----------------------------------------------------------------------===//
@@ -145,7 +136,6 @@ def ConvertGpuLaunchFuncToVulkanLaunchFunc
: Pass<"convert-gpu-launch-to-vulkan-launch", "ModuleOp"> {
let summary = "Convert gpu.launch_func to vulkanLaunch external call";
let constructor = "mlir::createConvertGpuLaunchFuncToVulkanLaunchFuncPass()";
- let dependentDialects = ["spirv::SPIRVDialect"];
}
def ConvertVulkanLaunchFuncToVulkanCalls
@@ -153,7 +143,6 @@ def ConvertVulkanLaunchFuncToVulkanCalls
let summary = "Convert vulkanLaunch external call to Vulkan runtime external "
"calls";
let constructor = "mlir::createConvertVulkanLaunchFuncToVulkanCallsPass()";
- let dependentDialects = ["LLVM::LLVMDialect"];
}
//===----------------------------------------------------------------------===//
@@ -164,7 +153,6 @@ def ConvertLinalgToLLVM : Pass<"convert-linalg-to-llvm", "ModuleOp"> {
let summary = "Convert the operations from the linalg dialect into the LLVM "
"dialect";
let constructor = "mlir::createConvertLinalgToLLVMPass()";
- let dependentDialects = ["scf::SCFDialect", "LLVM::LLVMDialect"];
}
//===----------------------------------------------------------------------===//
@@ -175,7 +163,6 @@ def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
let summary = "Convert the operations from the linalg dialect into the "
"Standard dialect";
let constructor = "mlir::createConvertLinalgToStandardPass()";
- let dependentDialects = ["StandardOpsDialect"];
}
//===----------------------------------------------------------------------===//
@@ -185,7 +172,6 @@ def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
def ConvertLinalgToSPIRV : Pass<"convert-linalg-to-spirv", "ModuleOp"> {
let summary = "Convert Linalg ops to SPIR-V ops";
let constructor = "mlir::createLinalgToSPIRVPass()";
- let dependentDialects = ["spirv::SPIRVDialect"];
}
//===----------------------------------------------------------------------===//
@@ -196,7 +182,6 @@ def SCFToStandard : Pass<"convert-scf-to-std"> {
let summary = "Convert SCF dialect to Standard dialect, replacing structured"
" control flow with a CFG";
let constructor = "mlir::createLowerToCFGPass()";
- let dependentDialects = ["StandardOpsDialect"];
}
//===----------------------------------------------------------------------===//
@@ -206,7 +191,6 @@ def SCFToStandard : Pass<"convert-scf-to-std"> {
def ConvertAffineForToGPU : FunctionPass<"convert-affine-for-to-gpu"> {
let summary = "Convert top-level AffineFor Ops to GPU kernels";
let constructor = "mlir::createAffineForToGPUPass()";
- let dependentDialects = ["gpu::GPUDialect"];
let options = [
Option<"numBlockDims", "gpu-block-dims", "unsigned", /*default=*/"1u",
"Number of GPU block dimensions for mapping">,
@@ -218,7 +202,6 @@ def ConvertAffineForToGPU : FunctionPass<"convert-affine-for-to-gpu"> {
def ConvertParallelLoopToGpu : Pass<"convert-parallel-loops-to-gpu"> {
let summary = "Convert mapped scf.parallel ops to gpu launch operations";
let constructor = "mlir::createParallelLoopToGpuPass()";
- let dependentDialects = ["AffineDialect", "gpu::GPUDialect"];
}
//===----------------------------------------------------------------------===//
@@ -229,7 +212,6 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
let summary = "Convert operations from the shape dialect into the standard "
"dialect";
let constructor = "mlir::createConvertShapeToStandardPass()";
- let dependentDialects = ["StandardOpsDialect"];
}
//===----------------------------------------------------------------------===//
@@ -239,7 +221,6 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> {
let summary = "Convert operations from the shape dialect to the SCF dialect";
let constructor = "mlir::createConvertShapeToSCFPass()";
- let dependentDialects = ["scf::SCFDialect"];
}
//===----------------------------------------------------------------------===//
@@ -249,7 +230,6 @@ def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> {
def ConvertSPIRVToLLVM : Pass<"convert-spirv-to-llvm", "ModuleOp"> {
let summary = "Convert SPIR-V dialect to LLVM dialect";
let constructor = "mlir::createConvertSPIRVToLLVMPass()";
- let dependentDialects = ["LLVM::LLVMDialect"];
}
//===----------------------------------------------------------------------===//
@@ -284,7 +264,6 @@ def ConvertStandardToLLVM : Pass<"convert-std-to-llvm", "ModuleOp"> {
LLVM IR types.
}];
let constructor = "mlir::createLowerToLLVMPass()";
- let dependentDialects = ["LLVM::LLVMDialect"];
let options = [
Option<"useAlignedAlloc", "use-aligned-alloc", "bool", /*default=*/"false",
"Use aligned_alloc in place of malloc for heap allocations">,
@@ -308,13 +287,11 @@ def ConvertStandardToLLVM : Pass<"convert-std-to-llvm", "ModuleOp"> {
def LegalizeStandardForSPIRV : Pass<"legalize-std-for-spirv"> {
let summary = "Legalize standard ops for SPIR-V lowering";
let constructor = "mlir::createLegalizeStdOpsForSPIRVLoweringPass()";
- let dependentDialects = ["spirv::SPIRVDialect"];
}
def ConvertStandardToSPIRV : Pass<"convert-std-to-spirv", "ModuleOp"> {
let summary = "Convert Standard Ops to SPIR-V dialect";
let constructor = "mlir::createConvertStandardToSPIRVPass()";
- let dependentDialects = ["spirv::SPIRVDialect"];
}
//===----------------------------------------------------------------------===//
@@ -325,7 +302,6 @@ def ConvertVectorToSCF : FunctionPass<"convert-vector-to-scf"> {
let summary = "Lower the operations from the vector dialect into the SCF "
"dialect";
let constructor = "mlir::createConvertVectorToSCFPass()";
- let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
let options = [
Option<"fullUnroll", "full-unroll", "bool", /*default=*/"false",
"Perform full unrolling when converting vector transfers to SCF">,
@@ -340,7 +316,6 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
let summary = "Lower the operations from the vector dialect into the LLVM "
"dialect";
let constructor = "mlir::createConvertVectorToLLVMPass()";
- let dependentDialects = ["LLVM::LLVMDialect"];
let options = [
Option<"reassociateFPReductions", "reassociate-fp-reductions",
"bool", /*default=*/"false",
@@ -356,7 +331,6 @@ def ConvertVectorToROCDL : Pass<"convert-vector-to-rocdl", "ModuleOp"> {
let summary = "Lower the operations from the vector dialect into the ROCDL "
"dialect";
let constructor = "mlir::createConvertVectorToROCDLPass()";
- let dependentDialects = ["ROCDL::ROCDLDialect"];
}
#endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index f43fabd19aae..810640058155 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -94,7 +94,6 @@ def AffineLoopUnrollAndJam : FunctionPass<"affine-loop-unroll-jam"> {
def AffineVectorize : FunctionPass<"affine-super-vectorize"> {
let summary = "Vectorize to a target independent n-D vector abstraction";
let constructor = "mlir::createSuperVectorizePass()";
- let dependentDialects = ["vector::VectorDialect"];
let options = [
ListOption<"vectorSizes", "virtual-vector-size", "int64_t",
"Specify an n-D virtual vector size for vectorization",
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
index 2f465f07a97e..04700f0aa17d 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
@@ -15,7 +15,6 @@
#define MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
-#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
index 0529b7f823f1..d21f5bc0b49b 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
@@ -19,11 +19,6 @@ include "mlir/IR/OpBase.td"
def LLVM_Dialect : Dialect {
let name = "llvm";
let cppNamespace = "LLVM";
-
- /// FIXME: at the moment this is a dependency of the translation to LLVM IR,
- /// not really one of this dialect per-se.
- let dependentDialects = ["omp::OpenMPDialect"];
-
let hasRegionArgAttrVerify = 1;
let extraClassDeclaration = [{
~LLVMDialect();
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
index 9cc5314bdb90..86d437c9b561 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
@@ -14,7 +14,6 @@
#ifndef MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_
#define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7d47e5012ac9..5f022e32b801 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -23,7 +23,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def NVVM_Dialect : Dialect {
let name = "nvvm";
let cppNamespace = "NVVM";
- let dependentDialects = ["LLVM::LLVMDialect"];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
index eb40373c3f11..bf761c357f90 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
@@ -22,7 +22,6 @@
#ifndef MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_
#define MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
index f85c4f02899b..0cd11690daa8 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
@@ -23,7 +23,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def ROCDL_Dialect : Dialect {
let name = "rocdl";
let cppNamespace = "ROCDL";
- let dependentDialects = ["LLVM::LLVMDialect"];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index dcf4b5ec06cb..11f12ad30eb6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -30,20 +30,17 @@ def LinalgFusion : FunctionPass<"linalg-fusion"> {
def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
let summary = "Fuse operations on RankedTensorType in linalg dialect";
let constructor = "mlir::createLinalgFusionOfTensorOpsPass()";
- let dependentDialects = ["AffineDialect"];
}
def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
let summary = "Lower the operations from the linalg dialect into affine "
"loops";
let constructor = "mlir::createConvertLinalgToAffineLoopsPass()";
- let dependentDialects = ["AffineDialect"];
}
def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
let summary = "Lower the operations from the linalg dialect into loops";
let constructor = "mlir::createConvertLinalgToLoopsPass()";
- let dependentDialects = ["scf::SCFDialect", "AffineDialect"];
}
def LinalgOnTensorsToBuffers : Pass<"convert-linalg-on-tensors-to-buffers", "ModuleOp"> {
@@ -57,7 +54,6 @@ def LinalgLowerToParallelLoops
let summary = "Lower the operations from the linalg dialect into parallel "
"loops";
let constructor = "mlir::createConvertLinalgToParallelLoopsPass()";
- let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
}
def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
@@ -74,9 +70,6 @@ def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
def LinalgTiling : FunctionPass<"linalg-tile"> {
let summary = "Tile operations in the linalg dialect";
let constructor = "mlir::createLinalgTilingPass()";
- let dependentDialects = [
- "AffineDialect", "scf::SCFDialect"
- ];
let options = [
ListOption<"tileSizes", "linalg-tile-sizes", "int64_t",
"Test generation of dynamic promoted buffers",
@@ -93,7 +86,6 @@ def LinalgTilingToParallelLoops
"Test generation of dynamic promoted buffers",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
];
- let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
}
#endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/SCF/Passes.td b/mlir/include/mlir/Dialect/SCF/Passes.td
index 6f3cf0e12642..483d0ba7c7be 100644
--- a/mlir/include/mlir/Dialect/SCF/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Passes.td
@@ -36,7 +36,6 @@ def SCFParallelLoopTiling : FunctionPass<"parallel-loop-tiling"> {
"Factors to tile parallel loops by",
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
];
- let dependentDialects = ["AffineDialect"];
}
#endif // MLIR_DIALECT_SCF_PASSES
diff --git a/mlir/include/mlir/IR/Dialect.h b/mlir/include/mlir/IR/Dialect.h
index 8c0fef0d7ccf..4f9e4cb3618b 100644
--- a/mlir/include/mlir/IR/Dialect.h
+++ b/mlir/include/mlir/IR/Dialect.h
@@ -16,8 +16,6 @@
#include "mlir/IR/OperationSupport.h"
#include "mlir/Support/TypeID.h"
-#include <map>
-
namespace mlir {
class DialectAsmParser;
class DialectAsmPrinter;
@@ -25,7 +23,7 @@ class DialectInterface;
class OpBuilder;
class Type;
-using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
+using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
/// Dialects are groups of MLIR operations and behavior associated with the
/// entire group. For example, hooks into other systems for constant folding,
@@ -214,87 +212,30 @@ class Dialect {
/// A collection of registered dialect interfaces.
DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces;
+ /// Registers a specific dialect creation function with the global registry.
+ /// Used through the registerDialect template.
+ /// Registrations are deduplicated by dialect TypeID and only the first
+ /// registration will be used.
+ static void
+ registerDialectAllocator(TypeID typeID,
+ const DialectAllocatorFunction &function);
+ template <typename ConcreteDialect>
friend void registerDialect();
friend class MLIRContext;
};
-/// The DialectRegistry maps a dialect namespace to a constructor for the
-/// matching dialect.
-/// This allows for decoupling the list of dialects "available" from the
-/// dialects loaded in the Context. The parser in particular will lazily load
-/// dialects in in the Context as operations are encountered.
-class DialectRegistry {
- using MapTy =
- std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
-
-public:
- template <typename ConcreteDialect>
- void insert() {
- insert(TypeID::get<ConcreteDialect>(),
- ConcreteDialect::getDialectNamespace(),
- static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
- // Just allocate the dialect, the context
- // takes ownership of it.
- return ctx->getOrLoadDialect<ConcreteDialect>();
- })));
- }
-
- template <typename ConcreteDialect, typename OtherDialect,
- typename... MoreDialects>
- void insert() {
- insert<ConcreteDialect>();
- insert<OtherDialect, MoreDialects...>();
- }
-
- /// Add a new dialect constructor to the registry.
- void insert(TypeID typeID, StringRef name, DialectAllocatorFunction ctor);
-
- /// Load a dialect for this namespace in the provided context.
- Dialect *loadByName(StringRef name, MLIRContext *context);
-
- // Register all dialects available in the current registry with the registry
- // in the provided context.
- void appendTo(DialectRegistry &destination) {
- for (const auto &nameAndRegistrationIt : registry)
- destination.insert(nameAndRegistrationIt.second.first,
- nameAndRegistrationIt.first,
- nameAndRegistrationIt.second.second);
- }
- // Load all dialects available in the registry in the provided context.
- void loadAll(MLIRContext *context) {
- for (const auto &nameAndRegistrationIt : registry)
- nameAndRegistrationIt.second.second(context);
- }
-
- MapTy::const_iterator begin() const { return registry.begin(); }
- MapTy::const_iterator end() const { return registry.end(); }
-
-private:
- MapTy registry;
-};
-
-/// Deprecated: this provides a global registry for convenience, while we're
-/// transitionning the registration mechanism to a stateless approach.
-DialectRegistry &getGlobalDialectRegistry();
-
-/// Registers all dialects from the global registries with the
-/// specified MLIRContext. This won't load the dialects in the context,
-/// but only make them available for lazy loading by name.
+/// Registers all dialects and hooks from the global registries with the
+/// specified MLIRContext.
/// Note: This method is not thread-safe.
void registerAllDialects(MLIRContext *context);
-/// Register and return the dialect with the given namespace in the provided
-/// context. Returns nullptr is there is no constructor registered for this
-/// dialect.
-inline Dialect *registerDialect(StringRef name, MLIRContext *context) {
- return getGlobalDialectRegistry().loadByName(name, context);
-}
-
/// Utility to register a dialect. Client can register their dialect with the
/// global registry by calling registerDialect<MyDialect>();
/// Note: This method is not thread-safe.
template <typename ConcreteDialect> void registerDialect() {
- getGlobalDialectRegistry().insert<ConcreteDialect>();
+ Dialect::registerDialectAllocator(
+ TypeID::get<ConcreteDialect>(),
+ [](MLIRContext *ctx) { ctx->getOrCreateDialect<ConcreteDialect>(); });
}
/// DialectRegistration provides a global initializer that registers a Dialect
diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h
index 3d467cd4f364..7e281f393af9 100644
--- a/mlir/include/mlir/IR/FunctionSupport.h
+++ b/mlir/include/mlir/IR/FunctionSupport.h
@@ -428,7 +428,7 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
if (!attr.first.strref().contains('.'))
return funcOp.emitOpError("arguments may only have dialect attributes");
auto dialectNamePair = attr.first.strref().split('.');
- if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) {
+ if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) {
if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0,
/*argIndex=*/i, attr)))
return failure();
@@ -444,7 +444,7 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
if (!attr.first.strref().contains('.'))
return funcOp.emitOpError("results may only have dialect attributes");
auto dialectNamePair = attr.first.strref().split('.');
- if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) {
+ if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) {
if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0,
/*resultIndex=*/i,
attr)))
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index e8a5d6e6d236..0192a8ae06af 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -19,12 +19,10 @@ namespace mlir {
class AbstractOperation;
class DiagnosticEngine;
class Dialect;
-class DialectRegistry;
class InFlightDiagnostic;
class Location;
class MLIRContextImpl;
class StorageUniquer;
-DialectRegistry &getGlobalDialectRegistry();
/// MLIRContext is the top-level object for a collection of MLIR modules. It
/// holds immortal uniqued objects like types, and the tables used to unique
@@ -36,69 +34,34 @@ DialectRegistry &getGlobalDialectRegistry();
///
class MLIRContext {
public:
- /// Create a new Context.
- /// The loadAllDialects parameters allows to load all dialects from the global
- /// registry on Context construction. It is deprecated and will be removed
- /// soon.
- explicit MLIRContext(bool loadAllDialects = true);
+ explicit MLIRContext();
~MLIRContext();
- /// Return information about all IR dialects loaded in the context.
- std::vector<Dialect *> getLoadedDialects();
-
- /// Return the dialect registry associated with this context.
- DialectRegistry &getDialectRegistry();
-
- /// Return information about all available dialects in the registry in this
- /// context.
- std::vector<StringRef> getAvailableDialects();
+ /// Return information about all registered IR dialects.
+ std::vector<Dialect *> getRegisteredDialects();
/// Get a registered IR dialect with the given namespace. If an exact match is
/// not found, then return nullptr.
- Dialect *getLoadedDialect(StringRef name);
+ Dialect *getRegisteredDialect(StringRef name);
/// Get a registered IR dialect for the given derived dialect type. The
/// derived type must provide a static 'getDialectNamespace' method.
- template <typename T>
- T *getLoadedDialect() {
- return static_cast<T *>(getLoadedDialect(T::getDialectNamespace()));
+ template <typename T> T *getRegisteredDialect() {
+ return static_cast<T *>(getRegisteredDialect(T::getDialectNamespace()));
}
/// Get (or create) a dialect for the given derived dialect type. The derived
/// type must provide a static 'getDialectNamespace' method.
template <typename T>
- T *getOrLoadDialect() {
- return static_cast<T *>(
- getOrLoadDialect(T::getDialectNamespace(), TypeID::get<T>(), [this]() {
+ T *getOrCreateDialect() {
+ return static_cast<T *>(getOrCreateDialect(
+ T::getDialectNamespace(), TypeID::get<T>(), [this]() {
std::unique_ptr<T> dialect(new T(this));
+ dialect->dialectID = TypeID::get<T>();
return dialect;
}));
}
- /// Load a dialect in the context.
- template <typename Dialect>
- void loadDialect() {
- getOrLoadDialect<Dialect>();
- }
-
- /// Load a list dialects in the context.
- template <typename Dialect, typename OtherDialect, typename... MoreDialects>
- void loadDialect() {
- getOrLoadDialect<Dialect>();
- loadDialect<OtherDialect, MoreDialects...>();
- }
-
- /// Deprecated: load all globally registered dialects into this context.
- /// This method will be removed soon, it can be used temporarily as we're
- /// phasing out the global registry.
- void loadAllGloballyRegisteredDialects();
-
- /// Get (or create) a dialect for the given derived dialect name.
- /// The dialect will be loaded from the registry if no dialect is found.
- /// If no dialect is loaded for this name and none is available in the
- /// registry, returns nullptr.
- Dialect *getOrLoadDialect(StringRef name);
-
/// Return true if we allow to create operation for unregistered dialects.
bool allowsUnregisteredDialects();
@@ -160,12 +123,10 @@ class MLIRContext {
const std::unique_ptr<MLIRContextImpl> impl;
/// Get a dialect for the provided namespace and TypeID: abort the program if
- /// a dialect exist for this namespace with
diff erent TypeID. If a dialect has
- /// not been loaded for this namespace/TypeID yet, use the provided ctor to
- /// create one on the fly and load it. Returns a pointer to the dialect owned
- /// by the context.
- Dialect *getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
- function_ref<std::unique_ptr<Dialect>()> ctor);
+ /// a dialect exist for this namespace with
diff erent TypeID. Returns a
+ /// pointer to the dialect owned by the context.
+ Dialect *getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID,
+ function_ref<std::unique_ptr<Dialect>()> ctor);
MLIRContext(const MLIRContext &) = delete;
void operator=(const MLIRContext &) = delete;
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index a28410f028d5..9cc57a617289 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -244,11 +244,6 @@ class Dialect {
// The description of the dialect.
string description = ?;
- // A list of dialects this dialect will load on construction as dependencies.
- // These are dialects that this dialect may involved in canonicalization
- // pattern or interfaces.
- list<string> dependentDialects = [];
-
// The C++ namespace that ops of this dialect should be placed into.
//
// By default, uses the name of the dialect as the only namespace. To avoid
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
index 147ececc4c5a..b76b26fe3483 100644
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -35,35 +35,30 @@
namespace mlir {
-// Add all the MLIR dialects to the provided registry.
-inline void registerAllDialects(DialectRegistry ®istry) {
- // clang-format off
- registry.insert<acc::OpenACCDialect,
- AffineDialect,
- avx512::AVX512Dialect,
- gpu::GPUDialect,
- LLVM::LLVMAVX512Dialect,
- LLVM::LLVMDialect,
- linalg::LinalgDialect,
- scf::SCFDialect,
- omp::OpenMPDialect,
- quant::QuantizationDialect,
- spirv::SPIRVDialect,
- StandardOpsDialect,
- vector::VectorDialect,
- NVVM::NVVMDialect,
- ROCDL::ROCDLDialect,
- SDBMDialect,
- shape::ShapeDialect>();
- // clang-format on
-}
-
// This function should be called before creating any MLIRContext if one expect
// all the possible dialects to be made available to the context automatically.
inline void registerAllDialects() {
- static bool initOnce =
- ([]() { registerAllDialects(getGlobalDialectRegistry()); }(), true);
- (void)initOnce;
+ static bool init_once = []() {
+ registerDialect<acc::OpenACCDialect>();
+ registerDialect<AffineDialect>();
+ registerDialect<avx512::AVX512Dialect>();
+ registerDialect<gpu::GPUDialect>();
+ registerDialect<LLVM::LLVMAVX512Dialect>();
+ registerDialect<LLVM::LLVMDialect>();
+ registerDialect<linalg::LinalgDialect>();
+ registerDialect<scf::SCFDialect>();
+ registerDialect<omp::OpenMPDialect>();
+ registerDialect<quant::QuantizationDialect>();
+ registerDialect<spirv::SPIRVDialect>();
+ registerDialect<StandardOpsDialect>();
+ registerDialect<vector::VectorDialect>();
+ registerDialect<NVVM::NVVMDialect>();
+ registerDialect<ROCDL::ROCDLDialect>();
+ registerDialect<SDBMDialect>();
+ registerDialect<shape::ShapeDialect>();
+ return true;
+ }();
+ (void)init_once;
}
} // namespace mlir
diff --git a/mlir/include/mlir/InitAllTranslations.h b/mlir/include/mlir/InitAllTranslations.h
index a1771dab144c..31ca0254cf89 100644
--- a/mlir/include/mlir/InitAllTranslations.h
+++ b/mlir/include/mlir/InitAllTranslations.h
@@ -28,7 +28,7 @@ void registerAVX512ToLLVMIRTranslation();
// expects all the possible translations to be made available to the context
// automatically.
inline void registerAllTranslations() {
- static bool initOnce = []() {
+ static bool init_once = []() {
registerFromLLVMIRTranslation();
registerFromSPIRVTranslation();
registerToLLVMIRTranslation();
@@ -38,7 +38,7 @@ inline void registerAllTranslations() {
registerAVX512ToLLVMIRTranslation();
return true;
}();
- (void)initOnce;
+ (void)init_once;
}
} // namespace mlir
diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index d982c7f6d41d..7c0f9bd958a1 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -9,7 +9,6 @@
#ifndef MLIR_PASS_PASS_H
#define MLIR_PASS_PASS_H
-#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/Pass/AnalysisManager.h"
#include "mlir/Pass/PassRegistry.h"
@@ -58,13 +57,6 @@ class Pass {
/// Returns the derived pass name.
virtual StringRef getName() const = 0;
- /// Register dependent dialects for the current pass.
- /// A pass is expected to register the dialects it will create entities for
- /// (Operations, Types, Attributes), other than dialect that exists in the
- /// input. For example, a pass that converts from Linalg to Affine would
- /// register the Affine dialect but does not need to register Linalg.
- virtual void getDependentDialects(DialectRegistry ®istry) const {}
-
/// Returns the command line argument used when registering this pass. Return
/// an empty string if one does not exist.
virtual StringRef getArgument() const {
diff --git a/mlir/include/mlir/Pass/PassBase.td b/mlir/include/mlir/Pass/PassBase.td
index 7a2feff4fe04..54b44031559e 100644
--- a/mlir/include/mlir/Pass/PassBase.td
+++ b/mlir/include/mlir/Pass/PassBase.td
@@ -78,9 +78,6 @@ class PassBase<string passArg, string base> {
// A C++ constructor call to create an instance of this pass.
code constructor = [{}];
- // A list of dialects this pass may produce entities in.
- list<string> dependentDialects = [];
-
// A set of options provided by this pass.
list<Option> options = [];
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index 29e7c07c2ee4..9cbfb0b27710 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -9,7 +9,6 @@
#ifndef MLIR_PASS_PASSMANAGER_H
#define MLIR_PASS_PASSMANAGER_H
-#include "mlir/IR/Dialect.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/Optional.h"
@@ -59,14 +58,6 @@ class OpPassManager {
pass_iterator end();
iterator_range<pass_iterator> getPasses() { return {begin(), end()}; }
- using const_pass_iterator = llvm::pointee_iterator<
- std::vector<std::unique_ptr<Pass>>::const_iterator>;
- const_pass_iterator begin() const;
- const_pass_iterator end() const;
- iterator_range<const_pass_iterator> getPasses() const {
- return {begin(), end()};
- }
-
/// Run the held passes over the given operation.
LogicalResult run(Operation *op, AnalysisManager am);
@@ -109,11 +100,6 @@ class OpPassManager {
/// Merge the pass statistics of this class into 'other'.
void mergeStatisticsInto(OpPassManager &other);
- /// Register dependent dialects for the current pass manager.
- /// This is forwarding to every pass in this PassManager, see the
- /// documentation for the same method on the Pass class.
- void getDependentDialects(DialectRegistry &dialects) const;
-
private:
OpPassManager(OperationName name, bool verifyPasses);
diff --git a/mlir/include/mlir/Support/MlirOptMain.h b/mlir/include/mlir/Support/MlirOptMain.h
index dde6452eb0a6..f235ea3fa520 100644
--- a/mlir/include/mlir/Support/MlirOptMain.h
+++ b/mlir/include/mlir/Support/MlirOptMain.h
@@ -22,15 +22,10 @@ namespace mlir {
struct LogicalResult;
class PassPipelineCLParser;
-/// Run an passPipeline on the provided memory buffer loaded as an MLIRModule.
-/// The preloadDialectsInContext option will trigger the upfront loading of all
-/// dialects from the global registry in the MLIRContext. This option is
-/// deprecated and will be removed soon.
LogicalResult MlirOptMain(llvm::raw_ostream &os,
std::unique_ptr<llvm::MemoryBuffer> buffer,
const PassPipelineCLParser &passPipeline,
bool splitInputFile, bool verifyDiagnostics,
- bool verifyPasses, bool allowUnregisteredDialects,
- bool preloadDialectsInContext = true);
+ bool verifyPasses, bool allowUnregisteredDialects);
} // end namespace mlir
diff --git a/mlir/include/mlir/TableGen/Dialect.h b/mlir/include/mlir/TableGen/Dialect.h
index 99217d8c7d3d..5e85806f377f 100644
--- a/mlir/include/mlir/TableGen/Dialect.h
+++ b/mlir/include/mlir/TableGen/Dialect.h
@@ -14,7 +14,6 @@
#include "mlir/Support/LLVM.h"
#include <string>
-#include <vector>
namespace llvm {
class Record;
@@ -26,7 +25,7 @@ namespace tblgen {
// and provides helper methods for accessing them.
class Dialect {
public:
- explicit Dialect(const llvm::Record *def);
+ explicit Dialect(const llvm::Record *def) : def(def) {}
// Returns the name of this dialect.
StringRef getName() const;
@@ -44,10 +43,6 @@ class Dialect {
// Returns the description of the dialect. Returns empty string if none.
StringRef getDescription() const;
- // Returns the list of dialect (class names) that this dialect depends on.
- // These are dialects that will be loaded on construction of this dialect.
- ArrayRef<StringRef> getDependentDialects() const;
-
// Returns the dialects extra class declaration code.
llvm::Optional<StringRef> getExtraClassDeclaration() const;
@@ -75,7 +70,6 @@ class Dialect {
private:
const llvm::Record *def;
- std::vector<StringRef> dependentDialects;
};
} // end namespace tblgen
} // end namespace mlir
diff --git a/mlir/include/mlir/TableGen/Pass.h b/mlir/include/mlir/TableGen/Pass.h
index 968c85416965..02427e42a525 100644
--- a/mlir/include/mlir/TableGen/Pass.h
+++ b/mlir/include/mlir/TableGen/Pass.h
@@ -94,9 +94,6 @@ class Pass {
/// Return the C++ constructor call to create an instance of this pass.
StringRef getConstructor() const;
- /// Return the dialects this pass needs to be registered.
- ArrayRef<StringRef> getDependentDialects() const;
-
/// Return the options provided by this pass.
ArrayRef<PassOption> getOptions() const;
@@ -107,7 +104,6 @@ class Pass {
private:
const llvm::Record *def;
- std::vector<StringRef> dependentDialects;
std::vector<PassOption> options;
std::vector<PassStatistic> statistics;
};
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 3292d5e7dec2..778780573498 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -162,8 +162,6 @@ def BufferPlacement : FunctionPass<"buffer-placement"> {
}];
let constructor = "mlir::createBufferPlacementPass()";
- // TODO: this pass likely shouldn't depend on Linalg?
- let dependentDialects = ["linalg::LinalgDialect"];
}
def Canonicalizer : Pass<"canonicalize"> {
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 94b0fd6c8e99..4ccfb45f2c43 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -9,11 +9,9 @@
#include "mlir-c/IR.h"
#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Dialect.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Types.h"
-#include "mlir/InitAllDialects.h"
#include "mlir/Parser.h"
#include "llvm/Support/raw_ostream.h"
@@ -91,16 +89,12 @@ class CallbackOstream : public llvm::raw_ostream {
/* ========================================================================== */
MlirContext mlirContextCreate() {
- auto *context = new MLIRContext(/*loadAllDialects=*/false);
+ auto *context = new MLIRContext;
return wrap(context);
}
void mlirContextDestroy(MlirContext context) { delete unwrap(context); }
-void mlirContextLoadAllDialects(MlirContext context) {
- getGlobalDialectRegistry().loadAll(unwrap(context));
-}
-
/* ========================================================================== */
/* Location API. */
/* ========================================================================== */
diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
index 42673936b878..1ebf48174aaf 100644
--- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
+++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
@@ -16,7 +16,6 @@
#include "../PassDetail.h"
#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
-#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Serialization.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 0460d98b44a4..7b57854dde98 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -19,7 +19,6 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
-#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h
index 7fa5a5a92f20..6da0bc81e7af 100644
--- a/mlir/lib/Conversion/PassDetail.h
+++ b/mlir/lib/Conversion/PassDetail.h
@@ -12,43 +12,11 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
-class AffineDialect;
-class StandardOpsDialect;
-
-// Forward declaration from Dialect.h
-template <typename ConcreteDialect>
-void registerDialect(DialectRegistry ®istry);
namespace gpu {
-class GPUDialect;
class GPUModuleOp;
} // end namespace gpu
-namespace LLVM {
-class LLVMDialect;
-class LLVMAVX512Dialect;
-} // end namespace LLVM
-
-namespace NVVM {
-class NVVMDialect;
-} // end namespace NVVM
-
-namespace ROCDL {
-class ROCDLDialect;
-} // end namespace ROCDL
-
-namespace scf {
-class SCFDialect;
-} // end namespace scf
-
-namespace spirv {
-class SPIRVDialect;
-} // end namespace spirv
-
-namespace vector {
-class VectorDialect;
-} // end namespace vector
-
#define GEN_PASS_CLASSES
#include "mlir/Conversion/Passes.h.inc"
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 60967039b545..efe4a3c958d7 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -125,7 +125,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
const LowerToLLVMOptions &options)
- : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()),
+ : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()),
options(options) {
assert(llvmDialect && "LLVM IR dialect is not registered");
if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
index a2e608dcb713..19643d271f8d 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
@@ -14,7 +14,6 @@
#include "../PassDetail.h"
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
-#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
diff --git a/mlir/lib/Dialect/Affine/Transforms/PassDetail.h b/mlir/lib/Dialect/Affine/Transforms/PassDetail.h
index da8f7ac3fc81..3bae0592b3d4 100644
--- a/mlir/lib/Dialect/Affine/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/Affine/Transforms/PassDetail.h
@@ -12,16 +12,6 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
-// Forward declaration from Dialect.h
-template <typename ConcreteDialect>
-void registerDialect(DialectRegistry ®istry);
-
-namespace linalg {
-class LinalgDialect;
-} // end namespace linalg
-namespace vector {
-class VectorDialect;
-} // end namespace vector
#define GEN_PASS_CLASSES
#include "mlir/Dialect/Affine/Passes.h.inc"
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index bf18c6cddd7e..009699be5263 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1224,7 +1224,6 @@ template <typename NamedStructuredOpType>
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 8> operandsInfo;
- result.getContext()->getOrLoadDialect<StandardOpsDialect>();
// Optional attributes may be added.
if (parser.parseOperandList(operandsInfo) ||
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
index 0415aeb8a1fd..7fa05ff12120 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
@@ -9,18 +9,9 @@
#ifndef DIALECT_LINALG_TRANSFORMS_PASSDETAIL_H_
#define DIALECT_LINALG_TRANSFORMS_PASSDETAIL_H_
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
-// Forward declaration from Dialect.h
-template <typename ConcreteDialect>
-void registerDialect(DialectRegistry ®istry);
-
-namespace scf {
-class SCFDialect;
-} // end namespace scf
#define GEN_PASS_CLASSES
#include "mlir/Dialect/Linalg/Passes.h.inc"
diff --git a/mlir/lib/Dialect/SCF/Transforms/PassDetail.h b/mlir/lib/Dialect/SCF/Transforms/PassDetail.h
index 6fa7f227d3da..95f8636b27c1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/PassDetail.h
+++ b/mlir/lib/Dialect/SCF/Transforms/PassDetail.h
@@ -12,11 +12,6 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
-// Forward declaration from Dialect.h
-template <typename ConcreteDialect>
-void registerDialect(DialectRegistry ®istry);
-
-class AffineDialect;
#define GEN_PASS_CLASSES
#include "mlir/Dialect/SCF/Passes.h.inc"
diff --git a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
index a1971c3da3b2..435c7fe25f0c 100644
--- a/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
+++ b/mlir/lib/Dialect/SDBM/SDBMExpr.cpp
@@ -517,7 +517,7 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
SDBMDialect *dialect;
} converter;
- converter.dialect = affine.getContext()->getOrLoadDialect<SDBMDialect>();
+ converter.dialect = affine.getContext()->getRegisteredDialect<SDBMDialect>();
if (auto result = converter.visit(affine))
return result;
diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
index 2b18adb37347..7959183e8968 100644
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -259,9 +259,7 @@ int mlir::JitRunnerMain(
}
}
- MLIRContext context(/*loadAllDialects=*/false);
- registerAllDialects(&context);
-
+ MLIRContext context;
auto m = parseMLIRInput(options.inputFilename, &context);
if (!m) {
llvm::errs() << "could not parse the input IR\n";
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index ba1a40cbcc34..555bb2bf0eb4 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -27,30 +27,21 @@ DialectAsmParser::~DialectAsmParser() {}
//===----------------------------------------------------------------------===//
/// Registry for all dialect allocation functions.
-static llvm::ManagedStatic<DialectRegistry> dialectRegistry;
-DialectRegistry &mlir::getGlobalDialectRegistry() { return *dialectRegistry; }
-
-void mlir::registerAllDialects(MLIRContext *context) {
- dialectRegistry->appendTo(context->getDialectRegistry());
-}
-
-Dialect *DialectRegistry::loadByName(StringRef name, MLIRContext *context) {
- auto it = registry.find(name.str());
- if (it == registry.end())
- return nullptr;
- return it->second.second(context);
+static llvm::ManagedStatic<llvm::MapVector<TypeID, DialectAllocatorFunction>>
+ dialectRegistry;
+
+void Dialect::registerDialectAllocator(
+ TypeID typeID, const DialectAllocatorFunction &function) {
+ assert(function &&
+ "Attempting to register an empty dialect initialize function");
+ dialectRegistry->insert({typeID, function});
}
-void DialectRegistry::insert(TypeID typeID, StringRef name,
- DialectAllocatorFunction ctor) {
- auto inserted =
- registry.insert(std::make_pair((std::string)name,
- std::make_pair(typeID, ctor)));
- if (!inserted.second && inserted.first->second.first != typeID) {
- llvm::report_fatal_error(
- "Trying to register
diff erent dialects for the same namespace: " +
- name);
- }
+/// Registers all dialects and hooks from the global registries with the
+/// specified MLIRContext.
+void mlir::registerAllDialects(MLIRContext *context) {
+ for (const auto &it : *dialectRegistry)
+ it.second(context);
}
//===----------------------------------------------------------------------===//
@@ -128,7 +119,7 @@ DialectInterface::~DialectInterface() {}
DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
MLIRContext *ctx, TypeID interfaceKind) {
- for (auto *dialect : ctx->getLoadedDialects()) {
+ for (auto *dialect : ctx->getRegisteredDialects()) {
if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
interfaces.insert(interface);
orderedInterfaces.push_back(interface);
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 7c8a637ede0f..0d66070657aa 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -31,13 +31,10 @@
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Debug.h"
#include "llvm/Support/RWMutex.h"
#include "llvm/Support/raw_ostream.h"
#include <memory>
-#define DEBUG_TYPE "mlircontext"
-
using namespace mlir;
using namespace mlir::detail;
@@ -278,8 +275,7 @@ class MLIRContextImpl {
/// This is a list of dialects that are created referring to this context.
/// The MLIRContext owns the objects.
- DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
- DialectRegistry dialectsRegistry;
+ std::vector<std::unique_ptr<Dialect>> dialects;
/// This is a mapping from operation name to AbstractOperation for registered
/// operations.
@@ -350,7 +346,7 @@ class MLIRContextImpl {
};
} // end namespace mlir
-MLIRContext::MLIRContext(bool loadAllDialects) : impl(new MLIRContextImpl()) {
+MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
// Initialize values based on the command line flags if they were provided.
if (clOptions.isConstructed()) {
disableMultithreading(clOptions->disableThreading);
@@ -359,9 +355,8 @@ MLIRContext::MLIRContext(bool loadAllDialects) : impl(new MLIRContextImpl()) {
}
// Register dialects with this context.
- getOrLoadDialect<BuiltinDialect>();
- if (loadAllDialects)
- loadAllGloballyRegisteredDialects();
+ getOrCreateDialect<BuiltinDialect>();
+ registerAllDialects(this);
// Initialize several common attributes and types to avoid the need to lock
// the context when accessing them.
@@ -443,72 +438,54 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
// Dialect and Operation Registration
//===----------------------------------------------------------------------===//
-DialectRegistry &MLIRContext::getDialectRegistry() {
- return impl->dialectsRegistry;
-}
-
/// Return information about all registered IR dialects.
-std::vector<Dialect *> MLIRContext::getLoadedDialects() {
+std::vector<Dialect *> MLIRContext::getRegisteredDialects() {
std::vector<Dialect *> result;
- result.reserve(impl->loadedDialects.size());
- for (auto &dialect : impl->loadedDialects)
- result.push_back(dialect.second.get());
- llvm::array_pod_sort(result.begin(), result.end(),
- [](Dialect *const *lhs, Dialect *const *rhs) -> int {
- return (*lhs)->getNamespace() < (*rhs)->getNamespace();
- });
- return result;
-}
-std::vector<StringRef> MLIRContext::getAvailableDialects() {
- std::vector<StringRef> result;
- for (auto &dialect : impl->dialectsRegistry)
- result.push_back(dialect.first);
+ result.reserve(impl->dialects.size());
+ for (auto &dialect : impl->dialects)
+ result.push_back(dialect.get());
return result;
}
/// Get a registered IR dialect with the given namespace. If none is found,
/// then return nullptr.
-Dialect *MLIRContext::getLoadedDialect(StringRef name) {
+Dialect *MLIRContext::getRegisteredDialect(StringRef name) {
// Dialects are sorted by name, so we can use binary search for lookup.
- auto it = impl->loadedDialects.find(name);
- return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr;
-}
-
-Dialect *MLIRContext::getOrLoadDialect(StringRef name) {
- Dialect *dialect = getLoadedDialect(name);
- if (dialect)
- return dialect;
- return impl->dialectsRegistry.loadByName(name, this);
+ auto it = llvm::lower_bound(
+ impl->dialects, name,
+ [](const auto &lhs, StringRef rhs) { return lhs->getNamespace() < rhs; });
+ return (it != impl->dialects.end() && (*it)->getNamespace() == name)
+ ? (*it).get()
+ : nullptr;
}
/// Get a dialect for the provided namespace and TypeID: abort the program if a
/// dialect exist for this namespace with
diff erent TypeID. Returns a pointer to
/// the dialect owned by the context.
Dialect *
-MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
- function_ref<std::unique_ptr<Dialect>()> ctor) {
+MLIRContext::getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID,
+ function_ref<std::unique_ptr<Dialect>()> ctor) {
auto &impl = getImpl();
// Get the correct insertion position sorted by namespace.
- std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace];
-
- if (!dialect) {
- LLVM_DEBUG(llvm::dbgs()
- << "Load new dialect in Context" << dialectNamespace);
- dialect = ctor();
- assert(dialect && "dialect ctor failed");
- return dialect.get();
- }
+ auto insertPt =
+ llvm::lower_bound(impl.dialects, nullptr,
+ [&](const std::unique_ptr<Dialect> &lhs,
+ const std::unique_ptr<Dialect> &rhs) {
+ if (!lhs)
+ return dialectNamespace < rhs->getNamespace();
+ return lhs->getNamespace() < dialectNamespace;
+ });
// Abort if dialect with namespace has already been registered.
- if (dialect->getTypeID() != dialectID)
+ if (insertPt != impl.dialects.end() &&
+ (*insertPt)->getNamespace() == dialectNamespace) {
+ if ((*insertPt)->getTypeID() == dialectID)
+ return insertPt->get();
llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
"' has already been registered");
-
- return dialect.get();
-}
-
-void MLIRContext::loadAllGloballyRegisteredDialects() {
- getGlobalDialectRegistry().loadAll(this);
+ }
+ auto it = impl.dialects.insert(insertPt, ctor());
+ return &**it;
}
bool MLIRContext::allowsUnregisteredDialects() {
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index dce570a9f368..152ed0124767 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -214,7 +214,7 @@ Dialect *Operation::getDialect() {
// If this operation hasn't been registered or doesn't have abstract
// operation, try looking up the dialect name in the context.
- return getContext()->getLoadedDialect(getName().getDialect());
+ return getContext()->getRegisteredDialect(getName().getDialect());
}
Region *Operation::getParentRegion() {
diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp
index 4caf9891383c..b1aed8842dc4 100644
--- a/mlir/lib/IR/Verifier.cpp
+++ b/mlir/lib/IR/Verifier.cpp
@@ -50,7 +50,7 @@ class OperationVerifier {
Dialect *getDialectForAttribute(const NamedAttribute &attr) {
assert(attr.first.strref().contains('.') && "expected dialect attribute");
auto dialectNamePair = attr.first.strref().split('.');
- return ctx->getLoadedDialect(dialectNamePair.first);
+ return ctx->getRegisteredDialect(dialectNamePair.first);
}
private:
@@ -218,7 +218,7 @@ LogicalResult OperationVerifier::verifyOperation(Operation &op) {
auto it = dialectAllowsUnknownOps.find(dialectPrefix);
if (it == dialectAllowsUnknownOps.end()) {
// If the operation dialect is registered, query it directly.
- if (auto *dialect = ctx->getLoadedDialect(dialectPrefix))
+ if (auto *dialect = ctx->getRegisteredDialect(dialectPrefix))
it = dialectAllowsUnknownOps
.try_emplace(dialectPrefix, dialect->allowsUnknownOperations())
.first;
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index 37ee938a4bcd..1c1261e6d765 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -12,7 +12,6 @@
#include "Parser.h"
#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Dialect.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/ADT/StringExtras.h"
@@ -247,11 +246,6 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
return emitError("duplicate key in dictionary attribute");
consumeToken();
- // Lazy load a dialect in the context if there is a possible namespace.
- auto splitName = nameId->strref().split('.');
- if (!splitName.second.empty())
- getContext()->getOrLoadDialect(splitName.first);
-
// Try to parse the '=' for the attribute value.
if (!consumeIf(Token::equal)) {
// If there is no '=', we treat this as a unit attribute.
@@ -823,9 +817,7 @@ Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
return (emitError("expected dialect namespace"), nullptr);
auto name = getToken().getStringValue();
- // Lazy load a dialect in the context if there is a possible namespace.
- Dialect *dialect = builder.getContext()->getOrLoadDialect(name);
-
+ auto *dialect = builder.getContext()->getRegisteredDialect(name);
// TODO: Allow for having an unknown dialect on an opaque
// attribute. Otherwise, it can't be roundtripped without having the dialect
// registered.
diff --git a/mlir/lib/Parser/DialectSymbolParser.cpp b/mlir/lib/Parser/DialectSymbolParser.cpp
index d45ddf071989..3b522a876f25 100644
--- a/mlir/lib/Parser/DialectSymbolParser.cpp
+++ b/mlir/lib/Parser/DialectSymbolParser.cpp
@@ -526,8 +526,7 @@ Attribute Parser::parseExtendedAttr(Type type) {
return Attribute();
// If we found a registered dialect, then ask it to parse the attribute.
- if (Dialect *dialect =
- builder.getContext()->getOrLoadDialect(dialectName)) {
+ if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
return parseSymbol<Attribute>(
symbolData, state.context, state.symbols, [&](Parser &parser) {
CustomDialectAsmParser customParser(symbolData, parser);
@@ -564,9 +563,7 @@ Type Parser::parseExtendedType() {
[&](StringRef dialectName, StringRef symbolData,
llvm::SMLoc loc) -> Type {
// If we found a registered dialect, then ask it to parse the type.
- auto *dialect = state.context->getOrLoadDialect(dialectName);
-
- if (dialect) {
+ if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
return parseSymbol<Type>(
symbolData, state.context, state.symbols, [&](Parser &parser) {
CustomDialectAsmParser customParser(symbolData, parser);
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index 837b08ca54c0..3a995a4e2b04 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -12,7 +12,6 @@
#include "Parser.h"
#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Dialect.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser.h"
@@ -728,7 +727,7 @@ Operation *OperationParser::parseGenericOperation() {
// Get location information for the operation.
auto srcLocation = getEncodedSourceLocation(getToken().getLoc());
- std::string name = getToken().getStringValue();
+ auto name = getToken().getStringValue();
if (name.empty())
return (emitError("empty operation name is invalid"), nullptr);
if (name.find('\0') != StringRef::npos)
@@ -738,15 +737,6 @@ Operation *OperationParser::parseGenericOperation() {
OperationState result(srcLocation, name);
- // Lazy load dialects in the context as needed.
- if (!result.name.getAbstractOperation()) {
- StringRef dialectName = StringRef(name).split('.').first;
- if (!getContext()->getLoadedDialect(dialectName) &&
- getContext()->getOrLoadDialect(dialectName)) {
- result.name = OperationName(name, getContext());
- }
- }
-
// Parse the operand list.
SmallVector<SSAUseInfo, 8> operandInfos;
if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
@@ -1452,28 +1442,17 @@ class CustomOpAsmParser : public OpAsmParser {
Operation *
OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
- llvm::SMLoc opLoc = getToken().getLoc();
- StringRef opName = getTokenSpelling();
+ auto opLoc = getToken().getLoc();
+ auto opName = getTokenSpelling();
auto *opDefinition = AbstractOperation::lookup(opName, getContext());
- if (!opDefinition) {
- if (opName.contains('.')) {
- // This op has a dialect, we try to check if we can register it in the
- // context on the fly.
- StringRef dialectName = opName.split('.').first;
- if (!getContext()->getLoadedDialect(dialectName) &&
- getContext()->getOrLoadDialect(dialectName)) {
- opDefinition = AbstractOperation::lookup(opName, getContext());
- }
- } else {
- // If the operation name has no namespace prefix we treat it as a standard
- // operation and prefix it with "std".
- // TODO: Would it be better to just build a mapping of the registered
- // operations in the standard dialect?
- if (getContext()->getOrLoadDialect("std"))
- opDefinition = AbstractOperation::lookup(Twine("std." + opName).str(),
- getContext());
- }
+ if (!opDefinition && !opName.contains('.')) {
+ // If the operation name has no namespace prefix we treat it as a standard
+ // operation and prefix it with "std".
+ // TODO: Would it be better to just build a mapping of the registered
+ // operations in the standard dialect?
+ opDefinition =
+ AbstractOperation::lookup(Twine("std." + opName).str(), getContext());
}
if (!opDefinition) {
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 9bc23c2e4a65..b791bf483e67 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -290,13 +290,6 @@ OpPassManager::pass_iterator OpPassManager::begin() {
}
OpPassManager::pass_iterator OpPassManager::end() { return impl->passes.end(); }
-OpPassManager::const_pass_iterator OpPassManager::begin() const {
- return impl->passes.begin();
-}
-OpPassManager::const_pass_iterator OpPassManager::end() const {
- return impl->passes.end();
-}
-
/// Run all of the passes in this manager over the current operation.
LogicalResult OpPassManager::run(Operation *op, AnalysisManager am) {
// Run each of the held passes.
@@ -353,16 +346,6 @@ void OpPassManager::printAsTextualPipeline(raw_ostream &os) {
::printAsTextualPipeline(impl->passes, os);
}
-static void registerDialectsForPipeline(const OpPassManager &pm,
- DialectRegistry &dialects) {
- for (const Pass &pass : pm.getPasses())
- pass.getDependentDialects(dialects);
-}
-
-void OpPassManager::getDependentDialects(DialectRegistry &dialects) const {
- registerDialectsForPipeline(*this, dialects);
-}
-
//===----------------------------------------------------------------------===//
// OpToOpPassAdaptor
//===----------------------------------------------------------------------===//
@@ -395,11 +378,6 @@ OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) {
mgrs.emplace_back(std::move(mgr));
}
-void OpToOpPassAdaptor::getDependentDialects(DialectRegistry &dialects) const {
- for (auto &pm : mgrs)
- pm.getDependentDialects(dialects);
-}
-
/// Merge the current pass adaptor into given 'rhs'.
void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) {
for (auto &pm : mgrs) {
@@ -743,11 +721,6 @@ LogicalResult PassManager::run(ModuleOp module) {
// pipeline.
getImpl().coalesceAdjacentAdaptorPasses();
- // Register all dialects for the current pipeline.
- DialectRegistry dependentDialects;
- getDependentDialects(dependentDialects);
- dependentDialects.loadAll(module.getContext());
-
// Construct an analysis manager for the pipeline.
ModuleAnalysisManager am(module, instrumentor.get());
diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
index f69701d85e15..2342a1a7af97 100644
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -43,10 +43,6 @@ class OpToOpPassAdaptor
/// Returns the pass managers held by this adaptor.
MutableArrayRef<OpPassManager> getPassManagers() { return mgrs; }
- /// Populate the set of dependent dialects for the passes in the current
- /// adaptor.
- void getDependentDialects(DialectRegistry &dialects) const override;
-
/// Return the async pass managers held by this parallel adaptor.
MutableArrayRef<SmallVector<OpPassManager, 1>> getParallelPassManagers() {
return asyncExecutors;
diff --git a/mlir/lib/Support/MlirOptMain.cpp b/mlir/lib/Support/MlirOptMain.cpp
index 1eafdfe42b0d..25e197083b62 100644
--- a/mlir/lib/Support/MlirOptMain.cpp
+++ b/mlir/lib/Support/MlirOptMain.cpp
@@ -75,14 +75,13 @@ static LogicalResult processBuffer(raw_ostream &os,
std::unique_ptr<MemoryBuffer> ownedBuffer,
bool verifyDiagnostics, bool verifyPasses,
bool allowUnregisteredDialects,
- bool preloadDialectsInContext,
const PassPipelineCLParser &passPipeline) {
// Tell sourceMgr about this buffer, which is what the parser will pick up.
SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
// Parse the input file.
- MLIRContext context(/*loadAllDialects=*/preloadDialectsInContext);
+ MLIRContext context;
context.allowUnregisteredDialects(allowUnregisteredDialects);
context.printOpOnDiagnostic(!verifyDiagnostics);
@@ -112,8 +111,7 @@ LogicalResult mlir::MlirOptMain(raw_ostream &os,
const PassPipelineCLParser &passPipeline,
bool splitInputFile, bool verifyDiagnostics,
bool verifyPasses,
- bool allowUnregisteredDialects,
- bool preloadDialectsInContext) {
+ bool allowUnregisteredDialects) {
// The split-input-file mode is a very specific mode that slices the file
// up into small pieces and checks each independently.
if (splitInputFile)
@@ -122,11 +120,10 @@ LogicalResult mlir::MlirOptMain(raw_ostream &os,
[&](std::unique_ptr<MemoryBuffer> chunkBuffer, raw_ostream &os) {
return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
verifyPasses, allowUnregisteredDialects,
- preloadDialectsInContext, passPipeline);
+ passPipeline);
},
os);
return processBuffer(os, std::move(buffer), verifyDiagnostics, verifyPasses,
- allowUnregisteredDialects, preloadDialectsInContext,
- passPipeline);
+ allowUnregisteredDialects, passPipeline);
}
diff --git a/mlir/lib/TableGen/Dialect.cpp b/mlir/lib/TableGen/Dialect.cpp
index 8aee067d26f7..6af77e7df0f6 100644
--- a/mlir/lib/TableGen/Dialect.cpp
+++ b/mlir/lib/TableGen/Dialect.cpp
@@ -16,11 +16,6 @@
using namespace mlir;
using namespace mlir::tblgen;
-Dialect::Dialect(const llvm::Record *def) : def(def) {
- for (StringRef dialect : def->getValueAsListOfStrings("dependentDialects"))
- dependentDialects.push_back(dialect);
-}
-
StringRef Dialect::getName() const { return def->getValueAsString("name"); }
StringRef Dialect::getCppNamespace() const {
@@ -51,10 +46,6 @@ StringRef Dialect::getDescription() const {
return getAsStringOrEmpty(*def, "description");
}
-ArrayRef<StringRef> Dialect::getDependentDialects() const {
- return dependentDialects;
-}
-
llvm::Optional<StringRef> Dialect::getExtraClassDeclaration() const {
auto value = def->getValueAsString("extraClassDeclaration");
return value.empty() ? llvm::Optional<StringRef>() : value;
diff --git a/mlir/lib/TableGen/Pass.cpp b/mlir/lib/TableGen/Pass.cpp
index f96180689c55..4bc46b622c2b 100644
--- a/mlir/lib/TableGen/Pass.cpp
+++ b/mlir/lib/TableGen/Pass.cpp
@@ -69,8 +69,6 @@ Pass::Pass(const llvm::Record *def) : def(def) {
options.push_back(PassOption(init));
for (auto *init : def->getValueAsListOfDefs("statistics"))
statistics.push_back(PassStatistic(init));
- for (StringRef dialect : def->getValueAsListOfStrings("dependentDialects"))
- dependentDialects.push_back(dialect);
}
StringRef Pass::getArgument() const {
@@ -90,9 +88,6 @@ StringRef Pass::getDescription() const {
StringRef Pass::getConstructor() const {
return def->getValueAsString("constructor");
}
-ArrayRef<StringRef> Pass::getDependentDialects() const {
- return dependentDialects;
-}
ArrayRef<PassOption> Pass::getOptions() const { return options; }
diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
index a5d833893879..470044bc9953 100644
--- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
+++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
@@ -836,7 +836,6 @@ LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) {
OwningModuleRef
mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
MLIRContext *context) {
- context->loadDialect<LLVMDialect>();
OwningModuleRef module(ModuleOp::create(
FileLineColLoc::get("", /*line=*/0, /*column=*/0, context)));
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 027422ba5dc4..215c1910f744 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -302,7 +302,8 @@ ModuleTranslation::ModuleTranslation(Operation *module,
: mlirModule(module), llvmModule(std::move(llvmModule)),
debugTranslation(
std::make_unique<DebugTranslation>(module, *this->llvmModule)),
- ompDialect(module->getContext()->getOrLoadDialect<omp::OpenMPDialect>()),
+ ompDialect(
+ module->getContext()->getRegisteredDialect<omp::OpenMPDialect>()),
typeTranslator(this->llvmModule->getContext()) {
assert(satisfiesLLVMModule(mlirModule) &&
"mlirModule should honor LLVM's module semantics.");
@@ -943,7 +944,7 @@ ModuleTranslation::lookupValues(ValueRange values) {
std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(
Operation *m, llvm::LLVMContext &llvmContext, StringRef name) {
- auto *dialect = m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>();
+ auto *dialect = m->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
assert(dialect && "LLVM dialect must be registered");
auto llvmModule = std::make_unique<llvm::Module>(name, llvmContext);
diff --git a/mlir/lib/Transforms/PassDetail.h b/mlir/lib/Transforms/PassDetail.h
index 220ed1aac407..c6f7e225d71a 100644
--- a/mlir/lib/Transforms/PassDetail.h
+++ b/mlir/lib/Transforms/PassDetail.h
@@ -12,13 +12,6 @@
#include "mlir/Pass/Pass.h"
namespace mlir {
-// Forward declaration from Dialect.h
-template <typename ConcreteDialect>
-void registerDialect(DialectRegistry &registry);
-
-namespace linalg {
-class LinalgDialect;
-} // end namespace linalg
#define GEN_PASS_CLASSES
#include "mlir/Transforms/Passes.h.inc"
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index df2d32f7913a..d6ab3513384f 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -243,7 +243,6 @@ static void printFirstOfEach(MlirOperation operation) {
int main() {
mlirRegisterAllDialects();
MlirContext ctx = mlirContextCreate();
- mlirContextLoadAllDialects(ctx);
MlirLocation location = mlirLocationUnknownGet(ctx);
MlirModule moduleOp = makeAdd(ctx, location);
diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp
index 062e4b591229..3fcfcf24ef8f 100644
--- a/mlir/test/EDSC/builder-api-test.cpp
+++ b/mlir/test/EDSC/builder-api-test.cpp
@@ -36,18 +36,16 @@ using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
static MLIRContext &globalContext() {
- static thread_local MLIRContext context(/*loadAllDialects=*/false);
- static thread_local bool initOnce = [&]() {
- // clang-format off
- context.loadDialect<AffineDialect,
- scf::SCFDialect,
- linalg::LinalgDialect,
- StandardOpsDialect,
- vector::VectorDialect>();
- // clang-format on
+ static bool init_once = []() {
+ registerDialect<AffineDialect>();
+ registerDialect<linalg::LinalgDialect>();
+ registerDialect<scf::SCFDialect>();
+ registerDialect<StandardOpsDialect>();
+ registerDialect<vector::VectorDialect>();
return true;
}();
- (void)initOnce;
+ (void)init_once;
+ static thread_local MLIRContext context;
context.allowUnregisteredDialects();
return context;
}
diff --git a/mlir/test/SDBM/sdbm-api-test.cpp b/mlir/test/SDBM/sdbm-api-test.cpp
index ddefc52fb461..0b58e2948145 100644
--- a/mlir/test/SDBM/sdbm-api-test.cpp
+++ b/mlir/test/SDBM/sdbm-api-test.cpp
@@ -19,19 +19,18 @@
using namespace mlir;
+// Load the SDBM dialect
+static DialectRegistration<SDBMDialect> SDBMRegistration;
static MLIRContext *ctx() {
- static thread_local MLIRContext context(/*loadAllDialects=*/false);
- static thread_local bool once =
- (context.getOrLoadDialect<SDBMDialect>(), true);
- (void)once;
+ static thread_local MLIRContext context;
return &context;
}
static SDBMDialect *dialect() {
static thread_local SDBMDialect *d = nullptr;
if (!d) {
- d = ctx()->getOrLoadDialect<SDBMDialect>();
+ d = ctx()->getRegisteredDialect<SDBMDialect>();
}
return d;
}
diff --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
index cfac2dce2300..a6719b060aac 100644
--- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
@@ -14,7 +14,6 @@
#include "mlir/Analysis/NestedMatcher.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
@@ -73,9 +72,6 @@ struct VectorizerTestPass
: public PassWrapper<VectorizerTestPass, FunctionPass> {
static constexpr auto kTestAffineMapOpName = "test_affine_map";
static constexpr auto kTestAffineMapAttrName = "affine_map";
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<vector::VectorDialect>();
- }
void runOnFunction() override;
void testVectorShapeRatio(llvm::raw_ostream &outs);
diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index 03c425d6d906..0c1069f38b67 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -30,7 +30,7 @@ void PrintOpAvailability::runOnFunction() {
auto f = getFunction();
llvm::outs() << f.getName() << "\n";
- Dialect *spvDialect = getContext().getLoadedDialect("spv");
+ Dialect *spvDialect = getContext().getRegisteredDialect("spv");
f.getOperation()->walk([&](Operation *op) {
if (op->getDialect() != spvDialect)
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index be5d799a0253..f2a17a9f3f5f 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -768,10 +768,6 @@ struct TestTypeConversionProducer
struct TestTypeConversionDriver
: public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<TestDialect>();
- }
-
void runOnOperation() override {
// Initialize the type converter.
TypeConverter converter;
diff --git a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp
index 0c72b6cd2a89..c043d0f02f8d 100644
--- a/mlir/test/lib/Transforms/TestAllReduceLowering.cpp
+++ b/mlir/test/lib/Transforms/TestAllReduceLowering.cpp
@@ -11,7 +11,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/GPU/Passes.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -20,9 +19,6 @@ using namespace mlir;
namespace {
struct TestAllReduceLoweringPass
: public PassWrapper<TestAllReduceLoweringPass, OperationPass<ModuleOp>> {
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<StandardOpsDialect>();
- }
void runOnOperation() override {
OwningRewritePatternList patterns;
populateGpuRewritePatterns(&getContext(), patterns);
diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
index 6cc0924191cb..5ad441aa15c3 100644
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -116,10 +116,6 @@ struct TestBufferPlacementPreparationPass
patterns->insert<GenericOpConverter>(context, placer, converter);
}
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<linalg::LinalgDialect>();
- }
-
void runOnOperation() override {
MLIRContext &context = this->getContext();
ConversionTarget target(context);
diff --git a/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp b/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp
index 3c2b933e99f6..08862dd06140 100644
--- a/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp
+++ b/mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp
@@ -13,9 +13,6 @@
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/MemoryPromotion.h"
-#include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/Pass/Pass.h"
@@ -29,10 +26,6 @@ namespace {
class TestGpuMemoryPromotionPass
: public PassWrapper<TestGpuMemoryPromotionPass,
OperationPass<gpu::GPUFuncOp>> {
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<StandardOpsDialect, scf::SCFDialect>();
- }
-
void runOnOperation() override {
gpu::GPUFuncOp op = getOperation();
for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
diff --git a/mlir/test/lib/Transforms/TestLinalgHoisting.cpp b/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
index 5d4031f90043..d1e478fec3bc 100644
--- a/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgHoisting.cpp
@@ -10,7 +10,6 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
#include "mlir/Pass/Pass.h"
@@ -23,9 +22,6 @@ struct TestLinalgHoisting
: public PassWrapper<TestLinalgHoisting, FunctionPass> {
TestLinalgHoisting() = default;
TestLinalgHoisting(const TestLinalgHoisting &pass) {}
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<AffineDialect>();
- }
void runOnFunction() override;
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index 2ca6b05c7616..f6c1160d35b0 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -15,7 +15,6 @@
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
@@ -31,16 +30,6 @@ struct TestLinalgTransforms
TestLinalgTransforms() = default;
TestLinalgTransforms(const TestLinalgTransforms &pass) {}
- void getDependentDialects(DialectRegistry &registry) const override {
- // clang-format off
- registry.insert<AffineDialect,
- scf::SCFDialect,
- StandardOpsDialect,
- vector::VectorDialect,
- gpu::GPUDialect>();
- // clang-format on
- }
-
void runOnFunction() override;
Option<bool> testPatterns{*this, "test-patterns",
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index ab8460318b49..9da3156d5359 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -8,9 +8,6 @@
#include <type_traits>
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
-#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorTransforms.h"
@@ -131,11 +128,6 @@ struct TestVectorTransferFullPartialSplitPatterns
TestVectorTransferFullPartialSplitPatterns() = default;
TestVectorTransferFullPartialSplitPatterns(
const TestVectorTransferFullPartialSplitPatterns &pass) {}
-
- void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
- }
-
Option<bool> useLinalgOps{
*this, "use-linalg-copy",
llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir
index 4cf6ea9d8a69..f99a68d6303c 100644
--- a/mlir/test/mlir-opt/commandline.mlir
+++ b/mlir/test/mlir-opt/commandline.mlir
@@ -1,5 +1,5 @@
// RUN: mlir-opt --show-dialects | FileCheck %s
-// CHECK: Available Dialects:
+// CHECK: Registered Dialects:
// CHECK: affine
// CHECK: gpu
// CHECK: linalg
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
index 92efef67e8f4..12e6aeef9162 100644
--- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
@@ -1703,7 +1703,7 @@ int main(int argc, char **argv) {
if (testEmitIncludeTdHeader)
output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\"";
- MLIRContext context(/*loadAllDialects=*/false);
+ MLIRContext context;
llvm::SourceMgr mgr;
mgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
Parser parser(mgr, &context);
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 53ea4dae20f7..efcb32856607 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -175,10 +175,11 @@ int main(int argc, char **argv) {
cl::ParseCommandLineOptions(argc, argv, "MLIR modular optimizer driver\n");
if(showDialects) {
- MLIRContext context(false);
- registerAllDialects(&context);
- llvm::outs() << "Available Dialects:\n";
- interleave(context.getAvailableDialects(), llvm::outs(), "\n");
+ llvm::outs() << "Registered Dialects:\n";
+ MLIRContext context;
+ for(Dialect *dialect : context.getRegisteredDialects()) {
+ llvm::outs() << dialect->getNamespace() << "\n";
+ }
return 0;
}
diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp
index 3a19379da8a3..13421c42c3c2 100644
--- a/mlir/tools/mlir-tblgen/DialectGen.cpp
+++ b/mlir/tools/mlir-tblgen/DialectGen.cpp
@@ -61,14 +61,11 @@ filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) {
///
/// {0}: The name of the dialect class.
/// {1}: The dialect namespace.
-/// {2}: initialization code that is emitted in the ctor body before calling
-/// initialize()
static const char *const dialectDeclBeginStr = R"(
class {0} : public ::mlir::Dialect {
explicit {0}(::mlir::MLIRContext *context)
: ::mlir::Dialect(getDialectNamespace(), context,
::mlir::TypeID::get<{0}>()) {{
- {2}
initialize();
}
void initialize();
@@ -77,12 +74,6 @@ class {0} : public ::mlir::Dialect {
static ::llvm::StringRef getDialectNamespace() { return "{1}"; }
)";
-/// Registration for a single dependent dialect: to be inserted in the ctor
-/// above for each dependent dialect.
-const char *const dialectRegistrationTemplate = R"(
- getContext()->getOrLoadDialect<{0}>();
-)";
-
/// The code block for the attribute parser/printer hooks.
static const char *const attrParserDecl = R"(
/// Parse an attribute registered to this dialect.
@@ -145,18 +136,9 @@ static void emitDialectDecl(Dialect &dialect,
iterator_range<DialectFilterIterator> dialectAttrs,
iterator_range<DialectFilterIterator> dialectTypes,
raw_ostream &os) {
- /// Build the list of dependent dialects
- std::string dependentDialectRegistrations;
- {
- llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
- for (StringRef dependentDialect : dialect.getDependentDialects())
- dialectsOs << llvm::formatv(dialectRegistrationTemplate,
- dependentDialect);
- }
// Emit the start of the decl.
std::string cppName = dialect.getCppClassName();
- os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
- dependentDialectRegistrations);
+ os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName());
// Check for any attributes/types registered to this dialect. If there are,
// add the hooks for parsing/printing.
diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp
index c1664a0c826c..c2dcdb8e4ac9 100644
--- a/mlir/tools/mlir-tblgen/PassGen.cpp
+++ b/mlir/tools/mlir-tblgen/PassGen.cpp
@@ -36,7 +36,6 @@ static llvm::cl::opt<std::string>
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2): The command line argument for the pass.
-/// {3}: The dependent dialects registration.
const char *const passDeclBegin = R"(
//===----------------------------------------------------------------------===//
// {0}
@@ -64,20 +63,9 @@ class {0}Base : public {1} {
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}
- /// Return the dialect that must be loaded in the context before this pass.
- void getDependentDialects(::mlir::DialectRegistry &registry) const override {
- {3}
- }
-
protected:
)";
-/// Registration for a single dependent dialect, to be inserted for each
-/// dependent dialect in the `getDependentDialects` above.
-const char *const dialectRegistrationTemplate = R"(
- registry.insert<{0}>();
-)";
-
/// Emit the declarations for each of the pass options.
static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
for (const PassOption &opt : pass.getOptions()) {
@@ -106,15 +94,8 @@ static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
static void emitPassDecl(const Pass &pass, raw_ostream &os) {
StringRef defName = pass.getDef()->getName();
- std::string dependentDialectRegistrations;
- {
- llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
- for (StringRef dependentDialect : pass.getDependentDialects())
- dialectsOs << llvm::formatv(dialectRegistrationTemplate,
- dependentDialect);
- }
os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(),
- pass.getArgument(), dependentDialectRegistrations);
+ pass.getArgument());
emitPassOptionDecls(pass, os);
emitPassStatisticDecls(pass, os);
os << "};\n";
diff --git a/mlir/tools/mlir-translate/mlir-translate.cpp b/mlir/tools/mlir-translate/mlir-translate.cpp
index 0d67286a8a91..914bd340b3f5 100644
--- a/mlir/tools/mlir-translate/mlir-translate.cpp
+++ b/mlir/tools/mlir-translate/mlir-translate.cpp
@@ -88,8 +88,7 @@ int main(int argc, char **argv) {
// Processes the memory buffer with a new MLIRContext.
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
raw_ostream &os) {
- MLIRContext context(false);
- registerAllDialects(&context);
+ MLIRContext context;
context.allowUnregisteredDialects();
context.printOpOnDiagnostic(!verifyDiagnostics);
llvm::SourceMgr sourceMgr;
diff --git a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
index bae95e1a13b6..97c94a54ffc4 100644
--- a/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
+++ b/mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
@@ -17,6 +17,9 @@
using namespace mlir;
using namespace mlir::quant;
+// Load the quant dialect
+static DialectRegistration<QuantizationDialect> QuantOpsRegistration;
+
namespace {
// Test UniformQuantizedValueConverter converts all APFloat to a magic number 5.
@@ -75,8 +78,7 @@ UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) {
}
TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
- MLIRContext ctx(/*loadAllDialects=*/false);
- ctx.getOrLoadDialect<QuantizationDialect>();
+ MLIRContext ctx;
IntegerType convertedType = IntegerType::get(8, &ctx);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType);
@@ -93,8 +95,7 @@ TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
}
TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
- MLIRContext ctx(/*loadAllDialects=*/false);
- ctx.getOrLoadDialect<QuantizationDialect>();
+ MLIRContext ctx;
IntegerType convertedType = IntegerType::get(8, &ctx);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType);
@@ -118,8 +119,7 @@ TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
}
TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
- MLIRContext ctx(/*loadAllDialects=*/false);
- ctx.getOrLoadDialect<QuantizationDialect>();
+ MLIRContext ctx;
IntegerType convertedType = IntegerType::get(8, &ctx);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType);
@@ -143,8 +143,7 @@ TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
}
TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
- MLIRContext ctx(/*loadAllDialects=*/false);
- ctx.getOrLoadDialect<QuantizationDialect>();
+ MLIRContext ctx;
IntegerType convertedType = IntegerType::get(8, &ctx);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType);
diff --git a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
index 4aa2ffed7e2b..fe5632d7ae16 100644
--- a/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
@@ -38,8 +38,7 @@ using ::testing::StrEq;
/// diagnostic checking utilities.
class DeserializationTest : public ::testing::Test {
protected:
- DeserializationTest() : context(/*loadAllDialects=*/false) {
- context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
+ DeserializationTest() {
// Register a diagnostic handler to capture the diagnostic so that we can
// check it later.
context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
diff --git a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
index cb89cd61de7b..3d57e559ca5e 100644
--- a/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
+++ b/mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
@@ -36,10 +36,7 @@ using namespace mlir;
class SerializationTest : public ::testing::Test {
protected:
- SerializationTest() : context(/*loadAllDialects=*/false) {
- context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
- createModuleOp();
- }
+ SerializationTest() { createModuleOp(); }
void createModuleOp() {
OpBuilder builder(&context);
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index 78f7dd53d8fd..df449a0da75c 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -32,7 +32,7 @@ static void testSplat(Type eltType, const EltTy &splatElt) {
namespace {
TEST(DenseSplatTest, BoolSplat) {
- MLIRContext context(false);
+ MLIRContext context;
IntegerType boolTy = IntegerType::get(1, &context);
RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
@@ -57,7 +57,7 @@ TEST(DenseSplatTest, BoolSplat) {
TEST(DenseSplatTest, LargeBoolSplat) {
constexpr int64_t boolCount = 56;
- MLIRContext context(false);
+ MLIRContext context;
IntegerType boolTy = IntegerType::get(1, &context);
RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
@@ -80,7 +80,7 @@ TEST(DenseSplatTest, LargeBoolSplat) {
}
TEST(DenseSplatTest, BoolNonSplat) {
- MLIRContext context(false);
+ MLIRContext context;
IntegerType boolTy = IntegerType::get(1, &context);
RankedTensorType shape = RankedTensorType::get({6}, boolTy);
@@ -92,7 +92,7 @@ TEST(DenseSplatTest, BoolNonSplat) {
TEST(DenseSplatTest, OddIntSplat) {
// Test detecting a splat with an odd(non 8-bit) integer bitwidth.
- MLIRContext context(false);
+ MLIRContext context;
constexpr size_t intWidth = 19;
IntegerType intTy = IntegerType::get(intWidth, &context);
APInt value(intWidth, 10);
@@ -101,7 +101,7 @@ TEST(DenseSplatTest, OddIntSplat) {
}
TEST(DenseSplatTest, Int32Splat) {
- MLIRContext context(false);
+ MLIRContext context;
IntegerType intTy = IntegerType::get(32, &context);
int value = 64;
@@ -109,7 +109,7 @@ TEST(DenseSplatTest, Int32Splat) {
}
TEST(DenseSplatTest, IntAttrSplat) {
- MLIRContext context(false);
+ MLIRContext context;
IntegerType intTy = IntegerType::get(85, &context);
Attribute value = IntegerAttr::get(intTy, 109);
@@ -117,7 +117,7 @@ TEST(DenseSplatTest, IntAttrSplat) {
}
TEST(DenseSplatTest, F32Splat) {
- MLIRContext context(false);
+ MLIRContext context;
FloatType floatTy = FloatType::getF32(&context);
float value = 10.0;
@@ -125,7 +125,7 @@ TEST(DenseSplatTest, F32Splat) {
}
TEST(DenseSplatTest, F64Splat) {
- MLIRContext context(false);
+ MLIRContext context;
FloatType floatTy = FloatType::getF64(&context);
double value = 10.0;
@@ -133,7 +133,7 @@ TEST(DenseSplatTest, F64Splat) {
}
TEST(DenseSplatTest, FloatAttrSplat) {
- MLIRContext context(false);
+ MLIRContext context;
FloatType floatTy = FloatType::getF32(&context);
Attribute value = FloatAttr::get(floatTy, 10.0);
@@ -141,7 +141,7 @@ TEST(DenseSplatTest, FloatAttrSplat) {
}
TEST(DenseSplatTest, BF16Splat) {
- MLIRContext context(false);
+ MLIRContext context;
FloatType floatTy = FloatType::getBF16(&context);
Attribute value = FloatAttr::get(floatTy, 10.0);
@@ -149,7 +149,7 @@ TEST(DenseSplatTest, BF16Splat) {
}
TEST(DenseSplatTest, StringSplat) {
- MLIRContext context(false);
+ MLIRContext context;
Type stringType =
OpaqueType::get(Identifier::get("test", &context), "string", &context);
StringRef value = "test-string";
@@ -157,7 +157,7 @@ TEST(DenseSplatTest, StringSplat) {
}
TEST(DenseSplatTest, StringAttrSplat) {
- MLIRContext context(false);
+ MLIRContext context;
Type stringType =
OpaqueType::get(Identifier::get("test", &context), "string", &context);
Attribute stringAttr = StringAttr::get("test-string", stringType);
@@ -165,28 +165,28 @@ TEST(DenseSplatTest, StringAttrSplat) {
}
TEST(DenseComplexTest, ComplexFloatSplat) {
- MLIRContext context(false);
+ MLIRContext context;
ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
std::complex<float> value(10.0, 15.0);
testSplat(complexType, value);
}
TEST(DenseComplexTest, ComplexIntSplat) {
- MLIRContext context(false);
+ MLIRContext context;
ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
std::complex<int64_t> value(10, 15);
testSplat(complexType, value);
}
TEST(DenseComplexTest, ComplexAPFloatSplat) {
- MLIRContext context(false);
+ MLIRContext context;
ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
testSplat(complexType, value);
}
TEST(DenseComplexTest, ComplexAPIntSplat) {
- MLIRContext context(false);
+ MLIRContext context;
ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
testSplat(complexType, value);
diff --git a/mlir/unittests/IR/DialectTest.cpp b/mlir/unittests/IR/DialectTest.cpp
index 5a0a229d2c82..bc389ce1f0da 100644
--- a/mlir/unittests/IR/DialectTest.cpp
+++ b/mlir/unittests/IR/DialectTest.cpp
@@ -26,12 +26,12 @@ struct AnotherTestDialect : public Dialect {
};
TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
- MLIRContext context(false);
+ MLIRContext context;
// Registering a dialect with the same namespace twice should result in a
// failure.
- context.loadDialect<TestDialect>();
- ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), "");
+ context.getOrCreateDialect<TestDialect>();
+ ASSERT_DEATH(context.getOrCreateDialect<AnotherTestDialect>(), "");
}
} // end namespace
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index 96693309bd13..95ddcccc565e 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -25,7 +25,7 @@ static Operation *createOp(MLIRContext *context,
namespace {
TEST(OperandStorageTest, NonResizable) {
- MLIRContext context(false);
+ MLIRContext context;
Builder builder(&context);
Operation *useOp =
@@ -49,7 +49,7 @@ TEST(OperandStorageTest, NonResizable) {
}
TEST(OperandStorageTest, Resizable) {
- MLIRContext context(false);
+ MLIRContext context;
Builder builder(&context);
Operation *useOp =
@@ -77,7 +77,7 @@ TEST(OperandStorageTest, Resizable) {
}
TEST(OperandStorageTest, RangeReplace) {
- MLIRContext context(false);
+ MLIRContext context;
Builder builder(&context);
Operation *useOp =
@@ -113,7 +113,7 @@ TEST(OperandStorageTest, RangeReplace) {
}
TEST(OperandStorageTest, MutableRange) {
- MLIRContext context(false);
+ MLIRContext context;
Builder builder(&context);
Operation *useOp =
diff --git a/mlir/unittests/Pass/AnalysisManagerTest.cpp b/mlir/unittests/Pass/AnalysisManagerTest.cpp
index 958cf43b209d..a99df3911a5d 100644
--- a/mlir/unittests/Pass/AnalysisManagerTest.cpp
+++ b/mlir/unittests/Pass/AnalysisManagerTest.cpp
@@ -24,7 +24,7 @@ struct OtherAnalysis {
};
TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) {
- MLIRContext context(false);
+ MLIRContext context;
// Test fine grain invalidation of the module analysis manager.
OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
@@ -45,7 +45,7 @@ TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) {
}
TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) {
- MLIRContext context(false);
+ MLIRContext context;
Builder builder(&context);
// Create a function and a module.
@@ -74,7 +74,7 @@ TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) {
}
TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) {
- MLIRContext context(false);
+ MLIRContext context;
Builder builder(&context);
// Create a function and a module.
@@ -117,7 +117,7 @@ struct CustomInvalidatingAnalysis {
};
TEST(AnalysisManagerTest, CustomInvalidation) {
- MLIRContext context(false);
+ MLIRContext context;
Builder builder(&context);
// Create a function and a module.
diff --git a/mlir/unittests/SDBM/SDBMTest.cpp b/mlir/unittests/SDBM/SDBMTest.cpp
index bbe87e3d292c..61d670650b4b 100644
--- a/mlir/unittests/SDBM/SDBMTest.cpp
+++ b/mlir/unittests/SDBM/SDBMTest.cpp
@@ -17,17 +17,18 @@
using namespace mlir;
+/// Load the SDBM dialect.
+static DialectRegistration<SDBMDialect> SDBMRegistration;
static MLIRContext *ctx() {
- static thread_local MLIRContext context(false);
- context.getOrLoadDialect<SDBMDialect>();
+ static thread_local MLIRContext context;
return &context;
}
static SDBMDialect *dialect() {
static thread_local SDBMDialect *d = nullptr;
if (!d) {
- d = ctx()->getOrLoadDialect<SDBMDialect>();
+ d = ctx()->getRegisteredDialect<SDBMDialect>();
}
return d;
}
diff --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp
index 46a37da6e944..3e3256e96cd0 100644
--- a/mlir/unittests/TableGen/OpBuildGen.cpp
+++ b/mlir/unittests/TableGen/OpBuildGen.cpp
@@ -25,16 +25,11 @@ namespace mlir {
// Test Fixture
//===----------------------------------------------------------------------===//
-static MLIRContext &getContext() {
- static MLIRContext ctx(false);
- ctx.getOrLoadDialect<TestDialect>();
- return ctx;
-}
/// Test fixture for providing basic utilities for testing.
class OpBuildGenTest : public ::testing::Test {
protected:
OpBuildGenTest()
- : ctx(getContext()), builder(&ctx), loc(builder.getUnknownLoc()),
+ : ctx{}, builder(&ctx), loc(builder.getUnknownLoc()),
i32Ty(builder.getI32Type()), f32Ty(builder.getF32Type()),
cstI32(builder.create<TableGenConstant>(loc, i32Ty)),
cstF32(builder.create<TableGenConstant>(loc, f32Ty)),
@@ -91,7 +86,7 @@ class OpBuildGenTest : public ::testing::Test {
}
protected:
- MLIRContext &ctx;
+ MLIRContext ctx;
OpBuilder builder;
Location loc;
Type i32Ty;
diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp
index 14b0abc675bf..c58fedb4ec4f 100644
--- a/mlir/unittests/TableGen/StructsGenTest.cpp
+++ b/mlir/unittests/TableGen/StructsGenTest.cpp
@@ -42,7 +42,7 @@ static test::TestStruct getTestStruct(mlir::MLIRContext *context) {
/// Validates that test::TestStruct::classof correctly identifies a valid
/// test::TestStruct.
TEST(StructsGenTest, ClassofTrue) {
- mlir::MLIRContext context(false);
+ mlir::MLIRContext context;
auto structAttr = getTestStruct(&context);
ASSERT_TRUE(test::TestStruct::classof(structAttr));
}
More information about the Mlir-commits
mailing list