[Mlir-commits] [mlir] 04f2b71 - [mlir] Fix unsafe create operation in GreedyPatternRewriter
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Mar 23 11:51:35 PDT 2020
Author: MaheshRavishankar
Date: 2020-03-23T11:50:40-07:00
New Revision: 04f2b717d23b17d3bd0a15f6b2b3c3c79a83b955
URL: https://github.com/llvm/llvm-project/commit/04f2b717d23b17d3bd0a15f6b2b3c3c79a83b955
DIFF: https://github.com/llvm/llvm-project/commit/04f2b717d23b17d3bd0a15f6b2b3c3c79a83b955.diff
LOG: [mlir] Fix unsafe create operation in GreedyPatternRewriter
When trying to fold an operation during operation creation, check that
the operation folding succeeds before inserting the op.
Differential Revision: https://reviews.llvm.org/D76415
Added:
Modified:
mlir/include/mlir/Transforms/FoldUtils.h
mlir/lib/Transforms/Utils/FoldUtils.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
index 83ce3bf0d072..0bab87c5e4e3 100644
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -75,11 +75,20 @@ class OperationFolder {
template <typename OpTy, typename... Args>
void create(OpBuilder &builder, SmallVectorImpl<Value> &results,
Location location, Args &&... args) {
- Operation *op = builder.create<OpTy>(location, std::forward<Args>(args)...);
- if (failed(tryToFold(op, results)))
+ // The op needs to be inserted only if the fold (below) fails, or the number
+ // of results of the op is zero (which is treated as an in-place
+ // fold). Using create methods of the builder will insert the op, so not
+ // using it here.
+ OperationState state(location, OpTy::getOperationName());
+ OpTy::build(&builder, state, std::forward<Args>(args)...);
+ Operation *op = Operation::create(state);
+
+ if (failed(tryToFold(builder, op, results)) || op->getNumResults() == 0) {
+ builder.insert(op);
results.assign(op->result_begin(), op->result_end());
- else if (op->getNumResults() != 0)
- op->erase();
+ return;
+ }
+ op->destroy();
}
/// Overload to create or fold a single result operation.
@@ -120,7 +129,7 @@ class OperationFolder {
/// Tries to perform folding on the given `op`. If successful, populates
/// `results` with the results of the folding.
LogicalResult tryToFold(
- Operation *op, SmallVectorImpl<Value> &results,
+ OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
function_ref<void(Operation *)> processGeneratedConstants = nullptr);
/// Try to get or create a new constant entry. On success this returns the
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 7d209b2231a2..f2099bca75ea 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -24,8 +24,8 @@ using namespace mlir;
/// inserted into.
static Region *getInsertionRegion(
DialectInterfaceCollection<OpFolderDialectInterface> &interfaces,
- Operation *op) {
- while (Region *region = op->getParentRegion()) {
+ Block *insertionBlock) {
+ while (Region *region = insertionBlock->getParent()) {
// Insert in this region for any of the following scenarios:
// * The parent is unregistered, or is known to be isolated from above.
// * The parent is a top-level operation.
@@ -40,7 +40,7 @@ static Region *getInsertionRegion(
return region;
// Traverse up the parent looking for an insertion region.
- op = parentOp;
+ insertionBlock = parentOp->getBlock();
}
llvm_unreachable("expected valid insertion region");
}
@@ -82,7 +82,8 @@ LogicalResult OperationFolder::tryToFold(
// Try to fold the operation.
SmallVector<Value, 8> results;
- if (failed(tryToFold(op, results, processGeneratedConstants)))
+ OpBuilder builder(op);
+ if (failed(tryToFold(builder, op, results, processGeneratedConstants)))
return failure();
// Check to see if the operation was just updated in place.
@@ -117,7 +118,8 @@ void OperationFolder::notifyRemoval(Operation *op) {
assert(constValue);
// Get the constant map that this operation was uniqued in.
- auto &uniquedConstants = foldScopes[getInsertionRegion(interfaces, op)];
+ auto &uniquedConstants =
+ foldScopes[getInsertionRegion(interfaces, op->getBlock())];
// Erase all of the references to this operation.
auto type = op->getResult(0).getType();
@@ -135,7 +137,7 @@ void OperationFolder::clear() {
/// Tries to perform folding on the given `op`. If successful, populates
/// `results` with the results of the folding.
LogicalResult OperationFolder::tryToFold(
- Operation *op, SmallVectorImpl<Value> &results,
+ OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
function_ref<void(Operation *)> processGeneratedConstants) {
SmallVector<Attribute, 8> operandConstants;
SmallVector<OpFoldResult, 8> foldResults;
@@ -164,9 +166,11 @@ LogicalResult OperationFolder::tryToFold(
// Create a builder to insert new operations into the entry block of the
// insertion region.
- auto *insertRegion = getInsertionRegion(interfaces, op);
+ auto *insertRegion =
+ getInsertionRegion(interfaces, builder.getInsertionBlock());
auto &entry = insertRegion->front();
- OpBuilder builder(&entry, entry.begin());
+ OpBuilder::InsertionGuard foldGuard(builder);
+ builder.setInsertionPoint(&entry, entry.begin());
// Get the constant map for the insertion region of this operation.
auto &uniquedConstants = foldScopes[insertRegion];
More information about the Mlir-commits
mailing list