[llvm-branch-commits] [llvm] [MLInliner][IR2Vec] Integrating IR2Vec with MLInliner (PR #143479)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Mon Jun 9 22:50:56 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: S. VenkataKeerthy (svkeerthy)

<details>
<summary>Changes</summary>

This patch changes MLInliner to use IR2Vec Symbolic embeddings.

(Fixes #<!-- -->141836, Tracking issue - #<!-- -->141817)

---

Patch is 36.65 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/143479.diff


10 Files Affected:

- (modified) llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h (+21-5) 
- (modified) llvm/include/llvm/Analysis/IR2Vec.h (+1-1) 
- (modified) llvm/include/llvm/Analysis/InlineAdvisor.h (+3) 
- (modified) llvm/include/llvm/Analysis/InlineModelFeatureMaps.h (+5-1) 
- (modified) llvm/include/llvm/Analysis/MLInlineAdvisor.h (+1) 
- (modified) llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp (+112-3) 
- (modified) llvm/lib/Analysis/IR2Vec.cpp (+2-2) 
- (modified) llvm/lib/Analysis/InlineAdvisor.cpp (+29) 
- (modified) llvm/lib/Analysis/MLInlineAdvisor.cpp (+33-1) 
- (modified) llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp (+154-25) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h b/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h
index babb6d9d6cf0c..06dbfc35a5294 100644
--- a/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h
+++ b/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h
@@ -15,6 +15,7 @@
 #define LLVM_ANALYSIS_FUNCTIONPROPERTIESANALYSIS_H
 
 #include "llvm/ADT/DenseSet.h"
+#include "llvm/Analysis/IR2Vec.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/Support/Compiler.h"
@@ -32,17 +33,19 @@ class FunctionPropertiesInfo {
   void updateAggregateStats(const Function &F, const LoopInfo &LI);
   void reIncludeBB(const BasicBlock &BB);
 
+  ir2vec::Embedding FunctionEmbedding = ir2vec::Embedding(0.0);
+  std::optional<ir2vec::Vocab> IR2VecVocab;
+
 public:
   LLVM_ABI static FunctionPropertiesInfo
   getFunctionPropertiesInfo(const Function &F, const DominatorTree &DT,
-                            const LoopInfo &LI);
+                            const LoopInfo &LI,
+                            const IR2VecVocabResult *VocabResult);
 
   LLVM_ABI static FunctionPropertiesInfo
   getFunctionPropertiesInfo(Function &F, FunctionAnalysisManager &FAM);
 
-  bool operator==(const FunctionPropertiesInfo &FPI) const {
-    return std::memcmp(this, &FPI, sizeof(FunctionPropertiesInfo)) == 0;
-  }
+  bool operator==(const FunctionPropertiesInfo &FPI) const;
 
   bool operator!=(const FunctionPropertiesInfo &FPI) const {
     return !(*this == FPI);
@@ -137,6 +140,19 @@ class FunctionPropertiesInfo {
   int64_t CallReturnsVectorPointerCount = 0;
   int64_t CallWithManyArgumentsCount = 0;
   int64_t CallWithPointerArgumentCount = 0;
+
+  const ir2vec::Embedding &getFunctionEmbedding() const {
+    return FunctionEmbedding;
+  }
+
+  const std::optional<ir2vec::Vocab> &getIR2VecVocab() const {
+    return IR2VecVocab;
+  }
+
+  // Helper intended to be useful for unittests
+  void setFunctionEmbeddingForTest(const ir2vec::Embedding &Embedding) {
+    FunctionEmbedding = Embedding;
+  }
 };
 
 // Analysis pass
@@ -192,7 +208,7 @@ class FunctionPropertiesUpdater {
 
   DominatorTree &getUpdatedDominatorTree(FunctionAnalysisManager &FAM) const;
 
-  DenseSet<const BasicBlock *> Successors;
+  DenseSet<const BasicBlock *> Successors, CallUsers;
 
   // Edges we might potentially need to remove from the dominator tree.
   SmallVector<DominatorTree::UpdateType, 2> DomTreeUpdates;
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 7976cc7470d5b..3f44c650e640d 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -239,7 +239,7 @@ class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
 public:
   static AnalysisKey Key;
   IR2VecVocabAnalysis() = default;
-  explicit IR2VecVocabAnalysis(ir2vec::Vocab &&Vocab);
+  explicit IR2VecVocabAnalysis(ir2vec::Vocab Vocab);
   using Result = IR2VecVocabResult;
   Result run(Module &M, ModuleAnalysisManager &MAM);
 };
diff --git a/llvm/include/llvm/Analysis/InlineAdvisor.h b/llvm/include/llvm/Analysis/InlineAdvisor.h
index 9d15136e81d10..d2cad4717cbdb 100644
--- a/llvm/include/llvm/Analysis/InlineAdvisor.h
+++ b/llvm/include/llvm/Analysis/InlineAdvisor.h
@@ -331,6 +331,9 @@ class InlineAdvisorAnalysis : public AnalysisInfoMixin<InlineAdvisorAnalysis> {
   };
 
   Result run(Module &M, ModuleAnalysisManager &MAM) { return Result(M, MAM); }
+
+private:
+  static bool initializeIR2VecVocab(Module &M, ModuleAnalysisManager &MAM);
 };
 
 /// Printer pass for the InlineAdvisorAnalysis results.
diff --git a/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h b/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
index 961d5091bf9f3..91d3378565fc5 100644
--- a/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
+++ b/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
@@ -142,6 +142,10 @@ enum class FeatureIndex : size_t {
   INLINE_FEATURE_ITERATOR(POPULATE_INDICES)
 #undef POPULATE_INDICES
 
+// IR2Vec embeddings
+  callee_embedding,
+  caller_embedding,
+
   NumberOfFeatures
 };
 // clang-format on
@@ -154,7 +158,7 @@ inlineCostFeatureToMlFeature(InlineCostFeatureIndex Feature) {
 constexpr size_t NumberOfFeatures =
     static_cast<size_t>(FeatureIndex::NumberOfFeatures);
 
-LLVM_ABI extern const std::vector<TensorSpec> FeatureMap;
+LLVM_ABI extern std::vector<TensorSpec> FeatureMap;
 
 LLVM_ABI extern const char *const DecisionName;
 LLVM_ABI extern const TensorSpec InlineDecisionSpec;
diff --git a/llvm/include/llvm/Analysis/MLInlineAdvisor.h b/llvm/include/llvm/Analysis/MLInlineAdvisor.h
index 580dd5e95d760..935e4c56dfce6 100644
--- a/llvm/include/llvm/Analysis/MLInlineAdvisor.h
+++ b/llvm/include/llvm/Analysis/MLInlineAdvisor.h
@@ -82,6 +82,7 @@ class MLInlineAdvisor : public InlineAdvisor {
   int64_t NodeCount = 0;
   int64_t EdgeCount = 0;
   int64_t EdgesOfLastSeenNodes = 0;
+  bool UseIR2Vec = false;
 
   std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels;
   const int32_t InitialIRSize = 0;
diff --git a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
index 9d044c8a35910..29d3aaf46dc06 100644
--- a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
+++ b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
@@ -199,6 +199,29 @@ void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB,
 #undef CHECK_OPERAND
     }
   }
+
+  if (IR2VecVocab) {
+    // We instantiate the IR2Vec embedder each time, as having a unique
+    // pointer to the embedder as a member of the class would make it
+    // non-copyable. Instantiating the embedder in itself is not costly.
+    auto EmbOrErr = ir2vec::Embedder::create(IR2VecKind::Symbolic,
+                                             *BB.getParent(), *IR2VecVocab);
+    if (Error Err = EmbOrErr.takeError()) {
+      handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
+        BB.getContext().emitError("Error creating IR2Vec embeddings: " +
+                                  EI.message());
+      });
+      return;
+    }
+    auto Embedder = std::move(*EmbOrErr);
+    const auto &BBEmbedding = Embedder->getBBVector(BB);
+    // Subtract BBEmbedding from Function embedding if the direction is -1,
+    // and add it if the direction is +1.
+    if (Direction == -1)
+      FunctionEmbedding -= BBEmbedding;
+    else
+      FunctionEmbedding += BBEmbedding;
+  }
 }
 
 void FunctionPropertiesInfo::updateAggregateStats(const Function &F,
@@ -220,14 +243,24 @@ void FunctionPropertiesInfo::updateAggregateStats(const Function &F,
 
 FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
     Function &F, FunctionAnalysisManager &FAM) {
+  // We use the cached result of the IR2VecVocabAnalysis run by
+  // InlineAdvisorAnalysis. If the IR2VecVocabAnalysis is not run, we don't
+  // use IR2Vec embeddings.
+  auto VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
+                         .getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
   return getFunctionPropertiesInfo(F, FAM.getResult<DominatorTreeAnalysis>(F),
-                                   FAM.getResult<LoopAnalysis>(F));
+                                   FAM.getResult<LoopAnalysis>(F), VocabResult);
 }
 
 FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
