[llvm] 99f0063 - Unpack the CostEstimate feature in ML inlining models.

Jacob Hegna via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 2 09:59:05 PDT 2021


Author: Jacob Hegna
Date: 2021-07-02T16:57:16Z
New Revision: 99f00635d7acf1cbcdba35e7621f3a211aa3f237

URL: https://github.com/llvm/llvm-project/commit/99f00635d7acf1cbcdba35e7621f3a211aa3f237
DIFF: https://github.com/llvm/llvm-project/commit/99f00635d7acf1cbcdba35e7621f3a211aa3f237.diff

LOG: Unpack the CostEstimate feature in ML inlining models.

This change yields an additional 2% size reduction on an internal search
binary, and an additional 0.5% size reduction on Fuchsia.

Differential Revision: https://reviews.llvm.org/D104751

Added: 
    llvm/unittests/Analysis/InlineCostTest.cpp

Modified: 
    llvm/include/llvm/Analysis/InlineCost.h
    llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
    llvm/lib/Analysis/CMakeLists.txt
    llvm/lib/Analysis/InlineCost.cpp
    llvm/lib/Analysis/MLInlineAdvisor.cpp
    llvm/lib/Analysis/models/inlining/config.py
    llvm/unittests/Analysis/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/InlineCost.h b/llvm/include/llvm/Analysis/InlineCost.h
index 7f04a8ce8f5fa..a974e07bd767a 100644
--- a/llvm/include/llvm/Analysis/InlineCost.h
+++ b/llvm/include/llvm/Analysis/InlineCost.h
@@ -15,6 +15,7 @@
 
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/CallGraphSCCPass.h"
+#include "llvm/Analysis/InlineModelFeatureMaps.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include <cassert>
 #include <climits>
@@ -270,6 +271,15 @@ Optional<int> getInliningCostEstimate(
     ProfileSummaryInfo *PSI = nullptr,
     OptimizationRemarkEmitter *ORE = nullptr);
 
+/// Get the expanded cost features for \p Call, computed without any threshold
+/// cutoff. Returns None if the analysis of the callee does not succeed.
+Optional<InlineCostFeatures> getInliningCostFeatures(
+    CallBase &Call, TargetTransformInfo &CalleeTTI,
+    function_ref<AssumptionCache &(Function &)> GetAssumptionCache,
+    function_ref<BlockFrequencyInfo &(Function &)> GetBFI = nullptr,
+    ProfileSummaryInfo *PSI = nullptr,
+    OptimizationRemarkEmitter *ORE = nullptr);
+
 /// Minimal filter to detect invalid constructs for inlining.
 InlineResult isInlineViable(Function &Callee);
 

