[llvm] [IR2Vec] Add support for Cmp predicates in vocabulary and embeddings (PR #156952)

S. VenkataKeerthy via llvm-commits llvm-commits at lists.llvm.org
Wed Oct 1 14:34:41 PDT 2025


https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/156952

>From 89e0e620c24e9cccad45970726fb1cd47b4bb281 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Wed, 3 Sep 2025 22:56:08 +0000
Subject: [PATCH] Support predicates

---
 llvm/include/llvm/Analysis/IR2Vec.h           | 55 ++++++++++++--
 llvm/lib/Analysis/IR2Vec.cpp                  | 76 ++++++++++++++++---
 .../IR2Vec/Inputs/dummy_2D_vocab.json         | 28 ++++++-
 .../Inputs/dummy_3D_nonzero_arg_vocab.json    | 28 ++++++-
 .../Inputs/dummy_3D_nonzero_opc_vocab.json    | 29 ++++++-
 .../Inputs/reference_default_vocab_print.txt  | 26 +++++++
 .../Inputs/reference_wtd1_vocab_print.txt     | 26 +++++++
 .../Inputs/reference_wtd2_vocab_print.txt     | 26 +++++++
 llvm/test/Analysis/IR2Vec/if-else.ll          |  2 +-
 llvm/test/Analysis/IR2Vec/unreachable.ll      |  2 +-
 llvm/test/tools/llvm-ir2vec/entities.ll       | 28 ++++++-
 llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp        |  2 +-
 llvm/unittests/Analysis/IR2VecTest.cpp        | 47 +++++++++++-
 13 files changed, 351 insertions(+), 24 deletions(-)

diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 3671c1c71ac0b..f3f9de460218b 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -36,6 +36,7 @@
 #define LLVM_ANALYSIS_IR2VEC_H
 
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/IR/Instructions.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/Type.h"
 #include "llvm/Support/CommandLine.h"
@@ -162,15 +163,34 @@ using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
 /// embeddings.
 class Vocabulary {
   friend class llvm::IR2VecVocabAnalysis;
+
+  // Vocabulary Slot Layout:
+  // +----------------+------------------------------------------------------+
+  // | Entity Type    | Index Range                                          |
+  // +----------------+------------------------------------------------------+
+  // | Opcodes        | [0 .. (MaxOpcodes-1)]                                |
+  // | Canonical Types| [MaxOpcodes .. (MaxOpcodes+MaxCanonicalTypeIDs-1)]   |
+  // | Operands       | [(MaxOpcodes+MaxCanonicalTypeIDs) .. NumCanEntries-1]|
+  // +----------------+------------------------------------------------------+
+  // Note: MaxOpcodes is the number of unique opcodes supported by LLVM IR.
+  //       MaxCanonicalTypeIDs is the number of canonicalized type IDs.
+  //       "Similar" LLVM Types are grouped/canonicalized together. E.g., all
+  //       float variants (FloatTy, DoubleTy, HalfTy, etc.) map to
+  //       CanonicalTypeID::FloatTy. This helps reduce the vocabulary size
+  //       and improves learning. Operands include Comparison predicates
+  //       (ICmp/FCmp) along with other operand types. This can be extended to
+  //       include other specializations in future.
   using VocabVector = std::vector<ir2vec::Embedding>;
   VocabVector Vocab;
 
-public:
-  // Slot layout:
-  // [0 .. MaxOpcodes-1]               => Instruction opcodes
-  // [MaxOpcodes .. MaxOpcodes+MaxCanonicalTypeIDs-1] => Canonicalized types
-  // [MaxOpcodes+MaxCanonicalTypeIDs .. NumCanonicalEntries-1] => Operand kinds
+  static constexpr unsigned NumICmpPredicates =
+      static_cast<unsigned>(CmpInst::LAST_ICMP_PREDICATE) -
+      static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) + 1;
+  static constexpr unsigned NumFCmpPredicates =
+      static_cast<unsigned>(CmpInst::LAST_FCMP_PREDICATE) -
+      static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + 1;
 
+public:
   /// Canonical type IDs supported by IR2Vec Vocabulary
   enum class CanonicalTypeID : unsigned {
     FloatTy,
@@ -207,13 +227,18 @@ class Vocabulary {
       static_cast<unsigned>(CanonicalTypeID::MaxCanonicalType);
   static constexpr unsigned MaxOperandKinds =
       static_cast<unsigned>(OperandKind::MaxOperandKind);
+  // CmpInst::Predicate has gaps. We want the vocabulary to be dense without
+  // empty slots.
+  static constexpr unsigned MaxPredicateKinds =
+      NumICmpPredicates + NumFCmpPredicates;
 
   Vocabulary() = default;
   LLVM_ABI Vocabulary(VocabVector &&Vocab) : Vocab(std::move(Vocab)) {}
 
   LLVM_ABI bool isValid() const { return Vocab.size() == NumCanonicalEntries; };
   LLVM_ABI unsigned getDimension() const;
-  /// Total number of entries (opcodes + canonicalized types + operand kinds)
+  /// Total number of entries (opcodes + canonicalized types + operand kinds +
+  /// predicates)
   static constexpr size_t getCanonicalSize() { return NumCanonicalEntries; }
 
   /// Function to get vocabulary key for a given Opcode
@@ -228,16 +253,21 @@ class Vocabulary {
   /// Function to classify an operand into OperandKind
   LLVM_ABI static OperandKind getOperandKind(const Value *Op);
 
+  /// Function to get vocabulary key for a given predicate
+  LLVM_ABI static StringRef getVocabKeyForPredicate(CmpInst::Predicate P);
+
   /// Functions to return the slot index or position of a given Opcode, TypeID,
   /// or OperandKind in the vocabulary.
   LLVM_ABI static unsigned getSlotIndex(unsigned Opcode);
   LLVM_ABI static unsigned getSlotIndex(Type::TypeID TypeID);
   LLVM_ABI static unsigned getSlotIndex(const Value &Op);
+  LLVM_ABI static unsigned getSlotIndex(CmpInst::Predicate P);
 
   /// Accessors to get the embedding for a given entity.
   LLVM_ABI const ir2vec::Embedding &operator[](unsigned Opcode) const;
   LLVM_ABI const ir2vec::Embedding &operator[](Type::TypeID TypeId) const;
   LLVM_ABI const ir2vec::Embedding &operator[](const Value &Arg) const;
+  LLVM_ABI const ir2vec::Embedding &operator[](CmpInst::Predicate P) const;
 
   /// Const Iterator type aliases
   using const_iterator = VocabVector::const_iterator;
@@ -274,7 +304,13 @@ class Vocabulary {
 
 private:
   constexpr static unsigned NumCanonicalEntries =
-      MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds;
+      MaxOpcodes + MaxCanonicalTypeIDs + MaxOperandKinds + MaxPredicateKinds;
+
+  // Base offsets for slot layout to simplify index computation
+  constexpr static unsigned OperandBaseOffset =
+      MaxOpcodes + MaxCanonicalTypeIDs;
+  constexpr static unsigned PredicateBaseOffset =
+      OperandBaseOffset + MaxOperandKinds;
 
   /// String mappings for CanonicalTypeID values
   static constexpr StringLiteral CanonicalTypeNames[] = {
@@ -326,6 +362,11 @@ class Vocabulary {
 
   /// Function to convert TypeID to CanonicalTypeID
   LLVM_ABI static CanonicalTypeID getCanonicalTypeID(Type::TypeID TypeID);
+
+  /// Function to get the predicate enum value for a given index. Index is
+  /// relative to the predicates section of the vocabulary. E.g., Index 0
+  /// corresponds to the first predicate.
+  LLVM_ABI static CmpInst::Predicate getPredicate(unsigned Index);
 };
 
 /// Embedder provides the interface to generate embeddings (vector
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 99afc0601d523..f51f0898cb37e 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -216,6 +216,8 @@ void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
       ArgEmb += Vocab[*Op];
     auto InstVector =
         Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+    if (const auto *IC = dyn_cast<CmpInst>(&I))
+      InstVector += Vocab[IC->getPredicate()];
     InstVecMap[&I] = InstVector;
     BBVector += InstVector;
   }
@@ -250,6 +252,9 @@ void FlowAwareEmbedder::computeEmbeddings(const BasicBlock &BB) const {
     // embeddings
     auto InstVector =
         Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
+    // Add compare predicate embedding as an additional operand if applicable
+    if (const auto *IC = dyn_cast<CmpInst>(&I))
+      InstVector += Vocab[IC->getPredicate()];
     InstVecMap[&I] = InstVector;
     BBVector += InstVector;
   }
@@ -278,7 +283,23 @@ unsigned Vocabulary::getSlotIndex(Type::TypeID TypeID) {
 unsigned Vocabulary::getSlotIndex(const Value &Op) {
   unsigned Index = static_cast<unsigned>(getOperandKind(&Op));
   assert(Index < MaxOperandKinds && "Invalid OperandKind");
-  return MaxOpcodes + MaxCanonicalTypeIDs + Index;
+  return OperandBaseOffset + Index;
+}
+
+unsigned Vocabulary::getSlotIndex(CmpInst::Predicate P) {
+  // CmpInst::Predicate has gaps: a value outside the FCmp/ICmp ranges would
+  // otherwise be silently mapped onto another predicate's slot (or past the
+  // end of the vocabulary), so reject it up front.
+  assert((CmpInst::isFPPredicate(P) || CmpInst::isIntPredicate(P)) &&
+         "Invalid predicate");
+  unsigned PU = static_cast<unsigned>(P);
+  unsigned FirstFC = static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE);
+  unsigned FirstIC = static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE);
+
+  // Dense layout: all FCmp predicates first, then all ICmp predicates.
+  unsigned PredIdx =
+      (PU >= FirstIC) ? (NumFCmpPredicates + (PU - FirstIC)) : (PU - FirstFC);
+  return PredicateBaseOffset + PredIdx;
 }
 
 const Embedding &Vocabulary::operator[](unsigned Opcode) const {
@@ -293,6 +308,10 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value &Arg) const {
   return Vocab[getSlotIndex(Arg)];
 }
 
+const ir2vec::Embedding &Vocabulary::operator[](CmpInst::Predicate P) const {
+  return Vocab[getSlotIndex(P)];
+}
+
 StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
   assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
 #define HANDLE_INST(NUM, OPCODE, CLASS)                                        \
@@ -338,18 +357,52 @@ Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
   return OperandKind::VariableID;
 }
 
+CmpInst::Predicate Vocabulary::getPredicate(unsigned Index) {
+  assert(Index < MaxPredicateKinds && "Invalid predicate index");
+  unsigned PredEnumVal =
+      (Index < NumFCmpPredicates)
+          ? (static_cast<unsigned>(CmpInst::FIRST_FCMP_PREDICATE) + Index)
+          : (static_cast<unsigned>(CmpInst::FIRST_ICMP_PREDICATE) +
+             (Index - NumFCmpPredicates));
+  return static_cast<CmpInst::Predicate>(PredEnumVal);
+}
+
+StringRef Vocabulary::getVocabKeyForPredicate(CmpInst::Predicate Pred) {
+  assert((CmpInst::isFPPredicate(Pred) || CmpInst::isIntPredicate(Pred)) &&
+         "Invalid predicate");
+  // Build all keys exactly once. Returning a view into one shared, mutable
+  // buffer would let the next call overwrite a key a caller still holds and
+  // would race between threads; function-local static initialization is
+  // thread-safe and these strings live until program exit, so the returned
+  // StringRef stays valid.
+  static const std::vector<std::string> PredKeys = [] {
+    std::vector<std::string> Keys;
+    Keys.reserve(MaxPredicateKinds);
+    for (unsigned Idx : seq(0u, MaxPredicateKinds)) {
+      CmpInst::Predicate P = getPredicate(Idx);
+      Keys.push_back((CmpInst::isFPPredicate(P) ? "FCMP_" : "ICMP_") +
+                     CmpInst::getPredicateName(P).str());
+    }
+    return Keys;
+  }();
+  return PredKeys[getSlotIndex(Pred) - PredicateBaseOffset];
+}
+
 StringRef Vocabulary::getStringKey(unsigned Pos) {
   assert(Pos < NumCanonicalEntries && "Position out of bounds in vocabulary");
   // Opcode
   if (Pos < MaxOpcodes)
     return getVocabKeyForOpcode(Pos + 1);
   // Type
-  if (Pos < MaxOpcodes + MaxCanonicalTypeIDs)
+  if (Pos < OperandBaseOffset)
     return getVocabKeyForCanonicalTypeID(
         static_cast<CanonicalTypeID>(Pos - MaxOpcodes));
   // Operand
-  return getVocabKeyForOperandKind(
-      static_cast<OperandKind>(Pos - MaxOpcodes - MaxCanonicalTypeIDs));
+  if (Pos < PredicateBaseOffset)
+    return getVocabKeyForOperandKind(
+        static_cast<OperandKind>(Pos - OperandBaseOffset));
+  // Predicates
+  return getVocabKeyForPredicate(getPredicate(Pos - PredicateBaseOffset));
 }
 
 // For now, assume vocabulary is stable unless explicitly invalidated.
@@ -363,11 +405,9 @@ Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
   VocabVector DummyVocab;
   DummyVocab.reserve(NumCanonicalEntries);
   float DummyVal = 0.1f;
-  // Create a dummy vocabulary with entries for all opcodes, types, and
-  // operands
-  for ([[maybe_unused]] unsigned _ :
-       seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxCanonicalTypeIDs +
-                   Vocabulary::MaxOperandKinds)) {
+  // Create a dummy vocabulary with entries for all opcodes, types, operands
+  // and predicates
+  for ([[maybe_unused]] unsigned _ : seq(0u, Vocabulary::NumCanonicalEntries)) {
     DummyVocab.push_back(Embedding(Dim, DummyVal));
     DummyVal += 0.1f;
   }
@@ -510,6 +550,25 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() {
   }
   Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),
                NumericArgEmbeddings.end());
+
+  // Handle predicates: their keys belong to the Operands (ArgVocab) section of
+  // the vocabulary file, and their slots follow the operand-kind slots. A key
+  // absent from ArgVocab is reported through handleMissingEntity; if that
+  // returns, the slot keeps its zero-initialized embedding.
+  std::vector<Embedding> NumericPredEmbeddings(Vocabulary::MaxPredicateKinds,
+                                               Embedding(Dim, 0));
+  for (unsigned PK : seq(0u, Vocabulary::MaxPredicateKinds)) {
+    StringRef VocabKey =
+        Vocabulary::getVocabKeyForPredicate(Vocabulary::getPredicate(PK));
+    auto It = ArgVocab.find(VocabKey.str());
+    if (It != ArgVocab.end()) {
+      NumericPredEmbeddings[PK] = It->second;
+      continue;
+    }
+    handleMissingEntity(VocabKey.str());
+  }
+  Vocab.insert(Vocab.end(), NumericPredEmbeddings.begin(),
+               NumericPredEmbeddings.end());
 }
 
 IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab)
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json
index 07fde84c1541b..ae36ff54686c5 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json
+++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_2D_vocab.json
@@ -87,6 +87,32 @@
         "Function": [1, 2],
         "Pointer": [3, 4],
         "Constant": [5, 6],
