[Mlir-commits] [mlir] 04f2b71 - [mlir] Fix unsafe create operation in GreedyPatternRewriter

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Mar 23 11:51:35 PDT 2020


Author: MaheshRavishankar
Date: 2020-03-23T11:50:40-07:00
New Revision: 04f2b717d23b17d3bd0a15f6b2b3c3c79a83b955

URL: https://github.com/llvm/llvm-project/commit/04f2b717d23b17d3bd0a15f6b2b3c3c79a83b955
DIFF: https://github.com/llvm/llvm-project/commit/04f2b717d23b17d3bd0a15f6b2b3c3c79a83b955.diff

LOG: [mlir] Fix unsafe create operation in GreedyPatternRewriter

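When trying to fold an operation during operation creation, check that
the operation folding succeeds before inserting the op. Previously the op
was always created through the builder (which inserts it) and erased
afterwards if folding succeeded. Now the op is built detached and is
inserted only if folding fails or the op has no results (which is treated
as an in-place fold); otherwise it is destroyed without ever being inserted.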

Differential Revision: https://reviews.llvm.org/D76415

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/FoldUtils.h
    mlir/lib/Transforms/Utils/FoldUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
index 83ce3bf0d072..0bab87c5e4e3 100644
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -75,11 +75,20 @@ class OperationFolder {
   template <typename OpTy, typename... Args>
   void create(OpBuilder &builder, SmallVectorImpl<Value> &results,
               Location location, Args &&... args) {
-    Operation *op = builder.create<OpTy>(location, std::forward<Args>(args)...);
-    if (failed(tryToFold(op, results)))
+    // The op needs to be inserted only if the fold (below) fails, or the number
+    // of results of the op is zero (which is treated as an in-place
+    // fold). Using create methods of the builder will insert the op, so not
+    // using it here.
+    OperationState state(location, OpTy::getOperationName());
+    OpTy::build(&builder, state, std::forward<Args>(args)...);
+    Operation *op = Operation::create(state);
+
+    if (failed(tryToFold(builder, op, results)) || op->getNumResults() == 0) {
+      builder.insert(op);
       results.assign(op->result_begin(), op->result_end());
-    else if (op->getNumResults() != 0)
-      op->erase();
+      return;
+    }
+    op->destroy();
   }
 
   /// Overload to create or fold a single result operation.
@@ -120,7 +129,7 @@ class OperationFolder {
   /// Tries to perform folding on the given `op`. If successful, populates
   /// `results` with the results of the folding.
   LogicalResult tryToFold(
-      Operation *op, SmallVectorImpl<Value> &results,
+      OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
       function_ref<void(Operation *)> processGeneratedConstants = nullptr);
 
   /// Try to get or create a new constant entry. On success this returns the

diff  --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index 7d209b2231a2..f2099bca75ea 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -24,8 +24,8 @@ using namespace mlir;
 /// inserted into.
 static Region *getInsertionRegion(
     DialectInterfaceCollection<OpFolderDialectInterface> &interfaces,
-    Operation *op) {
-  while (Region *region = op->getParentRegion()) {
+    Block *insertionBlock) {
+  while (Region *region = insertionBlock->getParent()) {
     // Insert in this region for any of the following scenarios:
     //  * The parent is unregistered, or is known to be isolated from above.
     //  * The parent is a top-level operation.
@@ -40,7 +40,7 @@ static Region *getInsertionRegion(
       return region;
 
     // Traverse up the parent looking for an insertion region.
-    op = parentOp;
+    insertionBlock = parentOp->getBlock();
   }
   llvm_unreachable("expected valid insertion region");
 }
@@ -82,7 +82,8 @@ LogicalResult OperationFolder::tryToFold(
 
   // Try to fold the operation.
   SmallVector<Value, 8> results;
-  if (failed(tryToFold(op, results, processGeneratedConstants)))
+  OpBuilder builder(op);
+  if (failed(tryToFold(builder, op, results, processGeneratedConstants)))
     return failure();
 
   // Check to see if the operation was just updated in place.
@@ -117,7 +118,8 @@ void OperationFolder::notifyRemoval(Operation *op) {
   assert(constValue);
 
   // Get the constant map that this operation was uniqued in.
-  auto &uniquedConstants = foldScopes[getInsertionRegion(interfaces, op)];
+  auto &uniquedConstants =
+      foldScopes[getInsertionRegion(interfaces, op->getBlock())];
 
   // Erase all of the references to this operation.
   auto type = op->getResult(0).getType();
@@ -135,7 +137,7 @@ void OperationFolder::clear() {
 /// Tries to perform folding on the given `op`. If successful, populates
 /// `results` with the results of the folding.
 LogicalResult OperationFolder::tryToFold(
-    Operation *op, SmallVectorImpl<Value> &results,
+    OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
     function_ref<void(Operation *)> processGeneratedConstants) {
   SmallVector<Attribute, 8> operandConstants;
   SmallVector<OpFoldResult, 8> foldResults;
@@ -164,9 +166,11 @@ LogicalResult OperationFolder::tryToFold(
 
   // Create a builder to insert new operations into the entry block of the
   // insertion region.
-  auto *insertRegion = getInsertionRegion(interfaces, op);
+  auto *insertRegion =
+      getInsertionRegion(interfaces, builder.getInsertionBlock());
   auto &entry = insertRegion->front();
-  OpBuilder builder(&entry, entry.begin());
+  OpBuilder::InsertionGuard foldGuard(builder);
+  builder.setInsertionPoint(&entry, entry.begin());
 
   // Get the constant map for the insertion region of this operation.
   auto &uniquedConstants = foldScopes[insertRegion];


        


More information about the Mlir-commits mailing list