[llvm] [IR2Vec] Add support for flow-aware embeddings (PR #152613)

via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 7 16:47:29 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-analysis

Author: S. VenkataKeerthy (svkeerthy)

<details>
<summary>Changes</summary>

This patch introduces support for Flow-Aware embeddings in IR2Vec, which capture data flow information in addition to symbolic representations.



---

Patch is 20.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152613.diff


6 Files Affected:

- (modified) llvm/include/llvm/Analysis/IR2Vec.h (+23-6) 
- (modified) llvm/lib/Analysis/IR2Vec.cpp (+65-9) 
- (added) llvm/test/Analysis/IR2Vec/basic-flowaware.ll (+72) 
- (renamed) llvm/test/Analysis/IR2Vec/basic-symbolic.ll (+1-13) 
- (added) llvm/test/Analysis/IR2Vec/basic-vocab.ll (+27) 
- (modified) llvm/unittests/Analysis/IR2VecTest.cpp (+71-5) 


``````````diff
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 17f41129fd4fa..3cfc206c94788 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -54,14 +54,19 @@ class IR2VecVocabAnalysis;
 /// of the IR entities. Flow-aware embeddings build on top of symbolic
 /// embeddings and additionally capture the flow information in the IR.
 /// IR2VecKind is used to specify the type of embeddings to generate.
-/// Currently, only Symbolic embeddings are supported.
-enum class IR2VecKind { Symbolic };
+/// Note: The implementation of FlowAware embeddings is not the same as the one
+/// described in the paper. The current implementation is a simplified version
+/// that captures the flow information (SSA-based use-defs) without tracing
+/// through memory-level use-defs in the embedding computation described in the
+/// paper.
+enum class IR2VecKind { Symbolic, FlowAware };
 
 namespace ir2vec {
 
 LLVM_ABI extern cl::opt<float> OpcWeight;
 LLVM_ABI extern cl::opt<float> TypeWeight;
 LLVM_ABI extern cl::opt<float> ArgWeight;
+LLVM_ABI extern cl::opt<IR2VecKind> IR2VecEmbeddingKind;
 
 /// Embedding is a datatype that wraps std::vector<double>. It provides
 /// additional functionality for arithmetic and comparison operations.
@@ -257,9 +262,8 @@ class Embedder {
   LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);
 
   /// Helper function to compute embeddings. It generates embeddings for all
-  /// the instructions and basic blocks in the function F. Logic of computing
-  /// the embeddings is specific to the kind of embeddings being computed.
-  virtual void computeEmbeddings() const = 0;
+  /// the instructions and basic blocks of F reachable from its entry block.
+  void computeEmbeddings() const;
 
   /// Helper function to compute the embedding for a given basic block.
   /// Specific to the kind of embeddings being computed.
@@ -296,7 +300,6 @@ class Embedder {
 /// representations obtained from the Vocabulary.
 class LLVM_ABI SymbolicEmbedder : public Embedder {
 private:
-  void computeEmbeddings() const override;
   void computeEmbeddings(const BasicBlock &BB) const override;
 
 public:
@@ -306,6 +309,20 @@ class LLVM_ABI SymbolicEmbedder : public Embedder {
   }
 };
 
+/// Class for computing the Flow-aware embeddings of IR2Vec.
+/// Flow-aware embeddings build on the vocabulary, just like Symbolic
+/// embeddings, and additionally capture the flow information in the IR.
+class LLVM_ABI FlowAwareEmbedder : public Embedder {
+private:
+  void computeEmbeddings(const BasicBlock &BB) const override;
+
+public:
+  FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)
+      : Embedder(F, Vocab) {
+    FuncVector = Embedding(Dimension, 0);
+  }
+};
+
 } // namespace ir2vec
 
 /// This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 95f30fd3f4275..0bea25ec26b8e 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -52,6 +52,15 @@ cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
 cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
                          cl::desc("Weight for argument embeddings"),
                          cl::cat(IR2VecCategory));
