[Mlir-commits] [mlir] f4ef77c - [mlir][Inliner] Properly handle callgraph node deletion

River Riddle llvmlistbot at llvm.org
Wed Jun 17 15:49:54 PDT 2020


Author: River Riddle
Date: 2020-06-17T15:45:56-07:00
New Revision: f4ef77cbb48b549211ecc18085f14ec7a17c01fc

URL: https://github.com/llvm/llvm-project/commit/f4ef77cbb48b549211ecc18085f14ec7a17c01fc
DIFF: https://github.com/llvm/llvm-project/commit/f4ef77cbb48b549211ecc18085f14ec7a17c01fc.diff

LOG: [mlir][Inliner] Properly handle callgraph node deletion

We previously weren't properly updating the SCC iterator when nodes were removed, leading to AddressSanitizer (ASan) failures in certain situations. This commit adds a CallGraphSCC class and defers operation deletion until inlining has finished.

Differential Revision: https://reviews.llvm.org/D81984

Added: 
    

Modified: 
    mlir/lib/Transforms/Inliner.cpp
    mlir/test/Transforms/inlining-dce.mlir
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index 0cd706790bdc..e17a379d54b8 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -242,19 +242,47 @@ void CGUseList::decrementDiscardableUses(CGUser &uses) {
 // CallGraph traversal
 //===----------------------------------------------------------------------===//
 
+namespace {
+/// This class represents a specific callgraph SCC.
+class CallGraphSCC {
+public:
+  CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
+      : parentIterator(parentIterator) {}
+  /// Return a range over the nodes within this SCC.
+  std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
+  std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
+
+  /// Reset the nodes of this SCC with those provided.
+  void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
+
+  /// Remove the given node from this SCC.
+  void remove(CallGraphNode *node) {
+    auto it = llvm::find(nodes, node);
+    if (it != nodes.end()) {
+      nodes.erase(it);
+      parentIterator.ReplaceNode(node, nullptr);
+    }
+  }
+
+private:
+  std::vector<CallGraphNode *> nodes;
+  llvm::scc_iterator<const CallGraph *> &parentIterator;
+};
+} // end anonymous namespace
+
 /// Run a given transformation over the SCCs of the callgraph in a bottom up
 /// traversal.
-static void runTransformOnCGSCCs(
-    const CallGraph &cg,
-    function_ref<void(MutableArrayRef<CallGraphNode *>)> sccTransformer) {
-  std::vector<CallGraphNode *> currentSCCVec;
-  auto cgi = llvm::scc_begin(&cg);
+static void
+runTransformOnCGSCCs(const CallGraph &cg,
+                     function_ref<void(CallGraphSCC &)> sccTransformer) {
+  llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
+  CallGraphSCC currentSCC(cgi);
   while (!cgi.isAtEnd()) {
     // Copy the current SCC and increment so that the transformer can modify the
     // SCC without invalidating our iterator.
-    currentSCCVec = *cgi;
+    currentSCC.reset(*cgi);
     ++cgi;
-    sccTransformer(currentSCCVec);
+    sccTransformer(currentSCC);
   }
 }
 
@@ -343,6 +371,19 @@ struct Inliner : public InlinerInterface {
                    /*traverseNestedCGNodes=*/true);
   }
 
+  /// Mark the given callgraph node for deletion.
+  void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
+
+  /// This method properly disposes of callables that became dead during
+  /// inlining. This should not be called while iterating over the SCCs.
+  void eraseDeadCallables() {
+    for (CallGraphNode *node : deadNodes)
+      node->getCallableRegion()->getParentOp()->erase();
+  }
+
+  /// The set of callables known to be dead.
+  SmallPtrSet<CallGraphNode *, 8> deadNodes;
+
   /// The current set of call instructions to consider for inlining.
   SmallVector<ResolvedCall, 8> calls;
 
@@ -368,27 +409,16 @@ static bool shouldInline(ResolvedCall &resolvedCall) {
   return true;
 }
 
-/// Delete the given node and remove it from the current scc and the callgraph.
-static void deleteNode(CallGraphNode *node, CGUseList &useList, CallGraph &cg,
-                       MutableArrayRef<CallGraphNode *> currentSCC) {
-  // Erase the parent operation and remove it from the various lists.
-  node->getCallableRegion()->getParentOp()->erase();
-  cg.eraseNode(node);
-
-  // Replace this node in the currentSCC with the external node.
-  auto it = llvm::find(currentSCC, node);
-  if (it != currentSCC.end())
-    *it = cg.getExternalNode();
-}
-
 /// Attempt to inline calls within the given scc. This function returns
 /// success if any calls were inlined, failure otherwise.
-static LogicalResult
-inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
-                 MutableArrayRef<CallGraphNode *> currentSCC) {
+static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
+                                      CallGraphSCC &currentSCC) {
   CallGraph &cg = inliner.cg;
   auto &calls = inliner.calls;
 
+  // A set of dead nodes to remove after inlining.
+  SmallVector<CallGraphNode *, 1> deadNodes;
+
   // Collect all of the direct calls within the nodes of the current SCC. We
   // don't traverse nested callgraph nodes, because they are handled separately
   // likely within a different SCC.
@@ -396,18 +426,13 @@ inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
     if (node->isExternal())
       continue;
 
-    // If this node is dead, just delete it now.
+    // Don't collect calls if the node is already dead.
     if (useList.isDead(node))
-      deleteNode(node, useList, cg, currentSCC);
+      deadNodes.push_back(node);
     else
       collectCallOps(*node->getCallableRegion(), node, cg, calls,
                      /*traverseNestedCGNodes=*/false);
   }