-        "Variable": [7, 8]
+        "Variable": [7, 8],
+        "FCMP_false": [9, 10],
+        "FCMP_oeq": [11, 12], 
+        "FCMP_ogt": [13, 14], 
+        "FCMP_oge": [15, 16], 
+        "FCMP_olt": [17, 18], 
+        "FCMP_ole": [19, 20], 
+        "FCMP_one": [21, 22], 
+        "FCMP_ord": [23, 24], 
+        "FCMP_uno": [25, 26], 
+        "FCMP_ueq": [27, 28], 
+        "FCMP_ugt": [29, 30], 
+        "FCMP_uge": [31, 32], 
+        "FCMP_ult": [33, 34], 
+        "FCMP_ule": [35, 36], 
+        "FCMP_une": [37, 38], 
+        "FCMP_true": [39, 40], 
+        "ICMP_eq": [41, 42], 
+        "ICMP_ne": [43, 44], 
+        "ICMP_ugt": [45, 46], 
+        "ICMP_uge": [47, 48], 
+        "ICMP_ult": [49, 50], 
+        "ICMP_ule": [51, 52], 
+        "ICMP_sgt": [53, 54], 
+        "ICMP_sge": [55, 56], 
+        "ICMP_slt": [57, 58], 
+        "ICMP_sle": [59, 60]
     }
 }
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json
index 932b3a217b70c..9003dc73954aa 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json
+++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_arg_vocab.json
@@ -86,6 +86,32 @@
         "Function": [1, 2, 3],
         "Pointer": [4, 5, 6],
         "Constant": [7, 8, 9],
