[Mlir-commits] [mlir] 213c6cd - Harden MLIR detection of misconfiguration when missing dialect registration
Mehdi Amini
llvmlistbot at llvm.org
Thu May 28 01:15:01 PDT 2020
Author: Mehdi Amini
Date: 2020-05-28T08:14:49Z
New Revision: 213c6cdf2e7a30d722cee4cd66b7d48fc396d44b
URL: https://github.com/llvm/llvm-project/commit/213c6cdf2e7a30d722cee4cd66b7d48fc396d44b
DIFF: https://github.com/llvm/llvm-project/commit/213c6cdf2e7a30d722cee4cd66b7d48fc396d44b.diff
LOG: Harden MLIR detection of misconfiguration when missing dialect registration
This change will catch errors where C++ ops are used without being
registered, either when creating them through the OpBuilder or when trying
to cast to the C++ op.
Differential Revision: https://reviews.llvm.org/D80651
Added:
Modified:
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/MLIRContext.h
mlir/include/mlir/IR/OpDefinition.h
mlir/lib/IR/MLIRContext.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 424eb980cd33..0dcf4daf656f 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -374,6 +374,10 @@ class OpBuilder : public Builder {
template <typename OpTy, typename... Args>
OpTy create(Location location, Args &&... args) {
OperationState state(location, OpTy::getOperationName());
+ if (!state.name.getAbstractOperation())
+ llvm::report_fatal_error("Building op `" +
+ state.name.getStringRef().str() +
+ "` but it isn't registered in this MLIRContext");
OpTy::build(*this, state, std::forward<Args>(args)...);
auto *op = createOperation(state);
auto result = dyn_cast<OpTy>(op);
@@ -390,6 +394,10 @@ class OpBuilder : public Builder {
// Create the operation without using 'createOperation' as we don't want to
// insert it yet.
OperationState state(location, OpTy::getOperationName());
+ if (!state.name.getAbstractOperation())
+ llvm::report_fatal_error("Building op `" +
+ state.name.getStringRef().str() +
+ "` but it isn't registered in this MLIRContext");
OpTy::build(*this, state, std::forward<Args>(args)...);
Operation *op = Operation::create(state);
diff --git a/mlir/include/mlir/IR/MLIRContext.h b/mlir/include/mlir/IR/MLIRContext.h
index da0b0bd826ce..8e75bb624449 100644
--- a/mlir/include/mlir/IR/MLIRContext.h
+++ b/mlir/include/mlir/IR/MLIRContext.h
@@ -85,6 +85,9 @@ class MLIRContext {
/// directly.
std::vector<AbstractOperation *> getRegisteredOperations();
+ /// Return true if this operation name is registered in this context.
+ bool isOperationRegistered(StringRef name);
+
// This is effectively private given that only MLIRContext.cpp can see the
// MLIRContextImpl type.
MLIRContextImpl &getImpl() { return *impl; }
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index bf5bd70c2b7f..e92d54ec84f9 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1235,7 +1235,10 @@ class Op : public OpState,
static bool classof(Operation *op) {
if (auto *abstractOp = op->getAbstractOperation())
return TypeID::get<ConcreteType>() == abstractOp->typeID;
- return op->getName().getStringRef() == ConcreteType::getOperationName();
+ assert(op->getContext()->isOperationRegistered(
+ ConcreteType::getOperationName()) &&
+ "Casting attempt to an unregistered operation");
+ return false;
}
/// This is the hook used by the AsmParser to parse the custom form of this
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 0728f294be86..da607a2319bf 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -543,6 +543,13 @@ std::vector<AbstractOperation *> MLIRContext::getRegisteredOperations() {
return result;
}
+bool MLIRContext::isOperationRegistered(StringRef name) {
+ // Lock access to the context registry.
+ ScopedReaderLock registryLock(impl->contextMutex, impl->threadingIsEnabled);
+
+ return impl->registeredOperations.count(name);
+}
+
void Dialect::addOperation(AbstractOperation opInfo) {
assert((getNamespace().empty() ||
opInfo.name.split('.').first == getNamespace()) &&
@@ -621,8 +628,9 @@ Identifier Identifier::get(StringRef str, MLIRContext *context) {
static Dialect &lookupDialectForSymbol(MLIRContext *ctx, TypeID typeID) {
auto &impl = ctx->getImpl();
auto it = impl.registeredDialectSymbols.find(typeID);
- assert(it != impl.registeredDialectSymbols.end() &&
- "symbol is not registered.");
+ if (it == impl.registeredDialectSymbols.end())
+ llvm::report_fatal_error(
+ "Trying to create a type that was not registered in this MLIRContext.");
return *it->second;
}
More information about the Mlir-commits
mailing list