[Mlir-commits] [mlir] c2ecf16 - [mlir][Inliner] Support recursion in Inliner

Javed Absar llvmlistbot at llvm.org
Thu Jun 30 10:52:58 PDT 2022


Author: Javed Absar
Date: 2022-06-30T18:52:45+01:00
New Revision: c2ecf1622479ee95c6fa2c7232fac52300c2368e

URL: https://github.com/llvm/llvm-project/commit/c2ecf1622479ee95c6fa2c7232fac52300c2368e
DIFF: https://github.com/llvm/llvm-project/commit/c2ecf1622479ee95c6fa2c7232fac52300c2368e.diff

LOG: [mlir][Inliner] Support recursion in Inliner

This fixes bug https://github.com/llvm/llvm-project/issues/53492
and uses InlineHistory to track recursive inlining.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D127072

Added: 
    mlir/test/Transforms/inlining-recursive.mlir

Modified: 
    mlir/lib/Transforms/Inliner.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
index 8989ec7e1d95d..5ce32b14185ed 100644
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Support/DebugStringHelper.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/SCCIterator.h"
@@ -364,6 +365,31 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
 //===----------------------------------------------------------------------===//
 // Inliner
 //===----------------------------------------------------------------------===//
+
+#ifndef NDEBUG
+/// Debug output: `op`'s callee SymbolRefAttr, else "_unnamed_callee_".
+static std::string getNodeName(CallOpInterface op) {
+  auto sym = op.getCallableForCallee().dyn_cast<SymbolRefAttr>();
+  return sym ? debugString(sym) : "_unnamed_callee_";
+}
+#endif
+
+/// Returns true if `node` occurs in the inline history chain that starts at
+/// `inlineHistoryID` and follows each entry's parent ID (read-only walk).
+static bool inlineHistoryIncludes(
+    CallGraphNode *node, Optional<size_t> inlineHistoryID,
+    ArrayRef<std::pair<CallGraphNode *, Optional<size_t>>> inlineHistory) {
+  while (inlineHistoryID) {
+    assert(*inlineHistoryID < inlineHistory.size() &&
+           "Invalid inline history ID");
+    const auto &entry = inlineHistory[*inlineHistoryID];
+    if (entry.first == node)
+      return true;
+    inlineHistoryID = entry.second;
+  }
+  return false;
+}
+
 namespace {
 /// This class provides a specialization of the main inlining interface.
 struct Inliner : public InlinerInterface {
@@ -454,23 +480,43 @@ static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
     }
   }
 
+  // When inlining a callee produces new call sites, we want to keep track of
+  // the fact that they were inlined from the callee. This allows us to avoid
+  // infinite inlining.
+  using InlineHistoryT = Optional<size_t>;
+  SmallVector<std::pair<CallGraphNode *, InlineHistoryT>, 8> inlineHistory;
+  std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
+
+  LLVM_DEBUG({
+    llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
+    for (unsigned i = 0, e = calls.size(); i < e; ++i)
+      llvm::dbgs() << "  " << i << ". " << calls[i].call << ",\n";
+    llvm::dbgs() << "}\n";
+  });
+
   // Try to inline each of the call operations. Don't cache the end iterator
   // here as more calls may be added during inlining.
   bool inlinedAnyCalls = false;
-  for (unsigned i = 0; i != calls.size(); ++i) {
+  for (unsigned i = 0; i < calls.size(); ++i) {
     if (deadNodes.contains(calls[i].sourceNode))
       continue;
     ResolvedCall it = calls[i];
-    bool doInline = shouldInline(it);
+
+    InlineHistoryT inlineHistoryID = callHistory[i];
+    bool inHistory =
+        inlineHistoryIncludes(it.targetNode, inlineHistoryID, inlineHistory);
+    bool doInline = !inHistory && shouldInline(it);
     CallOpInterface call = it.call;
     LLVM_DEBUG({
       if (doInline)
-        llvm::dbgs() << "* Inlining call: " << call << "\n";
+        llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
       else
-        llvm::dbgs() << "* Not inlining call: " << call << "\n";
+        llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
     });
     if (!doInline)
       continue;
+
+    unsigned prevSize = calls.size();
     Region *targetRegion = it.targetNode->getCallableRegion();
 
     // If this is the last call to the target node and the node is discardable,
@@ -486,6 +532,29 @@ static LogicalResult inlineCallsInSCC(Inliner &inliner, CGUseList &useList,
     }
     inlinedAnyCalls = true;
 
+    // Create an inline history entry for this inlined call, so that we remember
+    // that new call sites came about due to inlining the callee.
+    InlineHistoryT newInlineHistoryID{inlineHistory.size()};
+    inlineHistory.push_back(std::make_pair(it.targetNode, inlineHistoryID));
+
+    auto historyToString = [](InlineHistoryT h) {
+      return h.has_value() ? std::to_string(h.value()) : "root";
+    };
+    (void)historyToString; // Only referenced inside LLVM_DEBUG.
+    LLVM_DEBUG(llvm::dbgs()
+               << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
+               << getNodeName(call) << ", " << historyToString(inlineHistoryID)
+               << "]\n");
+
+    for (unsigned k = prevSize; k != calls.size(); ++k) {
+      callHistory.push_back(newInlineHistoryID);
+      LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[k].call
+                              << "}\n   with historyID = " << newInlineHistoryID
+                              << ", added due to inlining of\n  call {" << call
+                              << "}\n with historyID = "
+                              << historyToString(inlineHistoryID) << "\n");
+    }
+
     // If the inlining was successful, Merge the new uses into the source node.
     useList.dropCallUses(it.sourceNode, call.getOperation(), cg);
     useList.mergeUsesAfterInlining(it.targetNode, it.sourceNode);

diff  --git a/mlir/test/Transforms/inlining-recursive.mlir b/mlir/test/Transforms/inlining-recursive.mlir
new file mode 100644
index 0000000000000..a02fe69133ad8
--- /dev/null
+++ b/mlir/test/Transforms/inlining-recursive.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s -inline='default-pipeline=''' | FileCheck %s
+// RUN: mlir-opt %s --mlir-disable-threading -inline='default-pipeline=''' | FileCheck %s
+
+// CHECK-LABEL: func.func @foo0
+func.func @foo0(%arg0 : i32) -> i32 {
+  // CHECK: call @foo1
+  // CHECK: }
+  %0 = arith.constant 0 : i32
+  %1 = arith.cmpi eq, %arg0, %0 : i32
+  cf.cond_br %1, ^exit, ^tail
+^exit:
+  return %0 : i32
+^tail:
+  %3 = call @foo1(%arg0) : (i32) -> i32
+  return %3 : i32
+}
+
+// CHECK-LABEL: func.func @foo1
+func.func @foo1(%arg0 : i32) -> i32 {
+  // CHECK:    call @foo1
+  %0 = arith.constant 1 : i32
+  %1 = arith.subi %arg0, %0 : i32
+  %2 = call @foo0(%1) : (i32) -> i32
+  return %2 : i32
+}


        


More information about the Mlir-commits mailing list