[llvm-branch-commits] [llvm] [IR2Vec] Minor vocab changes and exposing weights (PR #143200)
S. VenkataKeerthy via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Mon Jun 9 13:42:16 PDT 2025
https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/143200
From 7f2012cd56db0fc6e1c430a8d5b38d360b33145f Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Fri, 6 Jun 2025 20:32:32 +0000
Subject: [PATCH] Vocab changes1
---
llvm/include/llvm/Analysis/IR2Vec.h | 10 ++
llvm/lib/Analysis/IR2Vec.cpp | 82 +++++++++------
llvm/unittests/Analysis/IR2VecTest.cpp | 137 ++++++++++++++++++-------
3 files changed, 163 insertions(+), 66 deletions(-)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 14f28999b174c..3d32942670785 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -31,7 +31,9 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/IR/PassManager.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorOr.h"
+#include "llvm/Support/JSON.h"
#include <map>
namespace llvm {
@@ -43,6 +45,7 @@ class Function;
class Type;
class Value;
class raw_ostream;
+class LLVMContext;
/// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
/// Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -53,6 +56,11 @@ class raw_ostream;
enum class IR2VecKind { Symbolic };
namespace ir2vec {
+
+LLVM_ABI extern cl::opt<float> OpcWeight;
+LLVM_ABI extern cl::opt<float> TypeWeight;
+LLVM_ABI extern cl::opt<float> ArgWeight;
+
/// Embedding is a ADT that wraps std::vector<double>. It provides
/// additional functionality for arithmetic and comparison operations.
/// It is meant to be used *like* std::vector<double> but is more restrictive
@@ -224,10 +232,12 @@ class IR2VecVocabResult {
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
ir2vec::Vocab Vocabulary;
Error readVocabulary();
+ void emitError(Error Err, LLVMContext &Ctx);
public:
static AnalysisKey Key;
IR2VecVocabAnalysis() = default;
+ explicit IR2VecVocabAnalysis(ir2vec::Vocab &&Vocab);
using Result = IR2VecVocabResult;
Result run(Module &M, ModuleAnalysisManager &MAM);
};
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 25ce35d4ace37..2ad65c2f40c33 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -16,13 +16,11 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
-#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/Format.h"
-#include "llvm/Support/JSON.h"
#include "llvm/Support/MemoryBuffer.h"
using namespace llvm;
@@ -33,6 +31,8 @@ using namespace ir2vec;
STATISTIC(VocabMissCounter,
"Number of lookups to entites not present in the vocabulary");
+namespace llvm {
+namespace ir2vec {
static cl::OptionCategory IR2VecCategory("IR2Vec Options");
// FIXME: Use a default vocab when not specified
@@ -40,18 +40,20 @@ static cl::opt<std::string>
VocabFile("ir2vec-vocab-path", cl::Optional,
cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
cl::cat(IR2VecCategory));
-static cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional,
- cl::init(1.0),
- cl::desc("Weight for opcode embeddings"),
- cl::cat(IR2VecCategory));
-static cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional,
- cl::init(0.5),
- cl::desc("Weight for type embeddings"),
- cl::cat(IR2VecCategory));
-static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
- cl::init(0.2),
- cl::desc("Weight for argument embeddings"),
- cl::cat(IR2VecCategory));
+LLVM_ABI cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional,
+ cl::init(1.0),
+ cl::desc("Weight for opcode embeddings"),
+ cl::cat(IR2VecCategory));
+LLVM_ABI cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional,
+ cl::init(0.5),
+ cl::desc("Weight for type embeddings"),
+ cl::cat(IR2VecCategory));
+LLVM_ABI cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
+ cl::init(0.2),
+ cl::desc("Weight for argument embeddings"),
+ cl::cat(IR2VecCategory));
+} // namespace ir2vec
+} // namespace llvm
AnalysisKey IR2VecVocabAnalysis::Key;
@@ -251,9 +253,9 @@ bool IR2VecVocabResult::invalidate(
// by auto-generating a default vocabulary during the build time.
Error IR2VecVocabAnalysis::readVocabulary() {
auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
- if (!BufOrError) {
+ if (!BufOrError)
return createFileError(VocabFile, BufOrError.getError());
- }
+
auto Content = BufOrError.get()->getBuffer();
json::Path::Root Path("");
Expected<json::Value> ParsedVocabValue = json::parse(Content);
@@ -261,39 +263,57 @@ Error IR2VecVocabAnalysis::readVocabulary() {
return ParsedVocabValue.takeError();
bool Res = json::fromJSON(*ParsedVocabValue, Vocabulary, Path);
- if (!Res) {
+ if (!Res)
return createStringError(errc::illegal_byte_sequence,
"Unable to parse the vocabulary");
- }
- assert(Vocabulary.size() > 0 && "Vocabulary is empty");
+
+ if (Vocabulary.empty())
+ return createStringError(errc::illegal_byte_sequence,
+ "Vocabulary is empty");
unsigned Dim = Vocabulary.begin()->second.size();
- assert(Dim > 0 && "Dimension of vocabulary is zero");
- (void)Dim;
- assert(std::all_of(Vocabulary.begin(), Vocabulary.end(),
- [Dim](const std::pair<StringRef, Embedding> &Entry) {
- return Entry.second.size() == Dim;
- }) &&
- "All vectors in the vocabulary are not of the same dimension");
+ if (Dim == 0)
+ return createStringError(errc::illegal_byte_sequence,
+ "Dimension of vocabulary is zero");
+
+ if (!std::all_of(Vocabulary.begin(), Vocabulary.end(),
+ [Dim](const std::pair<StringRef, Embedding> &Entry) {
+ return Entry.second.size() == Dim;
+ }))
+ return createStringError(
+ errc::illegal_byte_sequence,
+ "All vectors in the vocabulary are not of the same dimension");
+
return Error::success();
}
+IR2VecVocabAnalysis::IR2VecVocabAnalysis(Vocab &&Vocabulary)
+ : Vocabulary(std::move(Vocabulary)) {}
+
+void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
+ handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
+ Ctx.emitError("Error reading vocabulary: " + EI.message());
+ });
+}
+
IR2VecVocabAnalysis::Result
IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
auto Ctx = &M.getContext();
+ // FIXME: Scale the vocabulary once. This would avoid scaling per use later.
+ // If vocabulary is already populated by the constructor, use it.
+ if (!Vocabulary.empty())
+ return IR2VecVocabResult(std::move(Vocabulary));
+
+ // Otherwise, try to read from the vocabulary file.
if (VocabFile.empty()) {
// FIXME: Use default vocabulary
Ctx->emitError("IR2Vec vocabulary file path not specified");
return IR2VecVocabResult(); // Return invalid result
}
if (auto Err = readVocabulary()) {
- handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
- Ctx->emitError("Error reading vocabulary: " + EI.message());
- });
+ emitError(std::move(Err), *Ctx);
return IR2VecVocabResult();
}
- // FIXME: Scale the vocabulary here once. This would avoid scaling per use
- // later.
return IR2VecVocabResult(std::move(Vocabulary));
}
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 46e9c71c58250..c2c65c92cfb07 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -261,25 +261,30 @@ TEST(IR2VecTest, IR2VecVocabResultValidity) {
EXPECT_EQ(validResult.getDimension(), 2u);
}
-// Helper to create a minimal function and embedder for getter tests
-struct GetterTestEnv {
- Vocab V = {};
+// Fixture for IR2Vec tests requiring IR setup and weight management.
+class IR2VecTestFixture : public ::testing::Test {
+protected:
+ Vocab V;
LLVMContext Ctx;
- std::unique_ptr<Module> M = nullptr;
+ std::unique_ptr<Module> M;
Function *F = nullptr;
BasicBlock *BB = nullptr;
- Instruction *Add = nullptr;
- Instruction *Ret = nullptr;
- std::unique_ptr<Embedder> Emb = nullptr;
+ Instruction *AddInst = nullptr;
+ Instruction *RetInst = nullptr;
- GetterTestEnv() {
+ float OriginalOpcWeight = ::OpcWeight;
+ float OriginalTypeWeight = ::TypeWeight;
+ float OriginalArgWeight = ::ArgWeight;
+
+ void SetUp() override {
V = {{"add", {1.0, 2.0}},
{"integerTy", {0.5, 0.5}},
{"constant", {0.2, 0.3}},
{"variable", {0.0, 0.0}},
{"unknownTy", {0.0, 0.0}}};
- M = std::make_unique<Module>("M", Ctx);
+ // Setup IR
+ M = std::make_unique<Module>("TestM", Ctx);
FunctionType *FTy = FunctionType::get(
Type::getInt32Ty(Ctx), {Type::getInt32Ty(Ctx), Type::getInt32Ty(Ctx)},
false);
@@ -288,61 +293,82 @@ struct GetterTestEnv {
Argument *Arg = F->getArg(0);
llvm::Value *Const = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
- Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
- Ret = ReturnInst::Create(Ctx, Add, BB);
+ AddInst = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
+ RetInst = ReturnInst::Create(Ctx, AddInst, BB);
+ }
+
+ void setWeights(float OpcWeight, float TypeWeight, float ArgWeight) {
+ ::OpcWeight = OpcWeight;
+ ::TypeWeight = TypeWeight;
+ ::ArgWeight = ArgWeight;
+ }
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
- EXPECT_TRUE(static_cast<bool>(Result));
- Emb = std::move(*Result);
+ void TearDown() override {
+ // Restore original global weights
+ ::OpcWeight = OriginalOpcWeight;
+ ::TypeWeight = OriginalTypeWeight;
+ ::ArgWeight = OriginalArgWeight;
}
};
-TEST(IR2VecTest, GetInstVecMap) {
- GetterTestEnv Env;
- const auto &InstMap = Env.Emb->getInstVecMap();
+TEST_F(IR2VecTestFixture, GetInstVecMap) {
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Result));
+ auto Emb = std::move(*Result);
+
+ const auto &InstMap = Emb->getInstVecMap();
EXPECT_EQ(InstMap.size(), 2u);
- EXPECT_TRUE(InstMap.count(Env.Add));
- EXPECT_TRUE(InstMap.count(Env.Ret));
+ EXPECT_TRUE(InstMap.count(AddInst));
+ EXPECT_TRUE(InstMap.count(RetInst));
- EXPECT_EQ(InstMap.at(Env.Add).size(), 2u);
- EXPECT_EQ(InstMap.at(Env.Ret).size(), 2u);
+ EXPECT_EQ(InstMap.at(AddInst).size(), 2u);
+ EXPECT_EQ(InstMap.at(RetInst).size(), 2u);
// Check values for add: {1.29, 2.31}
- EXPECT_THAT(InstMap.at(Env.Add),
+ EXPECT_THAT(InstMap.at(AddInst),
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
// Check values for ret: {0.0, 0.}; Neither ret nor voidTy are present in
// vocab
- EXPECT_THAT(InstMap.at(Env.Ret), ElementsAre(0.0, 0.0));
+ EXPECT_THAT(InstMap.at(RetInst), ElementsAre(0.0, 0.0));
}
-TEST(IR2VecTest, GetBBVecMap) {
- GetterTestEnv Env;
- const auto &BBMap = Env.Emb->getBBVecMap();
+TEST_F(IR2VecTestFixture, GetBBVecMap) {
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Result));
+ auto Emb = std::move(*Result);
+
+ const auto &BBMap = Emb->getBBVecMap();
EXPECT_EQ(BBMap.size(), 1u);
- EXPECT_TRUE(BBMap.count(Env.BB));
- EXPECT_EQ(BBMap.at(Env.BB).size(), 2u);
+ EXPECT_TRUE(BBMap.count(BB));
+ EXPECT_EQ(BBMap.at(BB).size(), 2u);
// BB vector should be sum of add and ret: {1.29, 2.31} + {0.0, 0.0} =
// {1.29, 2.31}
- EXPECT_THAT(BBMap.at(Env.BB),
+ EXPECT_THAT(BBMap.at(BB),
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
}
-TEST(IR2VecTest, GetBBVector) {
- GetterTestEnv Env;
- const auto &BBVec = Env.Emb->getBBVector(*Env.BB);
+TEST_F(IR2VecTestFixture, GetBBVector) {
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Result));
+ auto Emb = std::move(*Result);
+
+ const auto &BBVec = Emb->getBBVector(*BB);
EXPECT_EQ(BBVec.size(), 2u);
EXPECT_THAT(BBVec,
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
}
-TEST(IR2VecTest, GetFunctionVector) {
- GetterTestEnv Env;
- const auto &FuncVec = Env.Emb->getFunctionVector();
+TEST_F(IR2VecTestFixture, GetFunctionVector) {
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Result));
+ auto Emb = std::move(*Result);
+
+ const auto &FuncVec = Emb->getFunctionVector();
EXPECT_EQ(FuncVec.size(), 2u);
@@ -351,4 +377,45 @@ TEST(IR2VecTest, GetFunctionVector) {
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
}
+TEST_F(IR2VecTestFixture, GetFunctionVectorWithCustomWeights) {
+ setWeights(1.0, 1.0, 1.0);
+
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
+ ASSERT_TRUE(static_cast<bool>(Result));
+ auto Emb = std::move(*Result);
+
+ const auto &FuncVec = Emb->getFunctionVector();
+
+ EXPECT_EQ(FuncVec.size(), 2u);
+
+ // Expected: 1*([1.0 2.0] + [0.0 0.0]) + 1*([0.5 0.5] + [0.0 0.0]) + 1*([0.2
+ // 0.3] + [0.0 0.0])
+ EXPECT_THAT(FuncVec,
+ ElementsAre(DoubleNear(1.7, 1e-6), DoubleNear(2.8, 1e-6)));
+}
+
+TEST(IR2VecTest, IR2VecVocabAnalysisWithPrepopulatedVocab) {
+ Vocab InitialVocab = {{"key1", {1.1, 2.2}}, {"key2", {3.3, 4.4}}};
+ Vocab ExpectedVocab = InitialVocab;
+ unsigned ExpectedDim = InitialVocab.begin()->second.size();
+
+ IR2VecVocabAnalysis VocabAnalysis(std::move(InitialVocab));
+
+ LLVMContext TestCtx;
+ Module TestMod("TestModuleForVocabAnalysis", TestCtx);
+ ModuleAnalysisManager MAM;
+ IR2VecVocabResult Result = VocabAnalysis.run(TestMod, MAM);
+
+ EXPECT_TRUE(Result.isValid());
+ ASSERT_FALSE(Result.getVocabulary().empty());
+ EXPECT_EQ(Result.getDimension(), ExpectedDim);
+
+ const auto &ResultVocab = Result.getVocabulary();
+ EXPECT_EQ(ResultVocab.size(), ExpectedVocab.size());
+ for (const auto &pair : ExpectedVocab) {
+ EXPECT_TRUE(ResultVocab.count(pair.first));
+ EXPECT_THAT(ResultVocab.at(pair.first), ElementsAreArray(pair.second));
+ }
+}
+
} // end anonymous namespace
More information about the llvm-branch-commits mailing list