diff  --git a/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h b/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
index 9e5286d478cd1..1afa8a825f15a 100644
--- a/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
+++ b/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
@@ -16,6 +16,71 @@
 
 namespace llvm {
 
+// List of cost features. A "cost" feature is a summand of the heuristic-based
+// inline cost, and we define them separately to preserve the original heuristic
+// behavior.
+// Each entry is M(INDEX_NAME, NAME): INDEX_NAME becomes an enumerator of
+// InlineCostFeatureIndex, and NAME is the feature's textual name (the ML
+// inliner uses it to build its feature name list).
+#define INLINE_COST_FEATURE_ITERATOR(M)                                        \
+  M(SROASavings, "sroa_savings")                                               \
+  M(SROALosses, "sroa_losses")                                                 \
+  M(LoadElimination, "load_elimination")                                       \
+  M(CallPenalty, "call_penalty")                                               \
+  M(CallArgumentSetup, "call_argument_setup")                                  \
+  M(LoadRelativeIntrinsic, "load_relative_intrinsic")                          \
+  M(LoweredCallArgSetup, "lowered_call_arg_setup")                             \
+  M(IndirectCallPenalty, "indirect_call_penalty")                              \
+  M(JumpTablePenalty, "jump_table_penalty")                                    \
+  M(CaseClusterPenalty, "case_cluster_penalty")                                \
+  M(SwitchPenalty, "switch_penalty")                                           \
+  M(UnsimplifiedCommonInstructions, "unsimplified_common_instructions")        \
+  M(NumLoops, "num_loops")                                                     \
+  M(DeadBlocks, "dead_blocks")                                                 \
+  M(SimplifiedInstructions, "simplified_instructions")                         \
+  M(ConstantArgs, "constant_args")                                             \
+  M(ConstantOffsetPtrArgs, "constant_offset_ptr_args")                         \
+  M(CallSiteCost, "callsite_cost")                                             \
+  M(ColdCcPenalty, "cold_cc_penalty")                                          \
+  M(LastCallToStaticBonus, "last_call_to_static_bonus")                        \
+  M(IsMultipleBlocks, "is_multiple_blocks")                                    \
+  M(NestedInlines, "nested_inlines")                                           \
+  M(NestedInlineCostEstimate, "nested_inline_cost_estimate")                   \
+  M(Threshold, "threshold")
+
+// Dense index of each cost feature, in INLINE_COST_FEATURE_ITERATOR order.
+// NumberOfFeatures is the number of cost features, not a feature itself.
+// clang-format off
+enum class InlineCostFeatureIndex : size_t {
+#define POPULATE_INDICES(INDEX_NAME, NAME) INDEX_NAME,
+  INLINE_COST_FEATURE_ITERATOR(POPULATE_INDICES)
+#undef POPULATE_INDICES
+
+  NumberOfFeatures
+};
+// clang-format on
+
+// One value per cost feature, indexed by InlineCostFeatureIndex.
+using InlineCostFeatures =
+    std::array<int,
+               static_cast<size_t>(InlineCostFeatureIndex::NumberOfFeatures)>;
+
+// Returns false for the features explicitly listed below and true for every
+// other cost feature. NOTE(review): presumably the excluded entries (savings,
+// counts, nested-inline data, the threshold) are not direct summands of the
+// heuristic cost -- confirm against InlineCostCallAnalyzer.
+constexpr bool isHeuristicInlineCostFeature(InlineCostFeatureIndex Feature) {
+  return Feature != InlineCostFeatureIndex::SROASavings &&
+         Feature != InlineCostFeatureIndex::IsMultipleBlocks &&
+         Feature != InlineCostFeatureIndex::DeadBlocks &&
+         Feature != InlineCostFeatureIndex::SimplifiedInstructions &&
+         Feature != InlineCostFeatureIndex::ConstantArgs &&
+         Feature != InlineCostFeatureIndex::ConstantOffsetPtrArgs &&
+         Feature != InlineCostFeatureIndex::NestedInlines &&
+         Feature != InlineCostFeatureIndex::NestedInlineCostEstimate &&
+         Feature != InlineCostFeatureIndex::Threshold;
+}
+
 // List of features. Each feature is defined through a triple:
 // - the name of an enum member, which will be the feature index
 // - a textual name, used for Tensorflow model binding (so it needs to match the
@@ -48,12 +103,39 @@ namespace llvm {
     "number of module-internal users of the callee, +1 if the callee is "      \
     "exposed externally")
 
+// clang-format off
 enum class FeatureIndex : size_t {
+// InlineCost features - these must come first, so that each one keeps the
+// same underlying index as in InlineCostFeatureIndex.
+#define POPULATE_INDICES(INDEX_NAME, NAME) INDEX_NAME,
+  INLINE_COST_FEATURE_ITERATOR(POPULATE_INDICES)
+#undef POPULATE_INDICES
+
+// Non-cost features
 #define POPULATE_INDICES(INDEX_NAME, NAME, COMMENT) INDEX_NAME,
   INLINE_FEATURE_ITERATOR(POPULATE_INDICES)
 #undef POPULATE_INDICES
-      NumberOfFeatures
+
+  NumberOfFeatures
 };
+// clang-format on
+
+// Maps a cost feature to its entry in the full ML feature list. This is an
+// identity mapping on the underlying index, valid only because the cost
+// features are the leading entries of FeatureIndex (checked below).
+constexpr FeatureIndex
+inlineCostFeatureToMlFeature(InlineCostFeatureIndex Feature) {
+  return static_cast<FeatureIndex>(static_cast<size_t>(Feature));
+}
+
+// Compile-time check that every cost feature has the same index in both enums.
+#define CHECK_COST_FEATURE_INDEX(INDEX_NAME, NAME)                             \
+  static_assert(inlineCostFeatureToMlFeature(                                  \
+                    InlineCostFeatureIndex::INDEX_NAME) ==                     \
+                    FeatureIndex::INDEX_NAME,                                  \
+                "InlineCost features must come first in FeatureIndex");
+INLINE_COST_FEATURE_ITERATOR(CHECK_COST_FEATURE_INDEX)
+#undef CHECK_COST_FEATURE_INDEX
 
 constexpr size_t NumberOfFeatures =
     static_cast<size_t>(FeatureIndex::NumberOfFeatures);

diff  --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt
index 0165e4bc1bcc0..8e1efc1f2b85d 100644
--- a/llvm/lib/Analysis/CMakeLists.txt
+++ b/llvm/lib/Analysis/CMakeLists.txt
@@ -5,7 +5,7 @@ if (DEFINED LLVM_HAVE_TF_AOT OR DEFINED LLVM_HAVE_TF_API)
   # This url points to the most recent most which is known to be compatible with
   # LLVM. When better models are published, this url should be updated to aid
   # discoverability.
-  set(LLVM_INLINER_MODEL_CURRENT_URL "https://github.com/google/ml-compiler-opt/releases/download/inlining-Oz-v0.1/inlining-Oz-acabaf6-v0.1.tar.gz")
+  set(LLVM_INLINER_MODEL_CURRENT_URL "TO_BE_UPDATED")
 
   if (DEFINED LLVM_HAVE_TF_AOT)
     # If the path is empty, autogenerate the model

diff  --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index 92b0fbd840860..4e03629148c86 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -439,6 +439,25 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> {
   void dump();
 };
 
+// Estimates the number of comparisons needed to lower a switch with
+// NumCaseCluster case clusters as a binary search tree.
+// Every node of that tree is one comparison. Let f(n) be the node count for
+// n clusters; it obeys the recursion
+//   f(n) = 1 + f(n/2) + f(n - n/2),   when n > 3,
+//   f(n) = n,                         when n <= 3.
+// For n > 3 this yields a tree whose leaves are f(2) or f(3) subtrees, so the
+// leaves account for n comparisons, while the non-leaf nodes account for
+//   2^(log2(n) - 1) - 1
+//   = 2^log2(n) * 2^-1 - 1
+//   = n / 2 - 1.
+// Summing leaf and non-leaf comparisons gives the closed-form estimate
+//   n + n / 2 - 1 = n * 3 / 2 - 1.
+// The arithmetic is done in 64 bits so that 3 * n cannot overflow an int.
+int64_t getExpectedNumberOfCompare(int NumCaseCluster) {
+  const int64_t Clusters = NumCaseCluster;
+  return Clusters * 3 / 2 - 1;
+}
+
 /// FIXME: if it is necessary to derive from InlineCostCallAnalyzer, note
 /// the FIXME in onLoweredCall, when instantiating an InlineCostCallAnalyzer
 class InlineCostCallAnalyzer final : public CallAnalyzer {
@@ -582,28 +601,15 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
       addCost(JTCost, (int64_t)CostUpperBound);
       return;
     }
-    // Considering forming a binary search, we should find the number of nodes
-    // which is same as the number of comparisons when lowered. For a given
-    // number of clusters, n, we can define a recursive function, f(n), to find
-    // the number of nodes in the tree. The recursion is :
-    // f(n) = 1 + f(n/2) + f (n - n/2), when n > 3,
-    // and f(n) = n, when n <= 3.
-    // This will lead a binary tree where the leaf should be either f(2) or f(3)
-    // when n > 3.  So, the number of comparisons from leaves should be n, while
-    // the number of non-leaf should be :
-    //   2^(log2(n) - 1) - 1
-    //   = 2^log2(n) * 2^-1 - 1
-    //   = n / 2 - 1.
-    // Considering comparisons from leaf and non-leaf nodes, we can estimate the
-    // number of comparisons in a simple closed form :
-    //   n + n / 2 - 1 = n * 3 / 2 - 1
+
     if (NumCaseCluster <= 3) {
       // Suppose a comparison includes one compare and one conditional branch.
       addCost(NumCaseCluster * 2 * InlineConstants::InstrCost);
       return;
     }
 
-    int64_t ExpectedNumberOfCompare = 3 * (int64_t)NumCaseCluster / 2 - 1;
+    int64_t ExpectedNumberOfCompare =
+        getExpectedNumberOfCompare(NumCaseCluster);
     int64_t SwitchCost =
         ExpectedNumberOfCompare * 2 * InlineConstants::InstrCost;
 
@@ -936,6 +942,231 @@ class InlineCostCallAnalyzer final : public CallAnalyzer {
   int getCost() { return Cost; }
   bool wasDecidedByCostBenefit() { return DecidedByCostBenefit; }
 };
+
+/// Computes the inline-cost features (the individual summands of the heuristic
+/// inline cost plus some extra statistics) for the ML inliner, instead of
+/// folding them into a single cost. The analysis never stops early.
+class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
+private:
+  // Accumulated feature values, indexed by InlineCostFeatureIndex.
+  InlineCostFeatures Cost = {};
+
+  // FIXME: These constants are taken from the heuristic-based cost visitor.
+  // These should be removed entirely in a later revision to avoid reliance on
+  // heuristics in the ML inliner.
+  static constexpr int JTCostMultiplier = 4;
+  static constexpr int CaseClusterCostMultiplier = 2;
+  static constexpr int SwitchCostMultiplier = 2;
+
+  // FIXME: These are taken from the heuristic-based cost visitor: we should
+  // eventually abstract these to the CallAnalyzer to avoid duplication.
+  unsigned SROACostSavingOpportunities = 0;
+  int VectorBonus = 0;
+  int SingleBBBonus = 0;
+  int Threshold = 5;
+
+  // Per-alloca SROA savings; moved into SROALosses if SROA is later disabled
+  // for that alloca (see onDisableSROA).
+  DenseMap<AllocaInst *, unsigned> SROACosts;
+
+  void increment(InlineCostFeatureIndex Feature, int64_t Delta = 1) {
+    Cost[static_cast<size_t>(Feature)] += Delta;
+  }
+
+  void set(InlineCostFeatureIndex Feature, int64_t Value) {
+    Cost[static_cast<size_t>(Feature)] = Value;
+  }
+
+  void onDisableSROA(AllocaInst *Arg) override {
+    auto CostIt = SROACosts.find(Arg);
+    if (CostIt == SROACosts.end())
+      return;
+
+    increment(InlineCostFeatureIndex::SROALosses, CostIt->second);
+    SROACostSavingOpportunities -= CostIt->second;
+    SROACosts.erase(CostIt);
+  }
+
+  void onDisableLoadElimination() override {
+    set(InlineCostFeatureIndex::LoadElimination, 1);
+  }
+
+  void onCallPenalty() override {
+    increment(InlineCostFeatureIndex::CallPenalty,
+              InlineConstants::CallPenalty);
+  }
+
+  void onCallArgumentSetup(const CallBase &Call) override {
+    increment(InlineCostFeatureIndex::CallArgumentSetup,
+              Call.arg_size() * InlineConstants::InstrCost);
+  }
+
+  void onLoadRelativeIntrinsic() override {
+    increment(InlineCostFeatureIndex::LoadRelativeIntrinsic,
+              3 * InlineConstants::InstrCost);
+  }
+
+  void onLoweredCall(Function *F, CallBase &Call,
+                     bool IsIndirectCall) override {
+    increment(InlineCostFeatureIndex::LoweredCallArgSetup,
+              Call.arg_size() * InlineConstants::InstrCost);
+
+    // Indirect call: run a nested heuristic analysis of F and, if it succeeds,
+    // record its cost and count it. Direct calls get the plain call penalty.
+    // NOTE(review): nothing in this analyzer sets IndirectCallPenalty.
+    if (IsIndirectCall) {
+      InlineParams IndirectCallParams = {/* DefaultThreshold*/ 0,
+                                         /*HintThreshold*/ {},
+                                         /*ColdThreshold*/ {},
+                                         /*OptSizeThreshold*/ {},
+                                         /*OptMinSizeThreshold*/ {},
+                                         /*HotCallSiteThreshold*/ {},
+                                         /*LocallyHotCallSiteThreshold*/ {},
+                                         /*ColdCallSiteThreshold*/ {},
+                                         /*ComputeFullInlineCost*/ true,
+                                         /*EnableDeferral*/ true};
+      IndirectCallParams.DefaultThreshold =
+          InlineConstants::IndirectCallThreshold;
+
+      InlineCostCallAnalyzer CA(*F, Call, IndirectCallParams, TTI,
+                                GetAssumptionCache, GetBFI, PSI, ORE, false,
+                                true);
+      if (CA.analyze().isSuccess()) {
+        increment(InlineCostFeatureIndex::NestedInlineCostEstimate,
+                  CA.getCost());
+        increment(InlineCostFeatureIndex::NestedInlines, 1);
+      }
+    } else {
+      onCallPenalty();
+    }
+  }
+
+  void onFinalizeSwitch(unsigned JumpTableSize,
+                        unsigned NumCaseCluster) override {
+
+    if (JumpTableSize) {
+      int64_t JTCost =
+          static_cast<int64_t>(JumpTableSize) * InlineConstants::InstrCost +
+          JTCostMultiplier * InlineConstants::InstrCost;
+      increment(InlineCostFeatureIndex::JumpTablePenalty, JTCost);
+      return;
+    }
+
+    if (NumCaseCluster <= 3) {
+      increment(InlineCostFeatureIndex::CaseClusterPenalty,
+                NumCaseCluster * CaseClusterCostMultiplier *
+                    InlineConstants::InstrCost);
+      return;
+    }
+
+    int64_t ExpectedNumberOfCompare =
+        getExpectedNumberOfCompare(NumCaseCluster);
+
+    int64_t SwitchCost = ExpectedNumberOfCompare * SwitchCostMultiplier *
+                         InlineConstants::InstrCost;
+    increment(InlineCostFeatureIndex::SwitchPenalty, SwitchCost);
+  }
+
+  void onMissedSimplification() override {
+    increment(InlineCostFeatureIndex::UnsimplifiedCommonInstructions,
+              InlineConstants::InstrCost);
+  }
+
+  void onInitializeSROAArg(AllocaInst *Arg) override { SROACosts[Arg] = 0; }
+  void onAggregateSROAUse(AllocaInst *Arg) override {
+    SROACosts.find(Arg)->second += InlineConstants::InstrCost;
+    SROACostSavingOpportunities += InlineConstants::InstrCost;
+  }
+
+  void onBlockAnalyzed(const BasicBlock *BB) override {
+    if (BB->getTerminator()->getNumSuccessors() > 1)
+      set(InlineCostFeatureIndex::IsMultipleBlocks, 1);
+    // NOTE(review): SingleBBBonus is subtracted once per analyzed block,
+    // regardless of its successors -- confirm this matches the heuristic.
+    Threshold -= SingleBBBonus;
+  }
+
+  InlineResult finalizeAnalysis() override {
+    auto *Caller = CandidateCall.getFunction();
+    if (Caller->hasMinSize()) {
+      DominatorTree DT(F);
+      LoopInfo LI(DT);
+      for (Loop *L : LI) {
+        // Ignore loops that will not be executed
+        if (DeadBlocks.count(L->getHeader()))
+          continue;
+        increment(InlineCostFeatureIndex::NumLoops,
+                  InlineConstants::CallPenalty);
+      }
+    }
+    set(InlineCostFeatureIndex::DeadBlocks, DeadBlocks.size());
+    set(InlineCostFeatureIndex::SimplifiedInstructions,
+        NumInstructionsSimplified);
+    set(InlineCostFeatureIndex::ConstantArgs, NumConstantArgs);
+    set(InlineCostFeatureIndex::ConstantOffsetPtrArgs,
+        NumConstantOffsetPtrArgs);
+    set(InlineCostFeatureIndex::SROASavings, SROACostSavingOpportunities);
+
+    // Take back the vector bonus granted in onAnalysisStart: all of it when at
+    // most ~10% of the instructions are vector instructions, half of it when
+    // at most ~50% are. Adjust Threshold itself, since that is what is stored.
+    if (NumVectorInstructions <= NumInstructions / 10)
+      Threshold -= VectorBonus;
+    else if (NumVectorInstructions <= NumInstructions / 2)
+      Threshold -= VectorBonus / 2;
+
+    set(InlineCostFeatureIndex::Threshold, Threshold);
+
+    return InlineResult::success();
+  }
+
+  // Features must be computed in full, so never cut the analysis short.
+  bool shouldStop() override { return false; }
+
+  void onLoadEliminationOpportunity() override {
+    increment(InlineCostFeatureIndex::LoadElimination, 1);
+  }
+
+  InlineResult onAnalysisStart() override {
+    increment(InlineCostFeatureIndex::CallSiteCost,
+              -1 * getCallsiteCost(this->CandidateCall, DL));
+
+    set(InlineCostFeatureIndex::ColdCcPenalty,
+        (F.getCallingConv() == CallingConv::Cold));
+
+    // Flag a call that is the only use of a callee with local linkage.
+    set(InlineCostFeatureIndex::LastCallToStaticBonus,
+        F.hasLocalLinkage() && F.hasOneUse() &&
+            &F == CandidateCall.getCalledFunction());
+
+    // FIXME: we shouldn't repeat this logic in both the Features and Cost
+    // analyzer - instead, we should abstract it to a common method in the
+    // CallAnalyzer
+    int SingleBBBonusPercent = 50;
+    int VectorBonusPercent = TTI.getInlinerVectorBonusPercent();
+    Threshold += TTI.adjustInliningThreshold(&CandidateCall);
+    Threshold *= TTI.getInliningThresholdMultiplier();
+    SingleBBBonus = Threshold * SingleBBBonusPercent / 100;
+    VectorBonus = Threshold * VectorBonusPercent / 100;
+    Threshold += (SingleBBBonus + VectorBonus);
+
+    return InlineResult::success();
+  }
+
+public:
+  // ORE is passed to CallAnalyzer so nested indirect-call analyses get it.
+  InlineCostFeaturesAnalyzer(
+      const TargetTransformInfo &TTI,
+      function_ref<AssumptionCache &(Function &)> GetAssumptionCache,
+      function_ref<BlockFrequencyInfo &(Function &)> GetBFI,
+      ProfileSummaryInfo *PSI, OptimizationRemarkEmitter *ORE, Function &Callee,
+      CallBase &Call)
+      : CallAnalyzer(Callee, Call, TTI, GetAssumptionCache, GetBFI, PSI, ORE) {}
+
+  // Feature values accumulated by analyze().
+  const InlineCostFeatures &features() const { return Cost; }
+};
+
 } // namespace
 
 /// Test whether the given value is an Alloca-derived function argument.
