[llvm] [IR2Vec] Removing Dimension from `Embedder::create` (PR #142486)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Jun 2 14:07:28 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlgo
Author: S. VenkataKeerthy (svkeerthy)
<details>
<summary>Changes</summary>
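This PR removes the need to know the dimension of the embeddings when invoking `Embedder::create`. The `Embedder` constructor now derives the dimension from the size of the first vocabulary entry, so the vocabulary passed in must be non-empty. Having the `Dimension` parameter introduces complexities in downstream consumers.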
(Tracking issue: #141817)
---
Full diff: https://github.com/llvm/llvm-project/pull/142486.diff
4 Files Affected:
- (modified) llvm/docs/MLGO.rst (+1-2)
- (modified) llvm/include/llvm/Analysis/IR2Vec.h (+5-8)
- (modified) llvm/lib/Analysis/IR2Vec.cpp (+8-12)
- (modified) llvm/unittests/Analysis/IR2VecTest.cpp (+5-6)
``````````diff
diff --git a/llvm/docs/MLGO.rst b/llvm/docs/MLGO.rst
index 377c2aec44475..4f8fb3f59ca19 100644
--- a/llvm/docs/MLGO.rst
+++ b/llvm/docs/MLGO.rst
@@ -469,7 +469,6 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
return;
}
const ir2vec::Vocab &Vocabulary = VocabRes.getVocabulary();
- unsigned Dimension = VocabRes.getDimension();
Note that ``IR2VecVocabAnalysis`` pass is immutable.
@@ -481,7 +480,7 @@ embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
// Assuming F is an llvm::Function&
// For example, using IR2VecKind::Symbolic:
Expected<std::unique_ptr<ir2vec::Embedder>> EmbOrErr =
- ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary, Dimension);
+ ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);
if (auto Err = EmbOrErr.takeError()) {
// Handle error in embedder creation
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 288753b3b3b8f..9fd1b0ae8e248 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -84,7 +84,7 @@ class Embedder {
mutable BBEmbeddingsMap BBVecMap;
mutable InstEmbeddingsMap InstVecMap;
- Embedder(const Function &F, const Vocab &Vocabulary, unsigned Dimension);
+ Embedder(const Function &F, const Vocab &Vocabulary);
/// Helper function to compute embeddings. It generates embeddings for all
/// the instructions and basic blocks in the function F. Logic of computing
@@ -110,10 +110,8 @@ class Embedder {
virtual ~Embedder() = default;
/// Factory method to create an Embedder object.
- static Expected<std::unique_ptr<Embedder>> create(IR2VecKind Mode,
- const Function &F,
- const Vocab &Vocabulary,
- unsigned Dimension);
+ static Expected<std::unique_ptr<Embedder>>
+ create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary);
/// Returns a map containing instructions and the corresponding embeddings for
/// the function F if it has been computed. If not, it computes the embeddings
@@ -149,9 +147,8 @@ class SymbolicEmbedder : public Embedder {
void computeEmbeddings(const BasicBlock &BB) const override;
public:
- SymbolicEmbedder(const Function &F, const Vocab &Vocabulary,
- unsigned Dimension)
- : Embedder(F, Vocabulary, Dimension) {
+ SymbolicEmbedder(const Function &F, const Vocab &Vocabulary)
+ : Embedder(F, Vocabulary) {
FuncVector = Embedding(Dimension, 0);
}
};
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 67af44dcac424..490db5fdcdf99 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -59,19 +59,16 @@ AnalysisKey IR2VecVocabAnalysis::Key;
// Embedder and its subclasses
//===----------------------------------------------------------------------===//
-Embedder::Embedder(const Function &F, const Vocab &Vocabulary,
- unsigned Dimension)
- : F(F), Vocabulary(Vocabulary), Dimension(Dimension),
- OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {
-}
+Embedder::Embedder(const Function &F, const Vocab &Vocabulary)
+ : F(F), Vocabulary(Vocabulary),
+ Dimension(Vocabulary.begin()->second.size()), OpcWeight(::OpcWeight),
+ TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {}
-Expected<std::unique_ptr<Embedder>> Embedder::create(IR2VecKind Mode,
- const Function &F,
- const Vocab &Vocabulary,
- unsigned Dimension) {
+Expected<std::unique_ptr<Embedder>>
+Embedder::create(IR2VecKind Mode, const Function &F, const Vocab &Vocabulary) {
switch (Mode) {
case IR2VecKind::Symbolic:
- return std::make_unique<SymbolicEmbedder>(F, Vocabulary, Dimension);
+ return std::make_unique<SymbolicEmbedder>(F, Vocabulary);
}
return make_error<StringError>("Unknown IR2VecKind", errc::invalid_argument);
}
@@ -286,10 +283,9 @@ PreservedAnalyses IR2VecPrinterPass::run(Module &M,
assert(IR2VecVocabResult.isValid() && "IR2Vec Vocabulary is invalid");
auto Vocab = IR2VecVocabResult.getVocabulary();
- auto Dim = IR2VecVocabResult.getDimension();
for (Function &F : M) {
Expected<std::unique_ptr<Embedder>> EmbOrErr =
- Embedder::create(IR2VecKind::Symbolic, F, Vocab, Dim);
+ Embedder::create(IR2VecKind::Symbolic, F, Vocab);
if (auto Err = EmbOrErr.takeError()) {
handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {
OS << "Error creating IR2Vec embeddings: " << EI.message() << "\n";
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 0158038b59b6c..9e47b2cd8bedd 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -28,8 +28,7 @@ namespace {
class TestableEmbedder : public Embedder {
public:
- TestableEmbedder(const Function &F, const Vocab &V, unsigned Dim)
- : Embedder(F, V, Dim) {}
+ TestableEmbedder(const Function &F, const Vocab &V) : Embedder(F, V) {}
void computeEmbeddings() const override {}
void computeEmbeddings(const BasicBlock &BB) const override {}
using Embedder::lookupVocab;
@@ -50,7 +49,7 @@ TEST(IR2VecTest, CreateSymbolicEmbedder) {
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2);
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
EXPECT_TRUE(static_cast<bool>(Result));
auto *Emb = Result->get();
@@ -66,7 +65,7 @@ TEST(IR2VecTest, CreateInvalidMode) {
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
// static_cast an invalid int to IR2VecKind
- auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V, 2);
+ auto Result = Embedder::create(static_cast<IR2VecKind>(-1), *F, V);
EXPECT_FALSE(static_cast<bool>(Result));
std::string ErrMsg;
@@ -123,7 +122,7 @@ TEST(IR2VecTest, LookupVocab) {
FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
- TestableEmbedder E(*F, V, 2);
+ TestableEmbedder E(*F, V);
auto V_foo = E.lookupVocab("foo");
EXPECT_EQ(V_foo.size(), 2u);
EXPECT_THAT(V_foo, ElementsAre(1.0, 2.0));
@@ -190,7 +189,7 @@ struct GetterTestEnv {
Add = BinaryOperator::CreateAdd(Arg, Const, "add", BB);
Ret = ReturnInst::Create(Ctx, Add, BB);
- auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2);
+ auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V);
EXPECT_TRUE(static_cast<bool>(Result));
Emb = std::move(*Result);
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/142486
More information about the llvm-commits
mailing list