[llvm-branch-commits] [llvm] Support predicates (PR #156952)

S. VenkataKeerthy via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Sep 4 12:07:42 PDT 2025


https://github.com/svkeerthy created https://github.com/llvm/llvm-project/pull/156952

None

>From 6185e40a9a6731955e190131067cc3c5bc90595e Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Wed, 3 Sep 2025 22:56:08 +0000
Subject: [PATCH] Support predicates

---
 llvm/include/llvm/Analysis/IR2Vec.h    | 43 +++++++++++++---
 llvm/lib/Analysis/IR2Vec.cpp           | 70 ++++++++++++++++++++++----
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp |  2 +-
 3 files changed, 98 insertions(+), 17 deletions(-)

diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index b7b881999241e..d49854e2d06a8 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -36,6 +36,7 @@
 #define LLVM_ANALYSIS_IR2VEC_H
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/Type.h"
 #include "llvm/Support/CommandLine.h"
@@ -162,16 +163,25 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
 /// embeddings.
 class Vocabulary {
   friend class llvm::IR2VecVocabAnalysis;
+  // Slot layout:
+  // [0 .. MaxOpcodes-1]                              => Instruction opcodes
+  // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
+  // [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operands
+  //   Within Operands: first the OperandKind entries, then the compare
+  //   predicates (all FCmp predicates followed by all ICmp predicates, with
+  //   no gaps).
   using VocabVector = std::vector<ir2vec::Embedding>;
   VocabVector Vocab;
+
   bool Valid = false;
+  static constexpr unsigned NumICmpPredicates =
+      static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) -
+      static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) + 1;
+  static constexpr unsigned NumFCmpPredicates =
+      static_cast<unsigned>(CmpInst::LAST_FCMP_PREDICATE) -
+      static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + 1;
 
 public:
-  // Slot layout:
-  // [0 .. MaxOpcodes-1]               => Instruction opcodes
-  // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
-  // [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
-
   /// Canonical type IDs supported by IR2Vec Vocabulary
   enum class CanonicalTypeID : unsigned {
     FloatTy,
@@ -208,13 +218,18 @@ class Vocabulary {
       static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
   static constexpr unsigned MaxOperandKinds =
       static_cast<unsigned>(OperandKind::MaxOperandKind);
+  // CmpInst::Predicate has gaps. We want the vocabulary to be dense without
+  // empty slots.
+  static constexpr unsigned MaxPredicateKinds =
+      NumICmpPredicates + NumFCmpPredicates;
 
   Vocabulary() = default;
   LLVM_ABI Vocabulary(VocabVector &&Vocab);
 
   LLVM_ABI bool isValid() const;
   LLVM_ABI unsigned getDimension() const;
-  /// Total number of entries (opcodes + canonicalized types + operand kinds)
+  /// Total number of entries (opcodes + canonicalized types + operand kinds +
+  /// predicates)
   static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
 
   /// Function to get vocabulary key for a given Opcode
@@ -229,16 +244,21 @@ class Vocabulary {
   /// Function to classify an operand into OperandKind
   LLVM_ABI static OperandKind getOperandKind(const Value *Op);
 
+  /// Function to get vocabulary key for a given predicate
+  LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P);
+
   /// Functions to return the slot index or position of a given Opcode, TypeID,
-  /// or OperandKind in the vocabulary.
+  /// OperandKind, or compare predicate in the vocabulary.
   LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
   LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
   LLVM_ABI static unsigned getSlotIndex(const Value &Op);
+  LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P);
 
   /// Accessors to get the embedding for a given entity.
   LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
   LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
   LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;
+  LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const;
 
   /// Const Iterator type aliases
   using const_iterator = VocabVector::const_iterator;
@@ -275,7 +295,13 @@ class Vocabulary {
 
 private:
   constexpr static unsigned NumCanonicalEntries =
-      MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
+      MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
+
+  // Base offsets for slot layout to simplify index computation
+  constexpr static unsigned OperandBaseOffset =
+      MaxOpcodes + MaxCanonicalTypeIDs;
+  constexpr static unsigned PredicateBaseOffset =
+      OperandBaseOffset + MaxOperandKinds;
 
   /// String mappings for CanonicalTypeID values
   static constexpr StringLiteral CanonicalTypeNames[] = {
@@ -327,6 +353,9 @@ class Vocabulary {
 
   /// Function to convert TypeID to CanonicalTypeID
   LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
+
+  /// Map a dense predicate index [0, MaxPredicateKinds) to CmpInst::Predicate
+  LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index);
 };
 
 /// Embedder provides the interface to generate embeddings (vector
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 98849fd922843..c79c9c1ed493d 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -216,6 +216,9 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
       ArgEmb += Vocab[*Op];
     auto InstVector =
         Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+    // Add compare predicate embedding as an additional operand if applicable
+    if (const auto *Cmp = dyn_cast<CmpInst>(&I))
+      InstVector += Vocab[Cmp->getPredicate()];
     InstVecMap[&I] = InstVector;
     BBVector += InstVector;
   }
@@ -250,6 +253,9 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
     // embeddings
     auto InstVector =
         Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+    // Add compare predicate embedding as an additional operand if applicable
+    if (const auto *Cmp = dyn_cast<CmpInst>(&I))
+      InstVector += Vocab[Cmp->getPredicate()];
     InstVecMap[&I] = InstVector;
     BBVector += InstVector;
   }
