[Mlir-commits] [mlir] ba8424a - [mlir] Add Dynamic Dialects

River Riddle llvmlistbot at llvm.org
Mon Sep 19 09:58:31 PDT 2022


Author: Mathieu Fehr
Date: 2022-09-19T09:58:18-07:00
New Revision: ba8424a251d72756d3d697fd0b208c1c29f3b99f

URL: https://github.com/llvm/llvm-project/commit/ba8424a251d72756d3d697fd0b208c1c29f3b99f
DIFF: https://github.com/llvm/llvm-project/commit/ba8424a251d72756d3d697fd0b208c1c29f3b99f.diff

LOG: [mlir] Add Dynamic Dialects

Dynamic dialects are dialects that can be defined at runtime.
Dynamic dialects are extensible by new operations, types, and
attributes at runtime.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D125201

Added: 
    mlir/test/lib/Dialect/TestDyn/CMakeLists.txt
    mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp

Modified: 
    mlir/docs/DefiningDialects.md
    mlir/include/mlir/IR/DialectRegistry.h
    mlir/include/mlir/IR/ExtensibleDialect.h
    mlir/include/mlir/IR/MLIRContext.h
    mlir/lib/IR/Dialect.cpp
    mlir/lib/IR/ExtensibleDialect.cpp
    mlir/lib/IR/MLIRContext.cpp
    mlir/test/IR/dynamic.mlir
    mlir/test/lib/Dialect/CMakeLists.txt
    mlir/test/mlir-opt/commandline.mlir
    mlir/tools/mlir-opt/CMakeLists.txt
    mlir/tools/mlir-opt/mlir-opt.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/docs/DefiningDialects.md b/mlir/docs/DefiningDialects.md
