[llvm-branch-commits] [llvm] [MLGO][IR2Vec] Integrating IR2Vec with MLInliner (PR #143479)
    S. VenkataKeerthy via llvm-branch-commits 
    llvm-branch-commits at lists.llvm.org
       
    Tue Jun 10 12:54:34 PDT 2025
    
    
  
https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/143479
>From 579c3c03fa1117382c810c8d322f5d83689d79dc Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Tue, 10 Jun 2025 05:40:38 +0000
Subject: [PATCH] [MLInliner][IR2Vec] Integrating IR2Vec with MLInliner
---
 .../Analysis/FunctionPropertiesAnalysis.h     |  26 ++-
 llvm/include/llvm/Analysis/IR2Vec.h           |   2 +-
 llvm/include/llvm/Analysis/InlineAdvisor.h    |   3 +
 .../llvm/Analysis/InlineModelFeatureMaps.h    |   6 +-
 llvm/include/llvm/Analysis/MLInlineAdvisor.h  |   1 +
 .../Analysis/FunctionPropertiesAnalysis.cpp   | 115 ++++++++++-
 llvm/lib/Analysis/IR2Vec.cpp                  |   4 +-
 llvm/lib/Analysis/InlineAdvisor.cpp           |  29 +++
 llvm/lib/Analysis/MLInlineAdvisor.cpp         |  34 +++-
 .../FunctionPropertiesAnalysisTest.cpp        | 179 +++++++++++++++---
 10 files changed, 361 insertions(+), 38 deletions(-)
diff --git a/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h b/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h
index babb6d9d6cf0c..06dbfc35a5294 100644
--- a/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h
+++ b/llvm/include/llvm/Analysis/FunctionPropertiesAnalysis.h
@@ -15,6 +15,7 @@
 #define LLVM_ANALYSIS_FUNCTIONPROPERTIESANALYSIS_H
 
 #include "llvm/ADT/DenseSet.h"
+#include "llvm/Analysis/IR2Vec.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/Support/Compiler.h"
@@ -32,17 +33,27 @@ class FunctionPropertiesInfo {
   void updateAggregateStats(const Function &F, const LoopInfo &LI);
   void reIncludeBB(const BasicBlock &BB);
 
+  // Running sum of the IR2Vec embeddings of the BBs accounted for so far.
+  // Empty (zero-dimensional) unless a valid IR2Vec vocabulary was provided.
+  ir2vec::Embedding FunctionEmbedding = ir2vec::Embedding(0);
+  // Vocabulary used by updateForBB to embed BBs; std::nullopt disables
+  // IR2Vec. NOTE(review): every FunctionPropertiesInfo holds its own
+  // copy of the vocabulary; a non-owning pointer would avoid these copies.
+  std::optional<ir2vec::Vocab> IR2VecVocab;
+
 public:
+  /// IR2Vec embeddings are computed only if \p VocabResult is non-null and
+  /// valid.
   LLVM_ABI static FunctionPropertiesInfo
   getFunctionPropertiesInfo(const Function &F, const DominatorTree &DT,
-                            const LoopInfo &LI);
+                            const LoopInfo &LI,
+                            const IR2VecVocabResult *VocabResult);
 
   LLVM_ABI static FunctionPropertiesInfo
   getFunctionPropertiesInfo(Function &F, FunctionAnalysisManager &FAM);
 
-  bool operator==(const FunctionPropertiesInfo &FPI) const {
-    return std::memcmp(this, &FPI, sizeof(FunctionPropertiesInfo)) == 0;
-  }
+  /// Compares counters exactly and function embeddings approximately.
+  bool operator==(const FunctionPropertiesInfo &FPI) const;
 
   bool operator!=(const FunctionPropertiesInfo &FPI) const {
     return !(*this == FPI);
@@ -137,6 +148,21 @@ class FunctionPropertiesInfo {
   int64_t CallReturnsVectorPointerCount = 0;
   int64_t CallWithManyArgumentsCount = 0;
   int64_t CallWithPointerArgumentCount = 0;
+
+  /// IR2Vec embedding of the function; empty when IR2Vec is not in use.
+  const ir2vec::Embedding &getFunctionEmbedding() const {
+    return FunctionEmbedding;
+  }
+
+  /// IR2Vec vocabulary, or std::nullopt when IR2Vec is not in use.
+  const std::optional<ir2vec::Vocab> &getIR2VecVocab() const {
+    return IR2VecVocab;
+  }
+
+  /// Overrides the function embedding. Intended for unit tests only.
+  void setFunctionEmbeddingForTest(const ir2vec::Embedding &Embedding) {
+    FunctionEmbedding = Embedding;
+  }
 };
 
 // Analysis pass
@@ -192,7 +218,10 @@ class FunctionPropertiesUpdater {
 
   DominatorTree &getUpdatedDominatorTree(FunctionAnalysisManager &FAM) const;
 
   DenseSet<const BasicBlock *> Successors;
+  // BBs other than the call site BB that use the call's return value; their
+  // IR2Vec embeddings may change once the call is inlined.
+  DenseSet<const BasicBlock *> CallUsers;
 
   // Edges we might potentially need to remove from the dominator tree.
   SmallVector<DominatorTree::UpdateType, 2> DomTreeUpdates;
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 3a6f47ded8ca4..9acffb996283c 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -239,7 +239,7 @@ class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
 public:
   static AnalysisKey Key;
   IR2VecVocabAnalysis() = default;
-  explicit IR2VecVocabAnalysis(ir2vec::Vocab &&Vocab);
+  explicit IR2VecVocabAnalysis(ir2vec::Vocab Vocab);
   using Result = IR2VecVocabResult;
   Result run(Module &M, ModuleAnalysisManager &MAM);
 };
diff --git a/llvm/include/llvm/Analysis/InlineAdvisor.h b/llvm/include/llvm/Analysis/InlineAdvisor.h
index 9d15136e81d10..d2cad4717cbdb 100644
--- a/llvm/include/llvm/Analysis/InlineAdvisor.h
+++ b/llvm/include/llvm/Analysis/InlineAdvisor.h
@@ -331,6 +331,9 @@ class InlineAdvisorAnalysis : public AnalysisInfoMixin<InlineAdvisorAnalysis> {
   };
 
   Result run(Module &M, ModuleAnalysisManager &MAM) { return Result(M, MAM); }
+
+private:
+  static bool initializeIR2VecVocab(Module &M, ModuleAnalysisManager &MAM);
 };
 
 /// Printer pass for the InlineAdvisorAnalysis results.
diff --git a/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h b/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
index 961d5091bf9f3..91d3378565fc5 100644
--- a/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
+++ b/llvm/include/llvm/Analysis/InlineModelFeatureMaps.h
@@ -142,6 +142,10 @@ enum class FeatureIndex : size_t {
   INLINE_FEATURE_ITERATOR(POPULATE_INDICES)
 #undef POPULATE_INDICES
 
+// IR2Vec embeddings; specs are added to FeatureMap only when IR2Vec is used.
+  callee_embedding,
+  caller_embedding,
+
   NumberOfFeatures
 };
 // clang-format on
@@ -154,7 +158,7 @@ inlineCostFeatureToMlFeature(InlineCostFeatureIndex Feature) {
 constexpr size_t NumberOfFeatures =
     static_cast<size_t>(FeatureIndex::NumberOfFeatures);
 
-LLVM_ABI extern const std::vector<TensorSpec> FeatureMap;
+LLVM_ABI extern std::vector<TensorSpec> FeatureMap;
 
 LLVM_ABI extern const char *const DecisionName;
 LLVM_ABI extern const TensorSpec InlineDecisionSpec;
diff --git a/llvm/include/llvm/Analysis/MLInlineAdvisor.h b/llvm/include/llvm/Analysis/MLInlineAdvisor.h
index 580dd5e95d760..935e4c56dfce6 100644
--- a/llvm/include/llvm/Analysis/MLInlineAdvisor.h
+++ b/llvm/include/llvm/Analysis/MLInlineAdvisor.h
@@ -82,6 +82,7 @@ class MLInlineAdvisor : public InlineAdvisor {
   int64_t NodeCount = 0;
   int64_t EdgeCount = 0;
   int64_t EdgesOfLastSeenNodes = 0;
+  bool UseIR2Vec = false; // Feed caller/callee IR2Vec embeddings to the model.
 
   std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels;
   const int32_t InitialIRSize = 0;
diff --git a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
index 9d044c8a35910..29d3aaf46dc06 100644
--- a/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
+++ b/llvm/lib/Analysis/FunctionPropertiesAnalysis.cpp
@@ -199,6 +199,29 @@ void FunctionPropertiesInfo::updateForBB(const BasicBlock &BB,
 #undef CHECK_OPERAND
     }
   }
+
+  if (IR2VecVocab) {
+    // Create a fresh IR2Vec embedder on every call: keeping a unique_ptr to
+    // the embedder as a member would make this class non-copyable, and
+    // instantiating the embedder itself is not costly.
+    auto EmbOrErr = ir2vec::Embedder::create(IR2VecKind::Symbolic,
+                                             *BB.getParent(), *IR2VecVocab);
+    if (Error Err = EmbOrErr.takeError()) {
+      handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
+        BB.getContext().emitError("Error creating IR2Vec embeddings: " +
+                                  EI.message());
+      });
+      return;
+    }
+    auto Embedder = std::move(*EmbOrErr);
+    const auto &BBEmbedding = Embedder->getBBVector(BB);
+    // Direction == -1 removes this BB's contribution from the function
+    // embedding; any other Direction adds it.
+    if (Direction == -1)
+      FunctionEmbedding -= BBEmbedding;
+    else
+      FunctionEmbedding += BBEmbedding;
+  }
 }
 
 void FunctionPropertiesInfo::updateAggregateStats(const Function &F,
@@ -220,14 +243,24 @@ void FunctionPropertiesInfo::updateAggregateStats(const Function &F,
 
 FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
     Function &F, FunctionAnalysisManager &FAM) {
+  // Use the IR2VecVocabAnalysis result only if it is already cached for the
+  // module (InlineAdvisorAnalysis computes it when IR2Vec is requested);
+  // otherwise, IR2Vec embeddings are not computed.
+  auto *VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
+                          .getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
   return getFunctionPropertiesInfo(F, FAM.getResult<DominatorTreeAnalysis>(F),
-                                   FAM.getResult<LoopAnalysis>(F));
+                                   FAM.getResult<LoopAnalysis>(F), VocabResult);
 }
 
 FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