@@ -2502,6 +2711,27 @@ Optional<int> llvm::getInliningCostEstimate(
   return CA.getCost();
 }
 
+/// Computes the expanded cost features for \p Call. Returns None for a call
+/// without a statically known callee (e.g. an indirect call) or when the cost
+/// analysis does not succeed.
+Optional<InlineCostFeatures> llvm::getInliningCostFeatures(
+    CallBase &Call, TargetTransformInfo &CalleeTTI,
+    function_ref<AssumptionCache &(Function &)> GetAssumptionCache,
+    function_ref<BlockFrequencyInfo &(Function &)> GetBFI,
+    ProfileSummaryInfo *PSI, OptimizationRemarkEmitter *ORE) {
+  // The analyzer needs a concrete callee to walk; bail out instead of
+  // dereferencing a null pointer.
+  Function *Callee = Call.getCalledFunction();
+  if (!Callee)
+    return None;
+
+  InlineCostFeaturesAnalyzer CFA(CalleeTTI, GetAssumptionCache, GetBFI, PSI,
+                                 ORE, *Callee, Call);
+  if (!CFA.analyze().isSuccess())
+    return None;
+  return CFA.features();
+}
+
 Optional<InlineResult> llvm::getAttributeBasedInliningDecision(
     CallBase &Call, Function *Callee, TargetTransformInfo &CalleeTTI,
     function_ref<const TargetLibraryInfo &(Function &)> GetTLI) {

diff  --git a/llvm/lib/Analysis/MLInlineAdvisor.cpp b/llvm/lib/Analysis/MLInlineAdvisor.cpp
index 5ef460960f283..5b95ed223fd90 100644
--- a/llvm/lib/Analysis/MLInlineAdvisor.cpp
+++ b/llvm/lib/Analysis/MLInlineAdvisor.cpp
@@ -43,11 +43,23 @@ static cl::opt<float> SizeIncreaseThreshold(
              "blocking any further inlining."),
     cl::init(2.0));
 