-        "Variable": [10, 11, 12]
+        "Variable": [10, 11, 12],
+        "FCMP_false": [13, 14, 15],
+        "FCMP_oeq": [16, 17, 18],
+        "FCMP_ogt": [19, 20, 21],
+        "FCMP_oge": [22, 23, 24],
+        "FCMP_olt": [25, 26, 27],
+        "FCMP_ole": [28, 29, 30],
+        "FCMP_one": [31, 32, 33],
+        "FCMP_ord": [34, 35, 36],
+        "FCMP_uno": [37, 38, 39],
+        "FCMP_ueq": [40, 41, 42],
+        "FCMP_ugt": [43, 44, 45],
+        "FCMP_uge": [46, 47, 48],
+        "FCMP_ult": [49, 50, 51],
+        "FCMP_ule": [52, 53, 54],
+        "FCMP_une": [55, 56, 57],
+        "FCMP_true": [58, 59, 60],        
+        "ICMP_eq": [61, 62, 63],
+        "ICMP_ne": [64, 65, 66],
+        "ICMP_ugt": [67, 68, 69],
+        "ICMP_uge": [70, 71, 72],
+        "ICMP_ult": [73, 74, 75],
+        "ICMP_ule": [76, 77, 78],
+        "ICMP_sgt": [79, 80, 81],
+        "ICMP_sge": [82, 83, 84],
+        "ICMP_slt": [85, 86, 87],
+        "ICMP_sle": [88, 89, 90]
     }
 }
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json
index 19f3efee9f6a1..7ef85490b27df 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json
+++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json
@@ -47,6 +47,7 @@
         "FPTrunc": [133, 134, 135],
         "FPExt": [136, 137, 138],
         "PtrToInt": [139, 140, 141],
