[llvm] [NFC][IR2Vec] Minor refactoring of opcode access in vocabulary (PR #147585)
S. VenkataKeerthy via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 14 11:09:43 PDT 2025
https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/147585
From bebdb9e5f630723c377296dd2f4d40a6c748af6a Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Mon, 7 Jul 2025 21:30:29 +0000
Subject: [PATCH] [NFC][IR2Vec] Minor refactoring of opcode access in vocabulary
---
llvm/include/llvm/Analysis/IR2Vec.h | 9 ++++---
llvm/lib/Analysis/IR2Vec.cpp | 41 ++++++++++++++++-------------
2 files changed, 28 insertions(+), 22 deletions(-)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 151fb0a5e8ac6..0127df7970010 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -162,15 +162,18 @@ class Vocabulary {
static constexpr unsigned MaxOperandKinds =
static_cast<unsigned>(OperandKind::MaxOperandKind);
+ /// Helper function to get vocabulary key for a given Opcode
+ static StringRef getVocabKeyForOpcode(unsigned Opcode);
+
+ /// Helper function to get vocabulary key for a given TypeID
+ static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
+
/// Helper function to get vocabulary key for a given OperandKind
static StringRef getVocabKeyForOperandKind(OperandKind Kind);
/// Helper function to classify an operand into OperandKind
static OperandKind getOperandKind(const Value *Op);
- /// Helper function to get vocabulary key for a given TypeID
- static StringRef getVocabKeyForTypeID(Type::TypeID TypeID);
-
public:
Vocabulary() = default;
Vocabulary(VocabVector &&Vocab);
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index b1255c76367b2..c6e1fa32c9ffd 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -243,6 +243,17 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)];
}
+StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
+ assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); // 1-based
+#define HANDLE_INST(NUM, OPCODE, CLASS) \
+ if (Opcode == NUM) { \
+ return #OPCODE; \
+ }
+#include "llvm/IR/Instruction.def" // presumably one HANDLE_INST per opcode
+#undef HANDLE_INST
+ return "UnknownOpcode"; // reached only if no HANDLE_INST check matched
+}
+
StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
switch (TypeID) {
case Type::VoidTyID:
@@ -280,6 +291,7 @@ StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
default:
return "UnknownTy";
}
+ return "UnknownTy";
}
StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
@@ -316,14 +328,8 @@ StringRef Vocabulary::getStringKey(unsigned Pos) {
assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&
"Position out of bounds in vocabulary");
// Opcode
- if (Pos < MaxOpcodes) {
-#define HANDLE_INST(NUM, OPCODE, CLASS) \
- if (Pos == NUM - 1) { \
- return #OPCODE; \
- }
-#include "llvm/IR/Instruction.def"
-#undef HANDLE_INST
- }
+ if (Pos < MaxOpcodes)
+ return getVocabKeyForOpcode(Pos + 1);
// Type
if (Pos < MaxOpcodes + MaxTypeIDs)
return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - MaxOpcodes));
@@ -431,21 +437,18 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
// Handle Opcodes
std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
Embedding(Dim, 0));
-#define HANDLE_INST(NUM, OPCODE, CLASS) \
- { \
- auto It = OpcVocab.find(#OPCODE); \
- if (It != OpcVocab.end()) \
- NumericOpcodeEmbeddings[NUM - 1] = It->second; \
- else \
- handleMissingEntity(#OPCODE); \
+ for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
+ StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
+ auto It = OpcVocab.find(VocabKey.str());
+ if (It != OpcVocab.end())
+ NumericOpcodeEmbeddings[Opcode] = It->second;
+ else
+ handleMissingEntity(VocabKey.str());
}
-#include "llvm/IR/Instruction.def"
-#undef HANDLE_INST
Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),
NumericOpcodeEmbeddings.end());
- // Handle Types using direct iteration through TypeID enum
- // We iterate through all possible TypeID values and map them to embeddings
+ // Handle Types
std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxTypeIDs,
Embedding(Dim, 0));
for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) {
More information about the llvm-commits mailing list