+// Textual name of every ML inliner feature. The entries are generated from
+// the same iterator macros, in the same order, as the FeatureIndex
+// enumerators, so a FeatureIndex value can be used to index this array.
+// clang-format off
 const std::array<std::string, NumberOfFeatures> llvm::FeatureNameMap{
+// InlineCost features - these must come first, matching their position in
+// FeatureIndex.
+#define POPULATE_NAMES(INDEX_NAME, NAME) NAME,
+  INLINE_COST_FEATURE_ITERATOR(POPULATE_NAMES)
+#undef POPULATE_NAMES
+
+// Non-cost features
 #define POPULATE_NAMES(INDEX_NAME, NAME, COMMENT) NAME,
-    INLINE_FEATURE_ITERATOR(POPULATE_NAMES)
+  INLINE_FEATURE_ITERATOR(POPULATE_NAMES)
 #undef POPULATE_NAMES
 };
+// clang-format on
 
 const char *const llvm::DecisionName = "inlining_decision";
 const char *const llvm::DefaultDecisionName = "inlining_default";
@@ -217,6 +225,12 @@ std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
     CostEstimate = *IsCallSiteInlinable;
   }
 
+  const auto CostFeatures =
+      llvm::getInliningCostFeatures(CB, TIR, GetAssumptionCache);
+  if (!CostFeatures) {
+    return std::make_unique<InlineAdvice>(this, CB, ORE, false);
+  }
+
   if (Mandatory)
     return getMandatoryAdvice(CB, true);
 
