[Mlir-commits] [mlir] e0d884d - [mlir][irdl] Add IRDL registration

Arjun P llvmlistbot at llvm.org
Thu Apr 20 07:36:02 PDT 2023


Author: Mathieu Fehr
Date: 2023-04-20T15:35:41+01:00
New Revision: e0d884de360b5c3fe79c6a53f8f88b57f0e42275

URL: https://github.com/llvm/llvm-project/commit/e0d884de360b5c3fe79c6a53f8f88b57f0e42275
DIFF: https://github.com/llvm/llvm-project/commit/e0d884de360b5c3fe79c6a53f8f88b57f0e42275.diff

LOG: [mlir][irdl] Add IRDL registration

This patch adds support for loading IRDL dialects at runtime
with `mlir-opt`.

Given the following `dialect.irdl` file:
```mlir
module {
  irdl.dialect @cmath {
    irdl.type @complex {
      %0 = irdl.is f32
      %1 = irdl.is f64
      %2 = irdl.any_of(%0, %1)
      irdl.parameters(%2)
    }

    irdl.operation @norm {
      %0 = irdl.any
      %1 = irdl.parametric @complex<%0>
      irdl.operands(%1)
      irdl.results(%0)
    }
  }
}
```

the IRDL file can be loaded with the `mlir-opt --irdl-file=dialect.irdl`
command, and the following file can then be parsed:

```mlir
func.func @conorm(%p: !cmath.complex<f32>, %q: !cmath.complex<f32>) -> f32 {
  %norm_p = "cmath.norm"(%p) : (!cmath.complex<f32>) -> f32
  %norm_q = "cmath.norm"(%q) : (!cmath.complex<f32>) -> f32
  %pq = arith.mulf %norm_p, %norm_q : f32
  return %pq : f32
}
```

To minimize the size of this patch, the operation, attribute, and type verifiers all unconditionally return `success()`.

Depends on D144692

Reviewed By: rriddle, Mogball, mehdi_amini

Differential Revision: https://reviews.llvm.org/D144693

Added: 
    mlir/include/mlir/Dialect/IRDL/IRDLLoading.h
    mlir/lib/Dialect/IRDL/IRDLLoading.cpp
    mlir/test/Dialect/IRDL/test-cmath.mlir
    mlir/test/Dialect/IRDL/testd.mlir

Modified: 
    mlir/include/mlir/IR/OpDefinition.h
    mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
    mlir/lib/Dialect/IRDL/CMakeLists.txt
    mlir/lib/Tools/mlir-opt/MlirOptMain.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/IRDL/IRDLLoading.h b/mlir/include/mlir/Dialect/IRDL/IRDLLoading.h
new file mode 100644
index 0000000000000..62b6827ac3874
--- /dev/null
+++ b/mlir/include/mlir/Dialect/IRDL/IRDLLoading.h
@@ -0,0 +1,28 @@
+//===- IRDLLoading.h - IRDL dialect loading ---------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Manages the registration of MLIR objects from IRDL operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_IRDL_IRDLLOADING_H
+#define MLIR_DIALECT_IRDL_IRDLLOADING_H
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace irdl {
+
+/// Load all the IRDL-defined dialects found in `op` into its MLIRContext.
+LogicalResult loadDialects(ModuleOp op);
+
+} // namespace irdl
+} // namespace mlir
+
+#endif // MLIR_DIALECT_IRDL_IRDLLOADING_H

diff  --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index a86acafcc1e89..5a73d776996f9 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -23,8 +23,8 @@
 #include "mlir/IR/Operation.h"
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
-#include <type_traits>
 #include <optional>