-    const Function &F, const DominatorTree &DT, const LoopInfo &LI) {
+    const Function &F, const DominatorTree &DT, const LoopInfo &LI,
+    const IR2VecVocabResult *VocabResult) {
 
   FunctionPropertiesInfo FPI;
+  if (VocabResult && VocabResult->isValid()) {
+    FPI.IR2VecVocab = VocabResult->getVocabulary();
+    FPI.FunctionEmbedding = ir2vec::Embedding(VocabResult->getDimension(), 0.0);
+  }
   for (const auto &BB : F)
     if (DT.isReachableFromEntry(&BB))
       FPI.reIncludeBB(BB);
@@ -235,6 +268,66 @@ FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
   return FPI;
 }
 
+bool FunctionPropertiesInfo::operator==(
+    const FunctionPropertiesInfo &FPI) const {
+  if (BasicBlockCount != FPI.BasicBlockCount ||
+      BlocksReachedFromConditionalInstruction !=
+          FPI.BlocksReachedFromConditionalInstruction ||
+      Uses != FPI.Uses ||
+      DirectCallsToDefinedFunctions != FPI.DirectCallsToDefinedFunctions ||
+      LoadInstCount != FPI.LoadInstCount ||
+      StoreInstCount != FPI.StoreInstCount ||
+      MaxLoopDepth != FPI.MaxLoopDepth ||
+      TopLevelLoopCount != FPI.TopLevelLoopCount ||
+      TotalInstructionCount != FPI.TotalInstructionCount ||
+      BasicBlocksWithSingleSuccessor != FPI.BasicBlocksWithSingleSuccessor ||
+      BasicBlocksWithTwoSuccessors != FPI.BasicBlocksWithTwoSuccessors ||
+      BasicBlocksWithMoreThanTwoSuccessors !=
+          FPI.BasicBlocksWithMoreThanTwoSuccessors ||
+      BasicBlocksWithSinglePredecessor !=
+          FPI.BasicBlocksWithSinglePredecessor ||
+      BasicBlocksWithTwoPredecessors != FPI.BasicBlocksWithTwoPredecessors ||
+      BasicBlocksWithMoreThanTwoPredecessors !=
+          FPI.BasicBlocksWithMoreThanTwoPredecessors ||
+      BigBasicBlocks != FPI.BigBasicBlocks ||
+      MediumBasicBlocks != FPI.MediumBasicBlocks ||
+      SmallBasicBlocks != FPI.SmallBasicBlocks ||
+      CastInstructionCount != FPI.CastInstructionCount ||
+      FloatingPointInstructionCount != FPI.FloatingPointInstructionCount ||
+      IntegerInstructionCount != FPI.IntegerInstructionCount ||
+      ConstantIntOperandCount != FPI.ConstantIntOperandCount ||
+      ConstantFPOperandCount != FPI.ConstantFPOperandCount ||
+      ConstantOperandCount != FPI.ConstantOperandCount ||
+      InstructionOperandCount != FPI.InstructionOperandCount ||
+      BasicBlockOperandCount != FPI.BasicBlockOperandCount ||
+      GlobalValueOperandCount != FPI.GlobalValueOperandCount ||
+      InlineAsmOperandCount != FPI.InlineAsmOperandCount ||
+      ArgumentOperandCount != FPI.ArgumentOperandCount ||
+      UnknownOperandCount != FPI.UnknownOperandCount ||
+      CriticalEdgeCount != FPI.CriticalEdgeCount ||
+      ControlFlowEdgeCount != FPI.ControlFlowEdgeCount ||
+      UnconditionalBranchCount != FPI.UnconditionalBranchCount ||
+      IntrinsicCount != FPI.IntrinsicCount ||
+      DirectCallCount != FPI.DirectCallCount ||
+      IndirectCallCount != FPI.IndirectCallCount ||
+      CallReturnsIntegerCount != FPI.CallReturnsIntegerCount ||
+      CallReturnsFloatCount != FPI.CallReturnsFloatCount ||
+      CallReturnsPointerCount != FPI.CallReturnsPointerCount ||
+      CallReturnsVectorIntCount != FPI.CallReturnsVectorIntCount ||
+      CallReturnsVectorFloatCount != FPI.CallReturnsVectorFloatCount ||
+      CallReturnsVectorPointerCount != FPI.CallReturnsVectorPointerCount ||
+      CallWithManyArgumentsCount != FPI.CallWithManyArgumentsCount ||
+      CallWithPointerArgumentCount != FPI.CallWithPointerArgumentCount) {
+    return false;
+  }
+  // Check the equality of the function embeddings. We don't check the equality
+  // of Vocabulary as it remains the same.
+  if (!FunctionEmbedding.approximatelyEquals(FPI.FunctionEmbedding))
+    return false;
+
+  return true;
+}
+
 void FunctionPropertiesInfo::print(raw_ostream &OS) const {
 #define PRINT_PROPERTY(PROP_NAME) OS << #PROP_NAME ": " << PROP_NAME << "\n";
 
@@ -322,6 +415,16 @@ FunctionPropertiesUpdater::FunctionPropertiesUpdater(
   // The caller's entry BB may change due to new alloca instructions.
   LikelyToChangeBBs.insert(&*Caller.begin());
 
+  // Inlining can change the users of the value returned by the call
+  // instruction, which changes the embeddings computed for their BBs.
+  // We conservatively add the BBs with such uses to LikelyToChangeBBs.
+  for (const auto *User : CB.users())
+    CallUsers.insert(dyn_cast<Instruction>(User)->getParent());
+  // CallSiteBB can be removed from CallUsers if present, as it is handled
+  // separately.
+  CallUsers.erase(&CallSiteBB);
+  LikelyToChangeBBs.insert_range(CallUsers);
+
   // The successors may become unreachable in the case of `invoke` inlining.
   // We track successors separately, too, because they form a boundary, together
   // with the CB BB ('Entry') between which the inlined callee will be pasted.
@@ -435,6 +538,9 @@ void FunctionPropertiesUpdater::finish(FunctionAnalysisManager &FAM) const {
   if (&CallSiteBB != &*Caller.begin())
     Reinclude.insert(&*Caller.begin());
 
+  // Reinclude the BBs which use the values returned by call instruction
+  Reinclude.insert_range(CallUsers);
+
   // Distribute the successors to the 2 buckets.
   for (const auto *Succ : Successors)
     if (DT.isReachableFromEntry(Succ))
@@ -486,6 +592,9 @@ bool FunctionPropertiesUpdater::isUpdateValid(Function &F,
     return false;
   DominatorTree DT(F);
   LoopInfo LI(DT);
-  auto Fresh = FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI);
+  auto VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
+                         .getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
+  auto Fresh =
+      FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI, VocabResult);
   return FPI == Fresh;
 }
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 8a392e0709c7f..5f2245ad6aafb 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -294,8 +294,8 @@ Error IR2VecVocabAnalysis::readVocabulary() {
   return Error::success();
 }
 
-IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary)
-    : Vocabulary(std::move(Vocabulary)) {}
+IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab Vocabulary)
+    : Vocabulary(Vocabulary) {}
 
 void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
   handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