@@ -234,7 +248,6 @@ std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
                           FunctionLevels[&Caller]);
   ModelRunner->setFeature(FeatureIndex::NodeCount, NodeCount);
   ModelRunner->setFeature(FeatureIndex::NrCtantParams, NrCtantParams);
-  ModelRunner->setFeature(FeatureIndex::CostEstimate, CostEstimate);
   ModelRunner->setFeature(FeatureIndex::EdgeCount, EdgeCount);
   ModelRunner->setFeature(FeatureIndex::CallerUsers, CallerBefore.Uses);
   ModelRunner->setFeature(FeatureIndex::CallerConditionallyExecutedBlocks,
@@ -244,6 +257,16 @@ std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
   ModelRunner->setFeature(FeatureIndex::CalleeConditionallyExecutedBlocks,
                           CalleeBefore.BlocksReachedFromConditionalInstruction);
   ModelRunner->setFeature(FeatureIndex::CalleeUsers, CalleeBefore.Uses);
+  ModelRunner->setFeature(FeatureIndex::CostEstimate, CostEstimate);
+
+  // Add the cost features
+  for (size_t I = 0;
+       I < static_cast<size_t>(InlineCostFeatureIndex::NumberOfFeatures); ++I) {
+    ModelRunner->setFeature(
+        inlineCostFeatureToMlFeature(static_cast<InlineCostFeatureIndex>(I)),
+        CostFeatures->at(I));
+  }
+
   return getAdviceFromModel(CB, ORE);
 }
 