+cl::opt<IR2VecKind> IR2VecEmbeddingKind(
+    "ir2vec-kind", cl::Optional,
+    cl::values(clEnumValN(IR2VecKind::Symbolic, "symbolic",
+                          "Generate symbolic embeddings"),
+               clEnumValN(IR2VecKind::FlowAware, "flow-aware",
+                          "Generate flow-aware embeddings")),
+    cl::init(IR2VecKind::Symbolic), cl::desc("IR2Vec embedding kind"),
+    cl::cat(IR2VecCategory));
+
 } // namespace ir2vec
 } // namespace llvm
 
@@ -149,6 +158,8 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
   switch (Mode) {
   case IR2VecKind::Symbolic:
     return std::make_unique<SymbolicEmbedder>(F, Vocab);
+  case IR2VecKind::FlowAware:
+    return std::make_unique<FlowAwareEmbedder>(F, Vocab);
   }
   return nullptr;
 }
@@ -180,6 +191,17 @@ const Embedding &Embedder::getFunctionVector() const {
   return FuncVector;
 }
 
+void Embedder::computeEmbeddings() const {
+  if (F.isDeclaration())
+    return;
+
+  // Consider only the basic blocks that are reachable from entry
+  for (const BasicBlock *BB : depth_first(&F)) {
+    computeEmbeddings(*BB);
+    FuncVector += BBVecMap[BB];
+  }
+}
+
 void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
   Embedding BBVector(Dimension, 0);
 
@@ -196,15 +218,46 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
   BBVecMap[&BB] = BBVector;
 }
 
