[llvm-branch-commits] [llvm] Support predicates (PR #156952)
S. VenkataKeerthy via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Sep 4 12:07:42 PDT 2025
https://github.com/svkeerthy created https://github.com/llvm/llvm-project/pull/156952
None
>From 6185e40a9a6731955e190131067cc3c5bc90595e Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Wed, 3 Sep 2025 22:56:08 +0000
Subject: [PATCH] Support predicates
---
llvm/include/llvm/Analysis/IR2Vec.h | 43 +++++++++++++---
llvm/lib/Analysis/IR2Vec.cpp | 70 ++++++++++++++++++++++----
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp | 2 +-
3 files changed, 98 insertions(+), 17 deletions(-)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index b7b881999241e..d49854e2d06a8 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -36,6 +36,7 @@
#define LLVM_ANALYSIS_IR2VEC_H
#include "llvm/ADT/DenseMap.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/CommandLine.h"
@@ -162,16 +163,25 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
/// embeddings.
class Vocabulary {
friend class llvm::IR2VecVocabAnalysis;
+ // Slot layout:
+ // [0 .. MaxOpcodes-1] => Instruction opcodes
+ // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
+ // [MaxOpcodes+MaxCanonicalTypeIDs .. PredicateBaseOffset-1] => Operand kinds
+ // [PredicateBaseOffset .. NumCanonicalEntries-1] => Compare predicates,
+ //   densely packed: FCmp predicates first, then ICmp predicates.
+ // Predicates are treated as part of the operands section.
using VocabVector = std::vector<ir2vec::Embedding>;
VocabVector Vocab;
+
bool Valid = false;
+ static constexpr unsigned NumICmpPredicates =
+ static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) -
+ static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) + 1;
+ static constexpr unsigned NumFCmpPredicates =
+ static_cast<unsigned>(CmpInst::LAST_FCMP_PREDICATE) -
+ static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + 1;
public:
- // Slot layout:
- // [0 .. MaxOpcodes-1] => Instruction opcodes
- // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
- // [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
-
/// Canonical type IDs supported by IR2Vec Vocabulary
enum class CanonicalTypeID : unsigned {
FloatTy,
@@ -208,13 +218,18 @@ class Vocabulary {
static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
static constexpr unsigned MaxOperandKinds =
static_cast<unsigned>(OperandKind::MaxOperandKind);
+ // CmpInst::Predicate values have a gap between the FCmp and ICmp ranges;
+ // count only valid predicates so the vocabulary stays dense (no empty slots).
+ static constexpr unsigned MaxPredicateKinds =
+ NumICmpPredicates + NumFCmpPredicates;
Vocabulary() = default;
LLVM_ABI Vocabulary(VocabVector &&Vocab);
LLVM_ABI bool isValid() const;
LLVM_ABI unsigned getDimension() const;
- /// Total number of entries (opcodes + canonicalized types + operand kinds)
+ /// Total number of entries (opcodes + canonicalized types + operand kinds +
+ /// compare predicates)
static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
/// Function to get vocabulary key for a given Opcode
@@ -229,16 +244,21 @@ class Vocabulary {
/// Function to classify an operand into OperandKind
LLVM_ABI static OperandKind getOperandKind(const Value *Op);
+ /// Function to get the vocabulary key for a given compare predicate
+ LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P);
+
/// Functions to return the slot index or position of a given Opcode, TypeID,
/// or OperandKind in the vocabulary.
LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
LLVM_ABI static unsigned getSlotIndex(const Value &Op);
+ LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P); // Predicate slot
/// Accessors to get the embedding for a given entity.
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;
+ LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const;
/// Const Iterator type aliases
using const_iterator = VocabVector::const_iterator;
@@ -275,7 +295,13 @@ class Vocabulary {
private:
constexpr static unsigned NumCanonicalEntries =
- MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
+ MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
+
+ // First slot of the operand-kind section and of the compare-predicate section
+ constexpr static unsigned OperandBaseOffset =
+ MaxOpcodes + MaxCanonicalTypeIDs;
+ constexpr static unsigned PredicateBaseOffset =
+ OperandBaseOffset + MaxOperandKinds;
/// String mappings for CanonicalTypeID values
static constexpr StringLiteral CanonicalTypeNames[] = {
@@ -327,6 +353,9 @@ class Vocabulary {
/// Function to convert TypeID to CanonicalTypeID
LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
+
+ /// Map a dense predicate index in [0, MaxPredicateKinds) to its predicate
+ LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index);
};
/// Embedder provides the interface to generate embeddings (vector
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 98849fd922843..c79c9c1ed493d 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -216,6 +216,8 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
ArgEmb += Vocab[*Op];
auto InstVector =
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+ if (const auto *IC = dyn_cast<CmpInst>(&I))
+ InstVector += Vocab[IC->getPredicate()]; // Predicate as an extra operand
InstVecMap[&I] = InstVector;
BBVector += InstVector;
}
@@ -250,6 +252,9 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
// embeddings
auto InstVector =
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+ // For a CmpInst, also add its predicate embedding as an additional operand
+ if (const auto *IC = dyn_cast<CmpInst>(&I))
+ InstVector += Vocab[IC->getPredicate()];
InstVecMap[&I] = InstVector;
BBVector += InstVector;
}
@@ -285,7 +290,24 @@ unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
unsigned Vocabulary::getSlotIndex(const Value &Op) {
unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
assert(Index < MaxOperandKinds && "Invalid OperandKind");
- return MaxOpcodes + MaxCanonicalTypeIDs + Index;
+ return OperandBaseOffset + Index;
+}
+
+/// Map a compare predicate to its dense vocabulary slot. CmpInst::Predicate
+/// has a gap between the FCmp and ICmp ranges, so FCmp predicates are packed
+/// first and ICmp predicates follow immediately after them.
+unsigned Vocabulary::getSlotIndex(CmpInst::Predicate P) {
+ // BAD_*_PREDICATE and values in the gap would otherwise alias valid
+ // predicate slots or index past the end of the vocabulary.
+ assert((CmpInst::isFPPredicate(P) || CmpInst::isIntPredicate(P)) &&
+ "Invalid compare predicate");
+ unsigned PU = static_cast<unsigned>(P);
+ unsigned FirstFC = static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE);
+ unsigned FirstIC = static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE);
+
+ unsigned PredIdx =
+ (PU >= FirstIC) ? (NumFCmpPredicates + (PU - FirstIC)) : (PU - FirstFC);
+ return PredicateBaseOffset + PredIdx;
}
const Embedding &Vocabulary::operator[](unsigned Opcode) const {
@@ -300,6 +315,10 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
return Vocab[getSlotIndex(Arg)];
}
+const ir2vec::Embedding &Vocabulary::operator[](CmpInst::Predicate P) const {
+ return Vocab[getSlotIndex(P)]; // Dense predicate slot from getSlotIndex(P)
+}
+
StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
#define HANDLE_INST(NUM, OPCODE, CLASS) \
@@ -345,18 +364,53 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
return OperandKind::VariableID;
}
+CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) {
+ assert(Index < MaxPredicateKinds && "Invalid predicate index");
+ unsigned PredEnumVal =
+ (Index < NumFCmpPredicates)
+ ? (static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + Index)
+ : (static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) +
+ (Index - NumFCmpPredicates));
+ return static_cast<CmpInst::Predicate>(PredEnumVal);
+}
+
+/// Vocabulary key for a compare predicate. CmpInst::getPredicateName() alone
+/// is ambiguous because FCmp and ICmp share the names "ugt", "uge", "ult" and
+/// "ule", so keys are prefixed with the compare kind ("FCMP_" / "ICMP_") to
+/// give every predicate slot a distinct vocabulary entry.
+StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) {
+ assert((CmpInst::isFPPredicate(Pred) || CmpInst::isIntPredicate(Pred)) &&
+ "Invalid compare predicate");
+ // Built once (function-local static) so the returned StringRef remains
+ // valid after this call returns.
+ static const std::vector<std::string> Keys = [] {
+ std::vector<std::string> Names;
+ Names.reserve(MaxPredicateKinds);
+ for (unsigned Idx : seq(0u, MaxPredicateKinds)) {
+ CmpInst::Predicate P = getPredicate(Idx);
+ Names.push_back((CmpInst::isFPPredicate(P) ? "FCMP_" : "ICMP_") +
+ CmpInst::getPredicateName(P).str());
+ }
+ return Names;
+ }();
+ return Keys[getSlotIndex(Pred) - PredicateBaseOffset];
+}
+
StringRef Vocabulary::getStringKey(unsigned Pos) {
assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
// Opcode
if (Pos < MaxOpcodes)
return getVocabKeyForOpcode(Pos + 1);
// Type
- if (Pos < MaxOpcodes + MaxCanonicalTypeIDs)
+ if (Pos < OperandBaseOffset)
return getVocabKeyForCanonicalTypeID(
static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
// Operand
- return getVocabKeyForOperandKind(
- static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs));
+ if (Pos < PredicateBaseOffset)
+ return getVocabKeyForOperandKind(
+ static_cast<OperandKind>(Pos - OperandBaseOffset));
+ // Predicates
+ return getVocabKeyForPredicate(getPredicate(Pos - PredicateBaseOffset));
}
// For now, assume vocabulary is stable unless explicitly invalidated.
@@ -370,11 +406,9 @@ Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
VocabVector DummyVocab;
DummyVocab.reserve(NumCanonicalEntries);
float DummyVal = 0.1f;
- // Create a dummy vocabulary with entries for all opcodes, types, and
- // operands
- for ([[maybe_unused]] unsigned _ :
- seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs +
- Vocabulary::MaxOperandKinds)) {
+ // Create a dummy vocabulary with one entry per canonical slot (opcodes,
+ // types, operand kinds and compare predicates)
+ for ([[maybe_unused]] unsigned _ : seq(0u, Vocabulary::NumCanonicalEntries)) {
DummyVocab.push_back(Embedding(Dim, DummyVal));
DummyVal += 0.1f;
}
@@ -517,6 +551,24 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
}
Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),
NumericArgEmbeddings.end());
+
+ // Handle compare predicates: they follow the operand kinds in the slot
+ // layout and their keys are looked up in ArgVocab. Entries start out as
+ // zero vectors; keys that are not found go through handleMissingEntity().
+ std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds,
+ Embedding(Dim, 0));
+ for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) {
+ StringRef VocabKey =
+ Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK));
+ auto It = ArgVocab.find(VocabKey.str());
+ if (It != ArgVocab.end()) {
+ NumericPredEmbeddings[PK] = It->second;
+ continue;
+ }
+ handleMissingEntity(VocabKey.str());
+ }
+ Vocab.insert(Vocab.end(), NumericPredEmbeddings.begin(),
+ NumericPredEmbeddings.end());
}
IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab)
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index aabebf0cc90a9..1c656b8fcf4e7 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -184,7 +184,7 @@ class IR2VecTool {
// Add "Arg" relationships
unsigned ArgIndex = 0;
for (const Use &U : I.operands()) {
- unsigned OperandID = Vocabulary::getSlotIndex(*U);
+ unsigned OperandID = Vocabulary::getSlotIndex(*U.get());
unsigned RelationID = ArgRelation + ArgIndex;
OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
More information about the llvm-branch-commits
mailing list