diff  --git a/llvm/lib/Analysis/models/inlining/config.py b/llvm/lib/Analysis/models/inlining/config.py
index 3da64a6973975..78d3a8259cc29 100644
--- a/llvm/lib/Analysis/models/inlining/config.py
+++ b/llvm/lib/Analysis/models/inlining/config.py
@@ -26,11 +26,42 @@ def get_input_signature():
   # int64 features
   inputs = [
       tf.TensorSpec(dtype=tf.int64, shape=(), name=key) for key in [
-          'caller_basic_block_count', 'caller_conditionally_executed_blocks',
-          'caller_users', 'callee_basic_block_count',
-          'callee_conditionally_executed_blocks', 'callee_users',
-          'nr_ctant_params', 'node_count', 'edge_count', 'callsite_height',
-          'cost_estimate', 'inlining_default'
+          'caller_basic_block_count',
+          'caller_conditionally_executed_blocks',
+          'caller_users',
+          'callee_basic_block_count',
+          'callee_conditionally_executed_blocks',
+          'callee_users',
+          'nr_ctant_params',
+          'node_count',
+          'edge_count',
+          'callsite_height',
+          'cost_estimate',
+          'inlining_default',
+          'sroa_savings',
+          'sroa_losses',
+          'load_elimination',
+          'call_penalty',
+          'call_argument_setup',
+          'load_relative_intrinsic',
+          'lowered_call_arg_setup',
+          'indirect_call_penalty',
+          'jump_table_penalty',
+          'case_cluster_penalty',
+          'switch_penalty',
+          'unsimplified_common_instructions',
+          'num_loops',
+          'dead_blocks',
+          'simplified_instructions',
+          'constant_args',
+          'constant_offset_ptr_args',
+          'callsite_cost',
+          'cold_cc_penalty',
+          'last_call_to_static_bonus',
+          'is_multiple_blocks',
+          'nested_inlines',
+          'nested_inline_cost_estimate',
+          'threshold',
       ]
   ]
 