+#include <type_traits>
 
 namespace mlir {
 class Builder;
@@ -633,7 +633,7 @@ class OneTypedResult {
   class Impl
       : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
   public:
-   mlir::TypedValue<ResultType> getResult() {
+    mlir::TypedValue<ResultType> getResult() {
       return cast<mlir::TypedValue<ResultType>>(
           this->getOperation()->getResult(0));
     }
@@ -1255,6 +1255,14 @@ struct HasParent {
              << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
              << llvm::ArrayRef({ParentOpTypes::getOperationName()...}) << "'";
     }
+
+    /// Return the parent operation cast to the unique allowed parent type.
+    template <typename ParentOpType =
+                  std::tuple_element_t<0, std::tuple<ParentOpTypes...>>>
+    std::enable_if_t<sizeof...(ParentOpTypes) == 1, ParentOpType>
+    getParentOp() {
+      return llvm::cast<ParentOpType>(this->getOperation()->getParentOp());
+    }
   };
 };
 

diff  --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index f54c29c8f6e31..b4be148b64080 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -74,6 +74,13 @@ class MlirOptMainConfig {
   }
   bool shouldEmitBytecode() const { return emitBytecodeFlag; }
 
+  /// Set the path of the IRDL file to load before processing the input.
+  MlirOptMainConfig &setIrdlFile(llvm::StringRef file) {
+    irdlFileFlag = file.str();
+    return *this;
+  }
+  llvm::StringRef getIrdlFile() const { return irdlFileFlag; }
+
   /// Set the filename to use for logging actions, use "-" for stdout.
   MlirOptMainConfig &logActionsTo(StringRef filename) {
     logActionsToFlag = filename;
@@ -157,6 +164,9 @@ class MlirOptMainConfig {
   /// Emit bytecode instead of textual assembly when generating output.
   bool emitBytecodeFlag = false;
 
+  /// IRDL file to register before processing the input.
+  std::string irdlFileFlag = "";
+
   /// Log action execution to the given file (or "-" for stdout)
   std::string logActionsToFlag;
 

diff  --git a/mlir/lib/Dialect/IRDL/CMakeLists.txt b/mlir/lib/Dialect/IRDL/CMakeLists.txt
index 534c48ed5d1be..31efd5e37a665 100644
--- a/mlir/lib/Dialect/IRDL/CMakeLists.txt
+++ b/mlir/lib/Dialect/IRDL/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRIRDL
   IR/IRDL.cpp
+  IRDLLoading.cpp
 
   DEPENDS
   MLIRIRDLIncGen

diff  --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
new file mode 100644
index 0000000000000..15cd124a42298
--- /dev/null
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -0,0 +1,131 @@
+//===- IRDLLoading.cpp - IRDL dialect loading -------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Manages the loading of MLIR objects from IRDL operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/IRDL/IRDLLoading.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/SMLoc.h"
+
+using namespace mlir;
+using namespace mlir::irdl;
+
+/// Build a dynamic operation definition from an `irdl.operation` operation
+/// and register it in `dialect`. Always lets the enclosing walk continue.
+static WalkResult loadOperation(OperationOp op, ExtensibleDialect *dialect) {
+  // Verification is not implemented yet: every operation is accepted.
+  auto verifyOp = [](Operation *) { return success(); };
+
+  // IRDL does not support defining regions.
+  auto verifyRegions = [](Operation *) { return success(); };
+
+  // IRDL does not support defining custom parsers or printers: the parse
+  // hook always fails, and the print hook always emits the generic form.
+  auto parseOp = [](OpAsmParser &, OperationState &) { return failure(); };
+  auto printOp = [](Operation *toPrint, OpAsmPrinter &os, StringRef) {
+    os.printGenericOp(toPrint);
+  };
+
+  auto definition = DynamicOpDefinition::get(
+      op.getName(), dialect, std::move(verifyOp), std::move(verifyRegions),
+      std::move(parseOp), std::move(printOp));
+  dialect->registerDynamicOp(std::move(definition));
+
+  return WalkResult::advance();
+}
+
+/// Get or load, through `getOrLoadDynamicDialect`, one dynamic dialect per
+/// `DialectOp` found in `op`, and map each `DialectOp` to its dialect. No
+/// operation, type or attribute definition is registered here.
+static DenseMap<DialectOp, ExtensibleDialect *> loadEmptyDialects(ModuleOp op) {
+  DenseMap<DialectOp, ExtensibleDialect *> loaded;
+  op.walk([&loaded](DialectOp dialectOp) {
+    // The construction callback is a no-op: the dialect starts out empty.
+    auto noInit = [](DynamicDialect *) {};
+    DynamicDialect *dynDialect =
+        dialectOp.getContext()->getOrLoadDynamicDialect(dialectOp.getName(),
+                                                        noInit);
+    loaded.insert({dialectOp, dynDialect});
+  });
+  return loaded;
+}
+
+/// Preallocate, for each `TypeOp` in `op`, a type definition whose verifier
+/// always succeeds. This in particular allocates a TypeID for each of them.
+static DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>>
+preallocateTypeDefs(ModuleOp op,
+                    const DenseMap<DialectOp, ExtensibleDialect *> &dialects) {
+  DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> typeDefs;
+  op.walk([&](TypeOp typeOp) {
+    // The map is borrowed read-only, so use `lookup` rather than `operator[]`.
+    ExtensibleDialect *dialect = dialects.lookup(typeOp.getParentOp());
+    typeDefs.try_emplace(
+        typeOp, DynamicTypeDefinition::get(
+                    typeOp.getName(), dialect,
+                    [](function_ref<InFlightDiagnostic()>,
+                       ArrayRef<Attribute>) { return success(); }));
+  });
+  return typeDefs;
+}
+
+/// Preallocate, for each `AttributeOp` in `op`, an attribute definition whose
+/// verifier always succeeds. This in particular allocates a TypeID for each.
+static DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
+preallocateAttrDefs(ModuleOp op,
+                    const DenseMap<DialectOp, ExtensibleDialect *> &dialects) {
+  DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> attrDefs;
+  op.walk([&](AttributeOp attrOp) {
+    // The map is borrowed read-only, so use `lookup` rather than `operator[]`.
+    ExtensibleDialect *dialect = dialects.lookup(attrOp.getParentOp());
+    attrDefs.try_emplace(
+        attrOp, DynamicAttrDefinition::get(
+                    attrOp.getName(), dialect,
+                    [](function_ref<InFlightDiagnostic()>,
+                       ArrayRef<Attribute>) { return success(); }));
+  });
+  return attrDefs;
+}
+
+/// Load every dialect, operation, type and attribute defined in `op`.
+LogicalResult mlir::irdl::loadDialects(ModuleOp op) {
+  // Start by preallocating every dialect, type definition and attribute
+  // definition. This allocates the TypeIDs up front, so that type and
+  // attribute verifiers can refer to each other.
+  auto dialects = loadEmptyDialects(op);
+  auto typeDefs = preallocateTypeDefs(op, dialects);
+  auto attrDefs = preallocateAttrDefs(op, dialects);
+
+  // Define every operation and register it in its parent dialect.
+  WalkResult walkResult = op.walk([&](OperationOp operationOp) {
+    ExtensibleDialect *parent = dialects[operationOp.getParentOp()];
+    return loadOperation(operationOp, parent);
+  });
+  if (walkResult.wasInterrupted())
+    return failure();
+
+  // Register every preallocated type definition in its parent dialect.
+  for (auto &[typeOp, typeDef] : typeDefs) {
+    ExtensibleDialect *parent = dialects[typeOp.getParentOp()];
+    parent->registerDynamicType(std::move(typeDef));
+  }
+
+  // Register every preallocated attribute definition in its parent dialect.
+  for (auto &[attrOp, attrDef] : attrDefs) {
+    ExtensibleDialect *parent = dialects[attrOp.getParentOp()];
+    parent->registerDynamicAttr(std::move(attrDef));
+  }
+
+  return success();
+}

diff  --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 8f608ef145aa8..87c259dcd96c1 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -16,6 +16,8 @@
 #include "mlir/Debug/Counter.h"
 #include "mlir/Debug/ExecutionContext.h"
 #include "mlir/Debug/Observers/ActionLogging.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IRDLLoading.h"
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -69,6 +71,11 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
         "emit-bytecode", cl::desc("Emit bytecode when generating output"),
         cl::location(emitBytecodeFlag), cl::init(false));
 
+    static cl::opt<std::string, /*ExternalStorage=*/true> irdlFile(
+        "irdl-file",
+        cl::desc("IRDL file to register before processing the input"),
+        cl::location(irdlFileFlag), cl::init(""), cl::value_desc("filename"));
+
     static cl::opt<bool, /*ExternalStorage=*/true> explicitModule(
         "no-implicit-module",
         cl::desc("Disable implicit addition of a top-level module op during "
@@ -275,6 +282,35 @@ performActions(raw_ostream &os,
   return success();
 }
 
+/// Parse `irdlFile` and load the dialects it defines into `ctx`.
+LogicalResult loadIRDLDialects(StringRef irdlFile, MLIRContext &ctx) {
+  DialectRegistry registry;
+  registry.insert<irdl::IRDLDialect>();
+  ctx.appendDialectRegistry(registry);
+
+  // Set up the input file.
+  std::string errorMessage;
+  std::unique_ptr<MemoryBuffer> file = openInputFile(irdlFile, &errorMessage);
+  if (!file) {
+    emitError(UnknownLoc::get(&ctx)) << errorMessage;
+    return failure();
+  }
+
+  // Give the buffer to the source manager; the parser picks it up from there.
+  SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+  SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &ctx);
+
+  // Parse the input file. `parseSourceFile` returns a null module on failure,
+  // which must not be handed to `loadDialects`.
+  OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, &ctx));
+  if (!module)
+    return failure();
+
+  // Load IRDL dialects.
+  return irdl::loadDialects(module.get());
+}
+
 /// Parses the memory buffer.  If successfully, run a series of passes against
 /// it and print the result.
 static LogicalResult processBuffer(raw_ostream &os,
@@ -292,6 +328,12 @@ static LogicalResult processBuffer(raw_ostream &os,
   if (threadPool)
     context.setThreadPool(*threadPool);
 
+  StringRef irdlFile = config.getIrdlFile();
+  if (!irdlFile.empty()) {
+    if (failed(loadIRDLDialects(irdlFile, context)))
+      return failure();
+  }
+
   // Parse the input file.
   if (config.shouldPreloadDialectsInContext())
     context.loadAllAvailableDialects();

diff  --git a/mlir/test/Dialect/IRDL/test-cmath.mlir b/mlir/test/Dialect/IRDL/test-cmath.mlir
new file mode 100644
index 0000000000000..b7370c4fae730
--- /dev/null
+++ b/mlir/test/Dialect/IRDL/test-cmath.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s --irdl-file=%S/cmath.irdl.mlir | mlir-opt --irdl-file=%S/cmath.irdl.mlir | FileCheck %s
+
+module {
+  // CHECK: func.func @conorm(%[[p:[^:]*]]: !cmath.complex<f32>, %[[q:[^:]*]]: !cmath.complex<f32>) -> f32 {
+  // CHECK:   %[[norm_p:[^ ]*]] = "cmath.norm"(%[[p]]) : (!cmath.complex<f32>) -> f32
+  // CHECK:   %[[norm_q:[^ ]*]] = "cmath.norm"(%[[q]]) : (!cmath.complex<f32>) -> f32
+  // CHECK:   %[[pq:[^ ]*]] = arith.mulf %[[norm_p]], %[[norm_q]] : f32
+  // CHECK:   return %[[pq]] : f32
+  // CHECK: }
+  func.func @conorm(%p: !cmath.complex<f32>, %q: !cmath.complex<f32>) -> f32 {
+    %norm_p = "cmath.norm"(%p) : (!cmath.complex<f32>) -> f32
+    %norm_q = "cmath.norm"(%q) : (!cmath.complex<f32>) -> f32
+    %pq = arith.mulf %norm_p, %norm_q : f32
+    return %pq : f32
+  }
+
+  // CHECK: func.func @conorm2(%[[p:[^:]*]]: !cmath.complex<f32>, %[[q:[^:]*]]: !cmath.complex<f32>) -> f32 {
+  // CHECK:   %[[pq:[^ ]*]] = "cmath.mul"(%[[p]], %[[q]]) : (!cmath.complex<f32>, !cmath.complex<f32>) -> !cmath.complex<f32>
+  // CHECK:   %[[conorm:[^ ]*]] = "cmath.norm"(%[[pq]]) : (!cmath.complex<f32>) -> f32
+  // CHECK:   return %[[conorm]] : f32
+  // CHECK: }
+  func.func @conorm2(%p: !cmath.complex<f32>, %q: !cmath.complex<f32>) -> f32 {
+    %pq = "cmath.mul"(%p, %q) : (!cmath.complex<f32>, !cmath.complex<f32>) -> !cmath.complex<f32>
+    %conorm = "cmath.norm"(%pq) : (!cmath.complex<f32>) -> f32
+    return %conorm : f32
+  }
+}

diff  --git a/mlir/test/Dialect/IRDL/testd.mlir b/mlir/test/Dialect/IRDL/testd.mlir
new file mode 100644
index 0000000000000..f6d1bcb0e396f
--- /dev/null
+++ b/mlir/test/Dialect/IRDL/testd.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt %s --irdl-file=%S/testd.irdl.mlir -split-input-file -verify-diagnostics | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Type or attribute constraint
+//===----------------------------------------------------------------------===//
+
+func.func @typeFitsType() {
+  // CHECK: "testd.any"() : () -> !testd.parametric<i32>
+  "testd.any"() : () -> !testd.parametric<i32>
+  return
+}
+
+// -----
+
+func.func @attrDoesntFitType() {
+  "testd.any"() : () -> !testd.parametric<"foo">
+  return
+}
+
+// -----
+
+func.func @attrFitsAttr() {
+  // CHECK: "testd.any"() : () -> !testd.attr_in_type_out<"foo">
+  "testd.any"() : () -> !testd.attr_in_type_out<"foo">
+  return
+}
+
+// -----
+
+func.func @typeFitsAttr() {
+  // CHECK: "testd.any"() : () -> !testd.attr_in_type_out<i32>
+  "testd.any"() : () -> !testd.attr_in_type_out<i32>
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Equality constraint
+//===----------------------------------------------------------------------===//
+
+func.func @succeededEqConstraint() {
+  // CHECK: "testd.eq"() : () -> i32
+  "testd.eq"() : () -> i32
+  return
+}
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Any constraint
+//===----------------------------------------------------------------------===//
+
+func.func @succeededAnyConstraint() {
+  // CHECK: "testd.any"() : () -> i32
+  "testd.any"() : () -> i32
+  // CHECK: "testd.any"() : () -> i64
+  "testd.any"() : () -> i64
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Dynamic base constraint
+//===----------------------------------------------------------------------===//
+
+func.func @succeededDynBaseConstraint() {
+  // CHECK: "testd.dynbase"() : () -> !testd.parametric<i32>
+  "testd.dynbase"() : () -> !testd.parametric<i32>
+  // CHECK: "testd.dynbase"() : () -> !testd.parametric<!testd.parametric<i32>>
+  "testd.dynbase"() : () -> !testd.parametric<!testd.parametric<i32>>
+  return
+}
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Dynamic parameters constraint
+//===----------------------------------------------------------------------===//
+
+func.func @succeededDynParamsConstraint() {
+  // CHECK: "testd.dynparams"() : () -> !testd.parametric<i32>
+  "testd.dynparams"() : () -> !testd.parametric<i32>
+  return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Constraint variables
+//===----------------------------------------------------------------------===//
+
+func.func @succeededConstraintVars() {
+  // CHECK: "testd.constraint_vars"() : () -> (i32, i32)
+  "testd.constraint_vars"() : () -> (i32, i32)
+  return
+}
+
+// -----
+
+func.func @succeededConstraintVars2() {
+  // CHECK: "testd.constraint_vars"() : () -> (i64, i64)
+  "testd.constraint_vars"() : () -> (i64, i64)
+  return
+}


        


More information about the Mlir-commits mailing list