-    const Function &F, const DominatorTree &DT, const LoopInfo &LI) {
+    const Function &F, const DominatorTree &DT, const LoopInfo &LI,
+    const IR2VecVocabResult *VocabResult) {
 
   FunctionPropertiesInfo FPI;
+  if (VocabResult && VocabResult->isValid()) {
+    FPI.IR2VecVocab = VocabResult->getVocabulary();
+    FPI.FunctionEmbedding = ir2vec::Embedding(VocabResult->getDimension(), 0.0);
+  }
   for (const auto &BB : F)
     if (DT.isReachableFromEntry(&BB))
       FPI.reIncludeBB(BB);
@@ -235,6 +268,66 @@ FunctionPropertiesInfo FunctionPropertiesInfo::getFunctionPropertiesInfo(
   return FPI;
 }
 
+bool FunctionPropertiesInfo::operator==(
+    const FunctionPropertiesInfo &FPI) const {
+  if (BasicBlockCount != FPI.BasicBlockCount ||
+      BlocksReachedFromConditionalInstruction !=
+          FPI.BlocksReachedFromConditionalInstruction ||
+      Uses != FPI.Uses ||
+      DirectCallsToDefinedFunctions != FPI.DirectCallsToDefinedFunctions ||
+      LoadInstCount != FPI.LoadInstCount ||
+      StoreInstCount != FPI.StoreInstCount ||
+      MaxLoopDepth != FPI.MaxLoopDepth ||
+      TopLevelLoopCount != FPI.TopLevelLoopCount ||
+      TotalInstructionCount != FPI.TotalInstructionCount ||
+      BasicBlocksWithSingleSuccessor != FPI.BasicBlocksWithSingleSuccessor ||
+      BasicBlocksWithTwoSuccessors != FPI.BasicBlocksWithTwoSuccessors ||
+      BasicBlocksWithMoreThanTwoSuccessors !=
+          FPI.BasicBlocksWithMoreThanTwoSuccessors ||
+      BasicBlocksWithSinglePredecessor !=
+          FPI.BasicBlocksWithSinglePredecessor ||
+      BasicBlocksWithTwoPredecessors != FPI.BasicBlocksWithTwoPredecessors ||
+      BasicBlocksWithMoreThanTwoPredecessors !=
+          FPI.BasicBlocksWithMoreThanTwoPredecessors ||
+      BigBasicBlocks != FPI.BigBasicBlocks ||
+      MediumBasicBlocks != FPI.MediumBasicBlocks ||
+      SmallBasicBlocks != FPI.SmallBasicBlocks ||
+      CastInstructionCount != FPI.CastInstructionCount ||
+      FloatingPointInstructionCount != FPI.FloatingPointInstructionCount ||
+      IntegerInstructionCount != FPI.IntegerInstructionCount ||
+      ConstantIntOperandCount != FPI.ConstantIntOperandCount ||
+      ConstantFPOperandCount != FPI.ConstantFPOperandCount ||
+      ConstantOperandCount != FPI.ConstantOperandCount ||
+      InstructionOperandCount != FPI.InstructionOperandCount ||
+      BasicBlockOperandCount != FPI.BasicBlockOperandCount ||
+      GlobalValueOperandCount != FPI.GlobalValueOperandCount ||
+      InlineAsmOperandCount != FPI.InlineAsmOperandCount ||
+      ArgumentOperandCount != FPI.ArgumentOperandCount ||
+      UnknownOperandCount != FPI.UnknownOperandCount ||
+      CriticalEdgeCount != FPI.CriticalEdgeCount ||
+      ControlFlowEdgeCount != FPI.ControlFlowEdgeCount ||
+      UnconditionalBranchCount != FPI.UnconditionalBranchCount ||
+      IntrinsicCount != FPI.IntrinsicCount ||
+      DirectCallCount != FPI.DirectCallCount ||
+      IndirectCallCount != FPI.IndirectCallCount ||
+      CallReturnsIntegerCount != FPI.CallReturnsIntegerCount ||
+      CallReturnsFloatCount != FPI.CallReturnsFloatCount ||
+      CallReturnsPointerCount != FPI.CallReturnsPointerCount ||
+      CallReturnsVectorIntCount != FPI.CallReturnsVectorIntCount ||
+      CallReturnsVectorFloatCount != FPI.CallReturnsVectorFloatCount ||
+      CallReturnsVectorPointerCount != FPI.CallReturnsVectorPointerCount ||
+      CallWithManyArgumentsCount != FPI.CallWithManyArgumentsCount ||
+      CallWithPointerArgumentCount != FPI.CallWithPointerArgumentCount) {
+    return false;
+  }
+  // Check the equality of the function embeddings. We don't check the equality
+  // of Vocabulary as it remains the same. Embeddings of different dimensions
+  // (e.g. when only one side was computed with a vocabulary) are never equal,
+  // so compare the sizes before calling approximatelyEquals.
+  return FunctionEmbedding.size() == FPI.FunctionEmbedding.size() &&
+         FunctionEmbedding.approximatelyEquals(FPI.FunctionEmbedding);
+}
+
 void FunctionPropertiesInfo::print(raw_ostream &OS) const {
 #define PRINT_PROPERTY(PROP_NAME) OS << #PROP_NAME ": " << PROP_NAME << "\n";
 
@@ -322,6 +415,16 @@ FunctionPropertiesUpdater::FunctionPropertiesUpdater(
   // The caller's entry BB may change due to new alloca instructions.
   LikelyToChangeBBs.insert(&*Caller.begin());
 
+  // Inlining replaces the call's return value, so the IR2Vec embeddings of
+  // the BBs using it can change; conservatively add those BBs to
+  // LikelyToChangeBBs. cast<> asserts that every user is an Instruction.
+  for (const auto *User : CB.users())
+    CallUsers.insert(cast<Instruction>(User)->getParent());
+  // CallSiteBB is handled separately, so drop it from CallUsers if it is
+  // there.
+  CallUsers.erase(&CallSiteBB);
+  LikelyToChangeBBs.insert_range(CallUsers);
+
   // The successors may become unreachable in the case of `invoke` inlining.
   // We track successors separately, too, because they form a boundary, together
   // with the CB BB ('Entry') between which the inlined callee will be pasted.
@@ -435,6 +538,9 @@ void FunctionPropertiesUpdater::finish(FunctionAnalysisManager &FAM) const {
   if (&CallSiteBB != &*Caller.begin())
     Reinclude.insert(&*Caller.begin());
 
+  // Re-include CallUsers: the BBs that use the call's return value.
+  Reinclude.insert_range(CallUsers);
+
   // Distribute the successors to the 2 buckets.
   for (const auto *Succ : Successors)
     if (DT.isReachableFromEntry(Succ))
@@ -486,6 +592,10 @@ bool FunctionPropertiesUpdater::isUpdateValid(Function &F,
     return false;
   DominatorTree DT(F);
   LoopInfo LI(DT);
-  auto Fresh = FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI);
+  // Like getFunctionPropertiesInfo(F, FAM), use the cached vocabulary, if any.
+  auto *VocabResult = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
+                          .getCachedResult<IR2VecVocabAnalysis>(*F.getParent());
+  auto Fresh =
+      FunctionPropertiesInfo::getFunctionPropertiesInfo(F, DT, LI, VocabResult);
   return FPI == Fresh;
 }
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 4a5218e040b49..69fef01d7c781 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -289,8 +289,10 @@ Error IR2VecVocabAnalysis::readVocabulary() {
   return Error::success();
 }
 
-IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary)
+// Takes the vocabulary by value (so callers may copy or move it in) and moves
+// it into place.
+IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab Vocabulary)
     : Vocabulary(std::move(Vocabulary)) {}
 
 void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
   handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
diff --git a/llvm/lib/Analysis/InlineAdvisor.cpp b/llvm/lib/Analysis/InlineAdvisor.cpp
index 3d30f3d10a9d0..2e869dfd91713 100644
--- a/llvm/lib/Analysis/InlineAdvisor.cpp
+++ b/llvm/lib/Analysis/InlineAdvisor.cpp
@@ -16,6 +16,7 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/EphemeralValuesCache.h"
+#include "llvm/Analysis/IR2Vec.h"
 #include "llvm/Analysis/InlineCost.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/ProfileSummaryInfo.h"
@@ -64,6 +65,13 @@ static cl::opt<bool>
                         cl::desc("If true, annotate inline advisor remarks "
                                  "with LTO and pass information."));
 