-  if (calls.empty())
-    return failure();
-
-  // A set of dead nodes to remove after inlining.
-  SmallVector<CallGraphNode *, 1> deadNodes;
 
   // Try to inline each of the call operations. Don't cache the end iterator
   // here as more calls may be added during inlining.
@@ -453,8 +478,10 @@ inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
     }
   }
 
-  for (CallGraphNode *node : deadNodes)
-    deleteNode(node, useList, cg, currentSCC);
+  for (CallGraphNode *node : deadNodes) {
+    currentSCC.remove(node);
+    inliner.markForDeletion(node);
+  }
   calls.clear();
   return success(inlinedAnyCalls);
 }
@@ -462,8 +489,7 @@ inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
 /// Canonicalize the nodes within the given SCC with the given set of
 /// canonicalization patterns.
 static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
-                            MutableArrayRef<CallGraphNode *> currentSCC,
-                            MLIRContext *context,
+                            CallGraphSCC &currentSCC, MLIRContext *context,
                             const OwningRewritePatternList &canonPatterns) {
   // Collect the sets of nodes to canonicalize.
   SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
@@ -533,8 +559,7 @@ struct InlinerPass : public InlinerBase<InlinerPass> {
   /// Attempt to inline calls within the given scc, and run canonicalizations
   /// with the given patterns, until a fixed point is reached. This allows for
   /// the inlining of newly devirtualized calls.
-  void inlineSCC(Inliner &inliner, CGUseList &useList,
-                 MutableArrayRef<CallGraphNode *> currentSCC,
+  void inlineSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC &currentSCC,
                  MLIRContext *context,
                  const OwningRewritePatternList &canonPatterns);
 };
@@ -562,14 +587,16 @@ void InlinerPass::runOnOperation() {
   // Run the inline transform in post-order over the SCCs in the callgraph.
   Inliner inliner(context, cg);
   CGUseList useList(getOperation(), cg);
-  runTransformOnCGSCCs(cg, [&](MutableArrayRef<CallGraphNode *> scc) {
+  runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
     inlineSCC(inliner, useList, scc, context, canonPatterns);
   });
+
+  // After inlining, make sure to erase any callables proven to be dead.
+  inliner.eraseDeadCallables();
 }
 
 void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
-                            MutableArrayRef<CallGraphNode *> currentSCC,
-                            MLIRContext *context,
+                            CallGraphSCC &currentSCC, MLIRContext *context,
                             const OwningRewritePatternList &canonPatterns) {
   // If we successfully inlined any calls, run some simplifications on the
   // nodes of the scc. Continue attempting to inline until we reach a fixed

diff  --git a/mlir/test/Transforms/inlining-dce.mlir b/mlir/test/Transforms/inlining-dce.mlir
index 73b6489127d8..06504ec39e71 100644
--- a/mlir/test/Transforms/inlining-dce.mlir
+++ b/mlir/test/Transforms/inlining-dce.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -inline | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -inline -split-input-file | FileCheck %s
 
 // This file tests the callgraph dead code elimination performed by the inliner.
 
@@ -51,3 +51,23 @@ func @live_function_d() attributes {sym_visibility = "private"} {
 }
 
 "live.user"() {use = @live_function_d} : () -> ()
+
+// -----
+
+// This test checks that the inliner can properly handle the deletion of
+// functions in different SCCs that are referenced by calls materialized during
+// canonicalization.
+// CHECK: func @live_function_e
+func @live_function_e() {
+  call @dead_function_e() : () -> ()
+  return
+}
+// CHECK-NOT: func @dead_function_e
+func @dead_function_e() -> () attributes {sym_visibility = "private"} {
+  "test.fold_to_call_op"() {callee=@dead_function_f} : () -> ()
+  return
+}
+// CHECK-NOT: func @dead_function_f
+func @dead_function_f() attributes {sym_visibility = "private"} {
+  return
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index ac067e32e528..e4b1793c9399 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -173,6 +173,28 @@ TestBranchOp::getMutableSuccessorOperands(unsigned index) {
   return targetOperandsMutable();
 }
 
+//===----------------------------------------------------------------------===//
+// TestFoldToCallOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
+  using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(FoldToCallOp op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<CallOp>(op, ArrayRef<Type>(), op.calleeAttr(),
+                                        ValueRange());
+    return success();
+  }
+};
+} // end anonymous namespace
+
+void FoldToCallOp::getCanonicalizationPatterns(
+    OwningRewritePatternList &results, MLIRContext *context) {
+  results.insert<FoldToCallOpPattern>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // Test IsolatedRegionOp - parse passthrough region arguments.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index a4d74ebfd82b..630a7c58e449 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -321,6 +321,12 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op",
   }];
 }
 
+
+def FoldToCallOp : TEST_Op<"fold_to_call_op"> {
+  let arguments = (ins FlatSymbolRefAttr:$callee);
+  let hasCanonicalizer = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Test Traits
 //===----------------------------------------------------------------------===//


        


More information about the Mlir-commits mailing list