[Mlir-commits] [mlir] 8ac8c92 - [mlir][irdl] Add IRDL registration
Mathieu Fehr
llvmlistbot at llvm.org
Sun Apr 23 08:29:01 PDT 2023
Author: Mathieu Fehr
Date: 2023-04-23T17:28:44+01:00
New Revision: 8ac8c922fb3f15706f5cb1db2cc655d30b095766
URL: https://github.com/llvm/llvm-project/commit/8ac8c922fb3f15706f5cb1db2cc655d30b095766
DIFF: https://github.com/llvm/llvm-project/commit/8ac8c922fb3f15706f5cb1db2cc655d30b095766.diff
LOG: [mlir][irdl] Add IRDL registration
This patch adds support for loading IRDL dialects at runtime
with `mlir-opt`.
Given the following `dialect.irdl` file:
```mlir
module {
  irdl.dialect @cmath {
    irdl.type @complex {
      %0 = irdl.is f32
      %1 = irdl.is f64
      %2 = irdl.any_of(%0, %1)
      irdl.parameters(%2)
    }
    irdl.operation @norm {
      %0 = irdl.any
      %1 = irdl.parametric @complex<%0>
      irdl.operands(%1)
      irdl.results(%0)
    }
  }
}
```
the IRDL file can be loaded with the `mlir-opt --irdl-file=dialect.irdl`
command, and the following file can then be parsed:
```mlir
func.func @conorm(%p: !cmath.complex<f32>, %q: !cmath.complex<f32>) -> f32 {
  %norm_p = "cmath.norm"(%p) : (!cmath.complex<f32>) -> f32
  %norm_q = "cmath.norm"(%q) : (!cmath.complex<f32>) -> f32
  %pq = arith.mulf %norm_p, %norm_q : f32
  return %pq : f32
}
```
To minimize the size of this patch, the operation, attribute, and type verifiers all currently return `success()` unconditionally.
Depends on D144692
Reviewed By: rriddle, Mogball, mehdi_amini
Differential Revision: https://reviews.llvm.org/D144693
Added:
mlir/include/mlir/Dialect/IRDL/IRDLLoading.h
mlir/lib/Dialect/IRDL/IRDLLoading.cpp
mlir/test/Dialect/IRDL/test-cmath.mlir
mlir/test/Dialect/IRDL/testd.mlir
Modified:
mlir/include/mlir/IR/OpDefinition.h
mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
mlir/lib/Dialect/IRDL/CMakeLists.txt
mlir/lib/Tools/mlir-opt/CMakeLists.txt
mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/IRDL/IRDLLoading.h b/mlir/include/mlir/Dialect/IRDL/IRDLLoading.h
new file mode 100644
index 0000000000000..64ad72b1a982a
--- /dev/null
+++ b/mlir/include/mlir/Dialect/IRDL/IRDLLoading.h
@@ -0,0 +1,30 @@
+//===- IRDLLoading.h - IRDL dialect loading ---------------------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Manages the loading of MLIR dialects from IRDL operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_IRDL_IRDLLOADING_H
+#define MLIR_DIALECT_IRDL_IRDLLOADING_H
+
+namespace mlir {
+struct LogicalResult;
+class ModuleOp;
+
+namespace irdl {
+
+/// Load all the dialects defined by `irdl.dialect` operations nested in `op`,
+/// registering the operations, types, and attributes they define in the
+/// context of `op`. `op` must be a valid (non-null) module.
+/// Returns failure if an operation definition could not be loaded.
+LogicalResult loadDialects(ModuleOp op);
+
+} // namespace irdl
+} // namespace mlir
+
+#endif // MLIR_DIALECT_IRDL_IRDLLOADING_H
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index a86acafcc1e89..5a73d776996f9 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -23,8 +23,8 @@
#include "mlir/IR/Operation.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
-#include <type_traits>
#include <optional>
+#include <type_traits>
namespace mlir {
class Builder;
@@ -633,7 +633,7 @@ class OneTypedResult {
class Impl
: public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
public:
- mlir::TypedValue<ResultType> getResult() {
+ mlir::TypedValue<ResultType> getResult() {
return cast<mlir::TypedValue<ResultType>>(
this->getOperation()->getResult(0));
}
@@ -1255,6 +1255,14 @@ struct HasParent {
<< (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
<< llvm::ArrayRef({ParentOpTypes::getOperationName()...}) << "'";
}
+
+  /// Return the parent operation of this op, cast to its single allowed
+  /// parent op type. This accessor only exists when `HasParent` lists exactly
+  /// one parent op type; `llvm::cast` asserts (in assertion-enabled builds) if
+  /// the parent is not of that type.
+  template <typename ParentOpType =
+                std::tuple_element_t<0, std::tuple<ParentOpTypes...>>>
+  std::enable_if_t<sizeof...(ParentOpTypes) == 1, ParentOpType>
+  getParentOp() {
+    return llvm::cast<ParentOpType>(this->getOperation()->getParentOp());
+  }
};
};
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index 39f7cd5e0bd80..c6e5906565e01 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -78,6 +78,13 @@ class MlirOptMainConfig {
}
bool shouldEmitBytecode() const { return emitBytecodeFlag; }
+  /// Record the path of an IRDL file whose dialects are loaded before the
+  /// input is processed. An empty path means no IRDL file is loaded.
+  MlirOptMainConfig &setIrdlFile(StringRef file) {
+    irdlFileFlag = file.str();
+    return *this;
+  }
+  /// Return the IRDL file path set via `setIrdlFile` or `--irdl-file`.
+  StringRef getIrdlFile() const { return StringRef(irdlFileFlag); }
+
/// Set the filename to use for logging actions, use "-" for stdout.
MlirOptMainConfig &logActionsTo(StringRef filename) {
logActionsToFlag = filename;
@@ -173,6 +180,9 @@ class MlirOptMainConfig {
/// Emit bytecode instead of textual assembly when generating output.
bool emitBytecodeFlag = false;
+  /// IRDL file whose dialects are loaded before processing the input; when
+  /// empty (the default), no IRDL file is loaded.
+  std::string irdlFileFlag;
+
/// Log action execution to the given file (or "-" for stdout)
std::string logActionsToFlag;
diff --git a/mlir/lib/Dialect/IRDL/CMakeLists.txt b/mlir/lib/Dialect/IRDL/CMakeLists.txt
index 534c48ed5d1be..31efd5e37a665 100644
--- a/mlir/lib/Dialect/IRDL/CMakeLists.txt
+++ b/mlir/lib/Dialect/IRDL/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIRIRDL
IR/IRDL.cpp
+ IRDLLoading.cpp
DEPENDS
MLIRIRDLIncGen
diff --git a/mlir/lib/Dialect/IRDL/IRDLLoading.cpp b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
new file mode 100644
index 0000000000000..fb00085a7ee07
--- /dev/null
+++ b/mlir/lib/Dialect/IRDL/IRDLLoading.cpp
@@ -0,0 +1,131 @@
+//===- IRDLLoading.cpp - IRDL dialect loading -----------------------------===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Manages the loading of MLIR objects from IRDL operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/IRDL/IRDLLoading.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/ExtensibleDialect.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/SMLoc.h"
+
+using namespace mlir;
+using namespace mlir::irdl;
+
+/// Define the operation described by the `irdl.operation` op `op` and
+/// register it in `dialect`.
+///
+/// IRDL cannot express custom assembly formats yet, so the registered
+/// operation gets a parser that always fails and a printer that always uses
+/// the generic form. Its verifiers currently accept every operation.
+/// Always returns `WalkResult::advance()`.
+static WalkResult loadOperation(OperationOp op, ExtensibleDialect *dialect) {
+  // Verification is not implemented yet: accept every operation.
+  auto verifier = [](Operation *) { return success(); };
+  // IRDL does not support defining regions.
+  auto regionVerifier = [](Operation *) { return success(); };
+
+  // IRDL does not support defining custom parsers or printers.
+  auto parser = [](OpAsmParser &, OperationState &) { return failure(); };
+  auto printer = [](Operation *genericOp, OpAsmPrinter &p, StringRef) {
+    p.printGenericOp(genericOp);
+  };
+
+  dialect->registerDynamicOp(DynamicOpDefinition::get(
+      op.getName(), dialect, std::move(verifier), std::move(regionVerifier),
+      std::move(parser), std::move(printer)));
+
+  return WalkResult::advance();
+}
+
+/// For each `irdl.dialect` op nested in `op`, get or load a dynamic dialect
+/// with that name, without registering any operation, type, or attribute
+/// definitions in it.
+///
+/// Returns a map from each `irdl.dialect` op to its dialect.
+static DenseMap<DialectOp, ExtensibleDialect *> loadEmptyDialects(ModuleOp op) {
+  DenseMap<DialectOp, ExtensibleDialect *> loaded;
+  op.walk([&loaded](DialectOp dialectOp) {
+    // Definitions are added later, so the construction hook does nothing.
+    auto noInit = [](DynamicDialect *) {};
+    DynamicDialect *dialect = dialectOp.getContext()->getOrLoadDynamicDialect(
+        dialectOp.getName(), noInit);
+    loaded.insert({dialectOp, dialect});
+  });
+  return loaded;
+}
+
+/// Preallocate a type definition for every `irdl.type` op nested in `op`.
+/// Each definition gets a verifier that always succeeds (IRDL verification is
+/// not implemented yet), and none of them is registered in its dialect here.
+/// Creating the definitions up front allocates a TypeID for each one, so that
+/// verifiers can later refer to any of them.
+///
+/// `dialects` maps each `irdl.dialect` op to its loaded dialect. It is only
+/// read, so it is taken by const reference instead of being copied.
+static DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>>
+preallocateTypeDefs(ModuleOp op,
+                    const DenseMap<DialectOp, ExtensibleDialect *> &dialects) {
+  DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> typeDefs;
+  op.walk([&](TypeOp typeOp) {
+    // `lookup` is const and never inserts; a missing entry yields nullptr.
+    ExtensibleDialect *dialect = dialects.lookup(typeOp.getParentOp());
+    auto typeDef = DynamicTypeDefinition::get(
+        typeOp.getName(), dialect,
+        [](function_ref<InFlightDiagnostic()>, ArrayRef<Attribute>) {
+          return success();
+        });
+    typeDefs.try_emplace(typeOp, std::move(typeDef));
+  });
+  return typeDefs;
+}
+
+/// Preallocate an attribute definition for every `irdl.attribute` op nested
+/// in `op`. Each definition gets a verifier that always succeeds (IRDL
+/// verification is not implemented yet), and none of them is registered in
+/// its dialect here. Creating the definitions up front allocates a TypeID for
+/// each one, so that verifiers can later refer to any of them.
+///
+/// `dialects` maps each `irdl.dialect` op to its loaded dialect. It is only
+/// read, so it is taken by const reference instead of being copied.
+static DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>>
+preallocateAttrDefs(ModuleOp op,
+                    const DenseMap<DialectOp, ExtensibleDialect *> &dialects) {
+  DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> attrDefs;
+  op.walk([&](AttributeOp attrOp) {
+    // `lookup` is const and never inserts; a missing entry yields nullptr.
+    ExtensibleDialect *dialect = dialects.lookup(attrOp.getParentOp());
+    auto attrDef = DynamicAttrDefinition::get(
+        attrOp.getName(), dialect,
+        [](function_ref<InFlightDiagnostic()>, ArrayRef<Attribute>) {
+          return success();
+        });
+    attrDefs.try_emplace(attrOp, std::move(attrDef));
+  });
+  return attrDefs;
+}
+
+/// Load every dialect described by an `irdl.dialect` op nested in `op`,
+/// together with the operations, types, and attributes it defines.
+///
+/// Returns failure if loading an operation interrupts the walk (which does
+/// not currently happen, since `loadOperation` always advances), and success
+/// otherwise.
+LogicalResult mlir::irdl::loadDialects(ModuleOp op) {
+  // Phase 1: create the (still empty) dialects and preallocate type and
+  // attribute definitions. This allocates every TypeID up front, so type and
+  // attribute verifiers can refer to each other.
+  DenseMap<DialectOp, ExtensibleDialect *> dialects = loadEmptyDialects(op);
+  DenseMap<TypeOp, std::unique_ptr<DynamicTypeDefinition>> typeDefs =
+      preallocateTypeDefs(op, dialects);
+  DenseMap<AttributeOp, std::unique_ptr<DynamicAttrDefinition>> attrDefs =
+      preallocateAttrDefs(op, dialects);
+
+  // Phase 2: define and register every operation, stopping on failure.
+  WalkResult opsResult = op.walk([&](OperationOp opOp) {
+    ExtensibleDialect *owner = dialects[opOp.getParentOp()];
+    return loadOperation(opOp, owner);
+  });
+  if (opsResult.wasInterrupted())
+    return failure();
+
+  // Phase 3: hand the preallocated types, then attributes, to their dialects.
+  for (auto &[typeOp, typeDef] : typeDefs)
+    dialects[typeOp.getParentOp()]->registerDynamicType(std::move(typeDef));
+  for (auto &[attrOp, attrDef] : attrDefs)
+    dialects[attrOp.getParentOp()]->registerDynamicAttr(std::move(attrDef));
+
+  return success();
+}
diff --git a/mlir/lib/Tools/mlir-opt/CMakeLists.txt b/mlir/lib/Tools/mlir-opt/CMakeLists.txt
index a15677eda42b2..f24d4c60174ee 100644
--- a/mlir/lib/Tools/mlir-opt/CMakeLists.txt
+++ b/mlir/lib/Tools/mlir-opt/CMakeLists.txt
@@ -12,4 +12,5 @@ add_mlir_library(MLIROptLib
MLIRParser
MLIRPluginsLib
MLIRSupport
+ MLIRIRDL
)
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 28324508ee4f3..b9e65b1b8e622 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -16,6 +16,8 @@
#include "mlir/Debug/Counter.h"
#include "mlir/Debug/ExecutionContext.h"
#include "mlir/Debug/Observers/ActionLogging.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/IRDL/IRDLLoading.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
@@ -70,6 +72,11 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
"emit-bytecode", cl::desc("Emit bytecode when generating output"),
cl::location(emitBytecodeFlag), cl::init(false));
+ static cl::opt<std::string, /*ExternalStorage=*/true> irdlFile(
+ "irdl-file",
+ cl::desc("IRDL file to register before processing the input"),
+ cl::location(irdlFileFlag), cl::init(""), cl::value_desc("filename"));
+
static cl::opt<bool, /*ExternalStorage=*/true> explicitModule(
"no-implicit-module",
cl::desc("Disable implicit addition of a top-level module op during "
@@ -310,6 +317,33 @@ performActions(raw_ostream &os,
return success();
}
+/// Parse the IRDL file `irdlFile` and load every dialect it defines into
+/// `ctx`.
+///
+/// The IRDL dialect is first made available in `ctx` so the file's `irdl.*`
+/// operations can be parsed. If the file cannot be opened, an error is
+/// emitted and failure is returned. If it cannot be parsed, failure is
+/// returned instead of handing a null module to `irdl::loadDialects`.
+LogicalResult loadIRDLDialects(StringRef irdlFile, MLIRContext &ctx) {
+  DialectRegistry registry;
+  registry.insert<irdl::IRDLDialect>();
+  ctx.appendDialectRegistry(registry);
+
+  // Set up the input file.
+  std::string errorMessage;
+  std::unique_ptr<MemoryBuffer> file = openInputFile(irdlFile, &errorMessage);
+  if (!file) {
+    emitError(UnknownLoc::get(&ctx)) << errorMessage;
+    return failure();
+  }
+
+  // Give the buffer to the source manager.
+  // This will be picked up by the parser.
+  SourceMgr sourceMgr;
+  sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
+
+  SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &ctx);
+
+  // Parse the input file. `parseSourceFile` yields a null module on failure;
+  // walking a null module would crash, so bail out here.
+  OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, &ctx));
+  if (!module)
+    return failure();
+
+  // Load IRDL dialects.
+  return irdl::loadDialects(module.get());
+}
+
/// Parses the memory buffer. If successfully, run a series of passes against
/// it and print the result.
static LogicalResult processBuffer(raw_ostream &os,
@@ -327,6 +361,10 @@ static LogicalResult processBuffer(raw_ostream &os,
if (threadPool)
context.setThreadPool(*threadPool);
+ StringRef irdlFile = config.getIrdlFile();
+ if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, context)))
+ return failure();
+
// Parse the input file.
if (config.shouldPreloadDialectsInContext())
context.loadAllAvailableDialects();
diff --git a/mlir/test/Dialect/IRDL/test-cmath.mlir b/mlir/test/Dialect/IRDL/test-cmath.mlir
new file mode 100644
index 0000000000000..b7370c4fae730
--- /dev/null
+++ b/mlir/test/Dialect/IRDL/test-cmath.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s --irdl-file=%S/cmath.irdl.mlir | mlir-opt --irdl-file=%S/cmath.irdl.mlir | FileCheck %s
+
+// Round-trip test: the `cmath` dialect is loaded from cmath.irdl.mlir (not
+// part of this file), and the dynamically defined `!cmath.complex` type and
+// `cmath.norm` / `cmath.mul` operations are parsed and printed in generic
+// form. NOTE(review): `cmath.mul` must be defined in cmath.irdl.mlir — it is
+// not part of the example in the commit message.
+module {
+ // Two norms of complex values multiplied with a builtin arith op.
+ // CHECK: func.func @conorm(%[[p:[^:]*]]: !cmath.complex<f32>, %[[q:[^:]*]]: !cmath.complex<f32>) -> f32 {
+ // CHECK: %[[norm_p:[^ ]*]] = "cmath.norm"(%[[p]]) : (!cmath.complex<f32>) -> f32
+ // CHECK: %[[norm_q:[^ ]*]] = "cmath.norm"(%[[q]]) : (!cmath.complex<f32>) -> f32
+ // CHECK: %[[pq:[^ ]*]] = arith.mulf %[[norm_p]], %[[norm_q]] : f32
+ // CHECK: return %[[pq]] : f32
+ // CHECK: }
+ func.func @conorm(%p: !cmath.complex<f32>, %q: !cmath.complex<f32>) -> f32 {
+ %norm_p = "cmath.norm"(%p) : (!cmath.complex<f32>) -> f32
+ %norm_q = "cmath.norm"(%q) : (!cmath.complex<f32>) -> f32
+ %pq = arith.mulf %norm_p, %norm_q : f32
+ return %pq : f32
+ }
+
+ // Norm of a product, using only dynamically defined cmath operations.
+ // CHECK: func.func @conorm2(%[[p:[^:]*]]: !cmath.complex<f32>, %[[q:[^:]*]]: !cmath.complex<f32>) -> f32 {
+ // CHECK: %[[pq:[^ ]*]] = "cmath.mul"(%[[p]], %[[q]]) : (!cmath.complex<f32>, !cmath.complex<f32>) -> !cmath.complex<f32>
+ // CHECK: %[[conorm:[^ ]*]] = "cmath.norm"(%[[pq]]) : (!cmath.complex<f32>) -> f32
+ // CHECK: return %[[conorm]] : f32
+ // CHECK: }
+ func.func @conorm2(%p: !cmath.complex<f32>, %q: !cmath.complex<f32>) -> f32 {
+ %pq = "cmath.mul"(%p, %q) : (!cmath.complex<f32>, !cmath.complex<f32>) -> !cmath.complex<f32>
+ %conorm = "cmath.norm"(%pq) : (!cmath.complex<f32>) -> f32
+ return %conorm : f32
+ }
+}
diff --git a/mlir/test/Dialect/IRDL/testd.mlir b/mlir/test/Dialect/IRDL/testd.mlir
new file mode 100644
index 0000000000000..f6d1bcb0e396f
--- /dev/null
+++ b/mlir/test/Dialect/IRDL/testd.mlir
@@ -0,0 +1,108 @@
+// RUN: mlir-opt %s --irdl-file=%S/testd.irdl.mlir -split-input-file -verify-diagnostics | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Type or attribute constraint
+//===----------------------------------------------------------------------===//
+
+// A builtin type (i32) used as the parameter of the dynamic type
+// !testd.parametric is accepted and printed back unchanged.
+func.func @typeFitsType() {
+ // CHECK: "testd.any"() : () -> !testd.parametric<i32>
+ "testd.any"() : () -> !testd.parametric<i32>
+ return
+}
+
+// -----
+
+// Passes an attribute ("foo") as the parameter of !testd.parametric.
+// NOTE(review): no output or diagnostic is checked here. IRDL type verifiers
+// currently always succeed, so this input is accepted. Once verification is
+// implemented, this case presumably needs a diagnostic annotation — TODO
+// confirm against the constraints in testd.irdl.mlir.
+func.func @attrDoesntFitType() {
+ "testd.any"() : () -> !testd.parametric<"foo">
+ return
+}
+
+// -----
+
+// A string attribute used as the parameter of !testd.attr_in_type_out is
+// accepted and printed back unchanged.
+func.func @attrFitsAttr() {
+ // CHECK: "testd.any"() : () -> !testd.attr_in_type_out<"foo">
+ "testd.any"() : () -> !testd.attr_in_type_out<"foo">
+ return
+}
+
+// -----
+
+// A type (i32) used as the parameter of !testd.attr_in_type_out is accepted
+// and printed back unchanged (presumably because the parameter constraint
+// admits any attribute, types included — TODO confirm in testd.irdl.mlir).
+func.func @typeFitsAttr() {
+ // CHECK: "testd.any"() : () -> !testd.attr_in_type_out<i32>
+ "testd.any"() : () -> !testd.attr_in_type_out<i32>
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Equality constraint
+//===----------------------------------------------------------------------===//
+
+// testd.eq with an i32 result is accepted and round-trips. IRDL operation
+// verifiers currently always succeed, so this does not yet exercise the
+// equality constraint itself.
+func.func @succeededEqConstraint() {
+ // CHECK: "testd.eq"() : () -> i32
+ "testd.eq"() : () -> i32
+ return
+}
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Any constraint
+//===----------------------------------------------------------------------===//
+
+// testd.any accepts results of different builtin types (i32 and i64); both
+// ops round-trip unchanged.
+func.func @succeededAnyConstraint() {
+ // CHECK: "testd.any"() : () -> i32
+ "testd.any"() : () -> i32
+ // CHECK: "testd.any"() : () -> i64
+ "testd.any"() : () -> i64
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Dynamic base constraint
+//===----------------------------------------------------------------------===//
+
+// testd.dynbase returns !testd.parametric values, including one nested inside
+// another; both forms round-trip unchanged.
+func.func @succeededDynBaseConstraint() {
+ // CHECK: "testd.dynbase"() : () -> !testd.parametric<i32>
+ "testd.dynbase"() : () -> !testd.parametric<i32>
+ // CHECK: "testd.dynbase"() : () -> !testd.parametric<!testd.parametric<i32>>
+ "testd.dynbase"() : () -> !testd.parametric<!testd.parametric<i32>>
+ return
+}
+
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Dynamic parameters constraint
+//===----------------------------------------------------------------------===//
+
+// testd.dynparams with an !testd.parametric<i32> result is accepted and
+// round-trips unchanged.
+func.func @succeededDynParamsConstraint() {
+ // CHECK: "testd.dynparams"() : () -> !testd.parametric<i32>
+ "testd.dynparams"() : () -> !testd.parametric<i32>
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Constraint variables
+//===----------------------------------------------------------------------===//
+
+// testd.constraint_vars with two results of the same type (i32, i32) is
+// accepted and round-trips unchanged.
+func.func @succeededConstraintVars() {
+ // CHECK: "testd.constraint_vars"() : () -> (i32, i32)
+ "testd.constraint_vars"() : () -> (i32, i32)
+ return
+}
+
+// -----
+
+// Same as the previous case with (i64, i64) results: the op is accepted and
+// round-trips unchanged.
+func.func @succeededConstraintVars2() {
+ // CHECK: "testd.constraint_vars"() : () -> (i64, i64)
+ "testd.constraint_vars"() : () -> (i64, i64)
+ return
+}
More information about the Mlir-commits
mailing list