[llvm] [llvm-ir2vec] Added Enum for ir2vec embedding mode (PR #190466)

Nishant Sachdeva via llvm-commits llvm-commits at lists.llvm.org
Sat Apr 4 08:36:49 PDT 2026


https://github.com/nishant-sachdeva created https://github.com/llvm/llvm-project/pull/190466

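Currently, `initEmbedding()` takes the embedding mode as a string (`"sym"` or `"fa"`). This PR changes that parameter to take an enum value instead, exposed to Python as `ir2vec.IR2VecKind` (`IR2VecKind.Symbolic` or `IR2VecKind.FlowAware`). Note that this is an API break for existing callers: string modes are no longer accepted, and passing one now raises `TypeError` instead of `ValueError`.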

@svkeerthy 

>From 5422b32972164d9ade299fb9d0b25a9eff6920fe Mon Sep 17 00:00:00 2001
From: nishant-sachdeva <nishant.sachdeva at research.iiit.ac.in>
Date: Sat, 4 Apr 2026 21:01:49 +0530
Subject: [PATCH] Added Enum for ir2vec mode

---
 .../bindings/ir2vec-getBBEmbMap.py            |  2 +-
 .../llvm-ir2vec/bindings/ir2vec-getFuncEmb.py |  2 +-
 .../bindings/ir2vec-getFuncEmbMap.py          |  2 +-
 .../bindings/ir2vec-getFuncNames.py           |  2 +-
 .../bindings/ir2vec-getInstEmbMap.py          |  2 +-
 .../bindings/ir2vec-initEmbedding.py          | 14 +++++------
 llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp  | 23 ++++++++++---------
 7 files changed, 24 insertions(+), 23 deletions(-)

diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getBBEmbMap.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getBBEmbMap.py
index 333feadc6c932..a85876c7d8e9f 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getBBEmbMap.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getBBEmbMap.py
@@ -6,7 +6,7 @@
 ll_file = sys.argv[1]
 vocab_path = sys.argv[2]
 
-tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=vocab_path)
+tool = ir2vec.initEmbedding(filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocabPath=vocab_path)
 
 # Success case
 bb_map = tool.getBBEmbMap("conditional")
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmb.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmb.py
index 61b9464c89757..782cbb4f65ae3 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmb.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmb.py
@@ -6,7 +6,7 @@
 ll_file = sys.argv[1]
 vocab_path = sys.argv[2]
 
-tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=vocab_path)
+tool = ir2vec.initEmbedding(filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocabPath=vocab_path)
 
 # Success case
 emb = tool.getFuncEmb("add")
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmbMap.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmbMap.py
index 7600d5e4a2986..4235ed079c407 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmbMap.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncEmbMap.py
@@ -6,7 +6,7 @@
 ll_file = sys.argv[1]
 vocab_path = sys.argv[2]
 
-tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=vocab_path)
+tool = ir2vec.initEmbedding(filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocabPath=vocab_path)
 
 # Success case
 emb_map = tool.getFuncEmbMap()
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncNames.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncNames.py
index 432d80e97edb9..3e3048ead6e5f 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncNames.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getFuncNames.py
@@ -6,7 +6,7 @@
 ll_file = sys.argv[1]
 vocab_path = sys.argv[2]
 
-tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=vocab_path)
+tool = ir2vec.initEmbedding(filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocabPath=vocab_path)
 
 # Success case
 func_names = tool.getFuncNames()
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getInstEmbMap.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getInstEmbMap.py
index 3157ae34cfd3c..f9bd4ffcc9e5d 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getInstEmbMap.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-getInstEmbMap.py
@@ -6,7 +6,7 @@
 ll_file = sys.argv[1]
 vocab_path = sys.argv[2]
 
-tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=vocab_path)
+tool = ir2vec.initEmbedding(filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocabPath=vocab_path)
 
 # Success case
 inst_map = tool.getInstEmbMap("add")
diff --git a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-initEmbedding.py b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-initEmbedding.py
index c19935a0c6b7d..43c3db5d98c03 100644
--- a/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-initEmbedding.py
+++ b/llvm/test/tools/llvm-ir2vec/bindings/ir2vec-initEmbedding.py
@@ -7,41 +7,41 @@
 vocab_path = sys.argv[2]
 
 # Success case
-tool = ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=vocab_path)
+tool = ir2vec.initEmbedding(filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocabPath=vocab_path)
 print(f"SUCCESS: {type(tool).__name__}")
 # CHECK: SUCCESS: IR2VecTool
 
 # Error: Invalid mode
 try:
     ir2vec.initEmbedding(filename=ll_file, mode="invalid", vocabPath=vocab_path)
-except ValueError:
+except TypeError:
     print("ERROR: Invalid mode")
 # CHECK: ERROR: Invalid mode
 
 # Error: Empty vocab path
 try:
-    ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath="")
+    ir2vec.initEmbedding(filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocabPath="")
 except ValueError:
     print("ERROR: Empty vocab path")
 # CHECK: ERROR: Empty vocab path
 
 # Error: Invalid file
 try:
-    ir2vec.initEmbedding(filename="/bad.ll", mode="sym", vocabPath=vocab_path)
+    ir2vec.initEmbedding(filename="/bad.ll", mode=ir2vec.IR2VecKind.Symbolic, vocabPath=vocab_path)
 except ValueError:
     print("ERROR: Invalid file")
 # CHECK: ERROR: Invalid file
 
 # Error: Empty filename
 try:
-    ir2vec.initEmbedding(filename="", mode="sym", vocabPath=vocab_path)
+    ir2vec.initEmbedding(filename="", mode=ir2vec.IR2VecKind.Symbolic, vocabPath=vocab_path)
 except ValueError:
     print("ERROR: Empty filename")
 # CHECK: ERROR: Empty filename
 
 # Error: Invalid vocab file
 try:
-    ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath="/bad.json")
+    ir2vec.initEmbedding(filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocabPath="/bad.json")
 except ValueError:
     print("ERROR: Invalid vocab")
 # CHECK: ERROR: Invalid vocab
@@ -54,7 +54,7 @@
     f.write("{ this is not valid json }")
     bad_vocab = f.name
 try:
-    ir2vec.initEmbedding(filename=ll_file, mode="sym", vocabPath=bad_vocab)
+    ir2vec.initEmbedding(filename=ll_file, mode=ir2vec.IR2VecKind.Symbolic, vocabPath=bad_vocab)
 except ValueError:
     print("ERROR: Invalid vocab file")
 finally:
diff --git a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
index 2f885b11519c7..d76510f15b081 100644
--- a/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
+++ b/llvm/tools/llvm-ir2vec/Bindings/PyIR2Vec.cpp
@@ -46,15 +46,9 @@ class PyIR2VecTool {
   IR2VecKind OutputEmbeddingMode;
 
 public:
-  PyIR2VecTool(const std::string &Filename, const std::string &Mode,
+  PyIR2VecTool(const std::string &Filename, IR2VecKind Mode,
                const std::string &VocabPath) {
-    OutputEmbeddingMode = [](const std::string &Mode) -> IR2VecKind {
-      if (Mode == "sym")
-        return IR2VecKind::Symbolic;
-      if (Mode == "fa")
-        return IR2VecKind::FlowAware;
-      throw nb::value_error("Invalid mode. Use 'sym' or 'fa'");
-    }(Mode);
+    OutputEmbeddingMode = Mode;
 
     if (VocabPath.empty())
       throw nb::value_error("Empty Vocab Path not allowed");
@@ -200,8 +194,15 @@ class PyIR2VecTool {
 NB_MODULE(ir2vec, m) {
   m.doc() = std::string("Python bindings for ") + ToolName;
 
+  nb::enum_<IR2VecKind>(m, "IR2VecKind",
+                        "Embedding mode for IR2Vec representations")
+      .value("Symbolic", IR2VecKind::Symbolic, "Symbolic encodings only")
+      .value("FlowAware", IR2VecKind::FlowAware,
+             "Flow-aware encodings (includes data/control flow)")
+      .export_values();
+
   nb::class_<PyIR2VecTool>(m, "IR2VecTool")
-      .def(nb::init<const std::string &, const std::string &,
+      .def(nb::init<const std::string &, IR2VecKind,
                      const std::string &>(),
            nb::arg("filename"), nb::arg("mode"), nb::arg("vocabPath"))
       .def("getFuncNames", &PyIR2VecTool::getFuncNames,
@@ -228,10 +229,10 @@ NB_MODULE(ir2vec, m) {
 
   m.def(
       "initEmbedding",
-      [](const std::string &filename, const std::string &mode,
+      [](const std::string &filename, IR2VecKind mode,
          const std::string &vocabPath) {
         return std::make_unique<PyIR2VecTool>(filename, mode, vocabPath);
       },
-      nb::arg("filename"), nb::arg("mode") = "sym", nb::arg("vocabPath"),
+      nb::arg("filename"), nb::arg("mode") = IR2VecKind::Symbolic, nb::arg("vocabPath"),
       nb::rv_policy::take_ownership);
 }



More information about the llvm-commits mailing list