[llvm] [IR2Vec] Add support for flow-aware embeddings (PR #152613)
S. VenkataKeerthy via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 28 12:03:05 PDT 2025
https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/152613
>From 7869cc0d583680f7218fe69d59cc37f692589825 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Thu, 7 Aug 2025 23:43:45 +0000
Subject: [PATCH] Flow-Aware Embeddings
---
llvm/include/llvm/Analysis/IR2Vec.h | 31 +++++--
llvm/lib/Analysis/IR2Vec.cpp | 82 ++++++++++++++++---
llvm/test/Analysis/IR2Vec/basic-flowaware.ll | 72 ++++++++++++++++
.../IR2Vec/{basic.ll => basic-symbolic.ll} | 14 +---
llvm/test/Analysis/IR2Vec/basic-vocab.ll | 27 ++++++
llvm/unittests/Analysis/IR2VecTest.cpp | 74 +++++++++++++++--
6 files changed, 261 insertions(+), 39 deletions(-)
create mode 100644 llvm/test/Analysis/IR2Vec/basic-flowaware.ll
rename llvm/test/Analysis/IR2Vec/{basic.ll => basic-symbolic.ll} (81%)
create mode 100644 llvm/test/Analysis/IR2Vec/basic-vocab.ll
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 17f41129fd4fa..15221c7f07917 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -54,14 +54,19 @@ class IR2VecVocabAnalysis;
/// of the IR entities. Flow-aware embeddings build on top of symbolic
/// embeddings and additionally capture the flow information in the IR.
/// IR2VecKind is used to specify the type of embeddings to generate.
-/// Currently, only Symbolic embeddings are supported.
-enum class IR2VecKind { Symbolic };
+/// Note: The implementation of FlowAware embeddings is not the same as the
+/// one described in the paper. The current implementation is a simplified
+/// version that captures the flow information (SSA-based use-defs) without
+/// tracing through memory-level use-defs, unlike the embedding computation
+/// described in the paper.
+enum class IR2VecKind { Symbolic, FlowAware };
namespace ir2vec {
LLVM_ABI extern cl::opt<float> OpcWeight;
LLVM_ABI extern cl::opt<float> TypeWeight;
LLVM_ABI extern cl::opt<float> ArgWeight;
+LLVM_ABI extern cl::opt<IR2VecKind> IR2VecEmbeddingKind;
/// Embedding is a datatype that wraps std::vector<double>. It provides
/// additional functionality for arithmetic and comparison operations.
@@ -257,9 +262,8 @@ class Embedder {
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);
/// Helper function to compute embeddings. It generates embeddings for all
- /// the instructions and basic blocks in the function F. Logic of computing
- /// the embeddings is specific to the kind of embeddings being computed.
- virtual void computeEmbeddings() const = 0;
+ /// the instructions and basic blocks in the function F.
+ void computeEmbeddings() const;
/// Helper function to compute the embedding for a given basic block.
/// Specific to the kind of embeddings being computed.
@@ -296,14 +300,23 @@ class Embedder {
/// representations obtained from the Vocabulary.
class LLVM_ABI SymbolicEmbedder : public Embedder {
private:
- void computeEmbeddings() const override;
void computeEmbeddings(const BasicBlock &BB) const override;
public:
SymbolicEmbedder(const Function &F, const Vocabulary &Vocab)
- : Embedder(F, Vocab) {
- FuncVector = Embedding(Dimension, 0);
- }
+ : Embedder(F, Vocab) {}
+};
+
+/// Class for computing the Flow-aware embeddings of IR2Vec.
+/// Flow-aware embeddings build on the vocabulary, just like Symbolic
+/// embeddings, and additionally capture the flow information in the IR.
+class LLVM_ABI FlowAwareEmbedder : public Embedder {
+private:
+ void computeEmbeddings(const BasicBlock &BB) const override;
+
+public:
+ FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)
+ : Embedder(F, Vocab) {}
};
} // namespace ir2vec
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 95f30fd3f4275..7d0ad6a69a398 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -52,6 +52,15 @@ cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
cl::desc("Weight for argument embeddings"),
cl::cat(IR2VecCategory));
+cl::opt<IR2VecKind> IR2VecEmbeddingKind(
+ "ir2vec-kind", cl::Optional,
+ cl::values(clEnumValN(IR2VecKind::Symbolic, "symbolic",
+ "Generate symbolic embeddings"),
+ clEnumValN(IR2VecKind::FlowAware, "flow-aware",
+ "Generate flow-aware embeddings")),
+ cl::init(IR2VecKind::Symbolic), cl::desc("IR2Vec embedding kind"),
+ cl::cat(IR2VecCategory));
+
} // namespace ir2vec
} // namespace llvm
@@ -123,8 +132,12 @@ bool Embedding::approximatelyEquals(const Embedding &RHS,
double Tolerance) const {
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
for (size_t Itr = 0; Itr < this->size(); ++Itr)
- if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance)
+ if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance) {
+ LLVM_DEBUG(errs() << "Embedding mismatch at index " << Itr << ": "
+ << (*this)[Itr] << " vs " << RHS[Itr]
+ << "; Tolerance: " << Tolerance << "\n");
return false;
+ }
return true;
}
@@ -141,14 +154,16 @@ void Embedding::print(raw_ostream &OS) const {
Embedder::Embedder(const Function &F, const Vocabulary &Vocab)
: F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
- OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {
-}
+ OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight),
+ FuncVector(Embedding(Dimension, 0)) {}
std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
const Vocabulary &Vocab) {
switch (Mode) {
case IR2VecKind::Symbolic:
return std::make_unique<SymbolicEmbedder>(F, Vocab);
+ case IR2VecKind::FlowAware:
+ return std::make_unique<FlowAwareEmbedder>(F, Vocab);
}
return nullptr;
}
@@ -180,6 +195,17 @@ const Embedding &Embedder::getFunctionVector() const {
return FuncVector;
}
+void Embedder::computeEmbeddings() const {
+ if (F.isDeclaration())
+ return;
+
+ // Consider only the basic blocks that are reachable from entry
+ for (const BasicBlock *BB : depth_first(&F)) {
+ computeEmbeddings(*BB);
+ FuncVector += BBVecMap[BB];
+ }
+}
+
void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
Embedding BBVector(Dimension, 0);
@@ -196,15 +222,38 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
BBVecMap[&BB] = BBVector;
}
-void SymbolicEmbedder::computeEmbeddings() const {
- if (F.isDeclaration())
- return;
+void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
+ Embedding BBVector(Dimension, 0);
- // Consider only the basic blocks that are reachable from entry
- for (const BasicBlock *BB : depth_first(&F)) {
- computeEmbeddings(*BB);
- FuncVector += BBVecMap[BB];
+ // We consider only the non-debug and non-pseudo instructions
+ for (const auto &I : BB.instructionsWithoutDebug()) {
+ // TODO: Handle call instructions differently.
+ // For now, we treat them like other instructions
+ Embedding ArgEmb(Dimension, 0);
+ for (const auto &Op : I.operands()) {
+ // If the operand is defined by an instruction, we use its embedding
+ if (const auto *DefInst = dyn_cast<Instruction>(Op)) {
+ auto DefIt = InstVecMap.find(DefInst);
+ assert(DefIt != InstVecMap.end() &&
+ "Instruction should have been processed before its operands");
+ ArgEmb += DefIt->second;
+ continue;
+ }
+ // If the operand is not defined by an instruction, we use the vocabulary
+ else {
+ LLVM_DEBUG(errs() << "Using embedding from vocabulary for operand: "
+ << *Op << "=" << Vocab[Op][0] << "\n");
+ ArgEmb += Vocab[Op];
+ }
+ }
+ // Create the instruction vector by combining opcode, type, and arguments
+ // embeddings
+ auto InstVector =
+ Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+ InstVecMap[&I] = InstVector;
+ BBVector += InstVector;
}
+ BBVecMap[&BB] = BBVector;
}
// ==----------------------------------------------------------------------===//
@@ -552,8 +601,17 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid");
for (Function &F : M) {
- std::unique_ptr<Embedder> Emb =
- Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
+ std::unique_ptr<Embedder> Emb;
+ switch (IR2VecEmbeddingKind) {
+ case IR2VecKind::Symbolic:
+ Emb = std::make_unique<SymbolicEmbedder>(F, Vocabulary);
+ break;
+ case IR2VecKind::FlowAware:
+ Emb = std::make_unique<FlowAwareEmbedder>(F, Vocabulary);
+ break;
+ default:
+ llvm_unreachable("Unknown IR2Vec embedding kind");
+ }
if (!Emb) {
OS << "Error creating IR2Vec embeddings \n";
continue;
diff --git a/llvm/test/Analysis/IR2Vec/basic-flowaware.ll b/llvm/test/Analysis/IR2Vec/basic-flowaware.ll
new file mode 100644
index 0000000000000..4a7f970a9cf91
--- /dev/null
+++ b/llvm/test/Analysis/IR2Vec/basic-flowaware.ll
@@ -0,0 +1,72 @@
+; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-OPC
+; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_type_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-TYPE
+; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_arg_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-ARG
+
+define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
+entry:
+ %a.addr = alloca i32, align 4
+ %b.addr = alloca float, align 4
+ store i32 %a, ptr %a.addr, align 4
+ store float %b, ptr %b.addr, align 4
+ %0 = load i32, ptr %a.addr, align 4
+ %1 = load i32, ptr %a.addr, align 4
+ %mul = mul nsw i32 %0, %1
+ %conv = sitofp i32 %mul to float
+ %2 = load float, ptr %b.addr, align 4
+ %add = fadd float %conv, %2
+ ret float %add
+}
+
+; 3D-CHECK-OPC: IR2Vec embeddings for function _Z3abcif:
+; 3D-CHECK-OPC-NEXT: Function vector: [ 3630.00 3672.00 3714.00 ]
+; 3D-CHECK-OPC-NEXT: Basic block vectors:
+; 3D-CHECK-OPC-NEXT: Basic block: entry:
+; 3D-CHECK-OPC-NEXT: [ 3630.00 3672.00 3714.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction vectors:
+; 3D-CHECK-OPC-NEXT: Instruction: %a.addr = alloca i32, align 4 [ 91.00 92.00 93.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %b.addr = alloca float, align 4 [ 91.00 92.00 93.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: store i32 %a, ptr %a.addr, align 4 [ 188.00 190.00 192.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: store float %b, ptr %b.addr, align 4 [ 188.00 190.00 192.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %0 = load i32, ptr %a.addr, align 4 [ 185.00 187.00 189.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %1 = load i32, ptr %a.addr, align 4 [ 185.00 187.00 189.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %mul = mul nsw i32 %0, %1 [ 419.00 424.00 429.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %conv = sitofp i32 %mul to float [ 549.00 555.00 561.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 185.00 187.00 189.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %add = fadd float %conv, %2 [ 774.00 783.00 792.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: ret float %add [ 775.00 785.00 795.00 ]
+
+; 3D-CHECK-TYPE: IR2Vec embeddings for function _Z3abcif:
+; 3D-CHECK-TYPE-NEXT: Function vector: [ 355.50 376.50 397.50 ]
+; 3D-CHECK-TYPE-NEXT: Basic block vectors:
+; 3D-CHECK-TYPE-NEXT: Basic block: entry:
+; 3D-CHECK-TYPE-NEXT: [ 355.50 376.50 397.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction vectors:
+; 3D-CHECK-TYPE-NEXT: Instruction: %a.addr = alloca i32, align 4 [ 12.50 13.00 13.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %b.addr = alloca float, align 4 [ 12.50 13.00 13.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: store i32 %a, ptr %a.addr, align 4 [ 14.50 15.50 16.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: store float %b, ptr %b.addr, align 4 [ 14.50 15.50 16.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %0 = load i32, ptr %a.addr, align 4 [ 22.00 23.00 24.00 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %1 = load i32, ptr %a.addr, align 4 [ 22.00 23.00 24.00 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %mul = mul nsw i32 %0, %1 [ 53.50 56.00 58.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %conv = sitofp i32 %mul to float [ 54.00 57.00 60.00 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 13.00 14.00 15.00 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %add = fadd float %conv, %2 [ 67.50 72.00 76.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: ret float %add [ 69.50 74.50 79.50 ]
+
+; 3D-CHECK-ARG: IR2Vec embeddings for function _Z3abcif:
+; 3D-CHECK-ARG-NEXT: Function vector: [ 27.80 31.60 35.40 ]
+; 3D-CHECK-ARG-NEXT: Basic block vectors:
+; 3D-CHECK-ARG-NEXT: Basic block: entry:
+; 3D-CHECK-ARG-NEXT: [ 27.80 31.60 35.40 ]
+; 3D-CHECK-ARG-NEXT: Instruction vectors:
+; 3D-CHECK-ARG-NEXT: Instruction: %a.addr = alloca i32, align 4 [ 1.40 1.60 1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %b.addr = alloca float, align 4 [ 1.40 1.60 1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction: store i32 %a, ptr %a.addr, align 4 [ 3.40 3.80 4.20 ]
+; 3D-CHECK-ARG-NEXT: Instruction: store float %b, ptr %b.addr, align 4 [ 3.40 3.80 4.20 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %0 = load i32, ptr %a.addr, align 4 [ 1.40 1.60 1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %1 = load i32, ptr %a.addr, align 4 [ 1.40 1.60 1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %mul = mul nsw i32 %0, %1 [ 2.80 3.20 3.60 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %conv = sitofp i32 %mul to float [ 2.80 3.20 3.60 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 1.40 1.60 1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %add = fadd float %conv, %2 [ 4.20 4.80 5.40 ]
+; 3D-CHECK-ARG-NEXT: Instruction: ret float %add [ 4.20 4.80 5.40 ]
diff --git a/llvm/test/Analysis/IR2Vec/basic.ll b/llvm/test/Analysis/IR2Vec/basic-symbolic.ll
similarity index 81%
rename from llvm/test/Analysis/IR2Vec/basic.ll
rename to llvm/test/Analysis/IR2Vec/basic-symbolic.ll
index cb0544fb19860..35abd3c7fa269 100644
--- a/llvm/test/Analysis/IR2Vec/basic.ll
+++ b/llvm/test/Analysis/IR2Vec/basic-symbolic.ll
@@ -1,11 +1,7 @@
; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-OPC
; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_type_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-TYPE
; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_arg_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-ARG
-; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab1.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB1-CHECK
-; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab2.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB2-CHECK
-; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab3.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB3-CHECK
-; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab4.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB4-CHECK
-
+
define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
entry:
%a.addr = alloca i32, align 4
@@ -74,11 +70,3 @@ entry:
; 3D-CHECK-ARG-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 0.80 1.00 1.20 ]
; 3D-CHECK-ARG-NEXT: Instruction: %add = fadd float %conv, %2 [ 4.00 4.40 4.80 ]
; 3D-CHECK-ARG-NEXT: Instruction: ret float %add [ 2.00 2.20 2.40 ]
-
-; INCORRECT-VOCAB1-CHECK: error: Error reading vocabulary: Missing 'Opcodes' section in vocabulary file
-
-; INCORRECT-VOCAB2-CHECK: error: Error reading vocabulary: Missing 'Types' section in vocabulary file
-
-; INCORRECT-VOCAB3-CHECK: error: Error reading vocabulary: Missing 'Arguments' section in vocabulary file
-
-; INCORRECT-VOCAB4-CHECK: error: Error reading vocabulary: Vocabulary sections have different dimensions
diff --git a/llvm/test/Analysis/IR2Vec/basic-vocab.ll b/llvm/test/Analysis/IR2Vec/basic-vocab.ll
new file mode 100644
index 0000000000000..eeeee831814a8
--- /dev/null
+++ b/llvm/test/Analysis/IR2Vec/basic-vocab.ll
@@ -0,0 +1,27 @@
+; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab1.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB1-CHECK
+; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab2.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB2-CHECK
+; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab3.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB3-CHECK
+; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab4.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB4-CHECK
+
+define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
+entry:
+ %a.addr = alloca i32, align 4
+ %b.addr = alloca float, align 4
+ store i32 %a, ptr %a.addr, align 4
+ store float %b, ptr %b.addr, align 4
+ %0 = load i32, ptr %a.addr, align 4
+ %1 = load i32, ptr %a.addr, align 4
+ %mul = mul nsw i32 %0, %1
+ %conv = sitofp i32 %mul to float
+ %2 = load float, ptr %b.addr, align 4
+ %add = fadd float %conv, %2
+ ret float %add
+}
+
+; INCORRECT-VOCAB1-CHECK: error: Error reading vocabulary: Missing 'Opcodes' section in vocabulary file
+
+; INCORRECT-VOCAB2-CHECK: error: Error reading vocabulary: Missing 'Types' section in vocabulary file
+
+; INCORRECT-VOCAB3-CHECK: error: Error reading vocabulary: Missing 'Arguments' section in vocabulary file
+
+; INCORRECT-VOCAB4-CHECK: error: Error reading vocabulary: Vocabulary sections have different dimensions
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index f7838cc4068ce..f0c81e160ca15 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -30,7 +30,6 @@ namespace {
class TestableEmbedder : public Embedder {
public:
TestableEmbedder(const Function &F, const Vocabulary &V) : Embedder(F, V) {}
- void computeEmbeddings() const override {}
void computeEmbeddings(const BasicBlock &BB) const override {}
};
@@ -258,6 +257,18 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
EXPECT_NE(Emb, nullptr);
}
+TEST(IR2VecTest, CreateFlowAwareEmbedder) {
+ Vocabulary V = Vocabulary(Vocabulary::createDummyVocabForTest());
+
+ LLVMContext Ctx;
+ Module M("M", Ctx);
+ FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
+ Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
+
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ EXPECT_NE(Emb, nullptr);
+}
+
TEST(IR2VecTest, CreateInvalidMode) {
Vocabulary V = Vocabulary(Vocabulary::createDummyVocabForTest());
@@ -310,7 +321,7 @@ class IR2VecTestFixture : public ::testing::Test {
}
};
-TEST_F(IR2VecTestFixture, GetInstVecMap) {
+TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) {
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Emb));
@@ -329,7 +340,24 @@ TEST_F(IR2VecTestFixture, GetInstVecMap) {
EXPECT_TRUE(RetEmb.approximatelyEquals(Embedding(2, 17.0)));
}
-TEST_F(IR2VecTestFixture, GetBBVecMap) {
+TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) {
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
+
+ const auto &InstMap = Emb->getInstVecMap();
+
+ EXPECT_EQ(InstMap.size(), 2u);
+ EXPECT_TRUE(InstMap.count(AddInst));
+ EXPECT_TRUE(InstMap.count(RetInst));
+
+ EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
+ EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
+
+ EXPECT_TRUE(InstMap.at(AddInst).approximatelyEquals(Embedding(2, 27.9)));
+ EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 35.6)));
+}
+
+TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) {
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Emb));
@@ -344,7 +372,22 @@ TEST_F(IR2VecTestFixture, GetBBVecMap) {
EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 44.9)));
}
-TEST_F(IR2VecTestFixture, GetBBVector) {
+TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) {
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
+
+ const auto &BBMap = Emb->getBBVecMap();
+
+ EXPECT_EQ(BBMap.size(), 1u);
+ EXPECT_TRUE(BBMap.count(BB));
+ EXPECT_EQ(BBMap.at(BB).size(), 2u);
+
+ // BB vector should be sum of add and ret: {27.9, 27.9} + {35.6, 35.6} =
+ // {63.5, 63.5}
+ EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 63.5)));
+}
+
+TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Emb));
@@ -354,7 +397,17 @@ TEST_F(IR2VecTestFixture, GetBBVector) {
EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 44.9)));
}
-TEST_F(IR2VecTestFixture, GetFunctionVector) {
+TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) {
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
+
+ const auto &BBVec = Emb->getBBVector(*BB);
+
+ EXPECT_EQ(BBVec.size(), 2u);
+ EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 63.5)));
+}
+
+TEST_F(IR2VecTestFixture, GetFunctionVector_Symbolic) {
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Emb));
@@ -366,6 +419,17 @@ TEST_F(IR2VecTestFixture, GetFunctionVector) {
EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 44.9)));
}
+TEST_F(IR2VecTestFixture, GetFunctionVector_FlowAware) {
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
+
+ const auto &FuncVec = Emb->getFunctionVector();
+
+ EXPECT_EQ(FuncVec.size(), 2u);
+ // Function vector should match BB vector (only one BB): {63.5, 63.5}
+ EXPECT_TRUE(FuncVec.approximatelyEquals(Embedding(2, 63.5)));
+}
+
static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes;
static constexpr unsigned MaxTypeIDs = Vocabulary::MaxTypeIDs;
static constexpr unsigned MaxOperands = Vocabulary::MaxOperandKinds;
More information about the llvm-commits mailing list