[llvm] [IR2Vec] Add support for flow-aware embeddings (PR #152613)

S. VenkataKeerthy via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 28 12:00:06 PDT 2025


https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/152613

From 589fdb96761727812e27fb0ce4e6483d67a53ba8 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Thu, 7 Aug 2025 23:43:45 +0000
Subject: [PATCH] Flow-Aware Embeddings

---
 llvm/include/llvm/Analysis/IR2Vec.h           | 31 +++++--
 llvm/lib/Analysis/IR2Vec.cpp                  | 82 ++++++++++++++++---
 llvm/test/Analysis/IR2Vec/basic-flowaware.ll  | 72 ++++++++++++++++
 .../IR2Vec/{basic.ll => basic-symbolic.ll}    | 14 +---
 llvm/test/Analysis/IR2Vec/basic-vocab.ll      | 27 ++++++
 llvm/unittests/Analysis/IR2VecTest.cpp        | 74 +++++++++++++++--
 6 files changed, 261 insertions(+), 39 deletions(-)
 create mode 100644 llvm/test/Analysis/IR2Vec/basic-flowaware.ll
 rename llvm/test/Analysis/IR2Vec/{basic.ll => basic-symbolic.ll} (81%)
 create mode 100644 llvm/test/Analysis/IR2Vec/basic-vocab.ll

diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 17f41129fd4fa..15221c7f07917 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -54,14 +54,19 @@ class IR2VecVocabAnalysis;
 /// of the IR entities. Flow-aware embeddings build on top of symbolic
 /// embeddings and additionally capture the flow information in the IR.
 /// IR2VecKind is used to specify the type of embeddings to generate.
-/// Currently, only Symbolic embeddings are supported.
-enum class IR2VecKind { Symbolic };
+/// Note: Implementation of FlowAware embeddings is not the same as the one
+/// described in the paper. The current implementation is a simplified version
+/// that captures the flow information (SSA-based use-defs) without tracing
+/// through memory-level use-defs in the embedding computation described in the
+/// paper.
+enum class IR2VecKind { Symbolic, FlowAware };
 
 namespace ir2vec {
 
 LLVM_ABI extern cl::opt<float> OpcWeight;
 LLVM_ABI extern cl::opt<float> TypeWeight;
 LLVM_ABI extern cl::opt<float> ArgWeight;
+LLVM_ABI extern cl::opt<IR2VecKind> IR2VecEmbeddingKind;
 
 /// Embedding is a datatype that wraps std::vector<double>. It provides
 /// additional functionality for arithmetic and comparison operations.
@@ -257,9 +262,8 @@ class Embedder {
   LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);
 
   /// Helper function to compute embeddings. It generates embeddings for all
-  /// the instructions and basic blocks in the function F. Logic of computing
-  /// the embeddings is specific to the kind of embeddings being computed.
-  virtual void computeEmbeddings() const = 0;
+  /// the instructions and basic blocks in the function F.
+  void computeEmbeddings() const;
 
   /// Helper function to compute the embedding for a given basic block.
   /// Specific to the kind of embeddings being computed.
@@ -296,14 +300,23 @@ class Embedder {
 /// representations obtained from the Vocabulary.
 class LLVM_ABI SymbolicEmbedder : public Embedder {
 private:
-  void computeEmbeddings() const override;
   void computeEmbeddings(const BasicBlock &BB) const override;
 
 public:
   SymbolicEmbedder(const Function &F, const Vocabulary &Vocab)
-      : Embedder(F, Vocab) {
-    FuncVector = Embedding(Dimension, 0);
-  }
+      : Embedder(F, Vocab) {}
+};
+
+/// Class for computing the Flow-aware embeddings of IR2Vec.
+/// Flow-aware embeddings build on the vocabulary, just like Symbolic
+/// embeddings, and additionally capture the flow information in the IR.
+class LLVM_ABI FlowAwareEmbedder : public Embedder {
+private:
+  void computeEmbeddings(const BasicBlock &BB) const override;
+
+public:
+  FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)
+      : Embedder(F, Vocab) {}
 };
 
 } // namespace ir2vec
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 95f30fd3f4275..7d0ad6a69a398 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -52,6 +52,15 @@ cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
 cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
                          cl::desc("Weight for argument embeddings"),
                          cl::cat(IR2VecCategory));
+cl::opt<IR2VecKind> IR2VecEmbeddingKind(
+    "ir2vec-kind", cl::Optional,
+    cl::values(clEnumValN(IR2VecKind::Symbolic, "symbolic",
+                          "Generate symbolic embeddings"),
+               clEnumValN(IR2VecKind::FlowAware, "flow-aware",
+                          "Generate flow-aware embeddings")),
+    cl::init(IR2VecKind::Symbolic), cl::desc("IR2Vec embedding kind"),
+    cl::cat(IR2VecCategory));
+
 } // namespace ir2vec
 } // namespace llvm
 