+        "PtrToAddr": [202, 203, 204],
         "IntToPtr": [142, 143, 144],
         "BitCast": [145, 146, 147],
         "AddrSpaceCast": [148, 149, 150],
@@ -86,6 +87,32 @@
         "Function": [0, 0, 0],
         "Pointer": [0, 0, 0],
         "Constant": [0, 0, 0],
-        "Variable": [0, 0, 0]
+        "Variable": [0, 0, 0],
+        "FCMP_false": [0, 0, 0],
+        "FCMP_oeq": [0, 0, 0],
+        "FCMP_ogt": [0, 0, 0],
+        "FCMP_oge": [0, 0, 0],
+        "FCMP_olt": [0, 0, 0],
+        "FCMP_ole": [0, 0, 0],
+        "FCMP_one": [0, 0, 0],
+        "FCMP_ord": [0, 0, 0],
+        "FCMP_uno": [0, 0, 0],
+        "FCMP_ueq": [0, 0, 0],
+        "FCMP_ugt": [0, 0, 0],
+        "FCMP_uge": [0, 0, 0],
+        "FCMP_ult": [0, 0, 0],
+        "FCMP_ule": [0, 0, 0],
+        "FCMP_une": [0, 0, 0],
+        "FCMP_true": [0, 0, 0],
+        "ICMP_eq": [0, 0, 0],
+        "ICMP_ne": [0, 0, 0],
+        "ICMP_ugt": [0, 0, 0],
+        "ICMP_uge": [0, 0, 0],
+        "ICMP_ult": [0, 0, 0],
+        "ICMP_ule": [0, 0, 0],
+        "ICMP_sgt": [1, 1, 1],
+        "ICMP_sge": [0, 0, 0],
+        "ICMP_slt": [0, 0, 0],
+        "ICMP_sle": [0, 0, 0]
     }
 }
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
index df7769c9c6a65..d62b0dd157b0b 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_default_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function:  [ 0.20  0.40 ]
 Key: Pointer:  [ 0.60  0.80 ]
 Key: Constant:  [ 1.00  1.20 ]
 Key: Variable:  [ 1.40  1.60 ]
