[llvm] 673d725 - Reland "[llvm-ir2vec] Adding Inst Embeddings Map API to ir2vec python bindings (#180140)" (#184196)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Mar 5 09:03:19 PST 2026
Author: Nishant Sachdeva
Date: 2026-03-05T22:33:15+05:30
New Revision: 673d725c3bcf5b551e42ff96d00b77b8f3b9adcb
URL: https://github.com/llvm/llvm-project/commit/673d725c3bcf5b551e42ff96d00b77b8f3b9adcb
DIFF: https://github.com/llvm/llvm-project/commit/673d725c3bcf5b551e42ff96d00b77b8f3b9adcb.diff
LOG: Reland "[llvm-ir2vec] Adding Inst Embeddings Map API to ir2vec python bindings (#180140)" (#184196)
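Relanding change from https://github.com/llvm/llvm-project/pull/180140
- Adds `getInstEmbMap(funcName)` to the ir2vec Python bindings. It returns an
instruction embedding map for the named function, keyed by instruction text
(`getInstEmbMap(funcName) -> dict[str, ndarray]`, i.e. Map<Inst string, Embedding>).
- Refactors the IR2VecTool methods to create the embedder through a single
shared helper, `createIR2VecEmbedder`.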
Added:
Modified:
llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
llvm/tools/llvm-ir2vec/lib/Utils.cpp
llvm/tools/llvm-ir2vec/lib/Utils.h
Removed:
################################################################################
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index 502f8a2411aa8..bb29d33dc8ca6 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -43,6 +43,18 @@
print(f" BB: {bb_name}")
print(f" Embedding: {emb.tolist()}")
+ # Test getInstEmbMap
+ print("\n=== Instruction Embeddings ===")
+
+ # Test valid function names in sorted order
+ for func_name in sorted(["add", "multiply", "conditional"]):
+ inst_emb_map = tool.getInstEmbMap(func_name)
+ print(f"Function: {func_name}")
+ # Keys are instruction text, so instructions that print identically
+ # (e.g. a `br label %exit` ending more than one block) share one key
+ # and appear only once. Sort the keys so the output order is stable.
+ for inst_str in sorted(inst_emb_map.keys()):
+ emb = inst_emb_map[inst_str]
+ print(f" Inst: {inst_str}")
+ print(f" Embedding: {emb.tolist()}")
+
# CHECK: SUCCESS: Tool initialized
# CHECK: Tool type: IR2VecTool
# CHECK: === Function Embeddings ===
@@ -75,3 +87,29 @@
# CHECK: Function: multiply
# CHECK: BB: entry
# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
+# CHECK: === Instruction Embeddings ===
+# CHECK: Function: add
+# CHECK: Inst: %sum = add i32 %a, %b
+# CHECK-NEXT: Embedding: [37.0, 38.0, 39.0]
+# CHECK: Inst: ret i32 %sum
+# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
+# CHECK: Function: conditional
+# CHECK: Inst: %cmp = icmp sgt i32 %n, 0
+# CHECK-NEXT: Embedding: [157.20000000298023, 158.20000000298023, 159.20000000298023]
+# CHECK: Inst: %neg_val = sub i32 %n, 10
+# CHECK-NEXT: Embedding: [43.0, 44.0, 45.0]
+# CHECK: Inst: %pos_val = add i32 %n, 10
+# CHECK-NEXT: Embedding: [37.0, 38.0, 39.0]
+# CHECK: Inst: %result = phi i32 [ %pos_val, %positive ], [ %neg_val, %negative ]
+# CHECK-NEXT: Embedding: [163.0, 164.0, 165.0]
+# CHECK: Inst: br i1 %cmp, label %positive, label %negative
+# CHECK-NEXT: Embedding: [4.0, 5.0, 6.0]
+# CHECK: Inst: br label %exit
+# CHECK-NEXT: Embedding: [4.0, 5.0, 6.0]
+# CHECK: Inst: ret i32 %result
+# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
+# CHECK: Function: multiply
+# CHECK: Inst: %prod = mul i32 %x, %y
+# CHECK-NEXT: Embedding: [49.0, 50.0, 51.0]
+# CHECK: Inst: ret i32 %prod
+# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index e334997b31394..e4ddaf9c14e5a 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -150,6 +150,41 @@ class PyIR2VecTool {
return NbBBEmbMap;
}
+
+  /// Compute an embedding for every instruction of the function named
+  /// \p FuncName and return them as a Python dict.
+  ///
+  /// Keys are the textual IR of each instruction (the output of
+  /// Instruction::print with leading whitespace trimmed). Values are 1-D
+  /// float64 NumPy arrays. Entries are inserted in program order (blocks in
+  /// layout order, instructions in block order), so the dict's iteration
+  /// order is deterministic.
+  ///
+  /// NOTE: because the dict is keyed by instruction text, instructions that
+  /// print identically (e.g. `br label %exit` at the end of two different
+  /// blocks) share a single key; the last one in program order wins.
+  ///
+  /// Throws nb::value_error if the function is not in the module or if the
+  /// tool fails to compute the instruction embeddings.
+  nb::dict getInstEmbMap(const std::string &FuncName) {
+    const Function *F = M->getFunction(FuncName);
+    if (!F)
+      throw nb::value_error(
+          ("Function '" + FuncName + "' not found in module").c_str());
+
+    auto ToolInstEmbMap = Tool->getInstEmbeddingsMap(*F, OutputEmbeddingMode);
+    if (!ToolInstEmbMap)
+      throw nb::value_error(toString(ToolInstEmbMap.takeError()).c_str());
+
+    nb::dict NbInstEmbMap;
+
+    // Walk the function in program order and look each instruction up,
+    // instead of iterating the pointer-keyed map directly, whose iteration
+    // order is not tied to program order.
+    for (const BasicBlock &BB : *F) {
+      for (const Instruction &I : BB) {
+        auto It = ToolInstEmbMap->find(&I);
+        if (It == ToolInstEmbMap->end())
+          continue;
+
+        // Bind by reference: the embedding data is only read here.
+        const auto &InstEmbVec = It->second.getData();
+
+        // Copy into a heap buffer owned by a capsule so the NumPy array can
+        // outlive this call; the capsule frees the buffer.
+        double *NbInstEmbVec = new double[InstEmbVec.size()];
+        std::copy(InstEmbVec.begin(), InstEmbVec.end(), NbInstEmbVec);
+
+        auto NbArray = nb::ndarray<nb::numpy, double>(
+            NbInstEmbVec, {InstEmbVec.size()},
+            nb::capsule(NbInstEmbVec, [](void *P) noexcept {
+              delete[] static_cast<double *>(P);
+            }));
+
+        std::string InstStr;
+        raw_string_ostream OS(InstStr);
+        I.print(OS);
+
+        // Instruction::print may emit leading indentation; strip it so keys
+        // match the instruction text a user would write.
+        StringRef Key = StringRef(OS.str()).ltrim();
+        NbInstEmbMap[nb::str(Key.data(), Key.size())] = NbArray;
+      }
+    }
+
+    return NbInstEmbMap;
+  }
};
} // namespace
@@ -173,7 +208,13 @@ NB_MODULE(ir2vec, m) {
"Generate embeddings for all basic blocks in a function\n"
"Args: funcName (str) - IR-Name of the function\n"
"Returns: dict[str, ndarray[float64]] - "
- "{basic_block_name: embedding vector}");
+ "{basic_block_name: embedding vector}")
+ .def("getInstEmbMap", &PyIR2VecTool::getInstEmbMap, nb::arg("funcName"),
+ "Generate embeddings for all instructions in a function\n"
+ "Args: funcName (str) - IR-Name of the function\n"
+ "Returns: dict[str, ndarray[float64]] - "
+ "{instruction_string: embedding_vector}");
+
m.def(
"initEmbedding",
[](const std::string &filename, const std::string &mode,
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 70f501c116407..c508f8b8b5b46 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -152,8 +152,8 @@ void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
OS << Entities[EntityID] << '\t' << EntityID << '\n';
}
-Expected<Embedding> IR2VecTool::getFunctionEmbedding(const Function &F,
- IR2VecKind Kind) const {
+Expected<std::unique_ptr<Embedder>>
+IR2VecTool::createIR2VecEmbedder(const Function &F, IR2VecKind Kind) const {
if (!Vocab || !Vocab->isValid())
return createStringError(
errc::invalid_argument,
@@ -169,16 +169,20 @@ Expected<Embedding> IR2VecTool::getFunctionEmbedding(const Function &F,
"Failed to create embedder for function '%s'.",
F.getName().str().c_str());
- return Emb->getFunctionVector();
+ return std::move(Emb);
+}
+
+/// Compute the function-level embedding of \p F for the given \p Kind.
+/// Any failure from createIR2VecEmbedder is propagated unchanged.
+Expected<Embedding> IR2VecTool::getFunctionEmbedding(const Function &F,
+                                                     IR2VecKind Kind) const {
+  Expected<std::unique_ptr<Embedder>> EmbedderOrErr =
+      createIR2VecEmbedder(F, Kind);
+  if (Error Err = EmbedderOrErr.takeError())
+    return std::move(Err);
+
+  const std::unique_ptr<Embedder> &FuncEmbedder = *EmbedderOrErr;
+  return FuncEmbedder->getFunctionVector();
}
Expected<FuncEmbMap>
IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
- if (!Vocab || !Vocab->isValid())
- return createStringError(
- errc::invalid_argument,
- "Vocabulary is not valid. IR2VecTool not initialized.");
-
FuncEmbMap Result;
for (const Function &F : M.getFunctionDefs()) {
@@ -193,61 +197,47 @@ IR2VecTool::getFunctionEmbeddingsMap(IR2VecKind Kind) const {
Expected<BBEmbeddingsMap>
IR2VecTool::getBBEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
- if (!Vocab || !Vocab->isValid())
- return createStringError(
- errc::invalid_argument,
- "Vocabulary is not valid. IR2VecTool not initialized.");
+ auto Emb = createIR2VecEmbedder(F, Kind);
+ if (!Emb)
+ return Emb.takeError();
BBEmbeddingsMap Result;
- if (F.isDeclaration())
- return createStringError(errc::invalid_argument,
- "Function is a declaration.");
+ for (const BasicBlock &BB : F)
+ Result.try_emplace(&BB, (*Emb)->getBBVector(BB));
- auto Emb = Embedder::create(Kind, F, *Vocab);
+ return Result;
+}
+
+Expected<InstEmbeddingsMap>
+IR2VecTool::getInstEmbeddingsMap(const Function &F, IR2VecKind Kind) const {
+ auto Emb = createIR2VecEmbedder(F, Kind);
if (!Emb)
- return createStringError(errc::invalid_argument,
- "Failed to create embedder for function '%s'.",
- F.getName().str().c_str());
+ return Emb.takeError();
- for (const BasicBlock &BB : F)
- Result.try_emplace(&BB, Emb->getBBVector(BB));
+ InstEmbeddingsMap Result;
+
+ for (const Instruction &I : instructions(F))
+ Result.try_emplace(&I, (*Emb)->getInstVector(I));
return Result;
}
void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
EmbeddingLevel Level) const {
- if (!Vocab || !Vocab->isValid()) {
- WithColor::error(errs(), ToolName)
- << "Vocabulary is not valid. IR2VecTool not initialized.\n";
- return;
- }
-
for (const Function &F : M.getFunctionDefs())
writeEmbeddingsToStream(F, OS, Level);
}
void IR2VecTool::writeEmbeddingsToStream(const Function &F, raw_ostream &OS,
EmbeddingLevel Level) const {
- if (!Vocab || !Vocab->isValid()) {
- WithColor::error(errs(), ToolName)
- << "Vocabulary is not valid. IR2VecTool not initialized.\n";
- return;
- }
-
- if (F.isDeclaration()) {
- OS << "Function " << F.getName() << " is a declaration, skipping.\n";
- return;
- }
-
- // Create embedder for this function
- auto Emb = Embedder::create(IR2VecEmbeddingKind, F, *Vocab);
- if (!Emb) {
+ auto IR2VecEmbedderObj = createIR2VecEmbedder(F, IR2VecEmbeddingKind);
+ if (!IR2VecEmbedderObj) {
WithColor::error(errs(), ToolName)
- << "Failed to create embedder for function " << F.getName() << "\n";
+ << toString(IR2VecEmbedderObj.takeError()) << "\n";
return;
}
+ auto Emb = std::move(*IR2VecEmbedderObj);
OS << "Function: " << F.getName() << "\n";
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index 04ec4a74b1e24..ae2f931a90cf9 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -94,6 +94,10 @@ class IR2VecTool {
public:
explicit IR2VecTool(Module &M) : M(M) {}
+ /// Create an Embedder of the given \p Kind for function \p F using the
+ /// tool's vocabulary. This is the common setup step shared by the
+ /// embedding getters and writeEmbeddingsToStream.
+ ///
+ /// Returns an error if the vocabulary is not loaded or is not valid, or if
+ /// an embedder cannot be created for \p F.
+ Expected<std::unique_ptr<Embedder>>
+ createIR2VecEmbedder(const Function &F, IR2VecKind Kind) const;
+
/// Initialize the IR2Vec vocabulary from the specified file path.
Error initializeVocabulary(StringRef VocabPath);
@@ -127,6 +131,9 @@ class IR2VecTool {
/// Get embeddings for all basic blocks in a function
Expected<BBEmbeddingsMap> getBBEmbeddingsMap(const Function &F,
IR2VecKind Kind) const;
+ /// Get embeddings for all instructions in a function, keyed by the
+ /// instruction itself. Returns an error if an embedder cannot be created
+ /// for \p F (see createIR2VecEmbedder).
+ Expected<InstEmbeddingsMap> getInstEmbeddingsMap(const Function &F,
+ IR2VecKind Kind) const;
/// Generate embeddings for the entire module
void writeEmbeddingsToStream(raw_ostream &OS, EmbeddingLevel Level) const;
More information about the llvm-commits
mailing list