diff --git a/llvm/lib/Analysis/InlineAdvisor.cpp b/llvm/lib/Analysis/InlineAdvisor.cpp
index 3d30f3d10a9d0..2e869dfd91713 100644
--- a/llvm/lib/Analysis/InlineAdvisor.cpp
+++ b/llvm/lib/Analysis/InlineAdvisor.cpp
@@ -16,6 +16,7 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/EphemeralValuesCache.h"
+#include "llvm/Analysis/IR2Vec.h"
 #include "llvm/Analysis/InlineCost.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/ProfileSummaryInfo.h"
@@ -64,6 +65,13 @@ static cl::opt<bool>
                         cl::desc("If true, annotate inline advisor remarks "
                                  "with LTO and pass information."));
 
+// This flag is used to enable IR2Vec embeddings in the ML inliner; only valid
+// with ML inliner. The vocab file is used to initialize the embeddings.
+static cl::opt<std::string> IR2VecVocabFile(
+    "ml-inliner-ir2vec-vocab-file", cl::Hidden,
+    cl::desc("Vocab file for IR2Vec; Setting this enables "
+             "configuring the model to use IR2Vec embeddings."));
+
 namespace llvm {
 extern cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats;
 } // namespace llvm
@@ -206,6 +214,20 @@ void InlineAdvice::recordInliningWithCalleeDeleted() {
 AnalysisKey InlineAdvisorAnalysis::Key;
 AnalysisKey PluginInlineAdvisorAnalysis::Key;
 
+bool InlineAdvisorAnalysis::initializeIR2VecVocab(Module &M,
+                                                  ModuleAnalysisManager &MAM) {
+  if (!IR2VecVocabFile.empty()) {
+    auto IR2VecVocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
+    if (!IR2VecVocabResult.isValid()) {
+      M.getContext().emitError("Failed to load IR2Vec vocabulary");
+      return false;
+    }
+  }
+  // No vocab file specified is OK; we just don't use IR2Vec
+  // embeddings.
+  return true;
+}
+
 bool InlineAdvisorAnalysis::Result::tryCreate(
     InlineParams Params, InliningAdvisorMode Mode,
     const ReplayInlinerSettings &ReplaySettings, InlineContext IC) {
@@ -231,14 +253,21 @@ bool InlineAdvisorAnalysis::Result::tryCreate(
                                              /* EmitRemarks =*/true, IC);
     }
     break;
+    // Run IR2VecVocabAnalysis once per module to get the vocabulary.
+    // We run it here because it is immutable and we want to avoid running it
+    // multiple times.
   case InliningAdvisorMode::Development:
 #ifdef LLVM_HAVE_TFLITE
     LLVM_DEBUG(dbgs() << "Using development-mode inliner policy.\n");
+    if (!InlineAdvisorAnalysis::initializeIR2VecVocab(M, MAM))
+      return false;
     Advisor = llvm::getDevelopmentModeAdvisor(M, MAM, GetDefaultAdvice);
 #endif
     break;
   case InliningAdvisorMode::Release:
     LLVM_DEBUG(dbgs() << "Using release-mode inliner policy.\n");
+    if (!InlineAdvisorAnalysis::initializeIR2VecVocab(M, MAM))
+      return false;
     Advisor = llvm::getReleaseModeAdvisor(M, MAM, GetDefaultAdvice);
     break;
   }
