[Mlir-commits] [mlir] ba8424a - [mlir] Add Dynamic Dialects
River Riddle
llvmlistbot at llvm.org
Mon Sep 19 09:58:31 PDT 2022
Author: Mathieu Fehr
Date: 2022-09-19T09:58:18-07:00
New Revision: ba8424a251d72756d3d697fd0b208c1c29f3b99f
URL: https://github.com/llvm/llvm-project/commit/ba8424a251d72756d3d697fd0b208c1c29f3b99f
DIFF: https://github.com/llvm/llvm-project/commit/ba8424a251d72756d3d697fd0b208c1c29f3b99f.diff
LOG: [mlir] Add Dynamic Dialects
Dynamic dialects are dialects that can be defined at runtime.
Dynamic dialects are extensible by new operations, types, and
attributes at runtime.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D125201
Added:
mlir/test/lib/Dialect/TestDyn/CMakeLists.txt
mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp
Modified:
mlir/docs/DefiningDialects.md
mlir/include/mlir/IR/DialectRegistry.h
mlir/include/mlir/IR/ExtensibleDialect.h
mlir/include/mlir/IR/MLIRContext.h
mlir/lib/IR/Dialect.cpp
mlir/lib/IR/ExtensibleDialect.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/test/IR/dynamic.mlir
mlir/test/lib/Dialect/CMakeLists.txt
mlir/test/mlir-opt/commandline.mlir
mlir/tools/mlir-opt/CMakeLists.txt
mlir/tools/mlir-opt/mlir-opt.cpp
Removed:
################################################################################
diff --git a/mlir/docs/DefiningDialects.md b/mlir/docs/DefiningDialects.md
index 1b5a6fb48a150..7ef6ae0235ef5 100644
--- a/mlir/docs/DefiningDialects.md
+++ b/mlir/docs/DefiningDialects.md
@@ -372,6 +372,30 @@ if (auto extensibleDialect = llvm::dyn_cast<ExtensibleDialect>(dialect)) {
}
```
+### Defining a dynamic dialect
+
+Dynamic dialects are extensible dialects that can be defined at runtime. They
+are only populated with dynamic operations, types, and attributes. They can be
+registered in a `DialectRegistry` with `insertDynamic`.
+
+```c++
+auto populateDialect = [](MLIRContext *ctx, DynamicDialect *dialect) {
+  // Code that will be run when the dynamic dialect is created and loaded.
+  // For instance, this is where we register the dynamic operations, types, and
+  // attributes of the dialect.
+  ...
+};
+
+registry.insertDynamic("dialectName", populateDialect);
+```
+
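+Once the `DialectRegistry` containing a dynamic dialect has been added to an
+`MLIRContext`, the dialect can be loaded and retrieved by name with
+`getOrLoadDialect`. The population function runs when the dialect is loaded.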
+
+```c++
+Dialect *dialect = ctx->getOrLoadDialect("dialectName");
+```
+
### Defining an operation at runtime
The `DynamicOpDefinition` class represents the definition of an operation
diff --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h
index a94d9b35545ab..7874813106a26 100644
--- a/mlir/include/mlir/IR/DialectRegistry.h
+++ b/mlir/include/mlir/IR/DialectRegistry.h
@@ -26,6 +26,8 @@ class Dialect;
using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>;
+/// Callback used by `DialectRegistry::insertDynamic` to populate a newly
+/// created dynamic dialect (for instance by registering its dynamic operations,
+/// types, and attributes). It receives the context in which the dialect is
+/// being loaded and the dialect itself.
+using DynamicDialectPopulationFunction =
+    std::function<void(MLIRContext *, DynamicDialect *)>;
//===----------------------------------------------------------------------===//
// DialectExtension
@@ -135,8 +137,15 @@ class DialectRegistry {
void insert(TypeID typeID, StringRef name,
const DialectAllocatorFunction &ctor);
- /// Return an allocation function for constructing the dialect identified by
- /// its namespace, or nullptr if the namespace is not in this registry.
+  /// Add a new dynamic dialect with namespace `name` to the registry. When the
+  /// dialect is loaded in a context, it is created and then `ctor` is invoked
+  /// with that context and the newly created dialect. `ctor` is expected to
+  /// populate the dialect with its types, attributes, and operations, using
+  /// the methods defined in ExtensibleDialect such as `registerDynamicOp`.
+  void insertDynamic(StringRef name,
+                     const DynamicDialectPopulationFunction &ctor);
+
+ /// Return an allocation function for constructing the dialect identified
+ /// by its namespace, or nullptr if the namespace is not in this registry.
DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;
// Register all dialects available in the current registry with the registry
diff --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
index 84dc56749e775..520712f0bbc60 100644
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -550,6 +550,30 @@ class ExtensibleDialect : public mlir::Dialect {
/// Owns the TypeID generated at runtime for operations.
TypeIDAllocator typeIDAllocator;
};
+
+//===----------------------------------------------------------------------===//
+// Dynamic dialect
+//===----------------------------------------------------------------------===//
+
+/// A dialect that can be defined at runtime. It can be extended with new
+/// operations, types, and attributes at runtime.
+class DynamicDialect : public SelfOwningTypeID, public ExtensibleDialect {
+public:
+  /// Create a dynamic dialect with namespace `name` in context `ctx`.
+  DynamicDialect(StringRef name, MLIRContext *ctx);
+
+  /// Return the TypeID of this dialect. The TypeID of a dynamic dialect is
+  /// defined at its construction and is owned by the dialect object itself
+  /// (through SelfOwningTypeID).
+  TypeID getTypeID() const { return SelfOwningTypeID::getTypeID(); }
+
+  /// Check if the dialect is a dynamic dialect.
+  static bool classof(const Dialect *dialect);
+
+  /// Parse a type of this dialect. Only the dynamic types registered in the
+  /// dialect can be parsed.
+  Type parseType(DialectAsmParser &parser) const override;
+  /// Print a type of this dialect. The type is expected to be a dynamic type.
+  void printType(Type type, DialectAsmPrinter &printer) const override;
+
+  /// Parse an attribute of this dialect. Only the dynamic attributes
+  /// registered in the dialect can be parsed.
+  Attribute parseAttribute(DialectAsmParser &parser,
+                           Type type) const override;
+  /// Print an attribute of this dialect. The attribute is expected to be a
+  /// dynamic attribute.
+  void printAttribute(Attribute attr,
+                      DialectAsmPrinter &printer) const override;
+};
} // namespace mlir
namespace llvm {
@@ -561,6 +585,15 @@ struct isa_impl<mlir::ExtensibleDialect, mlir::Dialect> {
return mlir::ExtensibleDialect::classof(&dialect);
}
};
+
+/// Specialization of llvm::isa for DynamicDialect. It replaces the default
+/// isa behavior for Dialect, so that the check is routed through
+/// DynamicDialect::classof.
+template <>
+struct isa_impl<mlir::DynamicDialect, mlir::Dialect> {
+  static bool doit(const ::mlir::Dialect &d) {
+    const ::mlir::Dialect *candidate = &d;
+    return ::mlir::DynamicDialect::classof(candidate);
+  }
+};
} // namespace llvm
#endif // MLIR_IR_EXTENSIBLEDIALECT_H
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index a66399dee71d7..c162b00f8402c 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -24,6 +24,7 @@ class DebugActionManager;
class DiagnosticEngine;
class Dialect;
class DialectRegistry;
+class DynamicDialect;
class InFlightDiagnostic;
class Location;
class MLIRContextImpl;
@@ -110,6 +111,11 @@ class MLIRContext {
loadDialect<OtherDialect, MoreDialects...>();
}
+  /// Get (or create) a dynamic dialect for the given namespace.
+  /// If no dialect with this namespace is loaded yet, a new DynamicDialect is
+  /// created, `ctor` is invoked on it to populate it, and it is loaded into
+  /// this context. If a dynamic dialect with this namespace is already loaded,
+  /// it is returned and `ctor` is not invoked. It is a fatal error if a
+  /// non-dynamic dialect with the same namespace is already loaded.
+  DynamicDialect *
+  getOrLoadDynamicDialect(StringRef dialectNamespace,
+                          function_ref<void(DynamicDialect *)> ctor);
+
/// Load all dialects available in the registry in this context.
void loadAllAvailableDialects();
diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index e72e071d8f95a..c97c800d3f022 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/MapVector.h"
@@ -167,6 +168,24 @@ void DialectRegistry::insert(TypeID typeID, StringRef name,
}
}
+void DialectRegistry::insertDynamic(
+    StringRef name, const DynamicDialectPopulationFunction &ctor) {
+  // A dynamic dialect only gets its real TypeID when it is constructed, so the
+  // registry entry is tagged with the TypeID of `void` as a placeholder.
+  const TypeID placeholderID = TypeID::get<void>();
+
+  // The allocator may run long after `name` is gone, hence the owned copy.
+  // Loading creates the dialect, and `ctor` then populates it.
+  auto allocator = [ownedName = name.str(), ctor](MLIRContext *ctx) {
+    DynamicDialect *created = ctx->getOrLoadDynamicDialect(
+        ownedName,
+        [ctx, ctor](DynamicDialect *dialect) { ctor(ctx, dialect); });
+    assert(created && "Dynamic dialect creation unexpectedly failed");
+    return created;
+  };
+
+  insert(placeholderID, name, allocator);
+}
+
void DialectRegistry::applyExtensions(Dialect *dialect) const {
MLIRContext *ctx = dialect->getContext();
StringRef dialectName = dialect->getNamespace();
diff --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp
index 148ba9f04f1a4..41f44f57bfaed 100644
--- a/mlir/lib/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/IR/ExtensibleDialect.cpp
@@ -507,3 +507,84 @@ LogicalResult ExtensibleDialect::printIfDynamicAttr(Attribute attribute,
}
return failure();
}
+
+//===----------------------------------------------------------------------===//
+// Dynamic dialect
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Marker interface that is attached only to dynamic dialects (by the
+/// DynamicDialect constructor). DynamicDialect::classof checks for its
+/// presence to decide whether a dialect is dynamic.
+class IsDynamicDialect : public DialectInterface::Base<IsDynamicDialect> {
+public:
+  IsDynamicDialect(Dialect *dialect) : Base(dialect) {}
+
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IsDynamicDialect)
+};
+} // namespace
+
+DynamicDialect::DynamicDialect(StringRef name, MLIRContext *ctx)
+    : SelfOwningTypeID(),
+      // SelfOwningTypeID is the first base, so it is constructed before
+      // ExtensibleDialect, which receives the TypeID it owns.
+      ExtensibleDialect(name, ctx, SelfOwningTypeID::getTypeID()) {
+  // Tag the dialect so that DynamicDialect::classof recognizes it.
+  addInterfaces<IsDynamicDialect>();
+}
+
+bool DynamicDialect::classof(const Dialect *dialect) {
+  // The interface lookup goes through a non-const pointer; the dialect itself
+  // is only queried here, never modified.
+  auto *queried = const_cast<Dialect *>(dialect);
+  return queried->getRegisteredInterface<IsDynamicDialect>() != nullptr;
+}
+
+Type DynamicDialect::parseType(DialectAsmParser &parser) const {
+  auto tagLoc = parser.getCurrentLocation();
+  StringRef tag;
+  if (failed(parser.parseKeyword(&tag)))
+    return Type();
+
+  // Look the keyword up among the dynamic types registered in this dialect.
+  Type result;
+  auto dynResult = parseOptionalDynamicType(tag, parser, result);
+  if (!dynResult.has_value()) {
+    // The keyword does not name any dynamic type of this dialect.
+    parser.emitError(tagLoc, "expected dynamic type");
+    return Type();
+  }
+  return succeeded(dynResult.value()) ? result : Type();
+}
+
+void DynamicDialect::printType(Type type, DialectAsmPrinter &printer) const {
+  // A dynamic dialect only holds dynamic types, so printIfDynamicType is
+  // expected to always handle the type.
+  LogicalResult printed = printIfDynamicType(type, printer);
+  assert(succeeded(printed) && "non-dynamic type defined in dynamic dialect");
+  (void)printed;
+}
+
+Attribute DynamicDialect::parseAttribute(DialectAsmParser &parser,
+                                         Type type) const {
+  auto tagLoc = parser.getCurrentLocation();
+  StringRef attrTag;
+  if (failed(parser.parseKeyword(&attrTag)))
+    return Attribute();
+
+  // Look the keyword up among the dynamic attributes registered in this
+  // dialect. The expected `type` is not used here.
+  Attribute result;
+  auto dynResult = parseOptionalDynamicAttr(attrTag, parser, result);
+  if (!dynResult.has_value()) {
+    // The keyword does not name any dynamic attribute of this dialect.
+    parser.emitError(tagLoc, "expected dynamic attribute");
+    return Attribute();
+  }
+  return succeeded(dynResult.value()) ? result : Attribute();
+}
+
+void DynamicDialect::printAttribute(Attribute attr,
+                                    DialectAsmPrinter &printer) const {
+  // A dynamic dialect only holds dynamic attributes, so printIfDynamicAttr is
+  // expected to always handle the attribute.
+  LogicalResult printed = printIfDynamicAttr(attr, printer);
+  assert(succeeded(printed) &&
+         "non-dynamic attribute defined in dynamic dialect");
+  (void)printed;
+}
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 144dc61442b69..3d41823eb6c16 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/ExtensibleDialect.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OpImplementation.h"
@@ -455,6 +456,41 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
return dialect.get();
}
+/// Return the dynamic dialect with namespace `dialectNamespace`, creating,
+/// populating (through `ctor`) and loading it if it is not loaded yet.
+/// Aborts if a non-dynamic dialect already uses this namespace.
+DynamicDialect *MLIRContext::getOrLoadDynamicDialect(
+    StringRef dialectNamespace, function_ref<void(DynamicDialect *)> ctor) {
+  auto &impl = getImpl();
+  // Check whether a dialect with this namespace is already loaded.
+  auto dialectIt = impl.loadedDialects.find(dialectNamespace);
+
+  if (dialectIt != impl.loadedDialects.end()) {
+    // Reuse the already-loaded dynamic dialect; `ctor` is not invoked again.
+    if (auto *dynDialect = dyn_cast<DynamicDialect>(dialectIt->second.get()))
+      return dynDialect;
+    llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
+                             "' has already been registered");
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << "Load new dynamic dialect in Context "
+                          << dialectNamespace << "\n");
+#ifndef NDEBUG
+  if (impl.multiThreadedExecutionContext != 0)
+    llvm::report_fatal_error(
+        "Loading a dynamic dialect (" + dialectNamespace +
+        ") while in a multi-threaded execution context (maybe "
+        "the PassManager): this can indicate a "
+        "missing `dependentDialects` in a pass for example.");
+#endif
+
+  auto name = StringAttr::get(this, dialectNamespace);
+  // Keep the new dialect owned by a unique_ptr until the context takes it
+  // over, instead of holding it through a raw owning pointer.
+  auto dialectOwner = std::make_unique<DynamicDialect>(name, this);
+  DynamicDialect *dialect = dialectOwner.get();
+  Dialect *loaded = getOrLoadDialect(name, dialect->getTypeID(), [&]() {
+    // Populate the dialect before handing ownership to the context.
+    ctor(dialect);
+    return std::move(dialectOwner);
+  });
+  // The new dialect carries a freshly allocated TypeID, and no dialect with
+  // this namespace was loaded above, so the context must have loaded exactly
+  // this instance (if `getOrLoadDialect` did not fail).
+  assert(loaded == dialect && "unexpected dialect loaded for dynamic dialect");
+  (void)loaded;
+  return dialect;
+}
+
void MLIRContext::loadAllAvailableDialects() {
for (StringRef name : getAvailableDialects())
getOrLoadDialect(name);
diff --git a/mlir/test/IR/dynamic.mlir b/mlir/test/IR/dynamic.mlir
index 677fd9894dc85..cf03414c89a35 100644
--- a/mlir/test/IR/dynamic.mlir
+++ b/mlir/test/IR/dynamic.mlir
@@ -124,3 +124,18 @@ func.func @customOpParserPrinter() {
test.dynamic_custom_parser_printer custom_keyword
return
}
+
+//===----------------------------------------------------------------------===//
+// Dynamic dialect
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// Check that the verifier of a dynamic operation in a dynamic dialect
+// can fail. This shows that the dialect is correctly registered.
+
+func.func @failedDynamicDialectOpVerifier() {
+  // expected-error @+1 {{expected a single result, no operands and no regions}}
+  "test_dyn.one_result"() : () -> ()
+  return
+}
diff --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 6bc8635757e44..46b38dcfbc736 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -11,6 +11,7 @@ add_subdirectory(Shape)
add_subdirectory(SPIRV)
add_subdirectory(Tensor)
add_subdirectory(Test)
+add_subdirectory(TestDyn)
add_subdirectory(Tosa)
add_subdirectory(Transform)
add_subdirectory(Vector)
diff --git a/mlir/test/lib/Dialect/TestDyn/CMakeLists.txt b/mlir/test/lib/Dialect/TestDyn/CMakeLists.txt
new file mode 100644
index 0000000000000..13eb9040b0744
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestDyn/CMakeLists.txt
@@ -0,0 +1,8 @@
+# Test library defining the 'test_dyn' dynamic dialect; it is linked into
+# mlir-opt and kept out of libMLIR (EXCLUDE_FROM_LIBMLIR).
+add_mlir_dialect_library(MLIRTestDynDialect
+ TestDynDialect.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+)
diff --git a/mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp b/mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp
new file mode 100644
index 0000000000000..1757f1dbb873d
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp
@@ -0,0 +1,36 @@
+//===- TestDynDialect.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a fake 'test_dyn' dynamic dialect that is used to test the
+// registration of dynamic dialects.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/ExtensibleDialect.h"
+
+using namespace mlir;
+
+namespace test {
+/// Register the 'test_dyn' dynamic dialect in `registry`. When loaded, the
+/// dialect is populated with a single dynamic operation, `one_result`, whose
+/// verifier requires exactly one result, no operands, and no regions.
+void registerTestDynDialect(DialectRegistry &registry) {
+  registry.insertDynamic(
+      "test_dyn", [](MLIRContext *ctx, DynamicDialect *testDyn) {
+        auto opVerifier = [](Operation *op) -> LogicalResult {
+          if (op->getNumOperands() == 0 && op->getNumResults() == 1 &&
+              op->getNumRegions() == 0)
+            return success();
+          return op->emitError(
+              "expected a single result, no operands and no regions");
+        };
+
+        // The operation is required to have no regions, so there is nothing
+        // to verify in them.
+        auto opRegionVerifier = [](Operation *op) { return success(); };
+
+        testDyn->registerDynamicOp(DynamicOpDefinition::get(
+            "one_result", testDyn, opVerifier, opRegionVerifier));
+      });
+}
+} // namespace test
diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir
index d4a48f6552afa..a47568ca7ec35 100644
--- a/mlir/test/mlir-opt/commandline.mlir
+++ b/mlir/test/mlir-opt/commandline.mlir
@@ -34,6 +34,7 @@
// CHECK-NEXT: spv
// CHECK-NEXT: tensor
// CHECK-NEXT: test
+// CHECK-NEXT: test_dyn
// CHECK-NEXT: tosa
// CHECK-NEXT: transform
// CHECK-NEXT: vector
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 59036b1f467bc..e1a90956e122b 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -27,6 +27,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRTensorTestPasses
MLIRTestAnalysis
MLIRTestDialect
+ MLIRTestDynDialect
MLIRTestIR
MLIRTestPass
MLIRTestPDLL
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 72dcd91c5b869..05bb9e425550d 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -121,6 +121,7 @@ void registerTestNvgpuLowerings();
namespace test {
void registerTestDialect(DialectRegistry &);
void registerTestTransformDialectExtension(DialectRegistry &);
+void registerTestDynDialect(DialectRegistry &);
} // namespace test
#ifdef MLIR_INCLUDE_TESTS
@@ -225,6 +226,7 @@ int main(int argc, char **argv) {
#ifdef MLIR_INCLUDE_TESTS
::test::registerTestDialect(registry);
::test::registerTestTransformDialectExtension(registry);
+ ::test::registerTestDynDialect(registry);
#endif
return mlir::asMainReturnCode(
mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry,
More information about the Mlir-commits
mailing list