[llvm] [IR2Vec] Add support for flow-aware embeddings (PR #152613)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Aug 7 16:47:29 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-analysis
Author: S. VenkataKeerthy (svkeerthy)
<details>
<summary>Changes</summary>
This patch introduces support for Flow-Aware embeddings in IR2Vec, which capture data flow information in addition to symbolic representations.
---
Patch is 20.99 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152613.diff
6 Files Affected:
- (modified) llvm/include/llvm/Analysis/IR2Vec.h (+23-6)
- (modified) llvm/lib/Analysis/IR2Vec.cpp (+65-9)
- (added) llvm/test/Analysis/IR2Vec/basic-flowaware.ll (+72)
- (renamed) llvm/test/Analysis/IR2Vec/basic-symbolic.ll (+1-13)
- (added) llvm/test/Analysis/IR2Vec/basic-vocab.ll (+27)
- (modified) llvm/unittests/Analysis/IR2VecTest.cpp (+71-5)
``````````diff
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 17f41129fd4fa..3cfc206c94788 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -54,14 +54,19 @@ class IR2VecVocabAnalysis;
/// of the IR entities. Flow-aware embeddings build on top of symbolic
/// embeddings and additionally capture the flow information in the IR.
/// IR2VecKind is used to specify the type of embeddings to generate.
-/// Currently, only Symbolic embeddings are supported.
-enum class IR2VecKind { Symbolic };
+/// Note: The FlowAware implementation differs from the one described in the
+/// paper. It is a simplified version that captures flow information only
+/// through SSA-based use-def chains; unlike the paper, it does not trace
+/// use-defs through memory (e.g., stores and loads) when computing the
+/// embeddings.
+enum class IR2VecKind { Symbolic, FlowAware };
namespace ir2vec {
LLVM_ABI extern cl::opt<float> OpcWeight;
LLVM_ABI extern cl::opt<float> TypeWeight;
LLVM_ABI extern cl::opt<float> ArgWeight;
+LLVM_ABI extern cl::opt<IR2VecKind> IR2VecEmbeddingKind;
/// Embedding is a datatype that wraps std::vector<double>. It provides
/// additional functionality for arithmetic and comparison operations.
@@ -257,9 +262,8 @@ class Embedder {
LLVM_ABI Embedder(const Function &F, const Vocabulary &Vocab);
/// Helper function to compute embeddings. It generates embeddings for all
- /// the instructions and basic blocks in the function F. Logic of computing
- /// the embeddings is specific to the kind of embeddings being computed.
- virtual void computeEmbeddings() const = 0;
+ /// the instructions and basic blocks in the function F.
+ void computeEmbeddings() const;
/// Helper function to compute the embedding for a given basic block.
/// Specific to the kind of embeddings being computed.
@@ -296,7 +300,6 @@ class Embedder {
/// representations obtained from the Vocabulary.
class LLVM_ABI SymbolicEmbedder : public Embedder {
private:
- void computeEmbeddings() const override;
void computeEmbeddings(const BasicBlock &BB) const override;
public:
@@ -306,6 +309,20 @@ class LLVM_ABI SymbolicEmbedder : public Embedder {
}
};
+/// Class for computing the Flow-aware embeddings of IR2Vec.
+/// Flow-aware embeddings build on the vocabulary, just like Symbolic
+/// embeddings, and additionally capture the flow information in the IR.
+class LLVM_ABI FlowAwareEmbedder : public Embedder {
+private:
+ void computeEmbeddings(const BasicBlock &BB) const override;
+
+public:
+ FlowAwareEmbedder(const Function &F, const Vocabulary &Vocab)
+ : Embedder(F, Vocab) {
+ FuncVector = Embedding(Dimension, 0);
+ }
+};
+
} // namespace ir2vec
/// This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 95f30fd3f4275..0bea25ec26b8e 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -52,6 +52,15 @@ cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
cl::desc("Weight for argument embeddings"),
cl::cat(IR2VecCategory));
+cl::opt<IR2VecKind> IR2VecEmbeddingKind(
+ "ir2vec-kind", cl::Optional,
+ cl::values(clEnumValN(IR2VecKind::Symbolic, "symbolic",
+ "Generate symbolic embeddings"),
+ clEnumValN(IR2VecKind::FlowAware, "flow-aware",
+ "Generate flow-aware embeddings")),
+ cl::init(IR2VecKind::Symbolic), cl::desc("IR2Vec embedding kind"),
+ cl::cat(IR2VecCategory));
+
} // namespace ir2vec
} // namespace llvm
@@ -149,6 +158,8 @@ std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
switch (Mode) {
case IR2VecKind::Symbolic:
return std::make_unique<SymbolicEmbedder>(F, Vocab);
+ case IR2VecKind::FlowAware:
+ return std::make_unique<FlowAwareEmbedder>(F, Vocab);
}
return nullptr;
}
@@ -180,6 +191,17 @@ const Embedding &Embedder::getFunctionVector() const {
return FuncVector;
}
+/// Shared driver for all embedder kinds: runs the kind-specific
+/// computeEmbeddings(const BasicBlock &) override on every basic block of F
+/// reachable from the entry block and sums the resulting block vectors into
+/// FuncVector. Returns immediately for declarations.
+void Embedder::computeEmbeddings() const {
+ if (F.isDeclaration())
+ return;
+
+ // Consider only the basic blocks that are reachable from entry.
+ // depth_first() yields blocks in DFS preorder, so each block is visited
+ // after all of its dominators; this lets FlowAwareEmbedder find InstVecMap
+ // entries for non-PHI operands defined in other blocks.
+ for (const BasicBlock *BB : depth_first(&F)) {
+ computeEmbeddings(*BB);
+ // The per-block override above has just stored BBVecMap[BB].
+ FuncVector += BBVecMap[BB];
+ }
+}
+
void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
Embedding BBVector(Dimension, 0);
@@ -196,15 +218,46 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
BBVecMap[&BB] = BBVector;
}
-void SymbolicEmbedder::computeEmbeddings() const {
- if (F.isDeclaration())
- return;
+void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
+ Embedding BBVector(Dimension, 0);
- // Consider only the basic blocks that are reachable from entry
- for (const BasicBlock *BB : depth_first(&F)) {
- computeEmbeddings(*BB);
- FuncVector += BBVecMap[BB];
+ // We consider only the non-debug and non-pseudo instructions
+ for (const auto &I : BB.instructionsWithoutDebug()) {
+ // TODO: Handle call instructions differently.
+ // For now, we treat them like other instructions
+ Embedding ArgEmb(Dimension, 0);
+ for (const auto &Op : I.operands()) {
+ // If the operand is defined elsewhere, we use its embedding
+ if (const Instruction *DefInst = dyn_cast<Instruction>(Op)) {
+ auto DefIt = InstVecMap.find(DefInst);
+ // No assert: a def may legitimately be missing, e.g. a PHI operand on a
+ // loop back edge whose defining block DFS visits after the PHI's block.
+ if (DefIt != InstVecMap.end()) {
+ ArgEmb += DefIt->second;
+ continue;
+ }
+ // Otherwise fall back to the operand's vocabulary embedding.
+ // TODO: iterate to a fixed point to account for cyclic (PHI) uses.
+ LLVM_DEBUG(dbgs() << "Warning: Operand defined by instruction not "
+ "found in InstVecMap: "
+ << *DefInst << "\n");
+ ArgEmb += Vocab[Op];
+ }
+ // If the operand is not defined by an instruction, we use the vocabulary
+ else {
+ LLVM_DEBUG(dbgs() << "Using embedding from vocabulary for operand: "
+ << *Op << "=" << Vocab[Op][0] << "\n");
+ ArgEmb += Vocab[Op];
+ }
+ }
+ // Create the instruction vector by combining opcode, type, and arguments
+ // embeddings
+ auto InstVector =
+ Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+ InstVecMap[&I] = InstVector;
+ BBVector += InstVector;
}
+ BBVecMap[&BB] = BBVector;
}
// ==----------------------------------------------------------------------===//
@@ -552,8 +605,11 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid");
for (Function &F : M) {
- std::unique_ptr<Embedder> Emb =
- Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
+ // Build the embedder kind selected on the command line (-ir2vec-kind);
+ // Embedder::create performs the per-kind dispatch.
+ std::unique_ptr<Embedder> Emb =
+ Embedder::create(IR2VecEmbeddingKind, F, Vocabulary);
if (!Emb) {
OS << "Error creating IR2Vec embeddings \n";
continue;
diff --git a/llvm/test/Analysis/IR2Vec/basic-flowaware.ll b/llvm/test/Analysis/IR2Vec/basic-flowaware.ll
new file mode 100644
index 0000000000000..4a7f970a9cf91
--- /dev/null
+++ b/llvm/test/Analysis/IR2Vec/basic-flowaware.ll
@@ -0,0 +1,72 @@
+; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-OPC
+; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_type_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-TYPE
+; RUN: opt -passes='print<ir2vec>' -ir2vec-kind=flow-aware -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_arg_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-ARG
+
+define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
+entry:
+ %a.addr = alloca i32, align 4
+ %b.addr = alloca float, align 4
+ store i32 %a, ptr %a.addr, align 4
+ store float %b, ptr %b.addr, align 4
+ %0 = load i32, ptr %a.addr, align 4
+ %1 = load i32, ptr %a.addr, align 4
+ %mul = mul nsw i32 %0, %1
+ %conv = sitofp i32 %mul to float
+ %2 = load float, ptr %b.addr, align 4
+ %add = fadd float %conv, %2
+ ret float %add
+}
+
+; 3D-CHECK-OPC: IR2Vec embeddings for function _Z3abcif:
+; 3D-CHECK-OPC-NEXT: Function vector: [ 3630.00 3672.00 3714.00 ]
+; 3D-CHECK-OPC-NEXT: Basic block vectors:
+; 3D-CHECK-OPC-NEXT: Basic block: entry:
+; 3D-CHECK-OPC-NEXT: [ 3630.00 3672.00 3714.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction vectors:
+; 3D-CHECK-OPC-NEXT: Instruction: %a.addr = alloca i32, align 4 [ 91.00 92.00 93.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %b.addr = alloca float, align 4 [ 91.00 92.00 93.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: store i32 %a, ptr %a.addr, align 4 [ 188.00 190.00 192.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: store float %b, ptr %b.addr, align 4 [ 188.00 190.00 192.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %0 = load i32, ptr %a.addr, align 4 [ 185.00 187.00 189.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %1 = load i32, ptr %a.addr, align 4 [ 185.00 187.00 189.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %mul = mul nsw i32 %0, %1 [ 419.00 424.00 429.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %conv = sitofp i32 %mul to float [ 549.00 555.00 561.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 185.00 187.00 189.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: %add = fadd float %conv, %2 [ 774.00 783.00 792.00 ]
+; 3D-CHECK-OPC-NEXT: Instruction: ret float %add [ 775.00 785.00 795.00 ]
+
+; 3D-CHECK-TYPE: IR2Vec embeddings for function _Z3abcif:
+; 3D-CHECK-TYPE-NEXT: Function vector: [ 355.50 376.50 397.50 ]
+; 3D-CHECK-TYPE-NEXT: Basic block vectors:
+; 3D-CHECK-TYPE-NEXT: Basic block: entry:
+; 3D-CHECK-TYPE-NEXT: [ 355.50 376.50 397.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction vectors:
+; 3D-CHECK-TYPE-NEXT: Instruction: %a.addr = alloca i32, align 4 [ 12.50 13.00 13.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %b.addr = alloca float, align 4 [ 12.50 13.00 13.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: store i32 %a, ptr %a.addr, align 4 [ 14.50 15.50 16.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: store float %b, ptr %b.addr, align 4 [ 14.50 15.50 16.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %0 = load i32, ptr %a.addr, align 4 [ 22.00 23.00 24.00 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %1 = load i32, ptr %a.addr, align 4 [ 22.00 23.00 24.00 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %mul = mul nsw i32 %0, %1 [ 53.50 56.00 58.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %conv = sitofp i32 %mul to float [ 54.00 57.00 60.00 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 13.00 14.00 15.00 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: %add = fadd float %conv, %2 [ 67.50 72.00 76.50 ]
+; 3D-CHECK-TYPE-NEXT: Instruction: ret float %add [ 69.50 74.50 79.50 ]
+
+; 3D-CHECK-ARG: IR2Vec embeddings for function _Z3abcif:
+; 3D-CHECK-ARG-NEXT: Function vector: [ 27.80 31.60 35.40 ]
+; 3D-CHECK-ARG-NEXT: Basic block vectors:
+; 3D-CHECK-ARG-NEXT: Basic block: entry:
+; 3D-CHECK-ARG-NEXT: [ 27.80 31.60 35.40 ]
+; 3D-CHECK-ARG-NEXT: Instruction vectors:
+; 3D-CHECK-ARG-NEXT: Instruction: %a.addr = alloca i32, align 4 [ 1.40 1.60 1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %b.addr = alloca float, align 4 [ 1.40 1.60 1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction: store i32 %a, ptr %a.addr, align 4 [ 3.40 3.80 4.20 ]
+; 3D-CHECK-ARG-NEXT: Instruction: store float %b, ptr %b.addr, align 4 [ 3.40 3.80 4.20 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %0 = load i32, ptr %a.addr, align 4 [ 1.40 1.60 1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %1 = load i32, ptr %a.addr, align 4 [ 1.40 1.60 1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %mul = mul nsw i32 %0, %1 [ 2.80 3.20 3.60 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %conv = sitofp i32 %mul to float [ 2.80 3.20 3.60 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 1.40 1.60 1.80 ]
+; 3D-CHECK-ARG-NEXT: Instruction: %add = fadd float %conv, %2 [ 4.20 4.80 5.40 ]
+; 3D-CHECK-ARG-NEXT: Instruction: ret float %add [ 4.20 4.80 5.40 ]
diff --git a/llvm/test/Analysis/IR2Vec/basic.ll b/llvm/test/Analysis/IR2Vec/basic-symbolic.ll
similarity index 81%
rename from llvm/test/Analysis/IR2Vec/basic.ll
rename to llvm/test/Analysis/IR2Vec/basic-symbolic.ll
index cb0544fb19860..35abd3c7fa269 100644
--- a/llvm/test/Analysis/IR2Vec/basic.ll
+++ b/llvm/test/Analysis/IR2Vec/basic-symbolic.ll
@@ -1,11 +1,7 @@
; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-OPC
; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_type_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-TYPE
; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_nonzero_arg_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK-ARG
-; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab1.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB1-CHECK
-; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab2.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB2-CHECK
-; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab3.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB3-CHECK
-; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab4.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB4-CHECK
-
+
define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
entry:
%a.addr = alloca i32, align 4
@@ -74,11 +70,3 @@ entry:
; 3D-CHECK-ARG-NEXT: Instruction: %2 = load float, ptr %b.addr, align 4 [ 0.80 1.00 1.20 ]
; 3D-CHECK-ARG-NEXT: Instruction: %add = fadd float %conv, %2 [ 4.00 4.40 4.80 ]
; 3D-CHECK-ARG-NEXT: Instruction: ret float %add [ 2.00 2.20 2.40 ]
-
-; INCORRECT-VOCAB1-CHECK: error: Error reading vocabulary: Missing 'Opcodes' section in vocabulary file
-
-; INCORRECT-VOCAB2-CHECK: error: Error reading vocabulary: Missing 'Types' section in vocabulary file
-
-; INCORRECT-VOCAB3-CHECK: error: Error reading vocabulary: Missing 'Arguments' section in vocabulary file
-
-; INCORRECT-VOCAB4-CHECK: error: Error reading vocabulary: Vocabulary sections have different dimensions
diff --git a/llvm/test/Analysis/IR2Vec/basic-vocab.ll b/llvm/test/Analysis/IR2Vec/basic-vocab.ll
new file mode 100644
index 0000000000000..eeeee831814a8
--- /dev/null
+++ b/llvm/test/Analysis/IR2Vec/basic-vocab.ll
@@ -0,0 +1,27 @@
+; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab1.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB1-CHECK
+; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab2.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB2-CHECK
+; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab3.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB3-CHECK
+; RUN: not opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/incorrect_vocab4.json %s 2>&1 | FileCheck %s -check-prefix=INCORRECT-VOCAB4-CHECK
+
+define dso_local noundef float @_Z3abcif(i32 noundef %a, float noundef %b) #0 {
+entry:
+ %a.addr = alloca i32, align 4
+ %b.addr = alloca float, align 4
+ store i32 %a, ptr %a.addr, align 4
+ store float %b, ptr %b.addr, align 4
+ %0 = load i32, ptr %a.addr, align 4
+ %1 = load i32, ptr %a.addr, align 4
+ %mul = mul nsw i32 %0, %1
+ %conv = sitofp i32 %mul to float
+ %2 = load float, ptr %b.addr, align 4
+ %add = fadd float %conv, %2
+ ret float %add
+}
+
+; INCORRECT-VOCAB1-CHECK: error: Error reading vocabulary: Missing 'Opcodes' section in vocabulary file
+
+; INCORRECT-VOCAB2-CHECK: error: Error reading vocabulary: Missing 'Types' section in vocabulary file
+
+; INCORRECT-VOCAB3-CHECK: error: Error reading vocabulary: Missing 'Arguments' section in vocabulary file
+
+; INCORRECT-VOCAB4-CHECK: error: Error reading vocabulary: Vocabulary sections have different dimensions
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index e288585033c53..f6846963b3e2f 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -30,7 +30,6 @@ namespace {
class TestableEmbedder : public Embedder {
public:
TestableEmbedder(const Function &F, const Vocabulary &V) : Embedder(F, V) {}
- void computeEmbeddings() const override {}
void computeEmbeddings(const BasicBlock &BB) const override {}
};
@@ -258,6 +257,18 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
EXPECT_NE(Emb, nullptr);
}
+TEST(IR2VecTest, CreateFlowAwareEmbedder) {
+ Vocabulary V = Vocabulary(Vocabulary::createDummyVocabForTest());
+
+ LLVMContext Ctx;
+ Module M("M", Ctx);
+ FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
+ Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
+
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ EXPECT_NE(Emb, nullptr);
+}
+
TEST(IR2VecTest, CreateInvalidMode) {
Vocabulary V = Vocabulary(Vocabulary::createDummyVocabForTest());
@@ -307,10 +318,12 @@ class IR2VecTestFixture : public ::testing::Test {
AddInst = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
RetInst = ReturnInst::Create(Ctx, AddInst, BB);
}
};
-TEST_F(IR2VecTestFixture, GetInstVecMap) {
+TEST_F(IR2VecTestFixture, GetInstVecMap_Symbolic) {
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Emb));
@@ -327,7 +340,24 @@ TEST_F(IR2VecTestFixture, GetInstVecMap) {
EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 16.8)));
}
-TEST_F(IR2VecTestFixture, GetBBVecMap) {
+TEST_F(IR2VecTestFixture, GetInstVecMap_FlowAware) {
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
+
+ const auto &InstMap = Emb->getInstVecMap();
+
+ EXPECT_EQ(InstMap.size(), 2u);
+ EXPECT_TRUE(InstMap.count(AddInst));
+ EXPECT_TRUE(InstMap.count(RetInst));
+
+ EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
+ EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
+
+ EXPECT_TRUE(InstMap.at(AddInst).approximatelyEquals(Embedding(2, 27.6)));
+ EXPECT_TRUE(InstMap.at(RetInst).approximatelyEquals(Embedding(2, 35.2)));
+}
+
+TEST_F(IR2VecTestFixture, GetBBVecMap_Symbolic) {
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Emb));
@@ -342,7 +372,22 @@ TEST_F(IR2VecTestFixture, GetBBVecMap) {
EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 44.4)));
}
-TEST_F(IR2VecTestFixture, GetBBVector) {
+TEST_F(IR2VecTestFixture, GetBBVecMap_FlowAware) {
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
+
+ const auto &BBMap = Emb->getBBVecMap();
+
+ EXPECT_EQ(BBMap.size(), 1u);
+ EXPECT_TRUE(BBMap.count(BB));
+ EXPECT_EQ(BBMap.at(BB).size(), 2u);
+
+ // BB vector should be sum of add and ret: {27.6, 27.6} + {35.2, 35.2} =
+ // {62.8, 62.8}
+ EXPECT_TRUE(BBMap.at(BB).approximatelyEquals(Embedding(2, 62.8)));
+}
+
+TEST_F(IR2VecTestFixture, GetBBVector_Symbolic) {
auto Emb = Embedder::create(IR2VecKind::Symbolic, *F, V);
ASSERT_TRUE(static_cast<bool>(Emb));
@@ -352,7 +397,17 @@ TEST_F(IR2VecTestFixture, GetBBVector) {
EXPECT_TRUE(BBVec.approximatelyEquals(Embedding(2, 44.4)));
}
-TEST_F(IR2VecTestFixture, GetFunctionVector) {
+TEST_F(IR2VecTestFixture, GetBBVector_FlowAware) {
+ auto Emb = Embedder::create(IR2VecKind::FlowAware, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Emb));
+
+ const auto &BBVec ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/152613
More information about the llvm-commits
mailing list