+Key: FCMP_false:  [ 1.80  2.00 ]
+Key: FCMP_oeq:  [ 2.20  2.40 ]
+Key: FCMP_ogt:  [ 2.60  2.80 ]
+Key: FCMP_oge:  [ 3.00  3.20 ]
+Key: FCMP_olt:  [ 3.40  3.60 ]
+Key: FCMP_ole:  [ 3.80  4.00 ]
+Key: FCMP_one:  [ 4.20  4.40 ]
+Key: FCMP_ord:  [ 4.60  4.80 ]
+Key: FCMP_uno:  [ 5.00  5.20 ]
+Key: FCMP_ueq:  [ 5.40  5.60 ]
+Key: FCMP_ugt:  [ 5.80  6.00 ]
+Key: FCMP_uge:  [ 6.20  6.40 ]
+Key: FCMP_ult:  [ 6.60  6.80 ]
+Key: FCMP_ule:  [ 7.00  7.20 ]
+Key: FCMP_une:  [ 7.40  7.60 ]
+Key: FCMP_true:  [ 7.80  8.00 ]
+Key: ICMP_eq:  [ 8.20  8.40 ]
+Key: ICMP_ne:  [ 8.60  8.80 ]
+Key: ICMP_ugt:  [ 9.00  9.20 ]
+Key: ICMP_uge:  [ 9.40  9.60 ]
+Key: ICMP_ult:  [ 9.80  10.00 ]
+Key: ICMP_ule:  [ 10.20  10.40 ]
+Key: ICMP_sgt:  [ 10.60  10.80 ]
+Key: ICMP_sge:  [ 11.00  11.20 ]
+Key: ICMP_slt:  [ 11.40  11.60 ]
+Key: ICMP_sle:  [ 11.80  12.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
index f3ce809fd2fd2..e443adb17ac78 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd1_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function:  [ 0.50  1.00 ]
 Key: Pointer:  [ 1.50  2.00 ]
 Key: Constant:  [ 2.50  3.00 ]
 Key: Variable:  [ 3.50  4.00 ]
+Key: FCMP_false:  [ 4.50  5.00 ]
+Key: FCMP_oeq:  [ 5.50  6.00 ]
+Key: FCMP_ogt:  [ 6.50  7.00 ]
+Key: FCMP_oge:  [ 7.50  8.00 ]
+Key: FCMP_olt:  [ 8.50  9.00 ]
+Key: FCMP_ole:  [ 9.50  10.00 ]
+Key: FCMP_one:  [ 10.50  11.00 ]
+Key: FCMP_ord:  [ 11.50  12.00 ]
+Key: FCMP_uno:  [ 12.50  13.00 ]
+Key: FCMP_ueq:  [ 13.50  14.00 ]
+Key: FCMP_ugt:  [ 14.50  15.00 ]
+Key: FCMP_uge:  [ 15.50  16.00 ]
+Key: FCMP_ult:  [ 16.50  17.00 ]
+Key: FCMP_ule:  [ 17.50  18.00 ]
+Key: FCMP_une:  [ 18.50  19.00 ]
+Key: FCMP_true:  [ 19.50  20.00 ]
+Key: ICMP_eq:  [ 20.50  21.00 ]
+Key: ICMP_ne:  [ 21.50  22.00 ]
+Key: ICMP_ugt:  [ 22.50  23.00 ]
+Key: ICMP_uge:  [ 23.50  24.00 ]
+Key: ICMP_ult:  [ 24.50  25.00 ]
+Key: ICMP_ule:  [ 25.50  26.00 ]
+Key: ICMP_sgt:  [ 26.50  27.00 ]
+Key: ICMP_sge:  [ 27.50  28.00 ]
+Key: ICMP_slt:  [ 28.50  29.00 ]
+Key: ICMP_sle:  [ 29.50  30.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
index 72b25b9bd3d9c..7fb6043552f7b 100644
--- a/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
+++ b/llvm/test/Analysis/IR2Vec/Inputs/reference_wtd2_vocab_print.txt
@@ -82,3 +82,29 @@ Key: Function:  [ 0.00  0.00 ]
 Key: Pointer:  [ 0.00  0.00 ]
 Key: Constant:  [ 0.00  0.00 ]
 Key: Variable:  [ 0.00  0.00 ]
+Key: FCMP_false:  [ 0.00  0.00 ]
+Key: FCMP_oeq:  [ 0.00  0.00 ]
+Key: FCMP_ogt:  [ 0.00  0.00 ]
+Key: FCMP_oge:  [ 0.00  0.00 ]
+Key: FCMP_olt:  [ 0.00  0.00 ]
+Key: FCMP_ole:  [ 0.00  0.00 ]
+Key: FCMP_one:  [ 0.00  0.00 ]
+Key: FCMP_ord:  [ 0.00  0.00 ]
+Key: FCMP_uno:  [ 0.00  0.00 ]
+Key: FCMP_ueq:  [ 0.00  0.00 ]
+Key: FCMP_ugt:  [ 0.00  0.00 ]
+Key: FCMP_uge:  [ 0.00  0.00 ]
+Key: FCMP_ult:  [ 0.00  0.00 ]
+Key: FCMP_ule:  [ 0.00  0.00 ]
+Key: FCMP_une:  [ 0.00  0.00 ]
+Key: FCMP_true:  [ 0.00  0.00 ]
+Key: ICMP_eq:  [ 0.00  0.00 ]
+Key: ICMP_ne:  [ 0.00  0.00 ]
+Key: ICMP_ugt:  [ 0.00  0.00 ]
+Key: ICMP_uge:  [ 0.00  0.00 ]
+Key: ICMP_ult:  [ 0.00  0.00 ]
+Key: ICMP_ule:  [ 0.00  0.00 ]
+Key: ICMP_sgt:  [ 0.00  0.00 ]
+Key: ICMP_sge:  [ 0.00  0.00 ]
+Key: ICMP_slt:  [ 0.00  0.00 ]
+Key: ICMP_sle:  [ 0.00  0.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/if-else.ll b/llvm/test/Analysis/IR2Vec/if-else.ll
index fe532479086d3..804c1ca5cb6f6 100644
--- a/llvm/test/Analysis/IR2Vec/if-else.ll
+++ b/llvm/test/Analysis/IR2Vec/if-else.ll
@@ -29,7 +29,7 @@ return:                                           ; preds = %if.else, %if.then
 
 ; CHECK: Basic block vectors:
 ; CHECK-NEXT: Basic block: entry:
-; CHECK-NEXT: [ 816.00  825.00  834.00 ]
+; CHECK-NEXT: [ 816.20  825.20  834.20 ]
 ; CHECK-NEXT: Basic block: if.then:
 ; CHECK-NEXT: [ 195.00  198.00  201.00 ]
 ; CHECK-NEXT: Basic block: if.else:
diff --git a/llvm/test/Analysis/IR2Vec/unreachable.ll b/llvm/test/Analysis/IR2Vec/unreachable.ll
index b0e3e49978018..9be0ee1c2de7a 100644
--- a/llvm/test/Analysis/IR2Vec/unreachable.ll
+++ b/llvm/test/Analysis/IR2Vec/unreachable.ll
@@ -33,7 +33,7 @@ return:                                           ; preds = %if.else, %if.then
 
 ; CHECK: Basic block vectors:
 ; CHECK-NEXT: Basic block: entry:
-; CHECK-NEXT: [ 816.00  825.00  834.00 ]
+; CHECK-NEXT: [ 816.20  825.20  834.20 ]
 ; CHECK-NEXT: Basic block: if.then:
 ; CHECK-NEXT: [ 195.00  198.00  201.00 ]
 ; CHECK-NEXT: Basic block: if.else:
diff --git a/llvm/test/tools/llvm-ir2vec/entities.ll b/llvm/test/tools/llvm-ir2vec/entities.ll
index 4b51adf30bf74..8dbce57302f6f 100644
--- a/llvm/test/tools/llvm-ir2vec/entities.ll
+++ b/llvm/test/tools/llvm-ir2vec/entities.ll
@@ -1,6 +1,6 @@
 ; RUN: llvm-ir2vec entities | FileCheck %s
 
-CHECK: 84
+CHECK: 110
 CHECK-NEXT: Ret     0
 CHECK-NEXT: Br      1
 CHECK-NEXT: Switch  2
@@ -85,3 +85,29 @@ CHECK-NEXT: Function        80
 CHECK-NEXT: Pointer 81
 CHECK-NEXT: Constant        82
 CHECK-NEXT: Variable        83
+CHECK-NEXT: FCMP_false   84
+CHECK-NEXT: FCMP_oeq     85
+CHECK-NEXT: FCMP_ogt     86
+CHECK-NEXT: FCMP_oge     87
+CHECK-NEXT: FCMP_olt     88
+CHECK-NEXT: FCMP_ole     89
+CHECK-NEXT: FCMP_one     90
+CHECK-NEXT: FCMP_ord     91
+CHECK-NEXT: FCMP_uno     92
+CHECK-NEXT: FCMP_ueq     93
+CHECK-NEXT: FCMP_ugt     94
+CHECK-NEXT: FCMP_uge     95
+CHECK-NEXT: FCMP_ult     96
+CHECK-NEXT: FCMP_ule     97
+CHECK-NEXT: FCMP_une     98
+CHECK-NEXT: FCMP_true    99
+CHECK-NEXT: ICMP_eq      100
+CHECK-NEXT: ICMP_ne      101
+CHECK-NEXT: ICMP_ugt     102
+CHECK-NEXT: ICMP_uge     103
+CHECK-NEXT: ICMP_ult     104
+CHECK-NEXT: ICMP_ule     105
+CHECK-NEXT: ICMP_sgt     106
+CHECK-NEXT: ICMP_sge     107
+CHECK-NEXT: ICMP_slt     108
+CHECK-NEXT: ICMP_sle     109
diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
index aabebf0cc90a9..1c656b8fcf4e7 100644
--- a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
+++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -184,7 +184,7 @@ class IR2VecTool {
         // Add "Arg" relationships
         unsigned ArgIndex = 0;
         for (const Use &U : I.operands()) {
-          unsigned OperandID = Vocabulary::getSlotIndex(*U);
+          unsigned OperandID = Vocabulary::getSlotIndex(*U.get());
           unsigned RelationID = ArgRelation + ArgIndex;
           OS << Opcode << '\t' << OperandID << '\t' << RelationID << '\n';
 
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 9f2f6a3496ce0..9bc48e45eab5e 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -435,6 +435,7 @@ static constexpr unsigned MaxOpcodes = Vocabulary::MaxOpcodes;
 static constexpr unsigned MaxTypeIDs = Vocabulary::MaxTypeIDs;
 static constexpr unsigned MaxCanonicalTypeIDs = Vocabulary::MaxCanonicalTypeIDs;
 static constexpr unsigned MaxOperands = Vocabulary::MaxOperandKinds;
+static constexpr unsigned MaxPredicateKinds = Vocabulary::MaxPredicateKinds;
 
 // Mapping between LLVM Type::TypeID tokens and Vocabulary::CanonicalTypeID
 // names and their canonical string keys.
@@ -460,7 +461,8 @@ TEST(IR2VecVocabularyTest, DummyVocabTest) {
       EXPECT_EQ(Emb.size(), Dim);
 
     // Should have the correct total number of embeddings
-    EXPECT_EQ(VocabVecSize, MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands);
+    EXPECT_EQ(VocabVecSize, MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands +
+                                MaxPredicateKinds);
 
     auto ExpectedVocab = VocabVec;
 
@@ -527,6 +529,26 @@ TEST(IR2VecVocabularyTest, SlotIdxMapping) {
   EXPECT_EQ(Vocabulary::getSlotIndex(*Arg),
             EXPECTED_VOCAB_OPERAND_SLOT(Vocabulary::OperandKind::VariableID));
 #undef EXPECTED_VOCAB_OPERAND_SLOT
+
+  // Test getSlotIndex for predicates
+#define EXPECTED_VOCAB_PREDICATE_SLOT(X)                                       \
+  MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands + static_cast<unsigned>(X)
+  for (unsigned P = CmpInst::FIRST_FCMP_PREDICATE;
+       P <= CmpInst::LAST_FCMP_PREDICATE; ++P) {
+    CmpInst::Predicate Pred = static_cast<CmpInst::Predicate>(P);
+    unsigned ExpectedIdx =
+        EXPECTED_VOCAB_PREDICATE_SLOT((P - CmpInst::FIRST_FCMP_PREDICATE));
+    EXPECT_EQ(Vocabulary::getSlotIndex(Pred), ExpectedIdx);
+  }
+  auto ICMP_Start = CmpInst::LAST_FCMP_PREDICATE + 1;
+  for (unsigned P = CmpInst::FIRST_ICMP_PREDICATE;
+       P <= CmpInst::LAST_ICMP_PREDICATE; ++P) {
+    CmpInst::Predicate Pred = static_cast<CmpInst::Predicate>(P);
+    unsigned ExpectedIdx = EXPECTED_VOCAB_PREDICATE_SLOT(
+        ICMP_Start + P - CmpInst::FIRST_ICMP_PREDICATE);
+    EXPECT_EQ(Vocabulary::getSlotIndex(Pred), ExpectedIdx);
+  }
+#undef EXPECTED_VOCAB_PREDICATE_SLOT
 }
 
 #if GTEST_HAS_DEATH_TEST
@@ -569,6 +591,7 @@ TEST(IR2VecVocabularyTest, StringKeyGeneration) {
 
 #undef EXPECT_CANONICAL_TYPE_NAME
 
+  // Verify OperandKind -> string mapping
 #define HANDLE_OPERAND_KINDS(X)                                                \
   X(FunctionID, "Function")                                                    \
   X(PointerID, "Pointer")                                                      \
@@ -592,6 +615,28 @@ TEST(IR2VecVocabularyTest, StringKeyGeneration) {
       Vocabulary::getStringKey(MaxOpcodes + MaxCanonicalTypeIDs + 1);
   EXPECT_EQ(FuncArgKey, "Function");
   EXPECT_EQ(PtrArgKey, "Pointer");
+
+// Verify PredicateKind -> string mapping
+#define EXPECT_PREDICATE_KIND(PredNum, PredPos, PredKind)                      \
+  do {                                                                         \
+    std::string PredStr =                                                      \
+        std::string(PredKind) + "_" +                                          \
+        CmpInst::getPredicateName(static_cast<CmpInst::Predicate>(PredNum))    \
+            .str();                                                            \
+    unsigned Pos = MaxOpcodes + MaxCanonicalTypeIDs + MaxOperands + PredPos;   \
+    EXPECT_EQ(Vocabulary::getStringKey(Pos), PredStr);                         \
+  } while (0)
+
+  for (unsigned P = CmpInst::FIRST_FCMP_PREDICATE;
+       P <= CmpInst::LAST_FCMP_PREDICATE; ++P)
+    EXPECT_PREDICATE_KIND(P, P - CmpInst::FIRST_FCMP_PREDICATE, "FCMP");
+
+  auto ICMP_Pos = CmpInst::LAST_FCMP_PREDICATE + 1;
+  for (unsigned P = CmpInst::FIRST_ICMP_PREDICATE;
+       P <= CmpInst::LAST_ICMP_PREDICATE; ++P)
+    EXPECT_PREDICATE_KIND(P, ICMP_Pos++, "ICMP");
+
+#undef EXPECT_PREDICATE_KIND
 }
 
 TEST(IR2VecVocabularyTest, VocabularyDimensions) {



More information about the llvm-commits mailing list