+// Enables IR2Vec embeddings; only valid with the ML inliner.
+// NOTE(review): only non-emptiness is checked; confirm the path is used.
+static cl::opt<std::string> IR2VecVocabFile(
+    "ml-inliner-ir2vec-vocab-file", cl::Hidden,
+    cl::desc("Vocab file for IR2Vec; Setting this enables "
+             "configuring the model to use IR2Vec embeddings."));
+
 namespace llvm {
 extern cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats;
 } // namespace llvm
@@ -206,6 +214,20 @@ void InlineAdvice::recordInliningWithCalleeDeleted() {
 AnalysisKey InlineAdvisorAnalysis::Key;
 AnalysisKey PluginInlineAdvisorAnalysis::Key;
 
+bool InlineAdvisorAnalysis::initializeIR2VecVocab(Module &M,
+                                                  ModuleAnalysisManager &MAM) {
+  // No vocab file specified is OK; we just don't use IR2Vec embeddings.
+  if (IR2VecVocabFile.empty())
+    return true;
+  // Computing the result here caches it for FunctionPropertiesInfo. Bind it by
+  // reference: the result holds the vocabulary, so copying it is costly.
+  const auto &VocabResult = MAM.getResult<IR2VecVocabAnalysis>(M);
+  if (VocabResult.isValid())
+    return true;
+  M.getContext().emitError("Failed to load IR2Vec vocabulary");
+  return false;
+}
+
 bool InlineAdvisorAnalysis::Result::tryCreate(
     InlineParams Params, InliningAdvisorMode Mode,
     const ReplayInlinerSettings &ReplaySettings, InlineContext IC) {
@@ -231,14 +253,21 @@ bool InlineAdvisorAnalysis::Result::tryCreate(
                                              /* EmitRemarks =*/true, IC);
     }
     break;
   case InliningAdvisorMode::Development:
 #ifdef LLVM_HAVE_TFLITE
     LLVM_DEBUG(dbgs() << "Using development-mode inliner policy.\n");
+    // Compute and cache the IR2Vec vocabulary (if requested) before creating
+    // the advisor; FunctionPropertiesInfo reads the cached result.
+    if (!InlineAdvisorAnalysis::initializeIR2VecVocab(M, MAM))
+      return false;
     Advisor = llvm::getDevelopmentModeAdvisor(M, MAM, GetDefaultAdvice);
 #endif
     break;
   case InliningAdvisorMode::Release:
     LLVM_DEBUG(dbgs() << "Using release-mode inliner policy.\n");
+    // As above: cache the IR2Vec vocabulary before creating the advisor.
+    if (!InlineAdvisorAnalysis::initializeIR2VecVocab(M, MAM))
+      return false;
     Advisor = llvm::getReleaseModeAdvisor(M, MAM, GetDefaultAdvice);
     break;
   }
