[llvm-branch-commits] [llvm] [IR2Vec][NFC] Add helper methods for numeric ID mapping in Vocabulary (PR #149212)
S. VenkataKeerthy via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Jul 17 11:04:56 PDT 2025
https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/149212
>From 1d7ca8076757401353b403256f03ae9498dbe404 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Wed, 16 Jul 2025 21:49:05 +0000
Subject: [PATCH] exposing-new-methods
---
llvm/include/llvm/Analysis/IR2Vec.h | 9 ++++
llvm/lib/Analysis/IR2Vec.cpp | 20 +++++++-
llvm/unittests/Analysis/IR2VecTest.cpp | 63 ++++++++++++++++++++++++++
3 files changed, 90 insertions(+), 2 deletions(-)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 3d7edf08c8807..d87457cac7642 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -170,6 +170,10 @@ class Vocabulary {
unsigned getDimension() const;
size_t size() const;
+ static size_t expectedSize() {
+ return MaxOpcodes + MaxTypeIDs + MaxOperandKinds;
+ }
+
/// Helper function to get vocabulary key for a given Opcode
static StringRef getVocabKeyForOpcode(unsigned Opcode);
@@ -182,6 +186,11 @@ class Vocabulary {
/// Helper function to classify an operand into OperandKind
static OperandKind getOperandKind(const Value *Op);
+ /// Helpers to return the numeric ID of an Opcode, TypeID, or operand Value
+ static unsigned getNumericID(unsigned Opcode);
+ static unsigned getNumericID(Type::TypeID TypeID);
+ static unsigned getNumericID(const Value *Op);
+
/// Accessors to get the embedding for a given entity.
const ir2vec::Embedding &operator[](unsigned Opcode) const;
const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 898bf5b202feb..95f30fd3f4275 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -215,7 +215,7 @@ Vocabulary::Vocabulary(VocabVector &&Vocab)
: Vocab(std::move(Vocab)), Valid(true) {}
bool Vocabulary::isValid() const {
- return Vocab.size() == (MaxOpcodes + MaxTypeIDs + MaxOperandKinds) && Valid;
+ return Vocab.size() == Vocabulary::expectedSize() && Valid;
}
size_t Vocabulary::size() const {
@@ -324,8 +324,24 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
return OperandKind::VariableID;
}
+unsigned Vocabulary::getNumericID(unsigned Opcode) {
+ assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
+ return Opcode - 1; // Convert to zero-based index
+}
+
+unsigned Vocabulary::getNumericID(Type::TypeID TypeID) {
+ assert(static_cast<unsigned>(TypeID) < MaxTypeIDs && "Invalid type ID");
+ return MaxOpcodes + static_cast<unsigned>(TypeID);
+}
+
+unsigned Vocabulary::getNumericID(const Value *Op) {
+ unsigned Index = static_cast<unsigned>(getOperandKind(Op));
+ assert(Index < MaxOperandKinds && "Invalid OperandKind");
+ return MaxOpcodes + MaxTypeIDs + Index;
+}
+
StringRef Vocabulary::getStringKey(unsigned Pos) {
- assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&
+ assert(Pos < Vocabulary::expectedSize() &&
"Position out of bounds in vocabulary");
// Opcode
if (Pos < MaxOpcodes)
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index cb6d633306a81..7c9a5464bfe1d 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -396,6 +396,69 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
}
}
+TEST(IR2VecVocabularyTest, NumericIDMap) {
+ // Test getNumericID for opcodes
+ EXPECT_EQ(Vocabulary::getNumericID(1u), 0u);
+ EXPECT_EQ(Vocabulary::getNumericID(13u), 12u);
+ EXPECT_EQ(Vocabulary::getNumericID(MaxOpcodes), MaxOpcodes - 1);
+
+ // Test getNumericID for Type IDs
+ EXPECT_EQ(Vocabulary::getNumericID(Type::VoidTyID),
+ MaxOpcodes + static_cast<unsigned>(Type::VoidTyID));
+ EXPECT_EQ(Vocabulary::getNumericID(Type::HalfTyID),
+ MaxOpcodes + static_cast<unsigned>(Type::HalfTyID));
+ EXPECT_EQ(Vocabulary::getNumericID(Type::FloatTyID),
+ MaxOpcodes + static_cast<unsigned>(Type::FloatTyID));
+ EXPECT_EQ(Vocabulary::getNumericID(Type::IntegerTyID),
+ MaxOpcodes + static_cast<unsigned>(Type::IntegerTyID));
+ EXPECT_EQ(Vocabulary::getNumericID(Type::PointerTyID),
+ MaxOpcodes + static_cast<unsigned>(Type::PointerTyID));
+
+ // Test getNumericID for Value operands
+ LLVMContext Ctx;
+ Module M("TestM", Ctx);
+ FunctionType *FTy =
+ FunctionType::get(Type::getVoidTy(Ctx), {Type::getInt32Ty(Ctx)}, false);
+ Function *F = Function::Create(FTy, Function::ExternalLinkage, "testFunc", M);
+
+ // Test Function operand
+ EXPECT_EQ(Vocabulary::getNumericID(F),
+ MaxOpcodes + MaxTypeIDs + 0u); // Function = 0
+
+ // Test Constant operand
+ Constant *C = ConstantInt::get(Type::getInt32Ty(Ctx), 42);
+ EXPECT_EQ(Vocabulary::getNumericID(C),
+ MaxOpcodes + MaxTypeIDs + 2u); // Constant = 2
+
+ // Test Pointer operand
+ BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);
+ AllocaInst *PtrVal = new AllocaInst(Type::getInt32Ty(Ctx), 0, "ptr", BB);
+ EXPECT_EQ(Vocabulary::getNumericID(PtrVal),
+ MaxOpcodes + MaxTypeIDs + 1u); // Pointer = 1
+
+ // Test Variable operand (function argument)
+ Argument *Arg = F->getArg(0);
+ EXPECT_EQ(Vocabulary::getNumericID(Arg),
+ MaxOpcodes + MaxTypeIDs + 3u); // Variable = 3
+}
+
+#if GTEST_HAS_DEATH_TEST
+#ifndef NDEBUG
+TEST(IR2VecVocabularyTest, NumericIDMapInvalidInputs) {
+ // Test invalid opcode IDs
+ EXPECT_DEATH(Vocabulary::getNumericID(0u), "Invalid opcode");
+ EXPECT_DEATH(Vocabulary::getNumericID(MaxOpcodes + 1), "Invalid opcode");
+
+ // Test invalid type IDs
+ EXPECT_DEATH(Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs)),
+ "Invalid type ID");
+ EXPECT_DEATH(
+ Vocabulary::getNumericID(static_cast<Type::TypeID>(MaxTypeIDs + 10)),
+ "Invalid type ID");
+}
+#endif // NDEBUG
+#endif // GTEST_HAS_DEATH_TEST
+
TEST(IR2VecVocabularyTest, StringKeyGeneration) {
EXPECT_EQ(Vocabulary::getStringKey(0), "Ret");
EXPECT_EQ(Vocabulary::getStringKey(12), "Add");
More information about the llvm-branch-commits
mailing list