diff --git a/llvm/lib/Analysis/MLInlineAdvisor.cpp b/llvm/lib/Analysis/MLInlineAdvisor.cpp
index 81a3bc94a6ad8..3a9a68670e852 100644
--- a/llvm/lib/Analysis/MLInlineAdvisor.cpp
+++ b/llvm/lib/Analysis/MLInlineAdvisor.cpp
@@ -107,7 +107,7 @@ static cl::opt<bool> KeepFPICache(
     cl::init(false));
 
 // clang-format off
-const std::vector<TensorSpec> llvm::FeatureMap{
+std::vector<TensorSpec> llvm::FeatureMap{
 #define POPULATE_NAMES(DTYPE, SHAPE, NAME, __) TensorSpec::createSpec<DTYPE>(#NAME, SHAPE),
 // InlineCost features - these must come first
   INLINE_COST_FEATURE_ITERATOR(POPULATE_NAMES)
@@ -186,6 +186,20 @@ MLInlineAdvisor::MLInlineAdvisor(
     EdgeCount += getLocalCalls(KVP.first->getFunction());
   }
   NodeCount = AllNodes.size();
+
+  if (auto IR2VecVocabResult = MAM.getCachedResult<IR2VecVocabAnalysis>(M)) {
+    if (!IR2VecVocabResult->isValid()) {
+      M.getContext().emitError("IR2VecVocabAnalysis is not valid");
+      return;
+    }
+    // Add the IR2Vec features to the feature map
+    auto IR2VecDim = IR2VecVocabResult->getDimension();
+    FeatureMap.push_back(
+        TensorSpec::createSpec<float>("callee_embedding", {IR2VecDim}));
+    FeatureMap.push_back(
+        TensorSpec::createSpec<float>("caller_embedding", {IR2VecDim}));
+    UseIR2Vec = true;
+  }
 }
 
 unsigned MLInlineAdvisor::getInitialFunctionLevel(const Function &F) const {
@@ -433,6 +447,24 @@ std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
   *ModelRunner->getTensor<int64_t>(FeatureIndex::is_caller_avail_external) =
       Caller.hasAvailableExternallyLinkage();
 
+  if (UseIR2Vec) {
+    // Python side expects float embeddings. The IR2Vec embeddings are doubles
+    // as of now due to the restriction of fromJSON method used by the
+    // readVocabulary method in IR2VecVocabAnalysis.
+    auto setEmbedding = [&](const ir2vec::Embedding &Embedding,
+                            FeatureIndex Index) {
+      auto Embedding_float =
+          std::vector<float>(Embedding.begin(), Embedding.end());
+      std::memcpy(ModelRunner->getTensor<float>(Index), Embedding_float.data(),
+                  Embedding.size() * sizeof(float));
+    };
+
+    setEmbedding(CalleeBefore.getFunctionEmbedding(),
+                 FeatureIndex::callee_embedding);
+    setEmbedding(CallerBefore.getFunctionEmbedding(),
+                 FeatureIndex::caller_embedding);
+  }
+
   // Add the cost features
   for (size_t I = 0;
        I < static_cast<size_t>(InlineCostFeatureIndex::NumberOfFeatures); ++I) {
diff --git a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
index 0720d935b0362..3ef2964f2d170 100644
--- a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
+++ b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/Analysis/FunctionPropertiesAnalysis.h"
 #include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/IR2Vec.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Dominators.h"
@@ -20,15 +21,20 @@
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Transforms/Utils/Cloning.h"
+#include "gmock/gmock.h"
 #include "gtest/gtest.h"
 #include <cstring>
 
 using namespace llvm;
+using namespace testing;
 
 namespace llvm {
 LLVM_ABI extern cl::opt<bool> EnableDetailedFunctionProperties;
 LLVM_ABI extern cl::opt<bool> BigBasicBlockInstructionThreshold;
 LLVM_ABI extern cl::opt<bool> MediumBasicBlockInstrutionThreshold;
+LLVM_ABI extern cl::opt<float> ir2vec::OpcWeight;
+LLVM_ABI extern cl::opt<float> ir2vec::TypeWeight;
+LLVM_ABI...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/143479


More information about the llvm-branch-commits mailing list