[llvm] 52b1850 - [IR2Vec] Add support for Cmp predicates in vocabulary and embeddings (#156952)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Oct 1 16:39:27 PDT 2025
Author: S. VenkataKeerthy
Date: 2025-10-01T16:39:22-07:00
New Revision: 52b185075919a3d8ec9dc6b7d7da365e28532735
URL: https://github.com/llvm/llvm-project/commit/52b185075919a3d8ec9dc6b7d7da365e28532735
DIFF: https://github.com/llvm/llvm-project/commit/52b185075919a3d8ec9dc6b7d7da365e28532735.diff
LOG: [IR2Vec] Add support for Cmp predicates in vocabulary and embeddings (#156952)
Comparison predicates (equal, not equal, greater than, etc.) provide important semantic information about program behavior. Previously, IR2Vec only captured that a comparison was happening but not what kind of comparison it was. This PR extends the IR2Vec vocabulary to include comparison predicates (ICmp and FCmp) as part of the embedding space.
The changes are as follows:
1. Expand the vocabulary slot layout to include predicate entries after opcodes, types, and operands
2. Add methods to handle predicate embedding lookups and conversions
3. Update the embedder implementations to include predicate information when processing CmpInst instructions
4. Update test files to include the new predicate entries in the vocabulary
(Tracking issues: #141817, #141833)
Added:
Modified:
llvm/include/llvm/Analysis/IR2Vec.h
llvm/lib/Analysis/IR2Vec.cpp
llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json
llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json
llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json
llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
llvm/test/Analysis/IR2Vec/if-else.ll
llvm/test/Analysis/IR2Vec/unreachable.ll
llvm/test/tools/llvm-ir2vec/entities.ll
llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
llvm/unittests/Analysis/IR2VecTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 3671c1c71ac0b..f3f9de460218b 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -36,6 +36,7 @@
#define LLVM_ANALYSIS_IR2VEC_H
#include "llvm/ADT/DenseMap.h"
+#include "llvm/IR/Instructions.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/CommandLine.h"
@@ -162,15 +163,34 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
/// embeddings.
class Vocabulary {
friend class llvm::IR2VecVocabAnalysis;
+
+ // Vocabulary Slot Layout:
+ // +----------------+------------------------------------------------------+
+ // | Entity Type | Index Range |
+ // +----------------+------------------------------------------------------+
+ // | Opcodes | [0 .. (MaxOpcodes-1)] |
+ // | Canonical Types| [MaxOpcodes .. (MaxOpcodes+MaxCanonicalTypeIDs-1)] |
+ // | Operands | [(MaxOpcodes+MaxCanonicalTypeIDs) .. (NumCanEntries-1)] |
+ // +----------------+------------------------------------------------------+
+ // Note: MaxOpcodes is the number of unique opcodes supported by LLVM IR.
+ // MaxCanonicalTypeIDs is the number of canonicalized type IDs.
+ // "Similar" LLVM Types are grouped/canonicalized together. E.g., all
+ // float variants (FloatTy, DoubleTy, HalfTy, etc.) map to
+ // CanonicalTypeID::FloatTy. This helps reduce the vocabulary size
+ // and improves learning. Operands include Comparison predicates
+ // (ICmp/FCmp) along with other operand types. This can be extended to
+ // include other specializations in future.
using VocabVector = std::vector<ir2vec::Embedding>;
VocabVector Vocab;
-public:
- // Slot layout:
- // [0 .. MaxOpcodes-1] => Instruction opcodes
- // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
- // [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
+ static constexpr unsigned NumICmpPredicates =
+ static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) -
+ static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) + 1;
+ static constexpr unsigned NumFCmpPredicates =
+ static_cast<unsigned>(CmpInst::LAST_FCMP_PREDICATE) -
+ static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + 1;
+public:
/// Canonical type IDs supported by IR2Vec Vocabulary
enum class CanonicalTypeID : unsigned {
FloatTy,
@@ -207,13 +227,18 @@ class Vocabulary {
static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
static constexpr unsigned MaxOperandKinds =
static_cast<unsigned>(OperandKind::MaxOperandKind);
+ // CmpInst::Predicate has gaps. We want the vocabulary to be dense without
+ // empty slots.
+ static constexpr unsigned MaxPredicateKinds =
+ NumICmpPredicates + NumFCmpPredicates;
Vocabulary() = default;
LLVM_ABI Vocabulary(VocabVector &&Vocab) : Vocab(std::move(Vocab)) {}
LLVM_ABI bool isValid() const { return Vocab.size() == NumCanonicalEntries; };
LLVM_ABI unsigned getDimension() const;
- /// Total number of entries (opcodes + canonicalized types + operand kinds)
+ /// Total number of entries (opcodes + canonicalized types + operand kinds +
+ /// predicates)
static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
/// Function to get vocabulary key for a given Opcode
@@ -228,16 +253,21 @@ class Vocabulary {
/// Function to classify an operand into OperandKind
LLVM_ABI static OperandKind getOperandKind(const Value *Op);
+ /// Function to get vocabulary key for a given predicate
+ LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P);
+
/// Functions to return the slot index or position of a given Opcode, TypeID,
/// or OperandKind in the vocabulary.
LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
LLVM_ABI static unsigned getSlotIndex(const Value &Op);
+ LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P);
/// Accessors to get the embedding for a given entity.
LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;
+ LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const;
/// Const Iterator type aliases
using const_iterator = VocabVector::const_iterator;
@@ -274,7 +304,13 @@ class Vocabulary {
private:
constexpr static unsigned NumCanonicalEntries =
- MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
+ MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
+
+ // Base offsets for slot layout to simplify index computation
+ constexpr static unsigned OperandBaseOffset =
+ MaxOpcodes + MaxCanonicalTypeIDs;
+ constexpr static unsigned PredicateBaseOffset =
+ OperandBaseOffset + MaxOperandKinds;
/// String mappings for CanonicalTypeID values
static constexpr StringLiteral CanonicalTypeNames[] = {
@@ -326,6 +362,11 @@ class Vocabulary {
/// Function to convert TypeID to CanonicalTypeID
LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
+
+ /// Function to get the predicate enum value for a given index. Index is
+ /// relative to the predicates section of the vocabulary. E.g., Index 0
+ /// corresponds to the first predicate.
+ LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index);
};
/// Embedder provides the interface to generate embeddings (vector
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 99afc0601d523..f51f0898cb37e 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -216,6 +216,8 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
ArgEmb += Vocab[*Op];
auto InstVector =
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+ if (const auto *IC = dyn_cast<CmpInst>(&I))
+ InstVector += Vocab[IC->getPredicate()];
InstVecMap[&I] = InstVector;
BBVector += InstVector;
}
@@ -250,6 +252,9 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
// embeddings
auto InstVector =
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+ // Add compare predicate embedding as an additional operand if applicable
+ if (const auto *IC = dyn_cast<CmpInst>(&I))
+ InstVector += Vocab[IC->getPredicate()];
InstVecMap[&I] = InstVector;
BBVector += InstVector;
}
@@ -278,7 +283,17 @@ unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
unsigned Vocabulary::getSlotIndex(const Value &Op) {
unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
assert(Index < MaxOperandKinds && "Invalid OperandKind");
- return MaxOpcodes + MaxCanonicalTypeIDs + Index;
+ return OperandBaseOffset + Index;
+}
+
+unsigned Vocabulary::getSlotIndex(CmpInst::Predicate P) {
+ unsigned PU = static_cast<unsigned>(P);
+ unsigned FirstFC = static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE);
+ unsigned FirstIC = static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE);
+
+ unsigned PredIdx =
+ (PU >= FirstIC) ? (NumFCmpPredicates + (PU - FirstIC)) : (PU - FirstFC);
+ return PredicateBaseOffset + PredIdx;
}
const Embedding &Vocabulary::operator[](unsigned Opcode) const {
@@ -293,6 +308,10 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
return Vocab[getSlotIndex(Arg)];
}
+const ir2vec::Embedding &Vocabulary::operator[](CmpInst::Predicate P) const {
+ return Vocab[getSlotIndex(P)];
+}
+
StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
#define HANDLE_INST(NUM, OPCODE, CLASS) \
@@ -338,18 +357,41 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
return OperandKind::VariableID;
}
+CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) {
+ assert(Index < MaxPredicateKinds && "Invalid predicate index");
+ unsigned PredEnumVal =
+ (Index < NumFCmpPredicates)
+ ? (static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + Index)
+ : (static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) +
+ (Index - NumFCmpPredicates));
+ return static_cast<CmpInst::Predicate>(PredEnumVal);
+}
+
+StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) {
+ static SmallString<16> PredNameBuffer;
+ if (Pred < CmpInst::FIRST_ICMP_PREDICATE)
+ PredNameBuffer = "FCMP_";
+ else
+ PredNameBuffer = "ICMP_";
+ PredNameBuffer += CmpInst::getPredicateName(Pred);
+ return PredNameBuffer;
+}
+
StringRef Vocabulary::getStringKey(unsigned Pos) {
assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
// Opcode
if (Pos < MaxOpcodes)
return getVocabKeyForOpcode(Pos + 1);
// Type
- if (Pos < MaxOpcodes + MaxCanonicalTypeIDs)
+ if (Pos < OperandBaseOffset)
return getVocabKeyForCanonicalTypeID(
static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
// Operand
- return getVocabKeyForOperandKind(
- static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs));
+ if (Pos < PredicateBaseOffset)
+ return getVocabKeyForOperandKind(
+ static_cast<OperandKind>(Pos - OperandBaseOffset));
+ // Predicates
+ return getVocabKeyForPredicate(getPredicate(Pos - PredicateBaseOffset));
}
// For now, assume vocabulary is stable unless explicitly invalidated.
@@ -363,11 +405,9 @@ Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
VocabVector DummyVocab;
DummyVocab.reserve(NumCanonicalEntries);
float DummyVal = 0.1f;
- // Create a dummy vocabulary with entries for all opcodes, types, and
- // operands
- for ([[maybe_unused]] unsigned _ :
- seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs +
- Vocabulary::MaxOperandKinds)) {
+ // Create a dummy vocabulary with entries for all opcodes, types, operands
+ // and predicates
+ for ([[maybe_unused]] unsigned _ : seq(0u, Vocabulary::NumCanonicalEntries)) {
DummyVocab.push_back(Embedding(Dim, DummyVal));
DummyVal += 0.1f;
}
@@ -510,6 +550,24 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
}
Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),
NumericArgEmbeddings.end());
+
+ // Handle Predicates: part of Operands section. We look up predicate keys
+ // in ArgVocab.
+ std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds,
+ Embedding(Dim, 0));
+ NumericPredEmbeddings.reserve(Vocabulary::MaxPredicateKinds);
+ for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) {
+ StringRef VocabKey =
+ Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK));
+ auto It = ArgVocab.find(VocabKey.str());
+ if (It != ArgVocab.end()) {
+ NumericPredEmbeddings[PK] = It->second;
+ continue;
+ }
+ handleMissingEntity(VocabKey.str());
+ }
+ Vocab.insert(Vocab.end(), NumericPredEmbeddings.begin(),
+ NumericPredEmbeddings.end());
}
IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab)
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json
index 07fde84c1541b..ae36ff54686c5 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json
+++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json
@@ -87,6 +87,32 @@
"Function": [1, 2],
"Pointer": [3, 4],
"Constant": [5, 6],
- "Variable": [7, 8]
+ "Variable": [7, 8],
+ "FCMP_false": [9, 10],
+ "FCMP_oeq": [11, 12],
+ "FCMP_ogt": [13, 14],
+ "FCMP_oge": [15, 16],
+ "FCMP_olt": [17, 18],
+ "FCMP_ole": [19, 20],
+ "FCMP_one": [21, 22],
+ "FCMP_ord": [23, 24],
+ "FCMP_uno": [25, 26],
+ "FCMP_ueq": [27, 28],
+ "FCMP_ugt": [29, 30],
+ "FCMP_uge": [31, 32],
+ "FCMP_ult": [33, 34],
+ "FCMP_ule": [35, 36],
+ "FCMP_une": [37, 38],
+ "FCMP_true": [39, 40],
+ "ICMP_eq": [41, 42],
+ "ICMP_ne": [43, 44],
+ "ICMP_ugt": [45, 46],
+ "ICMP_uge": [47, 48],
+ "ICMP_ult": [49, 50],
+ "ICMP_ule": [51, 52],
+ "ICMP_sgt": [53, 54],
+ "ICMP_sge": [55, 56],
+ "ICMP_slt": [57, 58],
+ "ICMP_sle": [59, 60]
}
}
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json
index 932b3a217b70c..9003dc73954aa 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json
+++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json
@@ -86,6 +86,32 @@
"Function": [1, 2, 3],
"Pointer": [4, 5, 6],
"Constant": [7, 8, 9],
- "Variable": [10, 11, 12]
+ "Variable": [10, 11, 12],
+ "FCMP_false": [13, 14, 15],
+ "FCMP_oeq": [16, 17, 18],
+ "FCMP_ogt": [19, 20, 21],
+ "FCMP_oge": [22, 23, 24],
+ "FCMP_olt": [25, 26, 27],
+ "FCMP_ole": [28, 29, 30],
+ "FCMP_one": [31, 32, 33],
+ "FCMP_ord": [34, 35, 36],
+ "FCMP_uno": [37, 38, 39],
+ "FCMP_ueq": [40, 41, 42],
+ "FCMP_ugt": [43, 44, 45],
+ "FCMP_uge": [46, 47, 48],
+ "FCMP_ult": [49, 50, 51],
+ "FCMP_ule": [52, 53, 54],
+ "FCMP_une": [55, 56, 57],
+ "FCMP_true": [58, 59, 60],
+ "ICMP_eq": [61, 62, 63],
+ "ICMP_ne": [64, 65, 66],
+ "ICMP_ugt": [67, 68, 69],
+ "ICMP_uge": [70, 71, 72],
+ "ICMP_ult": [73, 74, 75],
+ "ICMP_ule": [76, 77, 78],
+ "ICMP_sgt": [79, 80, 81],
+ "ICMP_sge": [82, 83, 84],
+ "ICMP_slt": [85, 86, 87],
+ "ICMP_sle": [88, 89, 90]
}
}
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json
index 19f3efee9f6a1..7ef85490b27df 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json
+++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json
@@ -47,6 +47,7 @@
"FPTrunc": [133, 134, 135],
"FPExt": [136, 137, 138],
"PtrToInt": [139, 140, 141],
+ "PtrToAddr": [202, 203, 204],
"IntToPtr": [142, 143, 144],
"BitCast": [145, 146, 147],
"AddrSpaceCast": [148, 149, 150],
@@ -86,6 +87,32 @@
"Function": [0, 0, 0],
"Pointer": [0, 0, 0],
"Constant": [0, 0, 0],
- "Variable": [0, 0, 0]
+ "Variable": [0, 0, 0],
+ "FCMP_false": [0, 0, 0],
+ "FCMP_oeq": [0, 0, 0],
+ "FCMP_ogt": [0, 0, 0],
+ "FCMP_oge": [0, 0, 0],
+ "FCMP_olt": [0, 0, 0],
+ "FCMP_ole": [0, 0, 0],
+ "FCMP_one": [0, 0, 0],
+ "FCMP_ord": [0, 0, 0],
+ "FCMP_uno": [0, 0, 0],
+ "FCMP_ueq": [0, 0, 0],
+ "FCMP_ugt": [0, 0, 0],
+ "FCMP_uge": [0, 0, 0],
+ "FCMP_ult": [0, 0, 0],
+ "FCMP_ule": [0, 0, 0],
+ "FCMP_une": [0, 0, 0],
+ "FCMP_true": [0, 0, 0],
+ "ICMP_eq": [0, 0, 0],
+ "ICMP_ne": [0, 0, 0],
+ "ICMP_ugt": [0, 0, 0],
+ "ICMP_uge": [0, 0, 0],
+ "ICMP_ult": [0, 0, 0],
+ "ICMP_ule": [0, 0, 0],
+ "ICMP_sgt": [1, 1, 1],
+ "ICMP_sge": [0, 0, 0],
+ "ICMP_slt": [0, 0, 0],
+ "ICMP_sle": [0, 0, 0]
}
}
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
index df7769c9c6a65..d62b0dd157b0b 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function: [ 0.20 0.40 ]
Key: Pointer: [ 0.60 0.80 ]
Key: Constant: [ 1.00 1.20 ]
Key: Variable: [ 1.40 1.60 ]
+Key: FCMP_false: [ 1.80 2.00 ]
+Key: FCMP_oeq: [ 2.20 2.40 ]
+Key: FCMP_ogt: [ 2.60 2.80 ]
+Key: FCMP_oge: [ 3.00 3.20 ]
+Key: FCMP_olt: [ 3.40 3.60 ]
+Key: FCMP_ole: [ 3.80 4.00 ]
+Key: FCMP_one: [ 4.20 4.40 ]
+Key: FCMP_ord: [ 4.60 4.80 ]
+Key: FCMP_uno: [ 5.00 5.20 ]
+Key: FCMP_ueq: [ 5.40 5.60 ]
+Key: FCMP_ugt: [ 5.80 6.00 ]
+Key: FCMP_uge: [ 6.20 6.40 ]
+Key: FCMP_ult: [ 6.60 6.80 ]
+Key: FCMP_ule: [ 7.00 7.20 ]
+Key: FCMP_une: [ 7.40 7.60 ]
+Key: FCMP_true: [ 7.80 8.00 ]
+Key: ICMP_eq: [ 8.20 8.40 ]
+Key: ICMP_ne: [ 8.60 8.80 ]
+Key: ICMP_ugt: [ 9.00 9.20 ]
+Key: ICMP_uge: [ 9.40 9.60 ]
+Key: ICMP_ult: [ 9.80 10.00 ]
+Key: ICMP_ule: [ 10.20 10.40 ]
+Key: ICMP_sgt: [ 10.60 10.80 ]
+Key: ICMP_sge: [ 11.00 11.20 ]
+Key: ICMP_slt: [ 11.40 11.60 ]
+Key: ICMP_sle: [ 11.80 12.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
index f3ce809fd2fd2..e443adb17ac78 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function: [ 0.50 1.00 ]
Key: Pointer: [ 1.50 2.00 ]
Key: Constant: [ 2.50 3.00 ]
Key: Variable: [ 3.50 4.00 ]
+Key: FCMP_false: [ 4.50 5.00 ]
+Key: FCMP_oeq: [ 5.50 6.00 ]
+Key: FCMP_ogt: [ 6.50 7.00 ]
+Key: FCMP_oge: [ 7.50 8.00 ]
+Key: FCMP_olt: [ 8.50 9.00 ]
+Key: FCMP_ole: [ 9.50 10.00 ]
+Key: FCMP_one: [ 10.50 11.00 ]
+Key: FCMP_ord: [ 11.50 12.00 ]
+Key: FCMP_uno: [ 12.50 13.00 ]
+Key: FCMP_ueq: [ 13.50 14.00 ]
+Key: FCMP_ugt: [ 14.50 15.00 ]
+Key: FCMP_uge: [ 15.50 16.00 ]
+Key: FCMP_ult: [ 16.50 17.00 ]
+Key: FCMP_ule: [ 17.50 18.00 ]
+Key: FCMP_une: [ 18.50 19.00 ]
+Key: FCMP_true: [ 19.50 20.00 ]
+Key: ICMP_eq: [ 20.50 21.00 ]
+Key: ICMP_ne: [ 21.50 22.00 ]
+Key: ICMP_ugt: [ 22.50 23.00 ]
+Key: ICMP_uge: [ 23.50 24.00 ]
+Key: ICMP_ult: [ 24.50 25.00 ]
+Key: ICMP_ule: [ 25.50 26.00 ]
+Key: ICMP_sgt: [ 26.50 27.00 ]
+Key: ICMP_sge: [ 27.50 28.00 ]
+Key: ICMP_slt: [ 28.50 29.00 ]
+Key: ICMP_sle: [ 29.50 30.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
index 72b25b9bd3d9c..7fb6043552f7b 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function: [ 0.00 0.00 ]
Key: Pointer: [ 0.00 0.00 ]
Key: Constant: [ 0.00 0.00 ]
Key: Variable: [ 0.00 0.00 ]
+Key: FCMP_false: [ 0.00 0.00 ]
+Key: FCMP_oeq: [ 0.00 0.00 ]
+Key: FCMP_ogt: [ 0.00 0.00 ]
+Key: FCMP_oge: [ 0.00 0.00 ]
+Key: FCMP_olt: [ 0.00 0.00 ]
+Key: FCMP_ole: [ 0.00 0.00 ]
+Key: FCMP_one: [ 0.00 0.00 ]
+Key: FCMP_ord: [ 0.00 0.00 ]
+Key: FCMP_uno: [ 0.00 0.00 ]
+Key: FCMP_ueq: [ 0.00 0.00 ]
+Key: FCMP_ugt: [ 0.00 0.00 ]
+Key: FCMP_uge: [ 0.00 0.00 ]
+Key: FCMP_ult: [ 0.00 0.00 ]
+Key: FCMP_ule: [ 0.00 0.00 ]
+Key: FCMP_une: [ 0.00 0.00 ]
+Key: FCMP_true: [ 0.00 0.00 ]
+Key: ICMP_eq: [ 0.00 0.00 ]
+Key: ICMP_ne: [ 0.00 0.00 ]
+Key: ICMP_ugt: [ 0.00 0.00 ]
+Key: ICMP_uge: [ 0.00 0.00 ]
+Key: ICMP_ult: [ 0.00 0.00 ]
+Key: ICMP_ule: [ 0.00 0.00 ]
+Key: ICMP_sgt: [ 0.00 0.00 ]
+Key: ICMP_sge: [ 0.00 0.00 ]
+Key: ICMP_slt: [ 0.00 0.00 ]
+Key: ICMP_sle: [ 0.00 0.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/if-else.ll b/llvm/test/Analysis/IR2Vec/if-else.ll
index fe532479086d3..804c1ca5cb6f6 100644
--- a/llvm/test/Analysis/IR2Vec/if-else.ll
+++ b/llvm/test/Analysis/IR2Vec/if-else.ll
@@ -29,7 +29,7 @@ return: ; preds = %if.else, %if.then
; CHECK: Basic block vectors:
; CHECK-NEXT: Basic block: entry:
-; CHECK-NEXT: [ 816.00 825.00 834.00 ]
+; CHECK-NEXT: [ 816.20 825.20 834.20 ]
; CHECK-NEXT: Basic block: if.then:
; CHECK-NEXT: [ 195.00 198.00 201.00 ]
; CHECK-NEXT: Basic block: if.else:
diff --git a/llvm/test/Analysis/IR2Vec/unreachable.ll b/llvm/test/Analysis/IR2Vec/unreachable.ll
index b0e3e49978018..9be0ee1c2de7a 100644
--- a/llvm/test/Analysis/IR2Vec/unreachable.ll
+++ b/llvm/test/Analysis/IR2Vec/unreachable.ll
@@ -33,7 +33,7 @@ return: ; preds = %if.else, %if.then
; CHECK: Basic block vectors:
; CHECK-NEXT: Basic block: entry:
-; CHECK-NEXT: [ 816.00 825.00 834.00 ]
+; CHECK-NEXT: [ 816.20 825.20 834.20 ]
; CHECK-NEXT: Basic block: if.then:
; CHECK-NEXT: [ 195.00 198.00 201.00 ]
; CHECK-NEXT: Basic block: if.else:
diff --git a/llvm/test/tools/llvm-ir2vec/entities.ll b/llvm/test/tools/llvm-ir2vec/entities.ll
index 4b51adf30bf74..8dbce57302f6f 100644
--- a/llvm/test/tools/llvm-ir2vec/entities.ll
+++ b/llvm/test/tools/llvm-ir2vec/entities.ll
@@ -1,6 +1,6 @@
; RUN: llvm-ir2vec entities | FileCheck %s
-CHECK: 84
+CHECK: 110
CHECK-NEXT: Ret 0
CHECK-NEXT: Br 1
CHECK-NEXT: Switch 2
@@ -85,3 +85,29 @@ CHECK-NEXT: Function 80
CHECK-NEXT: Pointer 81
CHECK-NEXT: Constant 82
CHECK-NEXT: Variable 83
+CHECK-NEXT: FCMP_false 84
+CHECK-NEXT: FCMP_oeq 85
+CHECK-NEXT: FCMP_ogt 86
+CHECK-NEXT: FCMP_oge 87
+CHECK-NEXT: FCMP_olt 88
+CHECK-NEXT: FCMP_ole 89
+CHECK-NEXT: FCMP_one 90
+CHECK-NEXT: FCMP_ord 91
+CHECK-NEXT: FCMP_uno 92
+CHECK-NEXT: FCMP_ueq 93
+CHECK-NEXT: FCMP_ugt 94
+CHECK-NEXT: FCMP_uge 95
+CHECK-NEXT: FCMP_ult 96
+CHECK-NEXT: FCMP_ule 97
+CHECK-NEXT: FCMP_une 98
+CHECK-NEXT: FCMP_true 99
+CHECK-NEXT: ICMP_eq 100
+CHECK-NEXT: ICMP_ne 101
+CHECK-NEXT: ICMP_ugt 102
+CHECK-NEXT: ICMP_uge 103
+CHECK-NEXT: ICMP_ult 104
+CHECK-NEXT: ICMP_ule 105
+CHECK-NEXT: ICMP_sgt 106
+CHECK-NEXT: ICMP_sge 107
+CHECK-NEXT: ICMP_slt 108
+CHECK-NEXT: ICMP_sle 109
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index aabebf0cc90a9..1c656b8fcf4e7 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -184,7 +184,7 @@ class IR2VecTool {
// Add "Arg" relationships
unsigned ArgIndex = 0;
for (const Use &U : I.operands()) {
- unsigned OperandID = Vocabulary::getSlotIndex(*U);
+ unsigned OperandID = Vocabulary::getSlotIndex(*U.get());
unsigned RelationID = ArgRelation + ArgIndex;
OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 9f2f6a3496ce0..9bc48e45eab5e 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -435,6 +435,7 @@ static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes;
static constexpr unsigned MaxTypeIDs = Vocabulary::MaxTypeIDs;
static constexpr unsigned MaxCanonicalTypeIDs = Vocabulary::MaxCanonicalTypeIDs;
static constexpr unsigned MaxOperands = Vocabulary::MaxOperandKinds;
+static constexpr unsigned MaxPredicateKinds = Vocabulary::MaxPredicateKinds;
// Mapping between LLVM Type::TypeID tokens and Vocabulary::CanonicalTypeID
// names and their canonical string keys.
@@ -460,7 +461,8 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
EXPECT_EQ(Emb.size(), Dim);
// Should have the correct total number of embeddings
- EXPECT_EQ(VocabVecSize, MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands);
+ EXPECT_EQ(VocabVecSize, MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands +
+ MaxPredicateKinds);
auto ExpectedVocab = VocabVec;
@@ -527,6 +529,26 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
EXPECT_EQ(Vocabulary::getSlotIndex(*Arg),
EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::VariableID));
#undef EXPECTED_VOCAB_OPERAND_SLOT
+
+ // Test getSlotIndex for predicates
+#define EXPECTED_VOCAB_PREDICATE_SLOT(X) \
+ MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands + static_cast<unsigned>(X)
+ for (unsigned P = CmpInst::FIRST_FCMP_PREDICATE;
+ P <= CmpInst::LAST_FCMP_PREDICATE; ++P) {
+ CmpInst::Predicate Pred = static_cast<CmpInst::Predicate>(P);
+ unsigned ExpectedIdx =
+ EXPECTED_VOCAB_PREDICATE_SLOT((P - CmpInst::FIRST_FCMP_PREDICATE));
+ EXPECT_EQ(Vocabulary::getSlotIndex(Pred), ExpectedIdx);
+ }
+ auto ICMP_Start = CmpInst::LAST_FCMP_PREDICATE + 1;
+ for (unsigned P = CmpInst::FIRST_ICMP_PREDICATE;
+ P <= CmpInst::LAST_ICMP_PREDICATE; ++P) {
+ CmpInst::Predicate Pred = static_cast<CmpInst::Predicate>(P);
+ unsigned ExpectedIdx = EXPECTED_VOCAB_PREDICATE_SLOT(
+ ICMP_Start + P - CmpInst::FIRST_ICMP_PREDICATE);
+ EXPECT_EQ(Vocabulary::getSlotIndex(Pred), ExpectedIdx);
+ }
+#undef EXPECTED_VOCAB_PREDICATE_SLOT
}
#if GTEST_HAS_DEATH_TEST
@@ -569,6 +591,7 @@ TEST(IR2VecVocabularyTest, StringKeyGeneration) {
#undef EXPECT_CANONICAL_TYPE_NAME
+ // Verify OperandKind -> string mapping
#define HANDLE_OPERAND_KINDS(X) \
X(FunctionID, "Function") \
X(PointerID, "Pointer") \
@@ -592,6 +615,28 @@ TEST(IR2VecVocabularyTest, StringKeyGeneration) {
Vocabulary::getStringKey(MaxOpcodes + MaxCanonicalTypeIDs + 1);
EXPECT_EQ(FuncArgKey, "Function");
EXPECT_EQ(PtrArgKey, "Pointer");
+
+// Verify PredicateKind -> string mapping
+#define EXPECT_PREDICATE_KIND(PredNum, PredPos, PredKind) \
+ do { \
+ std::string PredStr = \
+ std::string(PredKind) + "_" + \
+ CmpInst::getPredicateName(static_cast<CmpInst::Predicate>(PredNum)) \
+ .str(); \
+ unsigned Pos = MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands + PredPos; \
+ EXPECT_EQ(Vocabulary::getStringKey(Pos), PredStr); \
+ } while (0)
+
+ for (unsigned P = CmpInst::FIRST_FCMP_PREDICATE;
+ P <= CmpInst::LAST_FCMP_PREDICATE; ++P)
+ EXPECT_PREDICATE_KIND(P, P - CmpInst::FIRST_FCMP_PREDICATE, "FCMP");
+
+ auto ICMP_Pos = CmpInst::LAST_FCMP_PREDICATE + 1;
+ for (unsigned P = CmpInst::FIRST_ICMP_PREDICATE;
+ P <= CmpInst::LAST_ICMP_PREDICATE; ++P)
+ EXPECT_PREDICATE_KIND(P, ICMP_Pos++, "ICMP");
+
+#undef EXPECT_PREDICATE_KIND
}
TEST(IR2VecVocabularyTest, VocabularyDimensions) {
More information about the llvm-commits
mailing list