[Mlir-commits] [mlir] c2ecf16 - [mlir][Inliner] Support recursion in Inliner
Javed Absar
llvmlistbot at llvm.org
Thu Jun 30 10:52:58 PDT 2022
Author: Javed Absar
Date: 2022-06-30T18:52:45+01:00
New Revision: c2ecf1622479ee95c6fa2c7232fac52300c2368e
URL: https://github.com/llvm/llvm-project/commit/c2ecf1622479ee95c6fa2c7232fac52300c2368e
DIFF: https://github.com/llvm/llvm-project/commit/c2ecf1622479ee95c6fa2c7232fac52300c2368e.diff
LOG: [mlir][Inliner] Support recursion in Inliner
This fixes Bug https://github.com/llvm/llvm-project/issues/53492
and uses InlineHistory to track recursive inlining.
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D127072
Added:
mlir/test/Transforms/inlining-recursive.mlir
Modified:
mlir/lib/Transforms/Inliner.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index 8989ec7e1d95d..5ce32b14185ed 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -19,6 +19,7 @@
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/SCCIterator.h"
@@ -364,6 +365,31 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
//===----------------------------------------------------------------------===//
// Inliner
//===----------------------------------------------------------------------===//
+
+#ifndef NDEBUG // Debug-only: `op` as text if its callee is a symbol.
+static std::string getNodeName(CallOpInterface op) {
+ if (op.getCallableForCallee().dyn_cast<SymbolRefAttr>())
+ return debugString(op);
+ // Callee is not a symbol reference: use a fixed placeholder name.
+ return "_unnamed_callee_";
+}
+
+/// Return true if the history chain starting at `inlineHistoryID` includes
+/// `node`. Each entry pairs an inlined callee with the history ID of the
+/// inlined call site, which is empty for calls not produced by inlining.
+static bool inlineHistoryIncludes(
+ CallGraphNode *node, Optional<size_t> inlineHistoryID,
+ ArrayRef<std::pair<CallGraphNode *, Optional<size_t>>> inlineHistory) {
+ while (inlineHistoryID.has_value()) {
+ size_t id = *inlineHistoryID;
+ assert(id < inlineHistory.size() && "Invalid inline history ID");
+ if (inlineHistory[id].first == node)
+ return true;
+ inlineHistoryID = inlineHistory[id].second;
+ }
+ return false;
+}
+
namespace {
/// This class provides a specialization of the main inlining interface.
struct Inliner : public InlinerInterface {
@@ -454,23 +480,43 @@ static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
}
}
+ // When inlining a callee produces new call sites, we want to keep track of
+ // the fact that they were inlined from the callee. This allows us to avoid
+ // infinite inlining.
+ using InlineHistoryT = Optional<size_t>;
+ SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
+ std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
+
+ LLVM_DEBUG({
+ llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
+ for (unsigned i = 0, e = calls.size(); i < e; ++i)
+ llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n";
+ llvm::dbgs() << "}\n";
+ });
+
// Try to inline each of the call operations. Don't cache the end iterator
// here as more calls may be added during inlining.
bool inlinedAnyCalls = false;
- for (unsigned i = 0; i != calls.size(); ++i) {
+ for (unsigned i = 0; i < calls.size(); ++i) {
if (deadNodes.contains(calls[i].sourceNode))
continue;
ResolvedCall it = calls[i];
- bool doInline = shouldInline(it);
+
+ InlineHistoryT inlineHistoryID = callHistory[i];
+ bool inHistory =
+ inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
+ bool doInline = !inHistory && shouldInline(it);
CallOpInterface call = it.call;
LLVM_DEBUG({
if (doInline)
- llvm::dbgs() << "* Inlining call: " << call << "\n";
+ llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
else
- llvm::dbgs() << "* Not inlining call: " << call << "\n";
+ llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
});
if (!doInline)
continue;
+
+ unsigned prevSize = calls.size();
Region *targetRegion = it.targetNode->getCallableRegion();
// If this is the last call to the target node and the node is discardable,
@@ -486,6 +532,29 @@ static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
}
inlinedAnyCalls = true;
+ // Create an inline history entry for this inlined call, so that we remember
+ // that the new call sites came about due to inlining the callee.
+ InlineHistoryT newInlineHistoryID{inlineHistory.size()};
+ inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));
+ // Renders a history ID for logs; an empty ID marks a call not from inlining.
+ auto historyToString = [](InlineHistoryT h) {
+ return h.has_value() ? std::to_string(h.value()) : "root";
+ };
+ (void)historyToString; // Silences unused warnings when LLVM_DEBUG is off.
+ LLVM_DEBUG(llvm::dbgs()
+ << "* new inlineHistory entry: " << historyToString(newInlineHistoryID)
+ << ". [" << getNodeName(call) << ", " << historyToString(inlineHistoryID)
+ << "]\n");
+ // Every call appended to `calls` by this inlining inherits the new entry.
+ for (unsigned k = prevSize; k != calls.size(); ++k) {
+ callHistory.push_back(newInlineHistoryID);
+ LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[k].call
+ << "}\n with historyID = "
+ << historyToString(newInlineHistoryID)
+ << ", added due to inlining of\n call {" << call
+ << "}\n with historyID = " << historyToString(inlineHistoryID) << "\n");
+ }
+
// If the inlining was successful, Merge the new uses into the source node.
useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);
diff --git a/mlir/test/Transforms/inlining-recursive.mlir b/mlir/test/Transforms/inlining-recursive.mlir
new file mode 100644
index 0000000000000..a02fe69133ad8
--- /dev/null
+++ b/mlir/test/Transforms/inlining-recursive.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s -inline='default-pipeline=''' | FileCheck %s
+// RUN: mlir-opt %s --mlir-disable-threading -inline='default-pipeline=''' | FileCheck %s
+// @foo0 <-> @foo1 mutual recursion: inlining must terminate (issue #53492).
+// CHECK-LABEL: func.func @foo0
+func.func @foo0(%arg0 : i32) -> i32 {
+ // CHECK: call @foo1
+ // CHECK: }
+ %0 = arith.constant 0 : i32
+ %1 = arith.cmpi eq, %arg0, %0 : i32
+ cf.cond_br %1, ^exit, ^tail
+^exit:
+ return %0 : i32
+^tail:
+ %3 = call @foo1(%arg0) : (i32) -> i32
+ return %3 : i32
+}
+// Expected: @foo1's call to @foo0 gets inlined, leaving a self-call to @foo1.
+// CHECK-LABEL: func.func @foo1
+func.func @foo1(%arg0 : i32) -> i32 {
+ // CHECK: call @foo1
+ %0 = arith.constant 1 : i32
+ %1 = arith.subi %arg0, %0 : i32
+ %2 = call @foo0(%1) : (i32) -> i32
+ return %2 : i32
+}
More information about the Mlir-commits mailing list