diff  --git a/llvm/unittests/Analysis/CMakeLists.txt b/llvm/unittests/Analysis/CMakeLists.txt
index 0480649352214..7e3e20e4af287 100644
--- a/llvm/unittests/Analysis/CMakeLists.txt
+++ b/llvm/unittests/Analysis/CMakeLists.txt
@@ -29,6 +29,7 @@ add_llvm_unittest_with_input_files(AnalysisTests
   DomTreeUpdaterTest.cpp
   GlobalsModRefTest.cpp
   FunctionPropertiesAnalysisTest.cpp
+  InlineCostTest.cpp
   IRSimilarityIdentifierTest.cpp
   IVDescriptorsTest.cpp
   LazyCallGraphTest.cpp

diff  --git a/llvm/unittests/Analysis/InlineCostTest.cpp b/llvm/unittests/Analysis/InlineCostTest.cpp
new file mode 100644
index 0000000000000..cb92ea79c4568
--- /dev/null
+++ b/llvm/unittests/Analysis/InlineCostTest.cpp
@@ -0,0 +1,92 @@
+//===- InlineCostTest.cpp - test for InlineCost ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/InlineCost.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+namespace {
+
+// Tests that we can retrieve the CostFeatures without an error, and that
+// features which are trivially zero for this callee are reported as zero.
+TEST(InlineCostTest, CostFeatures) {
+  using namespace llvm;
+
+  const auto *const IR = R"IR(
+define i32 @f(i32) {
+  ret i32 4
+}
+
+define i32 @g(i32) {
+  %2 = call i32 @f(i32 0)
+  ret i32 %2
+}
+)IR";
+
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(IR, Err, C);
+  ASSERT_TRUE(M);
+
+  auto *G = M->getFunction("g");
+  ASSERT_TRUE(G);
+
+  // Find the call to f in g. Stop scanning at the first call in any block:
+  // dyn_cast yields nullptr for other instructions and would clobber CB.
+  CallBase *CB = nullptr;
+  for (auto &BB : *G) {
+    for (auto &I : BB) {
+      CB = dyn_cast<CallBase>(&I);
+      if (CB)
+        break;
+    }
+    if (CB)
+      break;
+  }
+  ASSERT_TRUE(CB);
+
+  ModuleAnalysisManager MAM;
+  FunctionAnalysisManager FAM;
+  FAM.registerPass([&] { return TargetIRAnalysis(); });
+  FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
+  FAM.registerPass([&] { return AssumptionAnalysis(); });
+  MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
+
+  MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+  FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+
+  ModulePassManager MPM;
+  MPM.run(*M, MAM);
+
+  auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
+    return FAM.getResult<AssumptionAnalysis>(F);
+  };
+  auto &TIR = FAM.getResult<TargetIRAnalysis>(*G);
+
+  const auto Features =
+      llvm::getInliningCostFeatures(*CB, TIR, GetAssumptionCache);
+
+  // Check that the optional is not empty
+  ASSERT_TRUE(Features);
+
+  // f is a single block ending in `ret` with the default calling convention,
+  // so neither of these features may be set.
+  auto FeatureValue = [&](InlineCostFeatureIndex Idx) {
+    return (*Features)[static_cast<size_t>(Idx)];
+  };
+  EXPECT_EQ(FeatureValue(InlineCostFeatureIndex::IsMultipleBlocks), 0);
+  EXPECT_EQ(FeatureValue(InlineCostFeatureIndex::ColdCcPenalty), 0);
+}
+
+} // namespace


        


More information about the llvm-commits mailing list