@@ -123,8 +132,12 @@ bool Embedding::approximatelyEquals(const Embedding &RHS,
                                     double Tolerance) const {
   assert(this->size() == RHS.size() && "Vectors must have the same dimension");
   for (size_t Itr = 0; Itr < this->size(); ++Itr)
-    if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance)
+    if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance) {
+      LLVM_DEBUG(errs() << "Embedding mismatch at index " << Itr << ": "
+                        << (*this)[Itr] << " vs " << RHS[Itr]
+                        << "; Tolerance: " << Tolerance << "\n");
       return false;
+    }
   return true;
 }
 
@@ -141,14 +154,16 @@ void Embedding::print(raw_ostream &OS) const {
 
 Embedder::Embedder(const Function &F, const Vocabulary &Vocab)
     : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
-      OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {
-}
+      OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
+      FuncVector(Embedding(Dimension, 0)) {}
 
 std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
                                            const Vocabulary &Vocab) {
   switch (Mode) {
   case IR2VecKind::Symbolic:
     return std::make_unique<SymbolicEmbedder>(F, Vocab);
+  case IR2VecKind::FlowAware:
+    return std::make_unique<FlowAwareEmbedder>(F, Vocab);
   }
   return nullptr;
 }
@@ -180,6 +195,17 @@ const Embedding &Embedder::getFunctionVector() const {
   return FuncVector;
 }
 
+void Embedder::computeEmbeddings() const {
+  if (F.isDeclaration())
+    return;
+
+  // Consider only the basic blocks that are reachable from entry
+  for (const BasicBlock *BB : depth_first(&F)) {
+    computeEmbeddings(*BB);
+    FuncVector += BBVecMap[BB];
+  }
+}
+
 void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
   Embedding BBVector(Dimension, 0);
 
@@ -196,15 +222,38 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
   BBVecMap[&BB] = BBVector;
 }
 
-void SymbolicEmbedder::computeEmbeddings() const {
-  if (F.isDeclaration())
-    return;
+void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
+  Embedding BBVector(Dimension, 0);
 
-  // Consider only the basic blocks that are reachable from entry
-  for (const BasicBlock *BB : depth_first(&F)) {
-    computeEmbeddings(*BB);
-    FuncVector += BBVecMap[BB];
+  // We consider only the non-debug and non-pseudo instructions
+  for (const auto &I : BB.instructionsWithoutDebug()) {
+    // TODO: Handle call instructions differently.
+    // For now, we treat them like other instructions
+    Embedding ArgEmb(Dimension, 0);
+    for (const auto &Op : I.operands()) {
+      // If the operand is defined elsewhere, we use its embedding
+      if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
+        auto DefIt = InstVecMap.find(DefInst);
+        assert(DefIt != InstVecMap.end() &&
+               "Instruction should have been processed before its operands");
+        ArgEmb += DefIt->second;
+        continue;
+      }
+      // If the operand is not defined by an instruction, we use the vocabulary
+      else {
+        LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
+                          << *Op << "=" << Vocab[Op][0] << "\n");
+        ArgEmb += Vocab[Op];
+      }
+    }
+    // Create the instruction vector by combining opcode, type, and arguments
+    // embeddings
+    auto InstVector =
+        Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+    InstVecMap[&I] = InstVector;
+    BBVector += InstVector;
   }
+  BBVecMap[&BB] = BBVector;
 }
 
 // ==----------------------------------------------------------------------===//