-void SymbolicEmbedder::computeEmbeddings() const {
-  if (F.isDeclaration())
-    return;
+void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
+  Embedding BBVector(Dimension, 0);
 
-  // Consider only the basic blocks that are reachable from entry
-  for (const BasicBlock *BB : depth_first(&F)) {
-    computeEmbeddings(*BB);
-    FuncVector += BBVecMap[BB];
+  // We consider only the non-debug and non-pseudo instructions
+  for (const auto &I : BB.instructionsWithoutDebug()) {
+    // TODO: Handle call instructions differently.
+    // For now, we treat them like other instructions
+    Embedding ArgEmb(Dimension, 0);
+    for (const auto &Op : I.operands()) {
+      // If the operand is produced by an instruction, reuse its embedding.
+      if (const Instruction *DefInst = dyn_cast<Instruction>(Op)) {
+        auto DefIt = InstVecMap.find(DefInst);
+        // A definition may not be embedded yet when it reaches a PHI along a
+        // back-edge (blocks are visited in depth-first order from the entry).
+        if (DefIt != InstVecMap.end()) {
+          ArgEmb += DefIt->second;
+          continue;
+        }
+        // Not embedded yet (e.g. a PHI operand from a loop latch); fall back
+        // to the vocabulary embedding of the operand.
+        LLVM_DEBUG(dbgs() << "Warning: Operand defined by instruction not "
+                             "found in InstVecMap: "
+                          << *DefInst << "\n");
+        ArgEmb += Vocab[Op];
+      }
+      // If the operand is not defined by an instruction, we use the vocabulary
+      else {
+        LLVM_DEBUG(dbgs() << "Using embedding from vocabulary for operand: "
+                          << *Op << "=" << Vocab[Op][0] << "\n");
+        ArgEmb += Vocab[Op];
+      }
+    }
+    // Create the instruction vector by combining opcode, type, and arguments
+    // embeddings
+    auto InstVector =
+        Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+    InstVecMap[&I] = InstVector;
+    BBVector += InstVector;
   }
+  BBVecMap[&BB] = BBVector;
 }
 
 // ==----------------------------------------------------------------------===//
@@ -552,8 +605,11 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
   assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid");
 
   for (Function &F : M) {
-    std::unique_ptr<Embedder> Emb =
-        Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
+    std::unique_ptr<Embedder> Emb;
+    if (IR2VecEmbeddingKind == IR2VecKind::Symbolic)
+      Emb = Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
+    else
+      Emb = Embedder::create(IR2VecKind::FlowAware, F, Vocabulary);
     if (!Emb) {
       OS << "Error creating IR2Vec embeddings \n";
       continue;
diff --git a/llvm/test/Analysis/IR2Vec/basic-flowaware.ll b/llvm/test/Analysis/IR2Vec/basic-flowaware.ll
new file mode 100644
index 0000000000000..4a7f970a9cf91
--- /dev/null
+++ b/llvm/test/Analysis/IR2Vec/basic-flowaware.ll
@@ -0,0 +1,72 @@
+; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-OPC
+; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_type_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-TYPE
+; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_arg_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-ARG
+
+define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
+entry:
+  %a.addr = alloca i32, align 4
+  %b.addr = alloca float, align 4
+  store i32 %a, ptr %a.addr, align 4
+  store float %b, ptr %b.addr, align 4
+  %0 = load i32, ptr %a.addr, align 4
+  %1 = load i32, ptr %a.addr, align 4
+  %mul = mul nsw i32 %0, %1
+  %conv = sitofp i32 %mul to float
+  %2 = load float, ptr %b.addr, align 4
+  %add = fadd float %conv, %2
+  ret float %add
+}
+
+; 3D-CHECK-OPC: IR2Vec embeddings for function _Z3abcif:
+; 3D-CHECK-OPC-NEXT: Function vector: [ 3630.00 3672.00 3714.00 ]
+; 3D-CHECK-OPC-NEXT: Basic block vectors:
+; 3D-CHECK-OPC-NEXT: Basic block: entry:
+; 3D-CHECK-OPC-NEXT:  [ 3630.00 3672.00 3714.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction vectors:
+; 3D-CHECK-OPC-NEXT: Instruction:   %a.addr = alloca i32, align 4 [ 91.00  92.00  93.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction:   %b.addr = alloca float, align 4 [ 91.00  92.00  93.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction:   store i32 %a, ptr %a.addr, align 4 [ 188.00 190.00 192.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction:   store float %b, ptr %b.addr, align 4 [ 188.00 190.00 192.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction:   %0 = load i32, ptr %a.addr, align 4 [ 185.00 187.00 189.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction:   %1 = load i32, ptr %a.addr, align 4 [ 185.00 187.00 189.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction:   %mul = mul nsw i32 %0, %1 [ 419.00  424.00  429.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction:   %conv = sitofp i32 %mul to float [ 549.00  555.00  561.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction:   %2 = load float, ptr %b.addr, align 4 [ 185.00  187.00  189.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction:   %add = fadd float %conv, %2 [ 774.00  783.00  792.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction:   ret float %add [ 775.00  785.00  795.00 ]
+
+; 3D-CHECK-TYPE: IR2Vec embeddings for function _Z3abcif:
+; 3D-CHECK-TYPE-NEXT: Function vector:  [ 355.50  376.50  397.50 ]
+; 3D-CHECK-TYPE-NEXT: Basic block vectors:
+; 3D-CHECK-TYPE-NEXT: Basic block: entry:
+; 3D-CHECK-TYPE-NEXT:  [ 355.50  376.50  397.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction vectors:
+; 3D-CHECK-TYPE-NEXT: Instruction:   %a.addr = alloca i32, align 4 [ 12.50  13.00  13.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction:   %b.addr = alloca float, align 4 [ 12.50  13.00  13.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction:   store i32 %a, ptr %a.addr, align 4 [ 14.50  15.50  16.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction:   store float %b, ptr %b.addr, align 4 [ 14.50  15.50  16.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction:   %0 = load i32, ptr %a.addr, align 4 [ 22.00  23.00  24.00 ]
+; 3D-CHECK-TYPE-NEXT: Instruction:   %1 = load i32, ptr %a.addr, align 4 [ 22.00  23.00  24.00 ]
+; 3D-CHECK-TYPE-NEXT: Instruction:   %mul = mul nsw i32 %0, %1 [ 53.50  56.00  58.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction:   %conv = sitofp i32 %mul to float [ 54.00  57.00  60.00 ]
+; 3D-CHECK-TYPE-NEXT: Instruction:   %2 = load float, ptr %b.addr, align 4 [ 13.00  14.00  15.00 ]
+; 3D-CHECK-TYPE-NEXT: Instruction:   %add = fadd float %conv, %2 [ 67.50  72.00  76.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction:   ret float %add [ 69.50  74.50  79.50 ]
+
+; 3D-CHECK-ARG: IR2Vec embeddings for function _Z3abcif:
+; 3D-CHECK-ARG-NEXT: Function vector:  [ 27.80  31.60  35.40 ]
+; 3D-CHECK-ARG-NEXT: Basic block vectors:
+; 3D-CHECK-ARG-NEXT: Basic block: entry:
+; 3D-CHECK-ARG-NEXT:  [ 27.80  31.60  35.40 ]
+; 3D-CHECK-ARG-NEXT: Instruction vectors:
+; 3D-CHECK-ARG-NEXT: Instruction:   %a.addr = alloca i32, align 4 [ 1.40  1.60  1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction:   %b.addr = alloca float, align 4 [ 1.40  1.60  1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction:   store i32 %a, ptr %a.addr, align 4 [ 3.40  3.80  4.20 ]
+; 3D-CHECK-ARG-NEXT: Instruction:   store float %b, ptr %b.addr, align 4 [ 3.40  3.80  4.20 ]
+; 3D-CHECK-ARG-NEXT: Instruction:   %0 = load i32, ptr %a.addr, align 4 [ 1.40  1.60  1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction:   %1 = load i32, ptr %a.addr, align 4 [ 1.40  1.60  1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction:   %mul = mul nsw i32 %0, %1 [ 2.80  3.20  3.60 ]
+; 3D-CHECK-ARG-NEXT: Instruction:   %conv = sitofp i32 %mul to float [ 2.80  3.20  3.60 ]
+; 3D-CHECK-ARG-NEXT: Instruction:   %2 = load float, ptr %b.addr, align 4 [ 1.40  1.60  1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction:   %add = fadd float %conv, %2 [ 4.20  4.80  5.40 ]
+; 3D-CHECK-ARG-NEXT: Instruction:   ret float %add [ 4.20  4.80  5.40 ]
diff --git a/llvm/test/Analysis/IR2Vec/basic.ll b/llvm/test/Analysis/IR2Vec/basic-symbolic.ll
similarity index 81%
rename from llvm/test/Analysis/IR2Vec/basic.ll
rename to llvm/test/Analysis/IR2Vec/basic-symbolic.ll
index cb0544fb19860..35abd3c7fa269 100644
--- a/llvm/test/Analysis/IR2Vec/basic.ll
+++ b/llvm/test/Analysis/IR2Vec/basic-symbolic.ll
@@ -1,11 +1,7 @@
 ; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-OPC
 ; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_type_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-TYPE
 ; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_arg_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-ARG
-; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab1.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB1-CHECK
-; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab2.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB2-CHECK
-; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab3.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB3-CHECK
-; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab4.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB4-CHECK
- 
+
 define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
 entry:
   %a.addr = alloca i32, align 4
@@ -74,11 +70,3 @@ entry:
 ; 3D-CHECK-ARG-NEXT: Instruction:   %2 = load float, ptr %b.addr, align 4 [ 0.80  1.00  1.20 ]
 ; 3D-CHECK-ARG-NEXT: Instruction:   %add = fadd float %conv, %2 [ 4.00  4.40  4.80 ]
 ; 3D-CHECK-ARG-NEXT: Instruction:   ret float %add [ 2.00  2.20  2.40 ]
-
-; INCORRECT-VOCAB1-CHECK: error: Error reading vocabulary: Missing 'Opcodes' section in vocabulary file
-
-; INCORRECT-VOCAB2-CHECK: error: Error reading vocabulary: Missing 'Types' section in vocabulary file
-
-; INCORRECT-VOCAB3-CHECK: error: Error reading vocabulary: Missing 'Arguments' section in vocabulary file
-
-; INCORRECT-VOCAB4-CHECK: error: Error reading vocabulary: Vocabulary sections have different dimensions
diff --git a/llvm/test/Analysis/IR2Vec/basic-vocab.ll b/llvm/test/Analysis/IR2Vec/basic-vocab.ll
new file mode 100644
index 0000000000000..eeeee831814a8
--- /dev/null
+++ b/llvm/test/Analysis/IR2Vec/basic-vocab.ll
@@ -0,0 +1,27 @@
+; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab1.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB1-CHECK
+; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab2.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB2-CHECK
+; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab3.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB3-CHECK
+; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab4.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB4-CHECK
+
+define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
+entry:
+  %a.addr = alloca i32, align 4
+  %b.addr = alloca float, align 4
+  store i32 %a, ptr %a.addr, align 4
+  store float %b, ptr %b.addr, align 4
+  %0 = load i32, ptr %a.addr, align 4
+  %1 = load i32, ptr %a.addr, align 4
+  %mul = mul nsw i32 %0, %1
+  %conv = sitofp i32 %mul to float
+  %2 = load float, ptr %b.addr, align 4
+  %add = fadd float %conv, %2
+  ret float %add
+}
+
+; INCORRECT-VOCAB1-CHECK: error: Error reading vocabulary: Missing 'Opcodes' section in vocabulary file
+
+; INCORRECT-VOCAB2-CHECK: error: Error reading vocabulary: Missing 'Types' section in vocabulary file
+
+; INCORRECT-VOCAB3-CHECK: error: Error reading vocabulary: Missing 'Arguments' section in vocabulary file
+
+; INCORRECT-VOCAB4-CHECK: error: Error reading vocabulary: Vocabulary sections have different dimensions
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index e288585033c53..f6846963b3e2f 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -30,7 +30,6 @@ namespace {
 class TestableEmbedder : public Embedder {
 public:
   TestableEmbedder(const Function &F, const Vocabulary &V) : Embedder(F, V) {}
-  void computeEmbeddings() const override {}
   void computeEmbeddings(const BasicBlock &BB) const override {}
 };
 
@@ -258,6 +257,18 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
   EXPECT_NE(Emb, nullptr);
 }
 
+TEST(IR2VecTest, CreateFlowAwareEmbedder) {
+  Vocabulary V = Vocabulary(Vocabulary::createDummyVocabForTest());
+
+  LLVMContext Ctx;
+  Module M("M", Ctx);
+  FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
+  Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
+
+  auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+  EXPECT_NE(Emb, nullptr);
+}
+
 TEST(IR2VecTest, CreateInvalidMode) {
   Vocabulary V = Vocabulary(Vocabulary::createDummyVocabForTest());
 
@@ -307,10 +318,10 @@ class IR2VecTestFixture : public ::testing::Test {
 
     AddInst = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
     RetInst = ReturnInst::Create(Ctx, AddInst, BB);
   }
 };
 
-TEST_F(IR2VecTestFixture, GetInstVecMap) {
+TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) {
   auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
   ASSERT_TRUE(static_cast<bool>(Emb));
 
@@ -327,7 +338,24 @@ TEST_F(IR2VecTestFixture, GetInstVecMap) {
   EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 16.8)));
 }
 
-TEST_F(IR2VecTestFixture, GetBBVecMap) {
+TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) {
+  auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Emb));
+
+  const auto &InstMap = Emb->getInstVecMap();
+
+  EXPECT_EQ(InstMap.size(), 2u);
+  EXPECT_TRUE(InstMap.count(AddInst));
+  EXPECT_TRUE(InstMap.count(RetInst));
+
+  EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
+  EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
+
+  EXPECT_TRUE(InstMap.at(AddInst).approximatelyEquals(Embedding(2, 27.6)));
+  EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 35.2)));
+}
+
+TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) {
   auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
   ASSERT_TRUE(static_cast<bool>(Emb));
 
@@ -342,7 +370,22 @@ TEST_F(IR2VecTestFixture, GetBBVecMap) {
   EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 44.4)));
 }
 
-TEST_F(IR2VecTestFixture, GetBBVector) {
+TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) {
+  auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Emb));
+
+  const auto &BBMap = Emb->getBBVecMap();
+
+  EXPECT_EQ(BBMap.size(), 1u);
+  EXPECT_TRUE(BBMap.count(BB));
+  EXPECT_EQ(BBMap.at(BB).size(), 2u);
+
+  // BB vector should be sum of add and ret: {27.6, 27.6} + {35.2, 35.2} =
+  // {62.8, 62.8}
+  EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 62.8)));
+}
+
+TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
   auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
   ASSERT_TRUE(static_cast<bool>(Emb));
 
@@ -352,7 +397,17 @@ TEST_F(IR2VecTestFixture, GetBBVector) {
   EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 44.4)));
 }
 
-TEST_F(IR2VecTestFixture, GetFunctionVector) {
+TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) {
+  auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Emb));
+
+  const auto &BBVec ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/152613


More information about the llvm-commits mailing list