@@ -285,7 +290,21 @@ unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
 unsigned Vocabulary::getSlotIndex(const Value &Op) {
   unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
   assert(Index < MaxOperandKinds && "Invalid OperandKind");
-  return MaxOpcodes + MaxCanonicalTypeIDs + Index;
+  return OperandBaseOffset + Index;
+}
+
+unsigned Vocabulary::getSlotIndex(CmpInst::Predicate P) {
+  assert((CmpInst::isFPPredicate(P) || CmpInst::isIntPredicate(P)) &&
+         "Invalid compare predicate");
+  unsigned PU = static_cast<unsigned>(P);
+  unsigned FirstFC = static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE);
+  unsigned FirstIC = static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE);
+
+  // CmpInst::Predicate has a gap between the FCmp and ICmp ranges; remap to a
+  // dense index with all FCmp predicates first, followed by the ICmp ones.
+  unsigned PredIdx =
+      (PU >= FirstIC) ? (NumFCmpPredicates + (PU - FirstIC)) : (PU - FirstFC);
+  return PredicateBaseOffset + PredIdx;
 }
 
 const Embedding &Vocabulary::operator[](unsigned Opcode) const {
@@ -300,6 +315,10 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
   return Vocab[getSlotIndex(Arg)];
 }
 
+const ir2vec::Embedding &Vocabulary::operator[](CmpInst::Predicate P) const {
+  return Vocab[getSlotIndex(P)];
+}
+
 StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
   assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
 #define HANDLE_INST(NUM, OPCODE, CLASS)                                        \
@@ -345,18 +364,52 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
   return OperandKind::VariableID;
 }
 
+CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) {
+  assert(Index < MaxPredicateKinds && "Invalid predicate index");
+  unsigned PredEnumVal =
+      (Index < NumFCmpPredicates)
+          ? (static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + Index)
+          : (static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) +
+             (Index - NumFCmpPredicates));
+  return static_cast<CmpInst::Predicate>(PredEnumVal);
+}
+
+StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) {
+  // CmpInst::getPredicateName returns the same spelling for FCMP_UGT and
+  // ICMP_UGT (likewise UGE/ULT/ULE), so a bare predicate name is not a unique
+  // vocabulary key. Prefix each key with its compare kind; the keys are built
+  // once into function-local static storage so the returned StringRef stays
+  // valid after this call returns.
+  static const std::vector<std::string> Keys = [] {
+    std::vector<std::string> K;
+    K.reserve(MaxPredicateKinds);
+    for (unsigned Idx : seq(0u, MaxPredicateKinds)) {
+      CmpInst::Predicate P = getPredicate(Idx);
+      K.push_back((CmpInst::isFPPredicate(P) ? "FCMP_" : "ICMP_") +
+                  CmpInst::getPredicateName(P).str());
+    }
+    return K;
+  }();
+  assert((CmpInst::isFPPredicate(Pred) || CmpInst::isIntPredicate(Pred)) &&
+         "Invalid compare predicate");
+  return Keys[getSlotIndex(Pred) - PredicateBaseOffset];
+}
+
 StringRef Vocabulary::getStringKey(unsigned Pos) {
   assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
   // Opcode
   if (Pos < MaxOpcodes)
     return getVocabKeyForOpcode(Pos + 1);
   // Type
-  if (Pos < MaxOpcodes + MaxCanonicalTypeIDs)
+  if (Pos < OperandBaseOffset)
     return getVocabKeyForCanonicalTypeID(
         static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
   // Operand
-  return getVocabKeyForOperandKind(
-      static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs));
+  if (Pos < PredicateBaseOffset)
+    return getVocabKeyForOperandKind(
+        static_cast<OperandKind>(Pos - OperandBaseOffset));
+  // Predicates
+  return getVocabKeyForPredicate(getPredicate(Pos - PredicateBaseOffset));
 }
 
 // For now, assume vocabulary is stable unless explicitly invalidated.
@@ -370,11 +406,9 @@ Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
   VocabVector DummyVocab;
   DummyVocab.reserve(NumCanonicalEntries);
   float DummyVal = 0.1f;
-  // Create a dummy vocabulary with entries for all opcodes, types, and
-  // operands
-  for ([[maybe_unused]] unsigned _ :
-       seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs +
-                   Vocabulary::MaxOperandKinds)) {
+  // Create a dummy vocabulary with one entry per canonical slot (opcodes,
+  // types, operand kinds and compare predicates).
+  for ([[maybe_unused]] unsigned _ : seq(0u, Vocabulary::NumCanonicalEntries)) {
     DummyVocab.push_back(Embedding(Dim, DummyVal));
     DummyVal += 0.1f;
   }
@@ -517,6 +551,23 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
   }
   Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),
                NumericArgEmbeddings.end());
+
+  // Handle compare predicates: they follow the operand kinds in the Operands
+  // section, so their keys are looked up in ArgVocab.
+  std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds,
+                                               Embedding(Dim, 0));
+  for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) {
+    StringRef VocabKey =
+        Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK));
+    auto It = ArgVocab.find(VocabKey.str());
+    if (It != ArgVocab.end()) {
+      NumericPredEmbeddings[PK] = It->second;
+      continue;
+    }
+    handleMissingEntity(VocabKey.str());
+  }
+  Vocab.insert(Vocab.end(), NumericPredEmbeddings.begin(),
+               NumericPredEmbeddings.end());
 }
 
 IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab)
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index aabebf0cc90a9..1c656b8fcf4e7 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -184,7 +184,7 @@ class IR2VecTool {
         // Add "Arg" relationships
         unsigned ArgIndex = 0;
         for (const Use &U : I.operands()) {
-          unsigned OperandID = Vocabulary::getSlotIndex(*U);
+          unsigned OperandID = Vocabulary::getSlotIndex(*U.get());
           unsigned RelationID = ArgRelation + ArgIndex;
           OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
 



More information about the llvm-branch-commits mailing list