[Mlir-commits] [mlir] f4ef77c - [mlir][Inliner] Properly handle callgraph node deletion
River Riddle
llvmlistbot at llvm.org
Wed Jun 17 15:49:54 PDT 2020
Author: River Riddle
Date: 2020-06-17T15:45:56-07:00
New Revision: f4ef77cbb48b549211ecc18085f14ec7a17c01fc
URL: https://github.com/llvm/llvm-project/commit/f4ef77cbb48b549211ecc18085f14ec7a17c01fc
DIFF: https://github.com/llvm/llvm-project/commit/f4ef77cbb48b549211ecc18085f14ec7a17c01fc.diff
LOG: [mlir][Inliner] Properly handle callgraph node deletion
We previously weren't properly updating the SCC iterator when nodes were removed, leading to asan failures in certain situations. This commit adds a CallGraphSCC class and defers operation deletion until inlining has finished.
Differential Revision: https://reviews.llvm.org/D81984
Added:
Modified:
mlir/lib/Transforms/Inliner.cpp
mlir/test/Transforms/inlining-dce.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestOps.td
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index 0cd706790bdc..e17a379d54b8 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -242,19 +242,47 @@ void CGUseList::decrementDiscardableUses(CGUser &uses) {
// CallGraph traversal
//===----------------------------------------------------------------------===//
+namespace {
+/// This class represents a specific callgraph SCC.
+class CallGraphSCC {
+public:
+ CallGraphSCC(llvm::scc_iterator<const CallGraph *> &parentIterator)
+ : parentIterator(parentIterator) {}
+ /// Return a range over the nodes within this SCC.
+ std::vector<CallGraphNode *>::iterator begin() { return nodes.begin(); }
+ std::vector<CallGraphNode *>::iterator end() { return nodes.end(); }
+
+ /// Reset the nodes of this SCC with those provided.
+ void reset(const std::vector<CallGraphNode *> &newNodes) { nodes = newNodes; }
+
+ /// Remove the given node from this SCC.
+ void remove(CallGraphNode *node) {
+ auto it = llvm::find(nodes, node);
+ if (it != nodes.end()) {
+ nodes.erase(it);
+ parentIterator.ReplaceNode(node, nullptr);
+ }
+ }
+
+private:
+ std::vector<CallGraphNode *> nodes;
+ llvm::scc_iterator<const CallGraph *> &parentIterator;
+};
+} // end anonymous namespace
+
/// Run a given transformation over the SCCs of the callgraph in a bottom up
/// traversal.
-static void runTransformOnCGSCCs(
- const CallGraph &cg,
- function_ref<void(MutableArrayRef<CallGraphNode *>)> sccTransformer) {
- std::vector<CallGraphNode *> currentSCCVec;
- auto cgi = llvm::scc_begin(&cg);
+static void
+runTransformOnCGSCCs(const CallGraph &cg,
+ function_ref<void(CallGraphSCC &)> sccTransformer) {
+ llvm::scc_iterator<const CallGraph *> cgi = llvm::scc_begin(&cg);
+ CallGraphSCC currentSCC(cgi);
while (!cgi.isAtEnd()) {
// Copy the current SCC and increment so that the transformer can modify the
// SCC without invalidating our iterator.
- currentSCCVec = *cgi;
+ currentSCC.reset(*cgi);
++cgi;
- sccTransformer(currentSCCVec);
+ sccTransformer(currentSCC);
}
}
@@ -343,6 +371,19 @@ struct Inliner : public InlinerInterface {
/*traverseNestedCGNodes=*/true);
}
+ /// Mark the given callgraph node for deletion.
+ void markForDeletion(CallGraphNode *node) { deadNodes.insert(node); }
+
+ /// This method properly disposes of callables that became dead during
+ /// inlining. This should not be called while iterating over the SCCs.
+ void eraseDeadCallables() {
+ for (CallGraphNode *node : deadNodes)
+ node->getCallableRegion()->getParentOp()->erase();
+ }
+
+ /// The set of callables known to be dead.
+ SmallPtrSet<CallGraphNode *, 8> deadNodes;
+
/// The current set of call instructions to consider for inlining.
SmallVector<ResolvedCall, 8> calls;
@@ -368,27 +409,16 @@ static bool shouldInline(ResolvedCall &resolvedCall) {
return true;
}
-/// Delete the given node and remove it from the current scc and the callgraph.
-static void deleteNode(CallGraphNode *node, CGUseList &useList, CallGraph &cg,
- MutableArrayRef<CallGraphNode *> currentSCC) {
- // Erase the parent operation and remove it from the various lists.
- node->getCallableRegion()->getParentOp()->erase();
- cg.eraseNode(node);
-
- // Replace this node in the currentSCC with the external node.
- auto it = llvm::find(currentSCC, node);
- if (it != currentSCC.end())
- *it = cg.getExternalNode();
-}
-
/// Attempt to inline calls within the given scc. This function returns
/// success if any calls were inlined, failure otherwise.
-static LogicalResult
-inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
- MutableArrayRef<CallGraphNode *> currentSCC) {
+static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
+ CallGraphSCC &currentSCC) {
CallGraph &cg = inliner.cg;
auto &calls = inliner.calls;
+ // A set of dead nodes to remove after inlining.
+ SmallVector<CallGraphNode *, 1> deadNodes;
+
// Collect all of the direct calls within the nodes of the current SCC. We
// don't traverse nested callgraph nodes, because they are handled separately
// likely within a different SCC.
@@ -396,18 +426,13 @@ inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
if (node->isExternal())
continue;
- // If this node is dead, just delete it now.
+ // Don't collect calls if the node is already dead.
if (useList.isDead(node))
- deleteNode(node, useList, cg, currentSCC);
+ deadNodes.push_back(node);
else
collectCallOps(*node->getCallableRegion(), node, cg, calls,
/*traverseNestedCGNodes=*/false);
}
- if (calls.empty())
- return failure();
-
- // A set of dead nodes to remove after inlining.
- SmallVector<CallGraphNode *, 1> deadNodes;
// Try to inline each of the call operations. Don't cache the end iterator
// here as more calls may be added during inlining.
@@ -453,8 +478,10 @@ inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
}
}
- for (CallGraphNode *node : deadNodes)
- deleteNode(node, useList, cg, currentSCC);
+ for (CallGraphNode *node : deadNodes) {
+ currentSCC.remove(node);
+ inliner.markForDeletion(node);
+ }
calls.clear();
return success(inlinedAnyCalls);
}
@@ -462,8 +489,7 @@ inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
/// Canonicalize the nodes within the given SCC with the given set of
/// canonicalization patterns.
static void canonicalizeSCC(CallGraph &cg, CGUseList &useList,
- MutableArrayRef<CallGraphNode *> currentSCC,
- MLIRContext *context,
+ CallGraphSCC &currentSCC, MLIRContext *context,
const OwningRewritePatternList &canonPatterns) {
// Collect the sets of nodes to canonicalize.
SmallVector<CallGraphNode *, 4> nodesToCanonicalize;
@@ -533,8 +559,7 @@ struct InlinerPass : public InlinerBase<InlinerPass> {
/// Attempt to inline calls within the given scc, and run canonicalizations
/// with the given patterns, until a fixed point is reached. This allows for
/// the inlining of newly devirtualized calls.
- void inlineSCC(Inliner &inliner, CGUseList &useList,
- MutableArrayRef<CallGraphNode *> currentSCC,
+ void inlineSCC(Inliner &inliner, CGUseList &useList, CallGraphSCC &currentSCC,
MLIRContext *context,
const OwningRewritePatternList &canonPatterns);
};
@@ -562,14 +587,16 @@ void InlinerPass::runOnOperation() {
// Run the inline transform in post-order over the SCCs in the callgraph.
Inliner inliner(context, cg);
CGUseList useList(getOperation(), cg);
- runTransformOnCGSCCs(cg, [&](MutableArrayRef<CallGraphNode *> scc) {
+ runTransformOnCGSCCs(cg, [&](CallGraphSCC &scc) {
inlineSCC(inliner, useList, scc, context, canonPatterns);
});
+
+ // After inlining, make sure to erase any callables proven to be dead.
+ inliner.eraseDeadCallables();
}
void InlinerPass::inlineSCC(Inliner &inliner, CGUseList &useList,
- MutableArrayRef<CallGraphNode *> currentSCC,
- MLIRContext *context,
+ CallGraphSCC &currentSCC, MLIRContext *context,
const OwningRewritePatternList &canonPatterns) {
// If we successfully inlined any calls, run some simplifications on the
// nodes of the scc. Continue attempting to inline until we reach a fixed
diff --git a/mlir/test/Transforms/inlining-dce.mlir b/mlir/test/Transforms/inlining-dce.mlir
index 73b6489127d8..06504ec39e71 100644
--- a/mlir/test/Transforms/inlining-dce.mlir
+++ b/mlir/test/Transforms/inlining-dce.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect %s -inline | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect %s -inline -split-input-file | FileCheck %s
// This file tests the callgraph dead code elimination performed by the inliner.
@@ -51,3 +51,23 @@ func @live_function_d() attributes {sym_visibility = "private"} {
}
"live.user"() {use = @live_function_d} : () -> ()
+
+// -----
+
+// This test checks that the inliner can properly handle the deletion of
+// functions in different SCCs that are referenced by calls materialized during
+// canonicalization.
+// CHECK: func @live_function_e
+func @live_function_e() {
+ call @dead_function_e() : () -> ()
+ return
+}
+// CHECK-NOT: func @dead_function_e
+func @dead_function_e() -> () attributes {sym_visibility = "private"} {
+ "test.fold_to_call_op"() {callee=@dead_function_f} : () -> ()
+ return
+}
+// CHECK-NOT: func @dead_function_f
+func @dead_function_f() attributes {sym_visibility = "private"} {
+ return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index ac067e32e528..e4b1793c9399 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -173,6 +173,28 @@ TestBranchOp::getMutableSuccessorOperands(unsigned index) {
return targetOperandsMutable();
}
+//===----------------------------------------------------------------------===//
+// TestFoldToCallOp
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
+ using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(FoldToCallOp op,
+ PatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<CallOp>(op, ArrayRef<Type>(), op.calleeAttr(),
+ ValueRange());
+ return success();
+ }
+};
+} // end anonymous namespace
+
+void FoldToCallOp::getCanonicalizationPatterns(
+ OwningRewritePatternList &results, MLIRContext *context) {
+ results.insert<FoldToCallOpPattern>(context);
+}
+
//===----------------------------------------------------------------------===//
// Test IsolatedRegionOp - parse passthrough region arguments.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index a4d74ebfd82b..630a7c58e449 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -321,6 +321,12 @@ def FunctionalRegionOp : TEST_Op<"functional_region_op",
}];
}
+
+def FoldToCallOp : TEST_Op<"fold_to_call_op"> {
+ let arguments = (ins FlatSymbolRefAttr:$callee);
+ let hasCanonicalizer = 1;
+}
+
//===----------------------------------------------------------------------===//
// Test Traits
//===----------------------------------------------------------------------===//
More information about the Mlir-commits mailing list