diff --git a/llvm/lib/Analysis/MLInlineAdvisor.cpp b/llvm/lib/Analysis/MLInlineAdvisor.cpp
index 81a3bc94a6ad8..3a9a68670e852 100644
--- a/llvm/lib/Analysis/MLInlineAdvisor.cpp
+++ b/llvm/lib/Analysis/MLInlineAdvisor.cpp
@@ -107,7 +107,7 @@ static cl::opt<bool> KeepFPICache(
     cl::init(false));
 
 // clang-format off
-const std::vector<TensorSpec> llvm::FeatureMap{
+std::vector<TensorSpec> llvm::FeatureMap{ // Non-const: MLInlineAdvisor appends IR2Vec specs.
 #define POPULATE_NAMES(DTYPE, SHAPE, NAME, __) TensorSpec::createSpec<DTYPE>(#NAME, SHAPE),
 // InlineCost features - these must come first
   INLINE_COST_FEATURE_ITERATOR(POPULATE_NAMES)
@@ -186,6 +186,30 @@ MLInlineAdvisor::MLInlineAdvisor(
     EdgeCount += getLocalCalls(KVP.first->getFunction());
   }
   NodeCount = AllNodes.size();
+
+  if (auto *VocabResult = MAM.getCachedResult<IR2VecVocabAnalysis>(M)) {
+    if (!VocabResult->isValid()) {
+      M.getContext().emitError("IR2VecVocabAnalysis is not valid");
+      return;
+    }
+    // FeatureMap is a global shared by all advisors: append the IR2Vec specs
+    // only once, right after the static features, since getAdviceImpl
+    // addresses them by FeatureIndex. NOTE(review): confirm that ModelRunner
+    // was created with these specs, and note that advisors built concurrently
+    // or with different vocabulary dimensions still share this global.
+    constexpr size_t FirstIR2VecIndex =
+        static_cast<size_t>(FeatureIndex::callee_embedding);
+    auto IR2VecDim = VocabResult->getDimension();
+    if (FeatureMap.size() == FirstIR2VecIndex) {
+      FeatureMap.push_back(
+          TensorSpec::createSpec<float>("callee_embedding", {IR2VecDim}));
+      FeatureMap.push_back(
+          TensorSpec::createSpec<float>("caller_embedding", {IR2VecDim}));
+    }
+    assert(FeatureMap.size() == NumberOfFeatures &&
+           "IR2Vec specs must directly follow the static features");
+    UseIR2Vec = true;
+  }
 }
 
 unsigned MLInlineAdvisor::getInitialFunctionLevel(const Function &F) const {
@@ -433,6 +457,22 @@ std::unique_ptr<InlineAdvice> MLInlineAdvisor::getAdviceImpl(CallBase &CB) {
   *ModelRunner->getTensor<int64_t>(FeatureIndex::is_caller_avail_external) =
       Caller.hasAvailableExternallyLinkage();
 
+  if (UseIR2Vec) {
+    // The model expects float embeddings, but IR2Vec embeddings are currently
+    // doubles (a restriction of the fromJSON method used by readVocabulary).
+    // Convert element-wise straight into the input tensor, without a copy.
+    auto SetEmbedding = [&](const ir2vec::Embedding &Embedding,
+                            FeatureIndex Index) {
+      llvm::transform(Embedding, ModelRunner->getTensor<float>(Index),
+                      [](double V) { return static_cast<float>(V); });
+    };
+
+    SetEmbedding(CalleeBefore.getFunctionEmbedding(),
+                 FeatureIndex::callee_embedding);
+    SetEmbedding(CallerBefore.getFunctionEmbedding(),
+                 FeatureIndex::caller_embedding);
+  }
+
   // Add the cost features
   for (size_t I = 0;
        I < static_cast<size_t>(InlineCostFeatureIndex::NumberOfFeatures); ++I) {
diff --git a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
index 0720d935b0362..3ef2964f2d170 100644
--- a/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
+++ b/llvm/unittests/Analysis/FunctionPropertiesAnalysisTest.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/Analysis/FunctionPropertiesAnalysis.h"
 #include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/IR2Vec.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/IR/Dominators.h"
@@ -20,15 +21,20 @@
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/SourceMgr.h"
 #include "llvm/Transforms/Utils/Cloning.h"
+#include "gmock/gmock.h"
 #include "gtest/gtest.h"
 #include <cstring>
 
 using namespace llvm;
+using namespace testing;
 
 namespace llvm {
 LLVM_ABI extern cl::opt<bool> EnableDetailedFunctionProperties;
 LLVM_ABI extern cl::opt<bool> BigBasicBlockInstructionThreshold;
 LLVM_ABI extern cl::opt<bool> MediumBasicBlockInstrutionThreshold;
+LLVM_ABI extern cl::opt<float> ir2vec::OpcWeight;
+LLVM_ABI extern cl::opt<float> ir2vec::TypeWeight;
+LLVM_ABI extern cl::opt<float> ir2vec::ArgWeight;
 } // namespace llvm
 
 namespace {
@@ -36,17 +42,81 @@ namespace {
 class FunctionPropertiesAnalysisTest : public testing::Test {
 public:
   FunctionPropertiesAnalysisTest() {
+    createTestVocabulary(1);
+    MAM.registerPass([&] { return IR2VecVocabAnalysis(Vocabulary); });
+    MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+    FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
     FAM.registerPass([&] { return DominatorTreeAnalysis(); });
     FAM.registerPass([&] { return LoopAnalysis(); });
     FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
+
+    // Use unit IR2Vec weights; the original values are restored in TearDown.
+    ir2vec::OpcWeight = 1.0;
+    ir2vec::TypeWeight = 1.0;
+    ir2vec::ArgWeight = 1.0;
+  }
+
+private:
+  float OriginalOpcWeight = ir2vec::OpcWeight;
+  float OriginalTypeWeight = ir2vec::TypeWeight;
+  float OriginalArgWeight = ir2vec::ArgWeight;
+
+  void createTestVocabulary(unsigned Dim) {
+    Vocabulary["add"] = ir2vec::Embedding(Dim, 0.1);
+    Vocabulary["sub"] = ir2vec::Embedding(Dim, 0.2);
+    Vocabulary["mul"] = ir2vec::Embedding(Dim, 0.3);
+    Vocabulary["icmp"] = ir2vec::Embedding(Dim, 0.4);
+    Vocabulary["br"] = ir2vec::Embedding(Dim, 0.5);
+    Vocabulary["ret"] = ir2vec::Embedding(Dim, 0.6);
+    Vocabulary["store"] = ir2vec::Embedding(Dim, 0.7);
+    Vocabulary["load"] = ir2vec::Embedding(Dim, 0.8);
+    Vocabulary["alloca"] = ir2vec::Embedding(Dim, 0.9);
+    Vocabulary["phi"] = ir2vec::Embedding(Dim, 1.0);
+    Vocabulary["call"] = ir2vec::Embedding(Dim, 1.1);
+    Vocabulary["voidTy"] = ir2vec::Embedding(Dim, 1.3);
+    Vocabulary["floatTy"] = ir2vec::Embedding(Dim, 1.4);
+    Vocabulary["integerTy"] = ir2vec::Embedding(Dim, 1.5);
+    Vocabulary["functionTy"] = ir2vec::Embedding(Dim, 1.6);
+    Vocabulary["structTy"] = ir2vec::Embedding(Dim, 1.7);
+    Vocabulary["arrayTy"] = ir2vec::Embedding(Dim, 1.8);
+    Vocabulary["pointerTy"] = ir2vec::Embedding(Dim, 1.9);
+    Vocabulary["vectorTy"] = ir2vec::Embedding(Dim, 2.0);
+    Vocabulary["emptyTy"] = ir2vec::Embedding(Dim, 2.1);
+    Vocabulary["labelTy"] = ir2vec::Embedding(Dim, 2.2);
+    Vocabulary["tokenTy"] = ir2vec::Embedding(Dim, 2.3);
+    Vocabulary["metadataTy"] = ir2vec::Embedding(Dim, 2.4);
+    Vocabulary["unknownTy"] = ir2vec::Embedding(Dim, 2.5);
+    Vocabulary["function"] = ir2vec::Embedding(Dim, 3.1);
+    Vocabulary["pointer"] = ir2vec::Embedding(Dim, 3.2);
+    Vocabulary["constant"] = ir2vec::Embedding(Dim, 3.3);
+    Vocabulary["variable"] = ir2vec::Embedding(Dim, 3.4);
+    Vocabulary["getelementptr"] = ir2vec::Embedding(Dim, 3.5);
+    Vocabulary["invoke"] = ir2vec::Embedding(Dim, 3.6);
+    Vocabulary["landingpad"] = ir2vec::Embedding(Dim, 3.7);
+    Vocabulary["resume"] = ir2vec::Embedding(Dim, 3.8);
+    Vocabulary["catch"] = ir2vec::Embedding(Dim, 3.9);
+    Vocabulary["cleanup"] = ir2vec::Embedding(Dim, 4.0);
   }
 
 protected:
   std::unique_ptr<DominatorTree> DT;
   std::unique_ptr<LoopInfo> LI;
   FunctionAnalysisManager FAM;
+  ModuleAnalysisManager MAM;
+  ir2vec::Vocab Vocabulary;
+
+  void TearDown() override {
+    // Restore original IR2Vec weights
+    ir2vec::OpcWeight = OriginalOpcWeight;
+    ir2vec::TypeWeight = OriginalTypeWeight;
+    ir2vec::ArgWeight = OriginalArgWeight;
+  }
 
   FunctionPropertiesInfo buildFPI(Function &F) {
+    // FunctionPropertiesInfo only uses IR2Vec when the IR2VecVocabAnalysis
+    // result is already cached for the module, so compute and cache it here.
+    // The result is not bound to a local to avoid copying the vocabulary.
+    (void)MAM.getResult<IR2VecVocabAnalysis>(*F.getParent());
     return FunctionPropertiesInfo::getFunctionPropertiesInfo(F, FAM);
   }
 
@@ -62,15 +132,22 @@ class FunctionPropertiesAnalysisTest : public testing::Test {
       Err.print("MLAnalysisTests", errs());
     return Mod;
   }
-  
-  CallBase* findCall(Function& F, const char* Name = nullptr) {
+
+  CallBase *findCall(Function &F, const char *Name = nullptr) {
     for (auto &BB : F)
-      for (auto &I : BB )
+      for (auto &I : BB)
         if (auto *CB = dyn_cast<CallBase>(&I))
           if (!Name || CB->getName() == Name)
             return CB;
     return nullptr;
   }
+
+  std::unique_ptr<ir2vec::Embedder> createEmbedder(const Function &F) {
+    auto Emb = ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
+    if (!Emb) // A failed Expected must not be dereferenced.
+      report_fatal_error(Emb.takeError());
+    return std::move(*Emb);
+  }
 };
 
 TEST_F(FunctionPropertiesAnalysisTest, BasicTest) {
@@ -113,6 +190,8 @@ define internal i32 @top() {
   EXPECT_EQ(BranchesFeatures.StoreInstCount, 0);
   EXPECT_EQ(BranchesFeatures.MaxLoopDepth, 0);
   EXPECT_EQ(BranchesFeatures.TopLevelLoopCount, 0);
+  EXPECT_TRUE(BranchesFeatures.getFunctionEmbedding().approximatelyEquals(
+      createEmbedder(*BranchesFunction)->getFunctionVector()));
 
   Function *TopFunction = M->getFunction("top");
   FunctionPropertiesInfo TopFeatures = buildFPI(*TopFunction);
@@ -120,6 +199,8 @@ define internal i32 @top() {
   EXPECT_EQ(TopFeatures.BlocksReachedFromConditionalInstruction, 0);
   EXPECT_EQ(TopFeatures.Uses, 0);
   EXPECT_EQ(TopFeatures.DirectCallsToDefinedFunctions, 1);
+  EXPECT_TRUE(TopFeatures.getFunctionEmbedding().approximatelyEquals(
+      createEmbedder(*TopFunction)->getFunctionVector()));
   EXPECT_EQ(BranchesFeatures.LoadInstCount, 0);
   EXPECT_EQ(BranchesFeatures.StoreInstCount, 0);
   EXPECT_EQ(BranchesFeatures.MaxLoopDepth, 0);
@@ -159,6 +240,9 @@ define internal i32 @top() {
   EXPECT_EQ(DetailedBranchesFeatures.CallReturnsPointerCount, 0);
   EXPECT_EQ(DetailedBranchesFeatures.CallWithManyArgumentsCount, 0);
   EXPECT_EQ(DetailedBranchesFeatures.CallWithPointerArgumentCount, 0);
+  EXPECT_TRUE(
+      DetailedBranchesFeatures.getFunctionEmbedding().approximatelyEquals(
+          createEmbedder(*BranchesFunction)->getFunctionVector()));
   EnableDetailedFunctionProperties.setValue(false);
 }
 
@@ -210,6 +294,8 @@ define i64 @f1() {
   EXPECT_EQ(DetailedF1Properties.CallReturnsPointerCount, 0);
   EXPECT_EQ(DetailedF1Properties.CallWithManyArgumentsCount, 0);
   EXPECT_EQ(DetailedF1Properties.CallWithPointerArgumentCount, 0);
+  EXPECT_TRUE(DetailedF1Properties.getFunctionEmbedding().approximatelyEquals(
+      createEmbedder(*F1)->getFunctionVector()));
   EnableDetailedFunctionProperties.setValue(false);
 }
 
@@ -232,28 +318,29 @@ define i32 @f2(i32 %a) {
 )IR");
 
   Function *F1 = M->getFunction("f1");
-  CallBase* CB = findCall(*F1, "b");
+  CallBase *CB = findCall(*F1, "b");
   EXPECT_NE(CB, nullptr);
 
-  FunctionPropertiesInfo ExpectedInitial;
-  ExpectedInitial.BasicBlockCount = 1;
-  ExpectedInitial.TotalInstructionCount = 3;
-  ExpectedInitial.Uses = 1;
-  ExpectedInitial.DirectCallsToDefinedFunctions = 1;
-
-  FunctionPropertiesInfo ExpectedFinal = ExpectedInitial;
-  ExpectedFinal.DirectCallsToDefinedFunctions = 0;
-
   auto FPI = buildFPI(*F1);
-  EXPECT_EQ(FPI, ExpectedInitial);
+  EXPECT_EQ(FPI.BasicBlockCount, 1);
+  EXPECT_EQ(FPI.TotalInstructionCount, 3);
+  EXPECT_EQ(FPI.Uses, 1);
+  EXPECT_EQ(FPI.DirectCallsToDefinedFunctions, 1);
+  EXPECT_THAT(FPI.getFunctionEmbedding(), ElementsAre(DoubleNear(22.7, 1e-6)));
 
   FunctionPropertiesUpdater FPU(FPI, *CB);
   InlineFunctionInfo IFI;
   auto IR = llvm::InlineFunction(*CB, IFI);
   EXPECT_TRUE(IR.isSuccess());
   invalidate(*F1);
+
   EXPECT_TRUE(FPU.finishAndTest(FAM));
-  EXPECT_EQ(FPI, ExpectedFinal);
+  EXPECT_EQ(FPI.BasicBlockCount, 1);
+  EXPECT_EQ(FPI.TotalInstructionCount, 3);
+  EXPECT_EQ(FPI.Uses, 1);
+  EXPECT_EQ(FPI.DirectCallsToDefinedFunctions, 0);
+  EXPECT_TRUE(FPI.getFunctionEmbedding().approximatelyEquals(
+      createEmbedder(*F1)->getFunctionVector()));
 }
 
 TEST_F(FunctionPropertiesAnalysisTest, InlineSameBBLargerCFG) {
@@ -285,7 +372,7 @@ define i32 @f2(i32 %a) {
 )IR");
 
   Function *F1 = M->getFunction("f1");
-  CallBase* CB = findCall(*F1, "b");
+  CallBase *CB = findCall(*F1, "b");
   EXPECT_NE(CB, nullptr);
 
   FunctionPropertiesInfo ExpectedInitial;
@@ -294,6 +381,8 @@ define i32 @f2(i32 %a) {
   ExpectedInitial.TotalInstructionCount = 9;
   ExpectedInitial.Uses = 1;
   ExpectedInitial.DirectCallsToDefinedFunctions = 1;
+  ExpectedInitial.setFunctionEmbeddingForTest(
+      createEmbedder(*F1)->getFunctionVector());
 
   FunctionPropertiesInfo ExpectedFinal = ExpectedInitial;
   ExpectedFinal.DirectCallsToDefinedFunctions = 0;
@@ -307,6 +396,9 @@ define i32 @f2(i32 %a) {
   EXPECT_TRUE(IR.isSuccess());
   invalidate(*F1);
   EXPECT_TRUE(FPU.finishAndTest(FAM));
+
+  ExpectedFinal.setFunctionEmbeddingForTest(
+      createEmbedder(*F1)->getFunctionVector());
   EXPECT_EQ(FPI, ExpectedFinal);
 }
 
@@ -347,7 +439,7 @@ define i32 @f2(i32 %a) {
 )IR");
 
   Function *F1 = M->getFunction("f1");
-  CallBase* CB = findCall(*F1, "b");
+  CallBase *CB = findCall(*F1, "b");
   EXPECT_NE(CB, nullptr);
 
   FunctionPropertiesInfo ExpectedInitial;
@@ -356,6 +448,8 @@ define i32 @f2(i32 %a) {
   ExpectedInitial.TotalInstructionCount = 9;
   ExpectedInitial.Uses = 1;
   ExpectedInitial.DirectCallsToDefinedFunctions = 1;
+  ExpectedInitial.setFunctionEmbeddingForTest(
+      createEmbedder(*F1)->getFunctionVector());
 
   FunctionPropertiesInfo ExpectedFinal;
   ExpectedFinal.BasicBlockCount = 6;
@@ -374,6 +468,9 @@ define i32 @f2(i32 %a) {
   EXPECT_TRUE(IR.isSuccess());
   invalidate(*F1);
   EXPECT_TRUE(FPU.finishAndTest(FAM));
+
+  ExpectedFinal.setFunctionEmbeddingForTest(
+      createEmbedder(*F1)->getFunctionVector());
   EXPECT_EQ(FPI, ExpectedFinal);
 }
 
@@ -409,7 +506,7 @@ declare i32 @__gxx_personality_v0(...)
 )IR");
 
   Function *F1 = M->getFunction("caller");
-  CallBase* CB = findCall(*F1);
+  CallBase *CB = findCall(*F1);
   EXPECT_NE(CB, nullptr);
 
   auto FPI = buildFPI(*F1);
@@ -422,6 +519,8 @@ declare i32 @__gxx_personality_v0(...)
   EXPECT_EQ(static_cast<size_t>(FPI.BasicBlockCount), F1->size());
   EXPECT_EQ(static_cast<size_t>(FPI.TotalInstructionCount),
             F1->getInstructionCount());
+  EXPECT_TRUE(FPI.getFunctionEmbedding().approximatelyEquals(
+      createEmbedder(*F1)->getFunctionVector()));
 }
 
 TEST_F(FunctionPropertiesAnalysisTest, InvokeUnreachableHandler) {
@@ -462,7 +561,7 @@ declare i32 @__gxx_personality_v0(...)
 )IR");
 
   Function *F1 = M->getFunction("caller");
-  CallBase* CB = findCall(*F1);
+  CallBase *CB = findCall(*F1);
   EXPECT_NE(CB, nullptr);
 
   auto FPI = buildFPI(*F1);
@@ -475,6 +574,8 @@ declare i32 @__gxx_personality_v0(...)
   EXPECT_EQ(static_cast<size_t>(FPI.BasicBlockCount), F1->size() - 1);
   EXPECT_EQ(static_cast<size_t>(FPI.TotalInstructionCount),
             F1->getInstructionCount() - 2);
+  EXPECT_TRUE(FPI.getFunctionEmbedding().approximatelyEquals(
+      createEmbedder(*F1)->getFunctionVector()));
   EXPECT_EQ(FPI, FunctionPropertiesInfo::getFunctionPropertiesInfo(*F1, FAM));
 }
 
@@ -516,7 +617,7 @@ declare i32 @__gxx_personality_v0(...)
 )IR");
 
   Function *F1 = M->getFunction("caller");
-  CallBase* CB = findCall(*F1);
+  CallBase *CB = findCall(*F1);
   EXPECT_NE(CB, nullptr);
 
   auto FPI = buildFPI(*F1);
@@ -568,7 +669,7 @@ define void @outer() personality i8* null {
 )IR");
 
   Function *F1 = M->getFunction("outer");
-  CallBase* CB = findCall(*F1);
+  CallBase *CB = findCall(*F1);
   EXPECT_NE(CB, nullptr);
 
   auto FPI = buildFPI(*F1);
@@ -581,6 +682,8 @@ define void @outer() personality i8* null {
   EXPECT_EQ(static_cast<size_t>(FPI.BasicBlockCount), F1->size() - 1);
   EXPECT_EQ(static_cast<size_t>(FPI.TotalInstructionCount),
             F1->getInstructionCount() - 2);
+  EXPECT_TRUE(FPI.getFunctionEmbedding().approximatelyEquals(
+      createEmbedder(*F1)->getFunctionVector()));
   EXPECT_EQ(FPI, FunctionPropertiesInfo::getFunctionPropertiesInfo(*F1, FAM));
 }
 
@@ -624,7 +727,7 @@ if.then:
 )IR");
 
   Function *F1 = M->getFunction("outer");
-  CallBase* CB = findCall(*F1);
+  CallBase *CB = findCall(*F1);
   EXPECT_NE(CB, nullptr);
 
   auto FPI = buildFPI(*F1);
@@ -637,6 +740,8 @@ if.then:
   EXPECT_EQ(static_cast<size_t>(FPI.BasicBlockCount), F1->size() - 1);
   EXPECT_EQ(static_cast<size_t>(FPI.TotalInstructionCount),
             F1->getInstructionCount() - 2);
+  EXPECT_TRUE(FPI.getFunctionEmbedding().approximatelyEquals(
+      createEmbedder(*F1)->getFunctionVector()));
   EXPECT_EQ(FPI, FunctionPropertiesInfo::getFunctionPropertiesInfo(*F1, FAM));
 }
 
@@ -689,6 +794,8 @@ define i32 @f2(i32 %a) {
   ExpectedInitial.DirectCallsToDefinedFunctions = 1;
   ExpectedInitial.MaxLoopDepth = 1;
   ExpectedInitial.TopLevelLoopCount = 1;
+  ExpectedInitial.setFunctionEmbeddingForTest(
+      createEmbedder(*F1)->getFunctionVector());
 
   FunctionPropertiesInfo ExpectedFinal = ExpectedInitial;
   ExpectedFinal.BasicBlockCount = 6;
@@ -705,6 +812,9 @@ define i32 @f2(i32 %a) {
   EXPECT_TRUE(IR.isSuccess());
   invalidate(*F1);
   EXPECT_TRUE(FPU.finishAndTest(FAM));
+
+  ExpectedFinal.setFunctionEmbeddingForTest(
+      createEmbedder(*F1)->getFunctionVector());
   EXPECT_EQ(FPI, ExpectedFinal);
 }
 
@@ -733,7 +843,7 @@ cond.false:                                       ; preds = %entry
 extra2:
   br label %cond.end
 
-cond.end:                                         ; preds = %cond.false, %cond.true
+cond.end:                                         ; preds = %extra2, %cond.true
   %cond = phi i64 [ %conv2, %cond.true ], [ %call3, %extra ]
   ret i64 %cond
 }
@@ -757,7 +867,9 @@ declare void @llvm.trap()
   ExpectedInitial.BlocksReachedFromConditionalInstruction = 2;
   ExpectedInitial.Uses = 1;
   ExpectedInitial.DirectCallsToDefinedFunctions = 1;
-  
+  ExpectedInitial.setFunctionEmbeddingForTest(
+      createEmbedder(*F1)->getFunctionVector());
+
   FunctionPropertiesInfo ExpectedFinal = ExpectedInitial;
   ExpectedFinal.BasicBlockCount = 4;
   ExpectedFinal.DirectCallsToDefinedFunctions = 0;
@@ -772,6 +884,9 @@ declare void @llvm.trap()
   EXPECT_TRUE(IR.isSuccess());
   invalidate(*F1);
   EXPECT_TRUE(FPU.finishAndTest(FAM));
+
+  ExpectedFinal.setFunctionEmbeddingForTest(
+      createEmbedder(*F1)->getFunctionVector());
   EXPECT_EQ(FPI, ExpectedFinal);
 }
 
@@ -817,6 +932,8 @@ declare void @f3()
   ExpectedInitial.BlocksReachedFromConditionalInstruction = 0;
   ExpectedInitial.Uses = 1;
   ExpectedInitial.DirectCallsToDefinedFunctions = 1;
+  ExpectedInitial.setFunctionEmbeddingForTest(
+      createEmbedder(*F1)->getFunctionVector());
 
   FunctionPropertiesInfo ExpectedFinal = ExpectedInitial;
   ExpectedFinal.BasicBlockCount = 6;
@@ -832,6 +949,9 @@ declare void @f3()
   EXPECT_TRUE(IR.isSuccess());
   invalidate(*F1);
   EXPECT_TRUE(FPU.finishAndTest(FAM));
+
+  ExpectedFinal.setFunctionEmbeddingForTest(
+      createEmbedder(*F1)->getFunctionVector());
   EXPECT_EQ(FPI, ExpectedFinal);
 }
 
@@ -885,6 +1005,8 @@ define i64 @f1(i64 %e) {
   EXPECT_EQ(DetailedF1Properties.CallReturnsPointerCount, 0);
   EXPECT_EQ(DetailedF1Properties.CallWithManyArgumentsCount, 0);
   EXPECT_EQ(DetailedF1Properties.CallWithPointerArgumentCount, 0);
+  EXPECT_TRUE(DetailedF1Properties.getFunctionEmbedding().approximatelyEquals(
+      createEmbedder(*F1)->getFunctionVector()));
   EnableDetailedFunctionProperties.setValue(false);
 }
 
@@ -910,6 +1032,8 @@ declare float @llvm.cos.f32(float)
   EXPECT_EQ(DetailedF1Properties.CallReturnsPointerCount, 0);
   EXPECT_EQ(DetailedF1Properties.CallWithManyArgumentsCount, 0);
   EXPECT_EQ(DetailedF1Properties.CallWithPointerArgumentCount, 0);
+  EXPECT_TRUE(DetailedF1Properties.getFunctionEmbedding().approximatelyEquals(
+      createEmbedder(*F1)->getFunctionVector()));
   EnableDetailedFunctionProperties.setValue(false);
 }
 
@@ -943,6 +1067,8 @@ declare float @f5()
   EXPECT_EQ(DetailedF1Properties.CallReturnsPointerCount, 1);
   EXPECT_EQ(DetailedF1Properties.CallWithManyArgumentsCount, 1);
   EXPECT_EQ(DetailedF1Properties.CallWithPointerArgumentCount, 1);
+  EXPECT_TRUE(DetailedF1Properties.getFunctionEmbedding().approximatelyEquals(
+      createEmbedder(*F1)->getFunctionVector()));
   EnableDetailedFunctionProperties.setValue(false);
 }
 
@@ -972,10 +1098,11 @@ define i64 @f1(i64 %a) {
   EnableDetailedFunctionProperties.setValue(true);
   FunctionPropertiesInfo DetailedF1Properties = buildFPI(*F1);
   EXPECT_EQ(DetailedF1Properties.CriticalEdgeCount, 1);
+  EXPECT_TRUE(DetailedF1Properties.getFunctionEmbedding().approximatelyEquals(
+      createEmbedder(*F1)->getFunctionVector()));
   EnableDetailedFunctionProperties.setValue(false);
 }
 
-
 TEST_F(FunctionPropertiesAnalysisTest, FunctionReturnVectors) {
   LLVMContext C;
   std::unique_ptr<Module> M = makeLLVMModule(C,
@@ -998,6 +1125,8 @@ declare <4 x ptr> @f4()
   EXPECT_EQ(DetailedF1Properties.CallReturnsVectorIntCount, 1);
   EXPECT_EQ(DetailedF1Properties.CallReturnsVectorFloatCount, 1);
   EXPECT_EQ(DetailedF1Properties.CallReturnsVectorPointerCount, 1);
+  EXPECT_TRUE(DetailedF1Properties.getFunctionEmbedding().approximatelyEquals(
+      createEmbedder(*F1)->getFunctionVector()));
   EnableDetailedFunctionProperties.setValue(false);
 }
 
    
    
More information about the llvm-branch-commits
mailing list