index 1b5a6fb48a150..7ef6ae0235ef5 100644
--- a/mlir/docs/DefiningDialects.md
+++ b/mlir/docs/DefiningDialects.md
@@ -372,6 +372,30 @@ if (auto extensibleDialect = llvm::dyn_cast<ExtensibleDialect>(dialect)) {
 }
 ```
 
+### Defining a dynamic dialect
+
+Dynamic dialects are extensible dialects that can be defined at runtime. They
+are only populated with dynamic operations, types, and attributes. They can be
+registered in a `DialectRegistry` with `insertDynamic`.
+
+```c++
+auto populateDialect = [](MLIRContext *ctx, DynamicDialect *dialect) {
+  // Code that will be run when the dynamic dialect is created and loaded.
+  // For instance, this is where we register the dynamic operations, types, and
+  // attributes of the dialect.
+  ...
+};
+
+registry.insertDynamic("dialectName", populateDialect);
+```
+
+Once a dynamic dialect is registered in the `MLIRContext`, it can be retrieved
+with `getOrLoadDialect`.
+
+```c++
+Dialect *dialect = ctx->getOrLoadDialect("dialectName");
+```
+
 ### Defining an operation at runtime
 
 The `DynamicOpDefinition` class represents the definition of an operation

diff  --git a/mlir/include/mlir/IR/DialectRegistry.h b/mlir/include/mlir/IR/DialectRegistry.h
index a94d9b35545ab..7874813106a26 100644
--- a/mlir/include/mlir/IR/DialectRegistry.h
+++ b/mlir/include/mlir/IR/DialectRegistry.h
@@ -26,6 +26,8 @@ class Dialect;
 
 using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
 using DialectAllocatorFunctionRef = function_ref<Dialect *(MLIRContext *)>;
+
+// Declared here so this header does not depend on another header to provide
+// the declaration of DynamicDialect.
+class DynamicDialect;
+
+/// Callback used by DialectRegistry::insertDynamic to populate a freshly
+/// created dynamic dialect (its operations, types, and attributes) in the
+/// given context.
+using DynamicDialectPopulationFunction =
+    std::function<void(MLIRContext *, DynamicDialect *)>;
 
 //===----------------------------------------------------------------------===//
 // DialectExtension
@@ -135,8 +137,15 @@ class DialectRegistry {
   void insert(TypeID typeID, StringRef name,
               const DialectAllocatorFunction &ctor);
 
-  /// Return an allocation function for constructing the dialect identified by
-  /// its namespace, or nullptr if the namespace is not in this registry.
+  /// Add a dynamic dialect with namespace `name` to the registry. When the
+  /// dialect is loaded in a context, a new DynamicDialect is created and
+  /// `ctor` is invoked with that context and the created dialect. `ctor` is
+  /// expected to register the dialect's types, attributes, and operations
+  /// using the methods of ExtensibleDialect, such as registerDynamicOp.
+  void insertDynamic(StringRef name,
+                     const DynamicDialectPopulationFunction &ctor);
+
+  /// Return an allocation function for constructing the dialect identified
+  /// by its namespace, or nullptr if the namespace is not in this registry.
   DialectAllocatorFunctionRef getDialectAllocator(StringRef name) const;
 
   // Register all dialects available in the current registry with the registry

diff  --git a/mlir/include/mlir/IR/ExtensibleDialect.h b/mlir/include/mlir/IR/ExtensibleDialect.h
index 84dc56749e775..520712f0bbc60 100644
--- a/mlir/include/mlir/IR/ExtensibleDialect.h
+++ b/mlir/include/mlir/IR/ExtensibleDialect.h
@@ -550,6 +550,30 @@ class ExtensibleDialect : public mlir::Dialect {
   /// Owns the TypeID generated at runtime for operations.
   TypeIDAllocator typeIDAllocator;
 };
+
+//===----------------------------------------------------------------------===//
+// Dynamic dialect
+//===----------------------------------------------------------------------===//
+
+/// A dialect that can be defined at runtime. It can be extended with new
+/// operations, types, and attributes at runtime.
+/// A dialect that can be defined at runtime. It can be extended with new
+/// operations, types, and attributes at runtime.
+///
+/// A DynamicDialect owns its TypeID (through SelfOwningTypeID), since its
+/// TypeID can only be defined when the dialect is constructed. It is only
+/// populated with dynamic operations, types, and attributes.
+class DynamicDialect : public SelfOwningTypeID, public ExtensibleDialect {
+public:
+  /// Create a dynamic dialect with namespace `name` in context `ctx`.
+  DynamicDialect(StringRef name, MLIRContext *ctx);
+
+  /// Return the TypeID owned by this dialect, which is also the TypeID the
+  /// dialect is registered with.
+  TypeID getTypeID() { return SelfOwningTypeID::getTypeID(); }
+
+  /// Check if the dialect is a dynamic dialect.
+  static bool classof(const Dialect *dialect);
+
+  /// Parse and print the dynamic types registered in this dialect.
+  Type parseType(DialectAsmParser &parser) const override;
+  void printType(Type type, DialectAsmPrinter &printer) const override;
+
+  /// Parse and print the dynamic attributes registered in this dialect.
+  Attribute parseAttribute(DialectAsmParser &parser, Type type) const override;
+  void printAttribute(Attribute attr,
+                      DialectAsmPrinter &printer) const override;
+};
 } // namespace mlir
 
 namespace llvm {
@@ -561,6 +585,15 @@ struct isa_impl<mlir::ExtensibleDialect, mlir::Dialect> {
     return mlir::ExtensibleDialect::classof(&dialect);
   }
 };
+
+/// Provide isa functionality for DynamicDialect.
+/// This is to override the isa functionality for Dialect: every
+/// DynamicDialect owns a distinct TypeID allocated at runtime, so membership
+/// is decided by DynamicDialect::classof instead.
+template <>
+struct isa_impl<mlir::DynamicDialect, mlir::Dialect> {
+  static inline bool doit(const ::mlir::Dialect &dialect) {
+    return mlir::DynamicDialect::classof(&dialect);
+  }
+};
 } // namespace llvm
 
 #endif // MLIR_IR_EXTENSIBLEDIALECT_H

diff  --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index a66399dee71d7..c162b00f8402c 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -24,6 +24,7 @@ class DebugActionManager;
 class DiagnosticEngine;
 class Dialect;
 class DialectRegistry;
+class DynamicDialect;
 class InFlightDiagnostic;
 class Location;
 class MLIRContextImpl;
@@ -110,6 +111,11 @@ class MLIRContext {
     loadDialect<OtherDialect, MoreDialects...>();
   }
 
+  /// Get (or create) a dynamic dialect for the given name.
+  /// If no dialect with this namespace is loaded, a new DynamicDialect is
+  /// created, `ctor` is called on it, and it is loaded in the context. If a
+  /// dynamic dialect with this namespace is already loaded, it is returned and
+  /// `ctor` is not called. It is a fatal error if the namespace is already
+  /// used by a dialect that is not a DynamicDialect.
+  DynamicDialect *
+  getOrLoadDynamicDialect(StringRef dialectNamespace,
+                          function_ref<void(DynamicDialect *)> ctor);
+
   /// Load all dialects available in the registry in this context.
   void loadAllAvailableDialects();
 

diff  --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp
index e72e071d8f95a..c97c800d3f022 100644
--- a/mlir/lib/IR/Dialect.cpp
+++ b/mlir/lib/IR/Dialect.cpp
@@ -11,6 +11,7 @@
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/ExtensibleDialect.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Operation.h"
 #include "llvm/ADT/MapVector.h"
@@ -167,6 +168,24 @@ void DialectRegistry::insert(TypeID typeID, StringRef name,
   }
 }
 
+/// Register a dynamic dialect named `name`. Its allocator creates (or reuses)
+/// the dynamic dialect in the given context and lets `ctor` populate it.
+void DialectRegistry::insertDynamic(
+    StringRef name, const DynamicDialectPopulationFunction &ctor) {
+  // This TypeID marks dynamic dialects. We cannot give a TypeID for the
+  // dialect yet, since the TypeID of a dynamic dialect is defined at its
+  // construction.
+  TypeID typeID = TypeID::get<void>();
+
+  // Create the dialect, and then call ctor, which allocates its components.
+  // The inner callback is passed as a function_ref and only invoked during the
+  // getOrLoadDynamicDialect call, so capturing by reference is safe and avoids
+  // copying the std::function every time the dialect is loaded.
+  auto constructor = [nameStr = name.str(), ctor](MLIRContext *ctx) {
+    auto *dynDialect = ctx->getOrLoadDynamicDialect(
+        nameStr, [&](DynamicDialect *dialect) { ctor(ctx, dialect); });
+    assert(dynDialect && "Dynamic dialect creation unexpectedly failed");
+    return dynDialect;
+  };
+
+  insert(typeID, name, constructor);
+}
+
 void DialectRegistry::applyExtensions(Dialect *dialect) const {
   MLIRContext *ctx = dialect->getContext();
   StringRef dialectName = dialect->getNamespace();

diff  --git a/mlir/lib/IR/ExtensibleDialect.cpp b/mlir/lib/IR/ExtensibleDialect.cpp
index 148ba9f04f1a4..41f44f57bfaed 100644
--- a/mlir/lib/IR/ExtensibleDialect.cpp
+++ b/mlir/lib/IR/ExtensibleDialect.cpp
@@ -507,3 +507,84 @@ LogicalResult ExtensibleDialect::printIfDynamicAttr(Attribute attribute,
   }
   return failure();
 }
+
+//===----------------------------------------------------------------------===//
+// Dynamic dialect
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Interface that can only be implemented by dynamic dialects.
+/// The interface is used to check if a dialect is dynamic or not: the
+/// DynamicDialect constructor attaches it, and DynamicDialect::classof looks
+/// it up.
+class IsDynamicDialect : public DialectInterface::Base<IsDynamicDialect> {
+public:
+  explicit IsDynamicDialect(Dialect *dialect) : Base(dialect) {}
+
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(IsDynamicDialect)
+};
+} // namespace
+
+/// Construct a dynamic dialect named `name` in `ctx`. SelfOwningTypeID is the
+/// first base class of DynamicDialect, so it is initialized before
+/// ExtensibleDialect and its TypeID can be forwarded as the dialect's TypeID.
+DynamicDialect::DynamicDialect(StringRef name, MLIRContext *ctx)
+    : SelfOwningTypeID(),
+      ExtensibleDialect(name, ctx, SelfOwningTypeID::getTypeID()) {
+  // Tag this dialect so that DynamicDialect::classof recognizes it.
+  addInterfaces<IsDynamicDialect>();
+}
+
+bool DynamicDialect::classof(const Dialect *dialect) {
+  // A dialect is dynamic iff it carries the IsDynamicDialect marker interface.
+  auto *mutableDialect = const_cast<Dialect *>(dialect);
+  return mutableDialect->getRegisteredInterface<IsDynamicDialect>() != nullptr;
+}
+
+Type DynamicDialect::parseType(DialectAsmParser &parser) const {
+  // Every type of a dynamic dialect is a dynamic type introduced by a keyword
+  // tag; remember where the tag starts for error reporting.
+  auto tagLoc = parser.getCurrentLocation();
+  StringRef tag;
+  if (failed(parser.parseKeyword(&tag)))
+    return Type();
+
+  Type parsedType;
+  auto result = parseOptionalDynamicType(tag, parser, parsedType);
+  if (!result.has_value()) {
+    // parseOptionalDynamicType did not recognize the tag.
+    parser.emitError(tagLoc, "expected dynamic type");
+    return Type();
+  }
+  return succeeded(result.value()) ? parsedType : Type();
+}
+
+void DynamicDialect::printType(Type type, DialectAsmPrinter &printer) const {
+  // A dynamic dialect only defines dynamic types, so printing must succeed.
+  // The call lives outside the assert so it still runs in release builds.
+  LogicalResult printed = printIfDynamicType(type, printer);
+  assert(succeeded(printed) &&
+         "non-dynamic type defined in dynamic dialect");
+  (void)printed;
+}
+
+/// Parse a dynamic attribute of this dialect. Every attribute of a dynamic
+/// dialect is a dynamic attribute introduced by a keyword tag; the expected
+/// `type` is not used.
+Attribute DynamicDialect::parseAttribute(DialectAsmParser &parser,
+                                         Type type) const {
+  auto loc = parser.getCurrentLocation();
+  StringRef attrTag;
+  if (failed(parser.parseKeyword(&attrTag)))
+    return Attribute();
+
+  {
+    Attribute dynAttr;
+    auto parseResult = parseOptionalDynamicAttr(attrTag, parser, dynAttr);
+    if (parseResult.has_value()) {
+      if (succeeded(parseResult.value()))
+        return dynAttr;
+      return Attribute();
+    }
+  }
+
+  parser.emitError(loc, "expected dynamic attribute");
+  return Attribute();
+}
+
+void DynamicDialect::printAttribute(Attribute attr,
+                                    DialectAsmPrinter &printer) const {
+  // A dynamic dialect only defines dynamic attributes, so printing must
+  // succeed. The call lives outside the assert so it still runs in release
+  // builds.
+  LogicalResult printed = printIfDynamicAttr(attr, printer);
+  assert(succeeded(printed) &&
+         "non-dynamic attribute defined in dynamic dialect");
+  (void)printed;
+}

diff  --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 144dc61442b69..3d41823eb6c16 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/ExtensibleDialect.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/OpImplementation.h"
@@ -455,6 +456,41 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
   return dialect.get();
 }
 
+/// Return the dynamic dialect named `dialectNamespace`, creating and loading
+/// it (and calling `ctor` on it) if it is not loaded yet. Reports a fatal
+/// error if the namespace is taken by a non-dynamic dialect.
+DynamicDialect *MLIRContext::getOrLoadDynamicDialect(
+    StringRef dialectNamespace, function_ref<void(DynamicDialect *)> ctor) {
+  auto &impl = getImpl();
+  // Check whether a dialect with this namespace is already loaded.
+  auto dialectIt = impl.loadedDialects.find(dialectNamespace);
+
+  if (dialectIt != impl.loadedDialects.end()) {
+    if (auto *dynDialect = dyn_cast<DynamicDialect>(dialectIt->second.get()))
+      return dynDialect;
+    llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
+                             "' has already been registered");
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << "Load new dynamic dialect in Context "
+                          << dialectNamespace << "\n");
+#ifndef NDEBUG
+  if (impl.multiThreadedExecutionContext != 0)
+    llvm::report_fatal_error(
+        "Loading a dynamic dialect (" + dialectNamespace +
+        ") while in a multi-threaded execution context (maybe "
+        "the PassManager): this can indicate a "
+        "missing `dependentDialects` in a pass for example.");
+#endif
+
+  // Intern the namespace in the context; presumably this gives the dialect a
+  // name whose storage outlives the caller's `dialectNamespace` buffer.
+  auto name = StringAttr::get(this, dialectNamespace);
+
+  // The dialect is created before it is loaded because it owns the TypeID
+  // under which it is registered. Keep it in a unique_ptr until the context
+  // takes ownership, so it is released if the loader callback never runs.
+  auto ownedDialect = std::make_unique<DynamicDialect>(name, this);
+  TypeID typeID = ownedDialect->getTypeID();
+  Dialect *loaded = getOrLoadDialect(name, typeID, [&]() {
+    ctor(ownedDialect.get());
+    return std::move(ownedDialect);
+  });
+  // `typeID` is unique to the dialect created above, so the loaded dialect is
+  // that one.
+  return cast<DynamicDialect>(loaded);
+}
+
 void MLIRContext::loadAllAvailableDialects() {
   for (StringRef name : getAvailableDialects())
     getOrLoadDialect(name);

diff  --git a/mlir/test/IR/dynamic.mlir b/mlir/test/IR/dynamic.mlir
index 677fd9894dc85..cf03414c89a35 100644
--- a/mlir/test/IR/dynamic.mlir
+++ b/mlir/test/IR/dynamic.mlir
@@ -124,3 +124,18 @@ func.func @customOpParserPrinter() {
   test.dynamic_custom_parser_printer custom_keyword
   return
 }
+
+//===----------------------------------------------------------------------===//
+// Dynamic dialect
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// Check that the verifier of a dynamic operation in a dynamic dialect
+// can fail. This shows that the dialect is correctly registered.
+
+func.func @failedDynamicDialectOpVerifier() {
+  // expected-error @+1 {{expected a single result, no operands and no regions}}
+  "test_dyn.one_result"() : () -> ()
+  return
+}

diff  --git a/mlir/test/lib/Dialect/CMakeLists.txt b/mlir/test/lib/Dialect/CMakeLists.txt
index 6bc8635757e44..46b38dcfbc736 100644
--- a/mlir/test/lib/Dialect/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/CMakeLists.txt
@@ -11,6 +11,7 @@ add_subdirectory(Shape)
 add_subdirectory(SPIRV)
 add_subdirectory(Tensor)
 add_subdirectory(Test)
+add_subdirectory(TestDyn)
 add_subdirectory(Tosa)
 add_subdirectory(Transform)
 add_subdirectory(Vector)

diff  --git a/mlir/test/lib/Dialect/TestDyn/CMakeLists.txt b/mlir/test/lib/Dialect/TestDyn/CMakeLists.txt
new file mode 100644
index 0000000000000..13eb9040b0744
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestDyn/CMakeLists.txt
@@ -0,0 +1,8 @@
+# Library providing the 'test_dyn' dynamic dialect, linked into mlir-opt when
+# tests are enabled. It is not part of libMLIR (EXCLUDE_FROM_LIBMLIR).
+add_mlir_dialect_library(MLIRTestDynDialect
+  TestDynDialect.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+)

diff  --git a/mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp b/mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp
new file mode 100644
index 0000000000000..1757f1dbb873d
--- /dev/null
+++ b/mlir/test/lib/Dialect/TestDyn/TestDynDialect.cpp
@@ -0,0 +1,36 @@
+//===- TestDynDialect.cpp -------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a fake 'test_dyn' dynamic dialect that is used to test the
+// registration of dynamic dialects.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/ExtensibleDialect.h"
+
+using namespace mlir;
+
+namespace test {
+/// Register the 'test_dyn' dynamic dialect in `registry`. The dialect defines
+/// a single operation, `test_dyn.one_result`, whose verifier requires exactly
+/// one result, no operands, and no regions.
+void registerTestDynDialect(DialectRegistry &registry) {
+  auto populate = [](MLIRContext * /*ctx*/, DynamicDialect *testDyn) {
+    // Invariant verifier for `one_result`.
+    auto verifyOneResult = [](Operation *op) -> LogicalResult {
+      bool hasExpectedShape = op->getNumOperands() == 0 &&
+                              op->getNumResults() == 1 &&
+                              op->getNumRegions() == 0;
+      if (!hasExpectedShape)
+        return op->emitError(
+            "expected a single result, no operands and no regions");
+      return success();
+    };
+    // Region verifier for `one_result`: always succeeds.
+    auto verifyRegions = [](Operation * /*op*/) { return success(); };
+
+    testDyn->registerDynamicOp(DynamicOpDefinition::get(
+        "one_result", testDyn, verifyOneResult, verifyRegions));
+  };
+  registry.insertDynamic("test_dyn", populate);
+}
+} // namespace test

diff  --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir
index d4a48f6552afa..a47568ca7ec35 100644
--- a/mlir/test/mlir-opt/commandline.mlir
+++ b/mlir/test/mlir-opt/commandline.mlir
@@ -34,6 +34,7 @@
 // CHECK-NEXT: spv
 // CHECK-NEXT: tensor
 // CHECK-NEXT: test
+// CHECK-NEXT: test_dyn
 // CHECK-NEXT: tosa
 // CHECK-NEXT: transform
 // CHECK-NEXT: vector

diff  --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index 59036b1f467bc..e1a90956e122b 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -27,6 +27,7 @@ if(MLIR_INCLUDE_TESTS)
     MLIRTensorTestPasses
     MLIRTestAnalysis
     MLIRTestDialect
+    MLIRTestDynDialect
     MLIRTestIR
     MLIRTestPass
     MLIRTestPDLL

diff  --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 72dcd91c5b869..05bb9e425550d 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -121,6 +121,7 @@ void registerTestNvgpuLowerings();
 namespace test {
 void registerTestDialect(DialectRegistry &);
 void registerTestTransformDialectExtension(DialectRegistry &);
+// Defined in TestDynDialect.cpp (MLIRTestDynDialect library).
+void registerTestDynDialect(DialectRegistry &);
 } // namespace test
 
 #ifdef MLIR_INCLUDE_TESTS
@@ -225,6 +226,7 @@ int main(int argc, char **argv) {
 #ifdef MLIR_INCLUDE_TESTS
   ::test::registerTestDialect(registry);
   ::test::registerTestTransformDialectExtension(registry);
+  ::test::registerTestDynDialect(registry);
 #endif
   return mlir::asMainReturnCode(
       mlir::MlirOptMain(argc, argv, "MLIR modular optimizer driver\n", registry,


        


More information about the Mlir-commits mailing list