@@ -552,8 +601,17 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
   assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid");
 
   for (Function &F : M) {
-    std::unique_ptr<Embedder> Emb =
-        Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
+    std::unique_ptr<Embedder> Emb;
+    switch (IR2VecEmbeddingKind) {
+    case IR2VecKind::Symbolic:
+      Emb = std::make_unique<SymbolicEmbedder>(F, Vocabulary);
+      break;
+    case IR2VecKind::FlowAware:
+      Emb = std::make_unique<FlowAwareEmbedder>(F, Vocabulary);
+      break;
+    default:
+      llvm_unreachable("Unknown IR2Vec embedding kind");
+    }
     if (!Emb) {
       OS << "Error creating IR2Vec embeddings \n";
       continue;
diff --git a/llvm/test/Analysis/IR2Vec/basic-flowaware.ll b/llvm/test/Analysis/IR2Vec/basic-flowaware.ll
new file mode 100644
index 0000000000000..4a7f970a9cf91
--- /dev/null
+++ b/llvm/test/Analysis/IR2Vec/basic-flowaware.ll
@@ -0,0 +1,72 @@
+; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-OPC
+; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_type_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-TYPE
+; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_arg_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-ARG
+
+define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
+entry:
+  %a.addr = alloca i32, align 4
+  %b.addr = alloca float, align 4
+  store i32 %a, ptr %a.addr, align 4
+  store float %b, ptr %b.addr, align 4
+  %0 = load i32, ptr %a.addr, align 4
+  %1 = load i32, ptr %a.addr, align 4
+  %mul = mul nsw i32 %0, %1
+  %conv = sitofp i32 %mul to float
+  %2 = load float, ptr %b.addr, align 4
+  %add = fadd float %conv, %2
+  ret float %add
+}
+
+; 3D-CHECK-OPC: IR2Vec embeddings for function _Z3abcif:
+; 3D-CHECK-OPC-NEXT: Function vector: [ 3630.00 3672.00 3714.00 ]
+; 3D-CHECK-OPC-NEXT: Basic block vectors:
+; 3D-CHECK-OPC-NEXT: Basic block: entry:
+; 3D-CHECK-OPC-NEXT:  [ 3630.00 3672.00 3714.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction vectors:
+; 3D-CHECK-OPC-NEXT: Instruction:   %a.addr = alloca i32, align 4 [ 91.00  92.00  93.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction:   %b.addr = alloca float, align 4 [ 91.00  92.00  93.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction:   store i32 %a, ptr %a.addr, align 4 [ 188.00 190.00 192.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction:   store float %b, ptr %b.addr, align 4 [ 188.00 190.00 192.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction:   %0 = load i32, ptr %a.addr, align 4 [ 185.00 187.00 189.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction:   %1 = load i32, ptr %a.addr, align 4 [ 185.00 187.00 189.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction:   %mul = mul nsw i32 %0, %1 [ 419.00  424.00  429.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction:   %conv = sitofp i32 %mul to float [ 549.00  555.00  561.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction:   %2 = load float, ptr %b.addr, align 4 [ 185.00  187.00  189.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction:   %add = fadd float %conv, %2 [ 774.00  783.00  792.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction:   ret float %add [ 775.00  785.00  795.00 ]
+
+; 3D-CHECK-TYPE: IR2Vec embeddings for function _Z3abcif:
+; 3D-CHECK-TYPE-NEXT: Function vector:  [ 355.50  376.50  397.50 ]
+; 3D-CHECK-TYPE-NEXT: Basic block vectors:
+; 3D-CHECK-TYPE-NEXT: Basic block: entry:
+; 3D-CHECK-TYPE-NEXT:  [ 355.50  376.50  397.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction vectors:
+; 3D-CHECK-TYPE-NEXT: Instruction:   %a.addr = alloca i32, align 4 [ 12.50  13.00  13.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction:   %b.addr = alloca float, align 4 [ 12.50  13.00  13.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction:   store i32 %a, ptr %a.addr, align 4 [ 14.50  15.50  16.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction:   store float %b, ptr %b.addr, align 4 [ 14.50  15.50  16.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction:   %0 = load i32, ptr %a.addr, align 4 [ 22.00  23.00  24.00 ]
+; 3D-CHECK-TYPE-NEXT: Instruction:   %1 = load i32, ptr %a.addr, align 4 [ 22.00  23.00  24.00 ]
+; 3D-CHECK-TYPE-NEXT: Instruction:   %mul = mul nsw i32 %0, %1 [ 53.50  56.00  58.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction:   %conv = sitofp i32 %mul to float [ 54.00  57.00  60.00 ]
+; 3D-CHECK-TYPE-NEXT: Instruction:   %2 = load float, ptr %b.addr, align 4 [ 13.00  14.00  15.00 ]
+; 3D-CHECK-TYPE-NEXT: Instruction:   %add = fadd float %conv, %2 [ 67.50  72.00  76.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction:   ret float %add [ 69.50  74.50  79.50 ]
+
+; 3D-CHECK-ARG: IR2Vec embeddings for function _Z3abcif:
+; 3D-CHECK-ARG-NEXT: Function vector:  [ 27.80  31.60  35.40 ]
+; 3D-CHECK-ARG-NEXT: Basic block vectors:
+; 3D-CHECK-ARG-NEXT: Basic block: entry:
+; 3D-CHECK-ARG-NEXT:  [ 27.80  31.60  35.40 ]
+; 3D-CHECK-ARG-NEXT: Instruction vectors:
+; 3D-CHECK-ARG-NEXT: Instruction:   %a.addr = alloca i32, align 4 [ 1.40  1.60  1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction:   %b.addr = alloca float, align 4 [ 1.40  1.60  1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction:   store i32 %a, ptr %a.addr, align 4 [ 3.40  3.80  4.20 ]
+; 3D-CHECK-ARG-NEXT: Instruction:   store float %b, ptr %b.addr, align 4 [ 3.40  3.80  4.20 ]
+; 3D-CHECK-ARG-NEXT: Instruction:   %0 = load i32, ptr %a.addr, align 4 [ 1.40  1.60  1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction:   %1 = load i32, ptr %a.addr, align 4 [ 1.40  1.60  1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction:   %mul = mul nsw i32 %0, %1 [ 2.80  3.20  3.60 ]
+; 3D-CHECK-ARG-NEXT: Instruction:   %conv = sitofp i32 %mul to float [ 2.80  3.20  3.60 ]
+; 3D-CHECK-ARG-NEXT: Instruction:   %2 = load float, ptr %b.addr, align 4 [ 1.40  1.60  1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction:   %add = fadd float %conv, %2 [ 4.20  4.80  5.40 ]
+; 3D-CHECK-ARG-NEXT: Instruction:   ret float %add [ 4.20  4.80  5.40 ]
diff --git a/llvm/test/Analysis/IR2Vec/basic.ll b/llvm/test/Analysis/IR2Vec/basic-symbolic.ll
similarity index 81%
rename from llvm/test/Analysis/IR2Vec/basic.ll
rename to llvm/test/Analysis/IR2Vec/basic-symbolic.ll
index cb0544fb19860..35abd3c7fa269 100644
--- a/llvm/test/Analysis/IR2Vec/basic.ll
+++ b/llvm/test/Analysis/IR2Vec/basic-symbolic.ll
@@ -1,11 +1,7 @@
 ; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-OPC
 ; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_type_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-TYPE
 ; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_arg_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-ARG
-; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab1.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB1-CHECK
-; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab2.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB2-CHECK
-; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab3.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB3-CHECK
-; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab4.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB4-CHECK
- 
+
 define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
 entry:
   %a.addr = alloca i32, align 4
@@ -74,11 +70,3 @@ entry:
 ; 3D-CHECK-ARG-NEXT: Instruction:   %2 = load float, ptr %b.addr, align 4 [ 0.80  1.00  1.20 ]
 ; 3D-CHECK-ARG-NEXT: Instruction:   %add = fadd float %conv, %2 [ 4.00  4.40  4.80 ]
 ; 3D-CHECK-ARG-NEXT: Instruction:   ret float %add [ 2.00  2.20  2.40 ]
-
-; INCORRECT-VOCAB1-CHECK: error: Error reading vocabulary: Missing 'Opcodes' section in vocabulary file
-
-; INCORRECT-VOCAB2-CHECK: error: Error reading vocabulary: Missing 'Types' section in vocabulary file
-
-; INCORRECT-VOCAB3-CHECK: error: Error reading vocabulary: Missing 'Arguments' section in vocabulary file
-
-; INCORRECT-VOCAB4-CHECK: error: Error reading vocabulary: Vocabulary sections have different dimensions
diff --git a/llvm/test/Analysis/IR2Vec/basic-vocab.ll b/llvm/test/Analysis/IR2Vec/basic-vocab.ll
new file mode 100644
index 0000000000000..eeeee831814a8
--- /dev/null
+++ b/llvm/test/Analysis/IR2Vec/basic-vocab.ll
@@ -0,0 +1,27 @@
+; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab1.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB1-CHECK
+; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab2.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB2-CHECK
+; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab3.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB3-CHECK
+; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab4.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB4-CHECK
+ 
+define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
+entry:
+  %a.addr = alloca i32, align 4
+  %b.addr = alloca float, align 4
+  store i32 %a, ptr %a.addr, align 4
+  store float %b, ptr %b.addr, align 4
+  %0 = load i32, ptr %a.addr, align 4
+  %1 = load i32, ptr %a.addr, align 4
+  %mul = mul nsw i32 %0, %1
+  %conv = sitofp i32 %mul to float
+  %2 = load float, ptr %b.addr, align 4
+  %add = fadd float %conv, %2
+  ret float %add
+}
+
+; INCORRECT-VOCAB1-CHECK: error: Error reading vocabulary: Missing 'Opcodes' section in vocabulary file
+
+; INCORRECT-VOCAB2-CHECK: error: Error reading vocabulary: Missing 'Types' section in vocabulary file
+
+; INCORRECT-VOCAB3-CHECK: error: Error reading vocabulary: Missing 'Arguments' section in vocabulary file
+
+; INCORRECT-VOCAB4-CHECK: error: Error reading vocabulary: Vocabulary sections have different dimensions
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index f7838cc4068ce..f0c81e160ca15 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -30,7 +30,6 @@ namespace {
 class TestableEmbedder : public Embedder {
 public:
   TestableEmbedder(const Function &F, const Vocabulary &V) : Embedder(F, V) {}
-  void computeEmbeddings() const override {}
   void computeEmbeddings(const BasicBlock &BB) const override {}
 };
 
@@ -258,6 +257,18 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
   EXPECT_NE(Emb, nullptr);
 }
 
+TEST(IR2VecTest, CreateFlowAwareEmbedder) {
+  Vocabulary V = Vocabulary(Vocabulary::createDummyVocabForTest());
+
+  LLVMContext Ctx;
+  Module M("M", Ctx);
+  FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
+  Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
+
+  auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+  EXPECT_NE(Emb, nullptr);
+}
+
 TEST(IR2VecTest, CreateInvalidMode) {
   Vocabulary V = Vocabulary(Vocabulary::createDummyVocabForTest());
 
@@ -310,7 +321,7 @@ class IR2VecTestFixture : public ::testing::Test {
   }
 };
 
-TEST_F(IR2VecTestFixture, GetInstVecMap) {
+TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) {
   auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
   ASSERT_TRUE(static_cast<bool>(Emb));
 
@@ -329,7 +340,24 @@ TEST_F(IR2VecTestFixture, GetInstVecMap) {
   EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 17.0)));
 }
 
-TEST_F(IR2VecTestFixture, GetBBVecMap) {
+TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) {
+  auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Emb));
+
+  const auto &InstMap = Emb->getInstVecMap();
+
+  EXPECT_EQ(InstMap.size(), 2u);
+  EXPECT_TRUE(InstMap.count(AddInst));
+  EXPECT_TRUE(InstMap.count(RetInst));
+
+  EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
+  EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
+
+  EXPECT_TRUE(InstMap.at(AddInst).approximatelyEquals(Embedding(2, 27.9)));
+  EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 35.6)));
+}
+
+TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) {
   auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
   ASSERT_TRUE(static_cast<bool>(Emb));
 
@@ -344,7 +372,22 @@ TEST_F(IR2VecTestFixture, GetBBVecMap) {
   EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 44.9)));
 }
 
-TEST_F(IR2VecTestFixture, GetBBVector) {
+TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) {
+  auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Emb));
+
+  const auto &BBMap = Emb->getBBVecMap();
+
+  EXPECT_EQ(BBMap.size(), 1u);
+  EXPECT_TRUE(BBMap.count(BB));
+  EXPECT_EQ(BBMap.at(BB).size(), 2u);
+
+  // BB vector should be sum of add and ret: {27.9, 27.9} + {35.6, 35.6} =
+  // {63.5, 63.5}
+  EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 63.5)));
+}
+
+TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
   auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
   ASSERT_TRUE(static_cast<bool>(Emb));
 
@@ -354,7 +397,17 @@ TEST_F(IR2VecTestFixture, GetBBVector) {
   EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 44.9)));
 }
 
-TEST_F(IR2VecTestFixture, GetFunctionVector) {
+TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) {
+  auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Emb));
+
+  const auto &BBVec = Emb->getBBVector(*BB);
+
+  EXPECT_EQ(BBVec.size(), 2u);
+  EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 63.5)));
+}
+
+TEST_F(IR2VecTestFixture, GetFunctionVector_Symbolic) {
   auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
   ASSERT_TRUE(static_cast<bool>(Emb));
 
@@ -366,6 +419,17 @@ TEST_F(IR2VecTestFixture, GetFunctionVector) {
   EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 44.9)));
 }
 
+TEST_F(IR2VecTestFixture, GetFunctionVector_FlowAware) {
+  auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+  ASSERT_TRUE(static_cast<bool>(Emb));
+
+  const auto &FuncVec = Emb->getFunctionVector();
+
+  EXPECT_EQ(FuncVec.size(), 2u);
+  // Function vector should match BB vector (only one BB): {63.5, 63.5}
+  EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 63.5)));
+}
+
 static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes;
 static constexpr unsigned MaxTypeIDs = Vocabulary::MaxTypeIDs;
 static constexpr unsigned MaxOperands = Vocabulary::MaxOperandKinds;



More information about the llvm-commits mailing list