[llvm] [llvm-ir2vec] adding inst-embedding map API to ir2vec python bindings (PR #177308)
Nishant Sachdeva via llvm-commits
llvm-commits at lists.llvm.org
Wed Jan 28 03:18:45 PST 2026
https://github.com/nishant-sachdeva updated https://github.com/llvm/llvm-project/pull/177308
>From daa538fae9666837e03a31e628d0e70cce52264d Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 21 Jan 2026 10:47:10 +0530
Subject: [PATCH 01/13] Adding an initEmbedding API to ir2vec python bindings
---
.../llvm-ir2vec/bindings/ir2vec-bindings.py | 21 ++++-
.../tools/llvm-ir2vec/Bindings/CMakeLists.txt | 3 +-
llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 87 ++++++++++++++++++-
3 files changed, 106 insertions(+), 5 deletions(-)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index 459f353a478cb..e6734d2055cd8 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -1,7 +1,22 @@
-# RUN: env PYTHONPATH=%llvm_lib_dir %python %s
+# RUN: rm -rf %t.ll
+# RUN: echo "define i32 @add(i32 %%a, i32 %%b) {" > %t.ll
+# RUN: echo "entry:" >> %t.ll
+# RUN: echo " %%sum = add i32 %%a, %%b" >> %t.ll
+# RUN: echo " ret i32 %%sum" >> %t.ll
+# RUN: echo "}" >> %t.ll
+# RUN: env PYTHONPATH=%llvm_lib_dir %python %s %t.ll %ir2vec_test_vocab_dir/dummy_3D_nonzero_opc_vocab.json | FileCheck %s
+import sys
import ir2vec
-print("SUCCESS: Module imported")
+ll_file = sys.argv[1]
+vocab_path = sys.argv[2]
-# CHECK: SUCCESS: Module imported
+tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocab_path=vocab_path)
+
+if tool is not None:
+ print("SUCCESS: Tool initialized")
+ print(f"Tool type: {type(tool).__name__}")
+
+# CHECK: SUCCESS: Tool initialized
+# CHECK: Tool type: IR2VecTool
diff --git a/llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt b/llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt
index 5df3720a24777..677208774f5a1 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt
+++ b/llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt
@@ -1,4 +1,4 @@
-find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
+find_package(Python ${Python3_VERSION} EXACT COMPONENTS Interpreter Development.Module REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
@@ -10,5 +10,6 @@ set_target_properties(LLVMEmbUtils PROPERTIES POSITION_INDEPENDENT_CODE ON)
nanobind_add_module(ir2vec MODULE PyIR2Vec.cpp)
target_link_libraries(ir2vec PRIVATE LLVMEmbUtils)
+target_include_directories(ir2vec PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..)
message(STATUS "Python bindings for llvm-ir2vec will be built with nanobind")
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index b3a46d429b6d4..24297d15caaf1 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -7,7 +7,92 @@
//===----------------------------------------------------------------------===//
#include <nanobind/nanobind.h>
+#include <nanobind/stl/string.h>
+#include <nanobind/stl/unique_ptr.h>
+
+#include "lib/Utils.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IRReader/IRReader.h"
+#include "llvm/Support/SourceMgr.h"
+
+#include <fstream>
+#include <memory>
+#include <string>
namespace nb = nanobind;
+using namespace llvm;
+using namespace llvm::ir2vec;
+
+namespace llvm {
+namespace ir2vec {
+void setIR2VecVocabPath(StringRef Path);
+StringRef getIR2VecVocabPath();
+} // namespace ir2vec
+} // namespace llvm
+
+namespace {
+
+bool fileNotValid(const std::string &Filename) {
+ std::ifstream F(Filename, std::ios_base::in | std::ios_base::binary);
+ return !F.good();
+}
+
+std::unique_ptr<Module> getLLVMIR(const std::string &Filename,
+ LLVMContext &Context) {
+ SMDiagnostic Err;
+ auto M = parseIRFile(Filename, Err, Context);
+ if (!M)
+ throw std::runtime_error("Failed to parse IR file.");
+ return M;
+}
+
+class PyIR2VecTool {
+private:
+ std::unique_ptr<LLVMContext> Ctx;
+ std::unique_ptr<Module> M;
+ std::unique_ptr<IR2VecTool> Tool;
+
+public:
+ PyIR2VecTool(const std::string &Filename, const std::string &Mode,
+ const std::string &VocabPath) {
+ if (fileNotValid(Filename))
+ throw std::runtime_error("Invalid file path");
+
+ if (Mode != "sym" && Mode != "fa")
+ throw std::runtime_error("Invalid mode. Use 'sym' or 'fa'");
+
+ if (VocabPath.empty())
+ throw std::runtime_error("Error - Empty Vocab Path not allowed");
+
+ setIR2VecVocabPath(VocabPath);
+
+ Ctx = std::make_unique<LLVMContext>();
+ M = getLLVMIR(Filename, *Ctx);
+ Tool = std::make_unique<IR2VecTool>(*M);
+
+ bool Ok = Tool->initializeVocabulary();
+ if (!Ok)
+ throw std::runtime_error("Failed to initialize IR2Vec vocabulary");
+ }
+};
+
+} // namespace
+
+NB_MODULE(ir2vec, m) {
+ m.doc() = std::string("Python bindings for ") + ToolName;
+
+ nb::class_<PyIR2VecTool>(m, "IR2VecTool")
+ .def(nb::init<const std::string &, const std::string &,
+ const std::string &>(),
+ nb::arg("filename"), nb::arg("mode"), nb::arg("vocab_path"));
-NB_MODULE(ir2vec, m) { m.doc() = "Python bindings for IR2Vec"; }
+ m.def(
+ "initEmbedding",
+ [](const std::string &filename, const std::string &mode,
+ const std::string &vocab_path) {
+ return std::make_unique<PyIR2VecTool>(filename, mode, vocab_path);
+ },
+ nb::arg("filename"), nb::arg("mode") = "sym", nb::arg("vocab_path"),
+ nb::rv_policy::take_ownership);
+}
>From 025fd318cdf72040cc77de3b6776a427fc06556a Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Tue, 27 Jan 2026 14:57:43 +0530
Subject: [PATCH 02/13] initEmbedding API prepared with rebase on main
---
llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 18 +++++-------------
1 file changed, 5 insertions(+), 13 deletions(-)
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index 24297d15caaf1..f2eac46eb5f02 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -24,13 +24,6 @@ namespace nb = nanobind;
using namespace llvm;
using namespace llvm::ir2vec;
-namespace llvm {
-namespace ir2vec {
-void setIR2VecVocabPath(StringRef Path);
-StringRef getIR2VecVocabPath();
-} // namespace ir2vec
-} // namespace llvm
-
namespace {
bool fileNotValid(const std::string &Filename) {
@@ -63,17 +56,16 @@ class PyIR2VecTool {
throw std::runtime_error("Invalid mode. Use 'sym' or 'fa'");
if (VocabPath.empty())
- throw std::runtime_error("Error - Empty Vocab Path not allowed");
-
- setIR2VecVocabPath(VocabPath);
+ throw std::runtime_error("Empty Vocab Path not allowed");
Ctx = std::make_unique<LLVMContext>();
M = getLLVMIR(Filename, *Ctx);
Tool = std::make_unique<IR2VecTool>(*M);
- bool Ok = Tool->initializeVocabulary();
- if (!Ok)
- throw std::runtime_error("Failed to initialize IR2Vec vocabulary");
+ if (auto Err = Tool->initializeVocabulary(VocabPath)) {
+ throw std::runtime_error("Failed to initialize IR2Vec vocabulary: " +
+ toString(std::move(Err)));
+ }
}
};
>From ca2a4d749205b984cc4f9ea1a3e89f3a3cd644cd Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 28 Jan 2026 15:34:51 +0530
Subject: [PATCH 03/13] Refine the Python 3 version setting and requirements
around nanobind
---
llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt b/llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt
index 677208774f5a1..efa890f5025dc 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt
+++ b/llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt
@@ -1,4 +1,5 @@
-find_package(Python ${Python3_VERSION} EXACT COMPONENTS Interpreter Development.Module REQUIRED)
+set(BINDINGS_MINIMUM_PYTHON_VERSION 3.10)
+find_package(Python ${BINDINGS_MINIMUM_PYTHON_VERSION} EXACT COMPONENTS Interpreter Development.Module REQUIRED)
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
>From f6bcadf4162a8452782544afbf7ee44af38949f7 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 28 Jan 2026 15:39:35 +0530
Subject: [PATCH 04/13] Added status check for nanobind installation for ir2vec
python bindings
---
llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt | 10 +++++++++-
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt b/llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt
index efa890f5025dc..e5e3fb59b4da7 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt
+++ b/llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt
@@ -3,8 +3,16 @@ find_package(Python ${BINDINGS_MINIMUM_PYTHON_VERSION} EXACT COMPONENTS Interpre
execute_process(
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
- OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT
+ RESULT_VARIABLE STATUS
+ OUTPUT_VARIABLE nanobind_ROOT
+ OUTPUT_STRIP_TRAILING_WHITESPACE
+ ERROR_QUIET
)
+
+if(NOT STATUS EQUAL "0")
+ message(FATAL_ERROR "nanobind not found (install via 'pip install nanobind' or set nanobind_DIR)")
+endif()
+
find_package(nanobind CONFIG REQUIRED)
set_target_properties(LLVMEmbUtils PROPERTIES POSITION_INDEPENDENT_CODE ON)
>From 8619eb3d2ca75e71b984c6e6add528792ac0b6b7 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 28 Jan 2026 16:07:46 +0530
Subject: [PATCH 05/13] Refining behavior around PIC settings for python
bindings module
---
llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt b/llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt
index e5e3fb59b4da7..8231b45d4d276 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt
+++ b/llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt
@@ -1,3 +1,8 @@
+if(NOT LLVM_ENABLE_PIC)
+ message(FATAL_ERROR "Python bindings require LLVM_ENABLE_PIC=ON. "
+ "Please reconfigure LLVM with -DLLVM_ENABLE_PIC=ON")
+endif()
+
set(BINDINGS_MINIMUM_PYTHON_VERSION 3.10)
find_package(Python ${BINDINGS_MINIMUM_PYTHON_VERSION} EXACT COMPONENTS Interpreter Development.Module REQUIRED)
@@ -15,8 +20,6 @@ endif()
find_package(nanobind CONFIG REQUIRED)
-set_target_properties(LLVMEmbUtils PROPERTIES POSITION_INDEPENDENT_CODE ON)
-
nanobind_add_module(ir2vec MODULE PyIR2Vec.cpp)
target_link_libraries(ir2vec PRIVATE LLVMEmbUtils)
target_include_directories(ir2vec PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/..)
>From cd2b9d310ed59f9ca227ed6f446845d70a8fa2d4 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 28 Jan 2026 16:21:35 +0530
Subject: [PATCH 06/13] Address nits: error messages, header order, vocabPath
renaming
---
.../llvm-ir2vec/bindings/ir2vec-bindings.py | 2 +-
llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 27 +++++++------------
2 files changed, 11 insertions(+), 18 deletions(-)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index e6734d2055cd8..2ea8922124406 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -12,7 +12,7 @@
ll_file = sys.argv[1]
vocab_path = sys.argv[2]
-tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocab_path=vocab_path)
+tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=vocab_path)
if tool is not None:
print("SUCCESS: Tool initialized")
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index f2eac46eb5f02..de0f7f2738262 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -6,16 +6,16 @@
//
//===----------------------------------------------------------------------===//
-#include <nanobind/nanobind.h>
-#include <nanobind/stl/string.h>
-#include <nanobind/stl/unique_ptr.h>
-
#include "lib/Utils.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/SourceMgr.h"
+#include <nanobind/nanobind.h>
+#include <nanobind/stl/string.h>
+#include <nanobind/stl/unique_ptr.h>
+
#include <fstream>
#include <memory>
#include <string>
@@ -26,17 +26,13 @@ using namespace llvm::ir2vec;
namespace {
-bool fileNotValid(const std::string &Filename) {
- std::ifstream F(Filename, std::ios_base::in | std::ios_base::binary);
- return !F.good();
-}
-
std::unique_ptr<Module> getLLVMIR(const std::string &Filename,
LLVMContext &Context) {
SMDiagnostic Err;
auto M = parseIRFile(Filename, Err, Context);
if (!M)
- throw std::runtime_error("Failed to parse IR file.");
+ throw std::runtime_error("Failed to parse IR file '" + Filename +
+ "': " + Err.getMessage().str());
return M;
}
@@ -49,9 +45,6 @@ class PyIR2VecTool {
public:
PyIR2VecTool(const std::string &Filename, const std::string &Mode,
const std::string &VocabPath) {
- if (fileNotValid(Filename))
- throw std::runtime_error("Invalid file path");
-
if (Mode != "sym" && Mode != "fa")
throw std::runtime_error("Invalid mode. Use 'sym' or 'fa'");
@@ -77,14 +70,14 @@ NB_MODULE(ir2vec, m) {
nb::class_<PyIR2VecTool>(m, "IR2VecTool")
.def(nb::init<const std::string &, const std::string &,
const std::string &>(),
- nb::arg("filename"), nb::arg("mode"), nb::arg("vocab_path"));
+ nb::arg("filename"), nb::arg("mode"), nb::arg("vocabPath"));
m.def(
"initEmbedding",
[](const std::string &filename, const std::string &mode,
- const std::string &vocab_path) {
- return std::make_unique<PyIR2VecTool>(filename, mode, vocab_path);
+ const std::string &vocabPath) {
+ return std::make_unique<PyIR2VecTool>(filename, mode, vocabPath);
},
- nb::arg("filename"), nb::arg("mode") = "sym", nb::arg("vocab_path"),
+ nb::arg("filename"), nb::arg("mode") = "sym", nb::arg("vocabPath"),
nb::rv_policy::take_ownership);
}
>From 64228e6daf647564aacf30ae601dfdb69a01b19d Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 28 Jan 2026 16:30:14 +0530
Subject: [PATCH 07/13] Move from std::runtime_error to nb::value_error
---
llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index de0f7f2738262..c1b71f9971e22 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -31,8 +31,8 @@ std::unique_ptr<Module> getLLVMIR(const std::string &Filename,
SMDiagnostic Err;
auto M = parseIRFile(Filename, Err, Context);
if (!M)
- throw std::runtime_error("Failed to parse IR file '" + Filename +
- "': " + Err.getMessage().str());
+ throw nb::value_error(("Failed to parse IR file '" + Filename +
+ "': " + Err.getMessage().str()).c_str());
return M;
}
@@ -46,18 +46,18 @@ class PyIR2VecTool {
PyIR2VecTool(const std::string &Filename, const std::string &Mode,
const std::string &VocabPath) {
if (Mode != "sym" && Mode != "fa")
- throw std::runtime_error("Invalid mode. Use 'sym' or 'fa'");
+ throw nb::value_error("Invalid mode. Use 'sym' or 'fa'");
if (VocabPath.empty())
- throw std::runtime_error("Empty Vocab Path not allowed");
+ throw nb::value_error("Empty Vocab Path not allowed");
Ctx = std::make_unique<LLVMContext>();
M = getLLVMIR(Filename, *Ctx);
Tool = std::make_unique<IR2VecTool>(*M);
if (auto Err = Tool->initializeVocabulary(VocabPath)) {
- throw std::runtime_error("Failed to initialize IR2Vec vocabulary: " +
- toString(std::move(Err)));
+ throw nb::value_error(("Failed to initialize IR2Vec vocabulary: " +
+ toString(std::move(Err))).c_str());
}
}
};
>From cd8e787a458ec06b5c71692caecfbd12fae9d415 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 28 Jan 2026 16:30:28 +0530
Subject: [PATCH 08/13] formatting fixup
---
llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index c1b71f9971e22..530adee8e052e 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -32,7 +32,8 @@ std::unique_ptr<Module> getLLVMIR(const std::string &Filename,
auto M = parseIRFile(Filename, Err, Context);
if (!M)
throw nb::value_error(("Failed to parse IR file '" + Filename +
- "': " + Err.getMessage().str()).c_str());
+ "': " + Err.getMessage().str())
+ .c_str());
return M;
}
@@ -57,7 +58,8 @@ class PyIR2VecTool {
if (auto Err = Tool->initializeVocabulary(VocabPath)) {
throw nb::value_error(("Failed to initialize IR2Vec vocabulary: " +
- toString(std::move(Err))).c_str());
+ toString(std::move(Err)))
+ .c_str());
}
}
};
>From e5201d98de944a97ebf43a2b7d56005311c6c7c3 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 28 Jan 2026 16:46:11 +0530
Subject: [PATCH 09/13] Fix nanobind package lookup
---
.../tools/llvm-ir2vec/Bindings/CMakeLists.txt | 28 ++++++++++++-------
1 file changed, 18 insertions(+), 10 deletions(-)
diff --git a/llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt b/llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt
index 8231b45d4d276..376cea77106ff 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt
+++ b/llvm/tools/llvm-ir2vec/Bindings/CMakeLists.txt
@@ -6,16 +6,24 @@ endif()
set(BINDINGS_MINIMUM_PYTHON_VERSION 3.10)
find_package(Python ${BINDINGS_MINIMUM_PYTHON_VERSION} EXACT COMPONENTS Interpreter Development.Module REQUIRED)
-execute_process(
- COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
- RESULT_VARIABLE STATUS
- OUTPUT_VARIABLE nanobind_ROOT
- OUTPUT_STRIP_TRAILING_WHITESPACE
- ERROR_QUIET
-)
-
-if(NOT STATUS EQUAL "0")
- message(FATAL_ERROR "nanobind not found (install via 'pip install nanobind' or set nanobind_DIR)")
+if(nanobind_DIR)
+ message(STATUS "Using explicit nanobind cmake directory: ${nanobind_DIR}")
+else()
+ message(STATUS "Checking for nanobind in python path...")
+ execute_process(
+ COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
+ RESULT_VARIABLE STATUS
+ OUTPUT_VARIABLE PACKAGE_DIR
+ OUTPUT_STRIP_TRAILING_WHITESPACE
+ ERROR_QUIET
+ )
+
+ if(NOT STATUS EQUAL "0")
+ message(FATAL_ERROR "nanobind not found (install via 'pip install nanobind' or set nanobind_DIR)")
+ endif()
+
+ message(STATUS "found nanobind at: ${PACKAGE_DIR}")
+ set(nanobind_DIR "${PACKAGE_DIR}")
endif()
find_package(nanobind CONFIG REQUIRED)
>From c4a345ebf26fb628b6c44624da0e8e8c81307080 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 21 Jan 2026 18:20:16 +0530
Subject: [PATCH 10/13] Adding getFuncEmbMap functionality to ir2vec python
bindings
---
.../llvm-ir2vec/bindings/ir2vec-bindings.py | 41 +++++++++++++++++++
llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 35 +++++++++++++++-
llvm/tools/llvm-ir2vec/lib/Utils.cpp | 35 ++++++++++++++++
llvm/tools/llvm-ir2vec/lib/Utils.h | 9 ++++
4 files changed, 118 insertions(+), 2 deletions(-)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index 2ea8922124406..ec03cb456efe6 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -4,6 +4,30 @@
# RUN: echo " %%sum = add i32 %%a, %%b" >> %t.ll
# RUN: echo " ret i32 %%sum" >> %t.ll
# RUN: echo "}" >> %t.ll
+# RUN: echo "" >> %t.ll
+# RUN: echo "define i32 @multiply(i32 %%x, i32 %%y) {" >> %t.ll
+# RUN: echo "entry:" >> %t.ll
+# RUN: echo " %%prod = mul i32 %%x, %%y" >> %t.ll
+# RUN: echo " ret i32 %%prod" >> %t.ll
+# RUN: echo "}" >> %t.ll
+# RUN: echo "" >> %t.ll
+# RUN: echo "define i32 @conditional(i32 %%n) {" >> %t.ll
+# RUN: echo "entry:" >> %t.ll
+# RUN: echo " %%cmp = icmp sgt i32 %%n, 0" >> %t.ll
+# RUN: echo " br i1 %%cmp, label %%positive, label %%negative" >> %t.ll
+# RUN: echo "" >> %t.ll
+# RUN: echo "positive:" >> %t.ll
+# RUN: echo " %%pos_val = add i32 %%n, 10" >> %t.ll
+# RUN: echo " br label %%exit" >> %t.ll
+# RUN: echo "" >> %t.ll
+# RUN: echo "negative:" >> %t.ll
+# RUN: echo " %%neg_val = sub i32 %%n, 10" >> %t.ll
+# RUN: echo " br label %%exit" >> %t.ll
+# RUN: echo "" >> %t.ll
+# RUN: echo "exit:" >> %t.ll
+# RUN: echo " %%result = phi i32 [ %%pos_val, %%positive ], [ %%neg_val, %%negative ]" >> %t.ll
+# RUN: echo " ret i32 %%result" >> %t.ll
+# RUN: echo "}" >> %t.ll
# RUN: env PYTHONPATH=%llvm_lib_dir %python %s %t.ll %ir2vec_test_vocab_dir/dummy_3D_nonzero_opc_vocab.json | FileCheck %s
import sys
@@ -18,5 +42,22 @@
print("SUCCESS: Tool initialized")
print(f"Tool type: {type(tool).__name__}")
+ # Test getFuncEmbMap
+ func_emb_map = tool.getFuncEmbMap()
+ print(f"Number of functions: {len(func_emb_map)}")
+
+ # Check that all three functions are present
+ expected_funcs = ["add", "multiply", "conditional"]
+ for func_name in expected_funcs:
+ if func_name in func_emb_map:
+ emb = func_emb_map[func_name]
+ print(f"Function '{func_name}': embedding shape = {emb.shape}")
+ else:
+ print(f"ERROR: Function '{func_name}' not found")
+
# CHECK: SUCCESS: Tool initialized
# CHECK: Tool type: IR2VecTool
+# CHECK: Number of functions: 3
+# CHECK: Function 'add': embedding shape =
+# CHECK: Function 'multiply': embedding shape =
+# CHECK: Function 'conditional': embedding shape =
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index 530adee8e052e..346faf879c855 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -42,12 +42,18 @@ class PyIR2VecTool {
std::unique_ptr<LLVMContext> Ctx;
std::unique_ptr<Module> M;
std::unique_ptr<IR2VecTool> Tool;
+ IR2VecKind EmbKind;
public:
PyIR2VecTool(const std::string &Filename, const std::string &Mode,
const std::string &VocabPath) {
- if (Mode != "sym" && Mode != "fa")
+ EmbKind = [](const std::string &Mode) -> IR2VecKind {
+ if (Mode == "sym")
+ return IR2VecKind::Symbolic;
+ if (Mode == "fa")
+ return IR2VecKind::FlowAware;
throw nb::value_error("Invalid mode. Use 'sym' or 'fa'");
+ }(Mode);
if (VocabPath.empty())
throw nb::value_error("Empty Vocab Path not allowed");
@@ -62,6 +68,27 @@ class PyIR2VecTool {
.c_str());
}
}
+
+ nb::dict getFuncEmbMap() {
+ auto result = Tool->getFunctionEmbeddings(EmbKind);
+ nb::dict nb_result;
+
+ for (const auto &[func_ptr, embedding] : result) {
+ std::string func_name = func_ptr->getName().str();
+ auto data = embedding.getData();
+ size_t shape[1] = {data.size()};
+ double *data_ptr = new double[data.size()];
+ std::copy(data.data(), data.data() + data.size(), data_ptr);
+
+ auto nb_array = nb::ndarray<nb::numpy, double>(
+ data_ptr, {data.size()}, nb::capsule(data_ptr, [](void *p) noexcept {
+ delete[] static_cast<double *>(p);
+ }));
+ nb_result[nb::str(func_name.c_str())] = nb_array;
+ }
+
+ return nb_result;
+ }
};
} // namespace
@@ -72,7 +99,11 @@ NB_MODULE(ir2vec, m) {
nb::class_<PyIR2VecTool>(m, "IR2VecTool")
.def(nb::init<const std::string &, const std::string &,
const std::string &>(),
- nb::arg("filename"), nb::arg("mode"), nb::arg("vocabPath"));
+ nb::arg("filename"), nb::arg("mode"), nb::arg("vocabPath"))
+ .def("getFuncEmbMap", &PyIR2VecTool::getFuncEmbMap,
+ "Generate function-level embeddings for all functions\n"
+ "Returns: dict[str, ndarray[float64]] - "
+ "{function_name: embedding}");
m.def(
"initEmbedding",
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 190d9259e45b3..4e8589885e019 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -151,6 +151,41 @@ void IR2VecTool::writeEntitiesToStream(raw_ostream &OS) {
OS << Entities[EntityID] << '\t' << EntityID << '\n';
}
+std::pair<const Function *, Embedding>
+IR2VecTool::getFunctionEmbedding(const Function &F, IR2VecKind Kind) const {
+ assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+ if (F.isDeclaration())
+ return {nullptr, Embedding()};
+
+ auto Emb = Embedder::create(Kind, F, *Vocab);
+ if (!Emb) {
+ return {nullptr, Embedding()};
+ }
+
+ auto FuncVec = Emb->getFunctionVector();
+
+ return {&F, std::move(FuncVec)};
+}
+
+FuncEmbMap IR2VecTool::getFunctionEmbeddings(IR2VecKind Kind) const {
+ assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+ FuncEmbMap Result;
+
+ for (const Function &F : M) {
+ if (F.isDeclaration())
+ continue;
+
+ auto Emb = getFunctionEmbedding(F, Kind);
+ if (Emb.first != nullptr) {
+ Result.try_emplace(Emb.first, std::move(Emb.second));
+ }
+ }
+
+ return Result;
+}
+
void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
EmbeddingLevel Level) const {
if (!Vocab->isValid()) {
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index d9715b03c3082..d115d9a26ca90 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -16,6 +16,7 @@
#define LLVM_TOOLS_LLVM_IR2VEC_UTILS_UTILS_H
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
#include "llvm/Analysis/IR2Vec.h"
#include "llvm/CodeGen/MIR2Vec.h"
#include "llvm/CodeGen/MIRParser/MIRParser.h"
@@ -72,6 +73,7 @@ struct TripletResult {
/// Entity mappings: [entity_name]
using EntityList = std::vector<std::string>;
+using FuncEmbMap = DenseMap<const Function *, ir2vec::Embedding>;
namespace ir2vec {
@@ -112,6 +114,13 @@ class IR2VecTool {
/// Returns EntityList containing all entity strings
static EntityList collectEntityMappings();
+  /// Get embedding for a single function
+ std::pair<const Function *, Embedding>
+ getFunctionEmbedding(const Function &F, IR2VecKind Kind) const;
+
+ /// Get embeddings for all functions in the module
+ FuncEmbMap getFunctionEmbeddings(IR2VecKind Kind) const;
+
/// Dump entity ID to string mappings
static void writeEntitiesToStream(raw_ostream &OS);
>From 4cb7297153f8fb5fb2864eabdba146eb3d766c42 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 21 Jan 2026 19:36:48 +0530
Subject: [PATCH 11/13] Changing unit-test structure for function embeddings
---
.../llvm-ir2vec/bindings/ir2vec-bindings.py | 28 +++++++++----------
1 file changed, 13 insertions(+), 15 deletions(-)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index ec03cb456efe6..343ea4bdc25bb 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -40,24 +40,22 @@
if tool is not None:
print("SUCCESS: Tool initialized")
- print(f"Tool type: {type(tool).__name__}")
# Test getFuncEmbMap
+ print("\n=== Function Embeddings ===")
func_emb_map = tool.getFuncEmbMap()
- print(f"Number of functions: {len(func_emb_map)}")
- # Check that all three functions are present
- expected_funcs = ["add", "multiply", "conditional"]
- for func_name in expected_funcs:
- if func_name in func_emb_map:
- emb = func_emb_map[func_name]
- print(f"Function '{func_name}': embedding shape = {emb.shape}")
- else:
- print(f"ERROR: Function '{func_name}' not found")
+ # Sorting the function names for deterministic output
+ for func_name in sorted(func_emb_map.keys()):
+ emb = func_emb_map[func_name]
+ print(f"Function: {func_name}")
+ print(f" Embedding: {emb.tolist()}")
# CHECK: SUCCESS: Tool initialized
-# CHECK: Tool type: IR2VecTool
-# CHECK: Number of functions: 3
-# CHECK: Function 'add': embedding shape =
-# CHECK: Function 'multiply': embedding shape =
-# CHECK: Function 'conditional': embedding shape =
+# CHECK: === Function Embeddings ===
+# CHECK: Function: add
+# CHECK-NEXT: Embedding: [38.0, 40.0, 42.0]
+# CHECK: Function: conditional
+# CHECK-NEXT: Embedding: [413.20000000298023, 421.20000000298023, 429.20000000298023]
+# CHECK: Function: multiply
+# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
>From 070f364df338dd93627ef7bd681455065e92554d Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Wed, 21 Jan 2026 19:32:36 +0530
Subject: [PATCH 12/13] adding BB embedding map API to ir2vec python bindings
---
.../llvm-ir2vec/bindings/ir2vec-bindings.py | 23 ++++++++++++
llvm/tools/llvm-ir2vec/lib/Utils.cpp | 35 +++++++++++++++++++
llvm/tools/llvm-ir2vec/lib/Utils.h | 6 ++++
3 files changed, 64 insertions(+)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index 343ea4bdc25bb..693885ff02ae6 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -51,6 +51,16 @@
print(f"Function: {func_name}")
print(f" Embedding: {emb.tolist()}")
+ # Test getBBEmbMap
+ print("\n=== Basic Block Embeddings ===")
+ bb_emb_list = tool.getBBEmbMap()
+
+ # Sorting by BB name for deterministic output
+ bb_sorted = sorted(bb_emb_list, key=lambda x: (x[0], tuple(x[1].tolist())))
+ for bb_name, emb in bb_sorted:
+ print(f"BB: {bb_name}")
+ print(f" Embedding: {emb.tolist()}")
+
# CHECK: SUCCESS: Tool initialized
# CHECK: === Function Embeddings ===
# CHECK: Function: add
@@ -59,3 +69,16 @@
# CHECK-NEXT: Embedding: [413.20000000298023, 421.20000000298023, 429.20000000298023]
# CHECK: Function: multiply
# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
+# CHECK: === Basic Block Embeddings ===
+# CHECK: BB: entry
+# CHECK-NEXT: Embedding: [38.0, 40.0, 42.0]
+# CHECK: BB: entry
+# CHECK-NEXT: Embedding: [50.0, 52.0, 54.0]
+# CHECK: BB: entry
+# CHECK-NEXT: Embedding: [161.20000000298023, 163.20000000298023, 165.20000000298023]
+# CHECK: BB: exit
+# CHECK-NEXT: Embedding: [164.0, 166.0, 168.0]
+# CHECK: BB: negative
+# CHECK-NEXT: Embedding: [47.0, 49.0, 51.0]
+# CHECK: BB: positive
+# CHECK-NEXT: Embedding: [41.0, 43.0, 45.0]
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index 4e8589885e019..aab95f7341b6e 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -186,6 +186,41 @@ FuncEmbMap IR2VecTool::getFunctionEmbeddings(IR2VecKind Kind) const {
return Result;
}
+BBEmbeddingsMap IR2VecTool::getBBEmbeddings(const Function &F,
+ IR2VecKind Kind) const {
+ assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+ BBEmbeddingsMap Result;
+
+ if (F.isDeclaration())
+ return Result;
+
+ auto Emb = Embedder::create(Kind, F, *Vocab);
+ if (!Emb)
+ return Result;
+
+ for (const BasicBlock &BB : F)
+ Result.try_emplace(&BB, Emb->getBBVector(BB));
+
+ return Result;
+}
+
+BBEmbeddingsMap IR2VecTool::getBBEmbeddings(IR2VecKind Kind) const {
+ assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+ BBEmbeddingsMap Result;
+
+ for (const Function &F : M) {
+ if (F.isDeclaration())
+ continue;
+
+ BBEmbeddingsMap FuncBBVecs = getBBEmbeddings(F, Kind);
+ Result.insert(FuncBBVecs.begin(), FuncBBVecs.end());
+ }
+
+ return Result;
+}
+
void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
EmbeddingLevel Level) const {
if (!Vocab->isValid()) {
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index d115d9a26ca90..abd00d806197a 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -121,6 +121,12 @@ class IR2VecTool {
/// Get embeddings for all functions in the module
FuncEmbMap getFunctionEmbeddings(IR2VecKind Kind) const;
+ /// Get embeddings for all basic blocks in a function
+ BBEmbeddingsMap getBBEmbeddings(const Function &F, IR2VecKind Kind) const;
+
+ /// Get embeddings for all basic blocks in the module
+ BBEmbeddingsMap getBBEmbeddings(IR2VecKind Kind) const;
+
/// Dump entity ID to string mappings
static void writeEntitiesToStream(raw_ostream &OS);
>From 50b492ad38fb7ddff40f562039b4f132d55afac9 Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Thu, 22 Jan 2026 10:51:58 +0530
Subject: [PATCH 13/13] added instruction embedding map API to ir2vec python
bindings
---
.../llvm-ir2vec/bindings/ir2vec-bindings.py | 37 +++++++++++-
llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp | 56 ++++++++++++++++++-
llvm/tools/llvm-ir2vec/lib/Utils.cpp | 36 ++++++++++++
llvm/tools/llvm-ir2vec/lib/Utils.h | 6 ++
4 files changed, 133 insertions(+), 2 deletions(-)
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
index 693885ff02ae6..03a7e8444f8ef 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-bindings.py
@@ -55,12 +55,22 @@
print("\n=== Basic Block Embeddings ===")
bb_emb_list = tool.getBBEmbMap()
- # Sorting by BB name for deterministic output
+ # Sorting by BB name, then embedding values for deterministic output
bb_sorted = sorted(bb_emb_list, key=lambda x: (x[0], tuple(x[1].tolist())))
for bb_name, emb in bb_sorted:
print(f"BB: {bb_name}")
print(f" Embedding: {emb.tolist()}")
+ # Test getInstEmbMap
+ print("\n=== Instruction Embeddings ===")
+ inst_emb_list = tool.getInstEmbMap()
+
+ # Sorting by instruction string, then embedding values for deterministic output
+ inst_sorted = sorted(inst_emb_list, key=lambda x: (x[0], tuple(x[1].tolist())))
+ for inst_str, emb in inst_sorted:
+ print(f"Inst: {inst_str}")
+ print(f" Embedding: {emb.tolist()}")
+
# CHECK: SUCCESS: Tool initialized
# CHECK: === Function Embeddings ===
# CHECK: Function: add
@@ -82,3 +92,28 @@
# CHECK-NEXT: Embedding: [47.0, 49.0, 51.0]
# CHECK: BB: positive
# CHECK-NEXT: Embedding: [41.0, 43.0, 45.0]
+# CHECK: === Instruction Embeddings ===
+# CHECK: Inst: %cmp = icmp sgt i32 %n, 0
+# CHECK-NEXT: Embedding: [157.20000000298023, 158.20000000298023, 159.20000000298023]
+# CHECK: Inst: %neg_val = sub i32 %n, 10
+# CHECK-NEXT: Embedding: [43.0, 44.0, 45.0]
+# CHECK: Inst: %pos_val = add i32 %n, 10
+# CHECK-NEXT: Embedding: [37.0, 38.0, 39.0]
+# CHECK: Inst: %prod = mul i32 %x, %y
+# CHECK-NEXT: Embedding: [49.0, 50.0, 51.0]
+# CHECK: Inst: %result = phi i32 [ %pos_val, %positive ], [ %neg_val, %negative ]
+# CHECK-NEXT: Embedding: [163.0, 164.0, 165.0]
+# CHECK: Inst: %sum = add i32 %a, %b
+# CHECK-NEXT: Embedding: [37.0, 38.0, 39.0]
+# CHECK: Inst: br i1 %cmp, label %positive, label %negative
+# CHECK-NEXT: Embedding: [4.0, 5.0, 6.0]
+# CHECK: Inst: br label %exit
+# CHECK-NEXT: Embedding: [4.0, 5.0, 6.0]
+# CHECK: Inst: br label %exit
+# CHECK-NEXT: Embedding: [4.0, 5.0, 6.0]
+# CHECK: Inst: ret i32 %prod
+# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
+# CHECK: Inst: ret i32 %result
+# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
+# CHECK: Inst: ret i32 %sum
+# CHECK-NEXT: Embedding: [1.0, 2.0, 3.0]
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index 346faf879c855..64b192c7148c3 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -89,6 +89,53 @@ class PyIR2VecTool {
return nb_result;
}
+
+  /// Compute basic-block embeddings for every defined function in the module.
+  ///
+  /// Returns a Python list of (bb_name, embedding) tuples. `bb_name` comes
+  /// from BasicBlock::getName(), so unnamed blocks yield "" and names may
+  /// repeat across functions. `embedding` is a 1-D float64 NumPy array that
+  /// owns a private copy of the vector data. The order of the list follows
+  /// the underlying map and is unspecified; callers should sort if needed.
+  nb::list getBBEmbMap() {
+    auto result = Tool->getBBEmbeddings(EmbKind);
+    nb::list nb_result;
+
+    for (const auto &[bb_ptr, embedding] : result) {
+      const std::string bb_name = bb_ptr->getName().str();
+      // Bind by reference: the data is copied into the NumPy buffer below,
+      // so an intermediate copy of the vector would be wasted work.
+      const auto &data = embedding.getData();
+
+      // Keep the buffer under RAII until the capsule owns it, so a failure
+      // while creating the capsule cannot leak the allocation.
+      auto buffer = std::make_unique<double[]>(data.size());
+      std::copy(data.begin(), data.end(), buffer.get());
+      nb::capsule owner(buffer.get(), [](void *p) noexcept {
+        delete[] static_cast<double *>(p);
+      });
+      double *data_ptr = buffer.release();
+
+      auto nb_array = nb::ndarray<nb::numpy, double, nb::shape<-1>>(
+          data_ptr, {data.size()}, owner);
+      nb_result.append(nb::make_tuple(
+          nb::str(bb_name.data(), bb_name.size()), nb_array));
+    }
+
+    return nb_result;
+  }
+
+  /// Compute instruction embeddings for every defined function in the module.
+  ///
+  /// Returns a Python list of (inst_text, embedding) tuples. `inst_text` is
+  /// the instruction rendered with Instruction::print. `embedding` is a 1-D
+  /// float64 NumPy array that owns a private copy of the vector data. The
+  /// order of the list follows the underlying map and is unspecified;
+  /// callers should sort if needed.
+  nb::list getInstEmbMap() {
+    auto result = Tool->getInstEmbeddings(EmbKind);
+    nb::list nb_result;
+
+    for (const auto &[inst_ptr, embedding] : result) {
+      std::string inst_str;
+      llvm::raw_string_ostream RSO(inst_str);
+      inst_ptr->print(RSO);
+      RSO.flush();
+
+      // Bind by reference: the data is copied into the NumPy buffer below,
+      // so an intermediate copy of the vector would be wasted work.
+      const auto &data = embedding.getData();
+
+      // Keep the buffer under RAII until the capsule owns it, so a failure
+      // while creating the capsule cannot leak the allocation.
+      auto buffer = std::make_unique<double[]>(data.size());
+      std::copy(data.begin(), data.end(), buffer.get());
+      nb::capsule owner(buffer.get(), [](void *p) noexcept {
+        delete[] static_cast<double *>(p);
+      });
+      double *data_ptr = buffer.release();
+
+      // Create nanobind numpy array with dynamic 1D shape
+      auto nb_array = nb::ndarray<nb::numpy, double, nb::shape<-1>>(
+          data_ptr, {data.size()}, owner);
+
+      nb_result.append(nb::make_tuple(
+          nb::str(inst_str.data(), inst_str.size()), nb_array));
+    }
+
+    return nb_result;
+  }
};
} // namespace
@@ -103,7 +150,14 @@ NB_MODULE(ir2vec, m) {
.def("getFuncEmbMap", &PyIR2VecTool::getFuncEmbMap,
"Generate function-level embeddings for all functions\n"
"Returns: dict[str, ndarray[float64]] - "
- "{function_name: embedding}");
+ "{function_name: embedding}")
+ .def("getBBEmbMap", &PyIR2VecTool::getBBEmbMap,
+ "Generate basic block embeddings for all functions\n"
+ "Returns: list[tuple[str, ndarray[float64]]] - "
+ "[{bb_name, embedding}]")
+ .def("getInstEmbMap", &PyIR2VecTool::getInstEmbMap,
+ "Generate instruction embeddings for all functions\n"
+ "Returns: list[tuple[str, ndarray[float64]]]");
m.def(
"initEmbedding",
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.cpp b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
index aab95f7341b6e..c669f9aa69ea4 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.cpp
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.cpp
@@ -221,6 +221,42 @@ BBEmbeddingsMap IR2VecTool::getBBEmbeddings(IR2VecKind Kind) const {
return Result;
}
+// Per-function variant: one embedding per instruction of F, keyed by the
+// instruction's address. A declaration, or a function for which no embedder
+// could be created for Kind, yields an empty map.
+InstEmbeddingsMap IR2VecTool::getInstEmbeddings(const Function &F,
+                                                 IR2VecKind Kind) const {
+  assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+  InstEmbeddingsMap InstVectors;
+  if (F.isDeclaration())
+    return InstVectors;
+
+  auto FuncEmbedder = Embedder::create(Kind, F, *Vocab);
+  if (!FuncEmbedder)
+    return InstVectors;
+
+  // Walk block by block, visiting every instruction of the function.
+  for (const BasicBlock &Block : F)
+    for (const Instruction &Inst : Block)
+      InstVectors.try_emplace(&Inst, FuncEmbedder->getInstVector(Inst));
+
+  return InstVectors;
+}
+
+// Module-wide variant: merges the per-function instruction maps of every
+// function in the module that has a body.
+InstEmbeddingsMap IR2VecTool::getInstEmbeddings(IR2VecKind Kind) const {
+  assert(Vocab && Vocab->isValid() && "Vocabulary not initialized");
+
+  InstEmbeddingsMap AllInstVectors;
+  for (const Function &Fn : M) {
+    if (Fn.isDeclaration())
+      continue;
+    for (const auto &Entry : getInstEmbeddings(Fn, Kind))
+      AllInstVectors.insert(Entry);
+  }
+  return AllInstVectors;
+}
+
void IR2VecTool::writeEmbeddingsToStream(raw_ostream &OS,
EmbeddingLevel Level) const {
if (!Vocab->isValid()) {
diff --git a/llvm/tools/llvm-ir2vec/lib/Utils.h b/llvm/tools/llvm-ir2vec/lib/Utils.h
index abd00d806197a..db437be5849ee 100644
--- a/llvm/tools/llvm-ir2vec/lib/Utils.h
+++ b/llvm/tools/llvm-ir2vec/lib/Utils.h
@@ -127,6 +127,12 @@ class IR2VecTool {
/// Get embeddings for all basic blocks in the module
BBEmbeddingsMap getBBEmbeddings(IR2VecKind Kind) const;
+ /// Get embeddings for all instructions in \p F, keyed by instruction pointer.
+ /// Returns an empty map if \p F is a declaration or if no embedder could be
+ /// created for \p Kind. Asserts that the vocabulary is valid.
+ InstEmbeddingsMap getInstEmbeddings(const Function &F, IR2VecKind Kind) const;
+
+ /// Get embeddings for all instructions in the module: the per-function maps
+ /// of every defined function, merged; declarations are skipped.
+ InstEmbeddingsMap getInstEmbeddings(IR2VecKind Kind) const;
+
/// Dump entity ID to string mappings
static void writeEntitiesToStream(raw_ostream &OS);
More information about the llvm-commits
mailing list