[llvm] f9b3e34 - Adjust macros which define the ML inlining features.

Jacob Hegna via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 27 15:57:53 PDT 2023


Author: Jacob Hegna
Date: 2023-04-27T22:47:12Z
New Revision: f9b3e3411cc09393afecfdef6439a77e9ba77cc9

URL: https://github.com/llvm/llvm-project/commit/f9b3e3411cc09393afecfdef6439a77e9ba77cc9
DIFF: https://github.com/llvm/llvm-project/commit/f9b3e3411cc09393afecfdef6439a77e9ba77cc9.diff

LOG: Adjust macros which define the ML inlining features.

This aligns the inlining macros more closely with how the regalloc
macros are defined.

 - Explicitly specify the dtype/shape
 - Remove separate names for python/C++
 - Add docstring for inline cost features

Differential Revision: https://reviews.llvm.org/D149384

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
    llvm/lib/Analysis/InlineCost.cpp
    llvm/lib/Analysis/MLInlineAdvisor.cpp

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h b/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
index a64e4c3cb727..5523962085af 100644
--- a/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
+++ b/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
@@ -18,38 +18,60 @@
 
 namespace llvm {
 
+const std::vector<int64_t> ScalarShape;
+
 // List of cost features. A "cost" feature is a summand of the heuristic-based
 // inline cost, and we define them separately to preserve the original heuristic
 // behavior.
 #define INLINE_COST_FEATURE_ITERATOR(M)                                        \
-  M(SROASavings, "sroa_savings")                                               \
-  M(SROALosses, "sroa_losses")                                                 \
-  M(LoadElimination, "load_elimination")                                       \
-  M(CallPenalty, "call_penalty")                                               \
-  M(CallArgumentSetup, "call_argument_setup")                                  \
-  M(LoadRelativeIntrinsic, "load_relative_intrinsic")                          \
-  M(LoweredCallArgSetup, "lowered_call_arg_setup")                             \
-  M(IndirectCallPenalty, "indirect_call_penalty")                              \
-  M(JumpTablePenalty, "jump_table_penalty")                                    \
-  M(CaseClusterPenalty, "case_cluster_penalty")                                \
-  M(SwitchPenalty, "switch_penalty")                                           \
-  M(UnsimplifiedCommonInstructions, "unsimplified_common_instructions")        \
-  M(NumLoops, "num_loops")                                                     \
-  M(DeadBlocks, "dead_blocks")                                                 \
-  M(SimplifiedInstructions, "simplified_instructions")                         \
-  M(ConstantArgs, "constant_args")                                             \
-  M(ConstantOffsetPtrArgs, "constant_offset_ptr_args")                         \
-  M(CallSiteCost, "callsite_cost")                                             \
-  M(ColdCcPenalty, "cold_cc_penalty")                                          \
-  M(LastCallToStaticBonus, "last_call_to_static_bonus")                        \
-  M(IsMultipleBlocks, "is_multiple_blocks")                                    \
-  M(NestedInlines, "nested_inlines")                                           \
-  M(NestedInlineCostEstimate, "nested_inline_cost_estimate")                   \
-  M(Threshold, "threshold")
+  M(int64_t, ScalarShape, sroa_savings,                                        \
+    "Savings from SROA (scalar replacement of aggregates)")                    \
+  M(int64_t, ScalarShape, sroa_losses,                                         \
+    "Losses from SROA (scalar replacement of aggregates)")                     \
+  M(int64_t, ScalarShape, load_elimination,                                    \
+    "Cost of load elimination in the call")                                    \
+  M(int64_t, ScalarShape, call_penalty,                                        \
+    "Accumulation of penalty applied to call sites when inlining")             \
+  M(int64_t, ScalarShape, call_argument_setup,                                 \
+    "Accumulation of call argument setup costs")                               \
+  M(int64_t, ScalarShape, load_relative_intrinsic,                             \
+    "Accumulation of costs of loading relative intrinsics")                    \
+  M(int64_t, ScalarShape, lowered_call_arg_setup,                              \
+    "Accumulation of cost of lowered call argument setups")                    \
+  M(int64_t, ScalarShape, indirect_call_penalty,                               \
+    "Accumulation of costs for indirect calls")                                \
+  M(int64_t, ScalarShape, jump_table_penalty,                                  \
+    "Accumulation of costs for jump tables")                                   \
+  M(int64_t, ScalarShape, case_cluster_penalty,                                \
+    "Accumulation of costs for case clusters")                                 \
+  M(int64_t, ScalarShape, switch_penalty,                                      \
+    "Accumulation of costs for switch statements")                             \
+  M(int64_t, ScalarShape, unsimplified_common_instructions,                    \
+    "Costs from unsimplified common instructions")                             \
+  M(int64_t, ScalarShape, num_loops, "Number of loops in the caller")          \
+  M(int64_t, ScalarShape, dead_blocks, "Number of dead blocks in the caller")  \
+  M(int64_t, ScalarShape, simplified_instructions,                             \
+    "Number of simplified instructions")                                       \
+  M(int64_t, ScalarShape, constant_args,                                       \
+    "Number of constant arguments in the call site")                           \
+  M(int64_t, ScalarShape, constant_offset_ptr_args,                            \
+    "Number of constant offset pointer args in the call site")                 \
+  M(int64_t, ScalarShape, callsite_cost, "Estimated cost of the call site")    \
+  M(int64_t, ScalarShape, cold_cc_penalty,                                     \
+    "Penalty for a cold calling convention")                                   \
+  M(int64_t, ScalarShape, last_call_to_static_bonus,                           \
+    "Bonus for being the last call to static")                                 \
+  M(int64_t, ScalarShape, is_multiple_blocks,                                  \
+    "Boolean; is the Callee multiple blocks")                                  \
+  M(int64_t, ScalarShape, nested_inlines,                                      \
+    "Would the default inliner perfom nested inlining")                        \
+  M(int64_t, ScalarShape, nested_inline_cost_estimate,                         \
+    "Estimate of the accumulated cost of nested inlines")                      \
+  M(int64_t, ScalarShape, threshold, "Threshold for the heuristic inliner")
 
 // clang-format off
 enum class InlineCostFeatureIndex : size_t {
-#define POPULATE_INDICES(INDEX_NAME, NAME) INDEX_NAME,
+#define POPULATE_INDICES(DTYPE, SHAPE, NAME, DOC) NAME,
   INLINE_COST_FEATURE_ITERATOR(POPULATE_INDICES)
 #undef POPULATE_INDICES
 
@@ -62,15 +84,15 @@ using InlineCostFeatures =
                static_cast<size_t>(InlineCostFeatureIndex::NumberOfFeatures)>;
 
 constexpr bool isHeuristicInlineCostFeature(InlineCostFeatureIndex Feature) {
-  return Feature != InlineCostFeatureIndex::SROASavings &&
-         Feature != InlineCostFeatureIndex::IsMultipleBlocks &&
-         Feature != InlineCostFeatureIndex::DeadBlocks &&
-         Feature != InlineCostFeatureIndex::SimplifiedInstructions &&
-         Feature != InlineCostFeatureIndex::ConstantArgs &&
-         Feature != InlineCostFeatureIndex::ConstantOffsetPtrArgs &&
-         Feature != InlineCostFeatureIndex::NestedInlines &&
-         Feature != InlineCostFeatureIndex::NestedInlineCostEstimate &&
-         Feature != InlineCostFeatureIndex::Threshold;
+  return Feature != InlineCostFeatureIndex::sroa_savings &&
+         Feature != InlineCostFeatureIndex::is_multiple_blocks &&
+         Feature != InlineCostFeatureIndex::dead_blocks &&
+         Feature != InlineCostFeatureIndex::simplified_instructions &&
+         Feature != InlineCostFeatureIndex::constant_args &&
+         Feature != InlineCostFeatureIndex::constant_offset_ptr_args &&
+         Feature != InlineCostFeatureIndex::nested_inlines &&
+         Feature != InlineCostFeatureIndex::nested_inline_cost_estimate &&
+         Feature != InlineCostFeatureIndex::threshold;
 }
 
 // List of features. Each feature is defined through a triple:
@@ -81,39 +103,38 @@ constexpr bool isHeuristicInlineCostFeature(InlineCostFeatureIndex Feature) {
 // programmatically, and serves as workaround to inability of inserting comments
 // in macros.
 #define INLINE_FEATURE_ITERATOR(M)                                             \
-  M(CalleeBasicBlockCount, "callee_basic_block_count",                         \
+  M(int64_t, ScalarShape, callee_basic_block_count,                            \
     "number of basic blocks of the callee")                                    \
-  M(CallSiteHeight, "callsite_height",                                         \
+  M(int64_t, ScalarShape, callsite_height,                                     \
     "position of the call site in the original call graph - measured from "    \
     "the farthest SCC")                                                        \
-  M(NodeCount, "node_count",                                                   \
+  M(int64_t, ScalarShape, node_count,                                          \
     "total current number of defined functions in the module")                 \
-  M(NrCtantParams, "nr_ctant_params",                                          \
+  M(int64_t, ScalarShape, nr_ctant_params,                                     \
     "number of parameters in the call site that are constants")                \
-  M(CostEstimate, "cost_estimate", "total cost estimate (threshold - free)")   \
-  M(EdgeCount, "edge_count", "total number of calls in the module")            \
-  M(CallerUsers, "caller_users",                                               \
+  M(int64_t, ScalarShape, cost_estimate,                                       \
+    "total cost estimate (threshold - free)")                                  \
+  M(int64_t, ScalarShape, edge_count, "total number of calls in the module")   \
+  M(int64_t, ScalarShape, caller_users,                                        \
     "number of module-internal users of the caller, +1 if the caller is "      \
     "exposed externally")                                                      \
-  M(CallerConditionallyExecutedBlocks, "caller_conditionally_executed_blocks", \
+  M(int64_t, ScalarShape, caller_conditionally_executed_blocks,                \
     "number of blocks reached from a conditional instruction, in the caller")  \
-  M(CallerBasicBlockCount, "caller_basic_block_count",                         \
+  M(int64_t, ScalarShape, caller_basic_block_count,                            \
     "number of basic blocks in the caller")                                    \
-  M(CalleeConditionallyExecutedBlocks, "callee_conditionally_executed_blocks", \
+  M(int64_t, ScalarShape, callee_conditionally_executed_blocks,                \
     "number of blocks reached from a conditional instruction, in the callee")  \
-  M(CalleeUsers, "callee_users",                                               \
+  M(int64_t, ScalarShape, callee_users,                                        \
     "number of module-internal users of the callee, +1 if the callee is "      \
     "exposed externally")
 
 // clang-format off
 enum class FeatureIndex : size_t {
+#define POPULATE_INDICES(DTYPE, SHAPE, NAME, COMMENT) NAME,
 // InlineCost features - these must come first
-#define POPULATE_INDICES(INDEX_NAME, NAME) INDEX_NAME,
   INLINE_COST_FEATURE_ITERATOR(POPULATE_INDICES)
-#undef POPULATE_INDICES
 
 // Non-cost features
-#define POPULATE_INDICES(INDEX_NAME, NAME, COMMENT) INDEX_NAME,
   INLINE_FEATURE_ITERATOR(POPULATE_INDICES)
 #undef POPULATE_INDICES
 

diff --git a/llvm/lib/Analysis/InlineCost.cpp b/llvm/lib/Analysis/InlineCost.cpp
index 6877e44813f5..09d4b81ad6c4 100644
--- a/llvm/lib/Analysis/InlineCost.cpp
+++ b/llvm/lib/Analysis/InlineCost.cpp
@@ -1108,31 +1108,31 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
     if (CostIt == SROACosts.end())
       return;
 
-    increment(InlineCostFeatureIndex::SROALosses, CostIt->second);
+    increment(InlineCostFeatureIndex::sroa_losses, CostIt->second);
     SROACostSavingOpportunities -= CostIt->second;
     SROACosts.erase(CostIt);
   }
 
   void onDisableLoadElimination() override {
-    set(InlineCostFeatureIndex::LoadElimination, 1);
+    set(InlineCostFeatureIndex::load_elimination, 1);
   }
 
   void onCallPenalty() override {
-    increment(InlineCostFeatureIndex::CallPenalty, CallPenalty);
+    increment(InlineCostFeatureIndex::call_penalty, CallPenalty);
   }
 
   void onCallArgumentSetup(const CallBase &Call) override {
-    increment(InlineCostFeatureIndex::CallArgumentSetup,
+    increment(InlineCostFeatureIndex::call_argument_setup,
               Call.arg_size() * InstrCost);
   }
 
   void onLoadRelativeIntrinsic() override {
-    increment(InlineCostFeatureIndex::LoadRelativeIntrinsic, 3 * InstrCost);
+    increment(InlineCostFeatureIndex::load_relative_intrinsic, 3 * InstrCost);
   }
 
   void onLoweredCall(Function *F, CallBase &Call,
                      bool IsIndirectCall) override {
-    increment(InlineCostFeatureIndex::LoweredCallArgSetup,
+    increment(InlineCostFeatureIndex::lowered_call_arg_setup,
               Call.arg_size() * InstrCost);
 
     if (IsIndirectCall) {
@@ -1153,9 +1153,9 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
                                 GetAssumptionCache, GetBFI, PSI, ORE, false,
                                 true);
       if (CA.analyze().isSuccess()) {
-        increment(InlineCostFeatureIndex::NestedInlineCostEstimate,
+        increment(InlineCostFeatureIndex::nested_inline_cost_estimate,
                   CA.getCost());
-        increment(InlineCostFeatureIndex::NestedInlines, 1);
+        increment(InlineCostFeatureIndex::nested_inlines, 1);
       }
     } else {
       onCallPenalty();
@@ -1168,12 +1168,12 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
     if (JumpTableSize) {
       int64_t JTCost = static_cast<int64_t>(JumpTableSize) * InstrCost +
                        JTCostMultiplier * InstrCost;
-      increment(InlineCostFeatureIndex::JumpTablePenalty, JTCost);
+      increment(InlineCostFeatureIndex::jump_table_penalty, JTCost);
       return;
     }
 
     if (NumCaseCluster <= 3) {
-      increment(InlineCostFeatureIndex::CaseClusterPenalty,
+      increment(InlineCostFeatureIndex::case_cluster_penalty,
                 NumCaseCluster * CaseClusterCostMultiplier * InstrCost);
       return;
     }
@@ -1183,11 +1183,11 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
 
     int64_t SwitchCost =
         ExpectedNumberOfCompare * SwitchCostMultiplier * InstrCost;
-    increment(InlineCostFeatureIndex::SwitchPenalty, SwitchCost);
+    increment(InlineCostFeatureIndex::switch_penalty, SwitchCost);
   }
 
   void onMissedSimplification() override {
-    increment(InlineCostFeatureIndex::UnsimplifiedCommonInstructions,
+    increment(InlineCostFeatureIndex::unsimplified_common_instructions,
               InstrCost);
   }
 
@@ -1199,7 +1199,7 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
 
   void onBlockAnalyzed(const BasicBlock *BB) override {
     if (BB->getTerminator()->getNumSuccessors() > 1)
-      set(InlineCostFeatureIndex::IsMultipleBlocks, 1);
+      set(InlineCostFeatureIndex::is_multiple_blocks, 1);
     Threshold -= SingleBBBonus;
   }
 
@@ -1212,24 +1212,24 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
         // Ignore loops that will not be executed
         if (DeadBlocks.count(L->getHeader()))
           continue;
-        increment(InlineCostFeatureIndex::NumLoops,
+        increment(InlineCostFeatureIndex::num_loops,
                   InlineConstants::LoopPenalty);
       }
     }
-    set(InlineCostFeatureIndex::DeadBlocks, DeadBlocks.size());
-    set(InlineCostFeatureIndex::SimplifiedInstructions,
+    set(InlineCostFeatureIndex::dead_blocks, DeadBlocks.size());
+    set(InlineCostFeatureIndex::simplified_instructions,
         NumInstructionsSimplified);
-    set(InlineCostFeatureIndex::ConstantArgs, NumConstantArgs);
-    set(InlineCostFeatureIndex::ConstantOffsetPtrArgs,
+    set(InlineCostFeatureIndex::constant_args, NumConstantArgs);
+    set(InlineCostFeatureIndex::constant_offset_ptr_args,
         NumConstantOffsetPtrArgs);
-    set(InlineCostFeatureIndex::SROASavings, SROACostSavingOpportunities);
+    set(InlineCostFeatureIndex::sroa_savings, SROACostSavingOpportunities);
 
     if (NumVectorInstructions <= NumInstructions / 10)
       Threshold -= VectorBonus;
     else if (NumVectorInstructions <= NumInstructions / 2)
       Threshold -= VectorBonus / 2;
 
-    set(InlineCostFeatureIndex::Threshold, Threshold);
+    set(InlineCostFeatureIndex::threshold, Threshold);
 
     return InlineResult::success();
   }
@@ -1237,17 +1237,17 @@ class InlineCostFeaturesAnalyzer final : public CallAnalyzer {
   bool shouldStop() override { return false; }
 
   void onLoadEliminationOpportunity() override {
-    increment(InlineCostFeatureIndex::LoadElimination, 1);
+    increment(InlineCostFeatureIndex::load_elimination, 1);
   }
 
   InlineResult onAnalysisStart() override {
-    increment(InlineCostFeatureIndex::CallSiteCost,
+    increment(InlineCostFeatureIndex::callsite_cost,
               -1 * getCallsiteCost(this->CandidateCall, DL));
 
-    set(InlineCostFeatureIndex::ColdCcPenalty,
+    set(InlineCostFeatureIndex::cold_cc_penalty,
         (F.getCallingConv() == CallingConv::Cold));
 
-    set(InlineCostFeatureIndex::LastCallToStaticBonus,
+    set(InlineCostFeatureIndex::last_call_to_static_bonus,
         isSoleCallToLocalFunction(CandidateCall, F));
 
     // FIXME: we shouldn't repeat this logic in both the Features and Cost

diff --git a/llvm/lib/Analysis/MLInlineAdvisor.cpp b/llvm/lib/Analysis/MLInlineAdvisor.cpp
index 49d5deea8fc9..e7faad51a501 100644
--- a/llvm/lib/Analysis/MLInlineAdvisor.cpp
+++ b/llvm/lib/Analysis/MLInlineAdvisor.cpp
@@ -46,6 +46,8 @@ static cl::opt<bool>
     InteractiveIncludeDefault("inliner-interactive-include-default", cl::Hidden,
                               cl::desc(InclDefaultMsg));
 
+const std::vector<int64_t> ScalarShape = {1};
+
 #if defined(LLVM_HAVE_TF_AOT_INLINERSIZEMODEL)
 // codegen-ed file
 #include "InlinerSizeModel.h" // NOLINT
@@ -93,13 +95,11 @@ static cl::opt<bool> KeepFPICache(
 
 // clang-format off
 const std::vector<TensorSpec> llvm::FeatureMap{
-#define POPULATE_NAMES(_, NAME) TensorSpec::createSpec<int64_t>(NAME, {1} ),
+#define POPULATE_NAMES(DTYPE, SHAPE, NAME, __) TensorSpec::createSpec<DTYPE>(#NAME, SHAPE),
 // InlineCost features - these must come first
   INLINE_COST_FEATURE_ITERATOR(POPULATE_NAMES)
-#undef POPULATE_NAMES
 
 // Non-cost features
-#define POPULATE_NAMES(_, NAME, __) TensorSpec::createSpec<int64_t>(NAME, {1} ),
   INLINE_FEATURE_ITERATOR(POPULATE_NAMES)
 #undef POPULATE_NAMES
 };
@@ -383,26 +383,27 @@ std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
   auto &CallerBefore = getCachedFPI(Caller);
   auto &CalleeBefore = getCachedFPI(Callee);
 
-  *ModelRunner->getTensor<int64_t>(FeatureIndex::CalleeBasicBlockCount) =
+  *ModelRunner->getTensor<int64_t>(FeatureIndex::callee_basic_block_count) =
       CalleeBefore.BasicBlockCount;
-  *ModelRunner->getTensor<int64_t>(FeatureIndex::CallSiteHeight) =
+  *ModelRunner->getTensor<int64_t>(FeatureIndex::callsite_height) =
       getInitialFunctionLevel(Caller);
-  *ModelRunner->getTensor<int64_t>(FeatureIndex::NodeCount) = NodeCount;
-  *ModelRunner->getTensor<int64_t>(FeatureIndex::NrCtantParams) = NrCtantParams;
-  *ModelRunner->getTensor<int64_t>(FeatureIndex::EdgeCount) = EdgeCount;
-  *ModelRunner->getTensor<int64_t>(FeatureIndex::CallerUsers) =
+  *ModelRunner->getTensor<int64_t>(FeatureIndex::node_count) = NodeCount;
+  *ModelRunner->getTensor<int64_t>(FeatureIndex::nr_ctant_params) =
+      NrCtantParams;
+  *ModelRunner->getTensor<int64_t>(FeatureIndex::edge_count) = EdgeCount;
+  *ModelRunner->getTensor<int64_t>(FeatureIndex::caller_users) =
       CallerBefore.Uses;
   *ModelRunner->getTensor<int64_t>(
-      FeatureIndex::CallerConditionallyExecutedBlocks) =
+      FeatureIndex::caller_conditionally_executed_blocks) =
       CallerBefore.BlocksReachedFromConditionalInstruction;
-  *ModelRunner->getTensor<int64_t>(FeatureIndex::CallerBasicBlockCount) =
+  *ModelRunner->getTensor<int64_t>(FeatureIndex::caller_basic_block_count) =
       CallerBefore.BasicBlockCount;
   *ModelRunner->getTensor<int64_t>(
-      FeatureIndex::CalleeConditionallyExecutedBlocks) =
+      FeatureIndex::callee_conditionally_executed_blocks) =
       CalleeBefore.BlocksReachedFromConditionalInstruction;
-  *ModelRunner->getTensor<int64_t>(FeatureIndex::CalleeUsers) =
+  *ModelRunner->getTensor<int64_t>(FeatureIndex::callee_users) =
       CalleeBefore.Uses;
-  *ModelRunner->getTensor<int64_t>(FeatureIndex::CostEstimate) = CostEstimate;
+  *ModelRunner->getTensor<int64_t>(FeatureIndex::cost_estimate) = CostEstimate;
 
   // Add the cost features
   for (size_t I = 0;


        


More information about the llvm-commits mailing list