[llvm] Adding IR2Vec as an analysis pass (PR #134004)

S. VenkataKeerthy via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 1 16:18:37 PDT 2025


https://github.com/svkeerthy created https://github.com/llvm/llvm-project/pull/134004

This PR introduces IR2Vec as an analysis pass. The changes include:
- Adding logic for generating Symbolic encodings.
- Pretrained 75D vocabulary.
- lit tests.

>From 048b61df841305d89fc02bb0370cf8e290cd342f Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Tue, 1 Apr 2025 23:01:07 +0000
Subject: [PATCH 1/2] Adding IR2Vec as an analysis pass

---
 llvm/include/llvm/Analysis/IR2VecAnalysis.h   | 134 ++++++
 llvm/lib/Analysis/CMakeLists.txt              |   1 +
 llvm/lib/Analysis/IR2VecAnalysis.cpp          | 425 ++++++++++++++++++
 llvm/lib/Passes/PassBuilder.cpp               |   1 +
 llvm/lib/Passes/PassRegistry.def              |   3 +
 .../IR2Vec/Inputs/dummy_3D_vocab.json         |   7 +
 .../IR2Vec/Inputs/dummy_5D_vocab.json         |  11 +
 llvm/test/Analysis/IR2Vec/basic.ll            |  50 +++
 llvm/test/Analysis/IR2Vec/if-else.ll          |  38 ++
 9 files changed, 670 insertions(+)
 create mode 100644 llvm/include/llvm/Analysis/IR2VecAnalysis.h
 create mode 100644 llvm/lib/Analysis/IR2VecAnalysis.cpp
 create mode 100644 llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_vocab.json
 create mode 100644 llvm/test/Analysis/IR2Vec/Inputs/dummy_5D_vocab.json
 create mode 100644 llvm/test/Analysis/IR2Vec/basic.ll
 create mode 100644 llvm/test/Analysis/IR2Vec/if-else.ll

diff --git a/llvm/include/llvm/Analysis/IR2VecAnalysis.h b/llvm/include/llvm/Analysis/IR2VecAnalysis.h
new file mode 100644
index 0000000000000..9f61e8a3f5106
--- /dev/null
+++ b/llvm/include/llvm/Analysis/IR2VecAnalysis.h
@@ -0,0 +1,134 @@
+//===- IR2VecAnalysis.h - IR2Vec Analysis Implementation -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions. See the LICENSE file for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file contains the declaration of IR2VecAnalysis that computes
+/// IR2Vec Embeddings of the program.
+///
+/// Program embeddings are, or are derived from, learned representations of
+/// the program. Such embeddings are used to represent the
+/// programs as input to machine learning algorithms. IR2Vec represents the
+/// LLVM IR as embeddings.
+///
+/// The IR2Vec algorithm is described in the following paper:
+///
+///   IR2Vec: LLVM IR Based Scalable Program Embeddings, S. VenkataKeerthy,
+///   Rohit Aggarwal, Shalini Jain, Maunendra Sankar Desarkar, Ramakrishna
+///   Upadrasta, and Y. N. Srikant, ACM Transactions on Architecture and
+///   Code Optimization (TACO), 2020. https://doi.org/10.1145/3418463.
+///   https://arxiv.org/abs/1909.06228
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_IR2VECANALYSIS_H
+#define LLVM_ANALYSIS_IR2VECANALYSIS_H
+
+#include "llvm/ADT/MapVector.h"
+#include "llvm/IR/PassManager.h"
+#include <map>
+
+namespace llvm {
+
+class Module;
+class BasicBlock;
+class Instruction;
+class Function;
+
+namespace ir2vec {
+/// A dense, real-valued embedding vector.
+using Embedding = std::vector<double>;
+
+/// The vocabulary maps the name of an IR entity (an opcode, a type, an operand
+/// kind, ...) to its embedding.
+// ToDo: Keys are currently strings; integer keys would make lookups cheaper.
+using Vocab = std::map<std::string, ir2vec::Embedding>;
+} // namespace ir2vec
+
+class VocabResult;
+class IR2VecResult;
+
+/// Module analysis that loads the IR2Vec vocabulary. The vocabulary maps an
+/// entity of the IR (like an opcode, a type, or an argument) to its
+/// corresponding embedding.
+class VocabAnalysis : public AnalysisInfoMixin<VocabAnalysis> {
+  /// Dimension of the embeddings in the loaded vocabulary.
+  unsigned DIM = 0;
+  /// Entries filled in by readVocabulary() and handed over to the result.
+  ir2vec::Vocab Vocabulary;
+  /// Reads the vocabulary file into Vocabulary and DIM.
+  Error readVocabulary();
+
+public:
+  static AnalysisKey Key;
+  using Result = VocabResult;
+  VocabAnalysis() = default;
+  Result run(Module &M, ModuleAnalysisManager &MAM);
+};
+
+/// Result of VocabAnalysis: the vocabulary together with its dimension. A
+/// default-constructed result is invalid.
+class VocabResult {
+  ir2vec::Vocab Vocabulary;
+  bool Valid = false;
+  unsigned DIM = 0;
+
+public:
+  VocabResult() = default;
+  /// Builds a valid result holding \p Vocabulary, whose embeddings have
+  /// \p Dim elements.
+  VocabResult(const ir2vec::Vocab &Vocabulary, unsigned Dim);
+
+  /// Whether a vocabulary was successfully loaded.
+  bool isValid() const { return Valid; }
+  /// Returns the vocabulary; only meaningful on a valid result.
+  const ir2vec::Vocab &getVocabulary() const;
+  /// Number of elements in each embedding of the vocabulary.
+  unsigned getDimension() const { return DIM; }
+
+  /// Invalidation hook called by the new pass manager.
+  bool invalidate(Module &M, const PreservedAnalyses &PA,
+                  ModuleAnalysisManager::Invalidator &Inv);
+};
+
+/// Result of IR2VecAnalysis for one function: per-instruction and
+/// per-basic-block embeddings plus the embedding of the function itself. A
+/// default-constructed result is invalid.
+class IR2VecResult {
+  SmallMapVector<const Instruction *, ir2vec::Embedding, 128> InstVecMap;
+  SmallMapVector<const BasicBlock *, ir2vec::Embedding, 16> BBVecMap;
+  ir2vec::Embedding FuncVector;
+  unsigned DIM = 0;
+  bool Valid = false;
+
+public:
+  IR2VecResult() = default;
+  /// Builds a valid result from already computed embeddings.
+  IR2VecResult(
+      SmallMapVector<const Instruction *, ir2vec::Embedding, 128> InstMap,
+      SmallMapVector<const BasicBlock *, ir2vec::Embedding, 16> BBMap,
+      const ir2vec::Embedding &FuncVector, unsigned Dim);
+
+  /// Whether the embeddings were computed successfully.
+  bool isValid() const { return Valid; }
+
+  /// Per-instruction embeddings.
+  const SmallMapVector<const Instruction *, ir2vec::Embedding, 128> &
+  getInstVecMap() const;
+  /// Per-basic-block embeddings.
+  const SmallMapVector<const BasicBlock *, ir2vec::Embedding, 16> &
+  getBBVecMap() const;
+  /// Embedding of the whole function.
+  const ir2vec::Embedding &getFunctionVector() const;
+  /// Number of elements in each embedding.
+  unsigned getDimension() const;
+};
+
+/// This analysis provides the IR2Vec embeddings for instructions, basic blocks,
+/// and functions. It reads the vocabulary from the cached VocabAnalysis result
+/// of the enclosing module, so that analysis must have been computed first.
+///
+/// The analysis itself is stateless: the embedding weights live with the
+/// embedding algorithm in the implementation file.
+class IR2VecAnalysis : public AnalysisInfoMixin<IR2VecAnalysis> {
+public:
+  IR2VecAnalysis() = default;
+  static AnalysisKey Key;
+  using Result = IR2VecResult;
+  Result run(Function &F, FunctionAnalysisManager &FAM);
+};
+
+/// Module pass that prints the IR2Vec embeddings of every function in the
+/// module: the function vector, then the basic block vectors, then the
+/// instruction vectors.
+class IR2VecPrinterPass : public PassInfoMixin<IR2VecPrinterPass> {
+  /// Destination of the printed embeddings.
+  raw_ostream &OS;
+  /// Prints one embedding followed by a newline.
+  void printVector(const ir2vec::Embedding &Vec) const;
+
+public:
+  explicit IR2VecPrinterPass(raw_ostream &OS) : OS(OS) {}
+  PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
+  /// Printer passes must always run, even when optional passes are skipped.
+  static bool isRequired() { return true; }
+};
+
+} // namespace llvm
+
+#endif // LLVM_ANALYSIS_IR2VECANALYSIS_H
diff --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt
index fbf3b587d6bd2..8a6399f756f27 100644
--- a/llvm/lib/Analysis/CMakeLists.txt
+++ b/llvm/lib/Analysis/CMakeLists.txt
@@ -67,6 +67,7 @@ add_llvm_component_library(LLVMAnalysis
   GlobalsModRef.cpp
   GuardUtils.cpp
   HeatUtils.cpp
+  IR2VecAnalysis.cpp
   IRSimilarityIdentifier.cpp
   IVDescriptors.cpp
   IVUsers.cpp
diff --git a/llvm/lib/Analysis/IR2VecAnalysis.cpp b/llvm/lib/Analysis/IR2VecAnalysis.cpp
new file mode 100644
index 0000000000000..e6b170d6053ab
--- /dev/null
+++ b/llvm/lib/Analysis/IR2VecAnalysis.cpp
@@ -0,0 +1,425 @@
+//===- IR2VecAnalysis.cpp - IR2Vec Analysis Implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions. See the LICENSE file for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file implements the IR2Vec algorithm.
+///
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/IR2VecAnalysis.h"
+
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Errc.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/JSON.h"
+#include "llvm/Support/MemoryBuffer.h"
+
+using namespace llvm;
+using namespace ir2vec;
+
+#define DEBUG_TYPE "ir2vec"
+
+// Counts vocabulary lookups whose key is absent from the vocabulary; such
+// lookups yield a zero vector instead (see Embeddings::lookupVocab).
+STATISTIC(dataMissCounter, "Number of data misses in the vocabulary");
+
+/// IR2Vec computes two kinds of embeddings:
+///  - Symbolic embeddings capture the "syntactic" and "statistical
+///    correlation" of the IR entities.
+///  - Flow-aware embeddings build on top of symbolic embeddings and
+///    additionally capture the flow information in the IR.
+/// IR2VecKind selects which kind of embeddings to generate.
+// ToDo: Only Symbolic is supported for now; Flow-aware support will be added
+// in upcoming patches.
+enum IR2VecKind {
+  symbolic,  ///< Symbolic embeddings.
+  flowaware, ///< Flow-aware embeddings (not implemented yet).
+};
+
+static cl::OptionCategory IR2VecAnalysisCategory("IR2Vec Analysis Options");
+
+// Selects the kind of embeddings IR2VecAnalysis computes. Like the other
+// options in this file it has internal linkage: it is not declared in
+// IR2VecAnalysis.h, so there is no reason to export the symbol.
+static cl::opt<IR2VecKind>
+    IR2VecMode("ir2vec-mode",
+               cl::desc("Choose type of embeddings to generate:"),
+               cl::values(clEnumValN(symbolic, "symbolic",
+                                     "Generates symbolic embeddings"),
+                          clEnumValN(flowaware, "flowaware",
+                                     "Generates flow-aware embeddings")),
+               cl::init(symbolic), cl::cat(IR2VecAnalysisCategory));
+
+// ToDo: Use a default vocab when not specified
+static cl::opt<std::string>
+    VocabFile("ir2vec-vocab-path", cl::Optional,
+              cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
+              cl::cat(IR2VecAnalysisCategory));
+
+// Definitions of the static analysis keys declared in IR2VecAnalysis.h; the
+// new pass manager uses their addresses to identify each analysis.
+AnalysisKey VocabAnalysis::Key;
+AnalysisKey IR2VecAnalysis::Key;
+
+//===----------------------------------------------------------------------===//
+// Embeddings and its subclasses
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Embeddings provides the interface to generate vector representations for
+/// instructions, basic blocks, and functions. The vector
+/// representations are generated using IR2Vec algorithms.
+///
+/// The Embeddings class is an abstract class and it is intended to be
+/// subclassed for different IR2Vec algorithms like Symbolic and Flow-aware.
+class Embeddings {
+protected:
+  const Function &F;
+
+  /// Vocabulary used for lookups. It is held by reference instead of being
+  /// copied: an Embeddings object is created for every function, and copying
+  /// the whole vocabulary each time is wasteful. The referenced vocabulary
+  /// must outlive this object.
+  const Vocab &Vocabulary;
+
+  /// Weights for different entities (like opcode, arguments, types)
+  /// in the IR instructions to generate the vector representation.
+  // ToDo: Defaults to the values used in the original algorithm. Can be
+  // parameterized later.
+  float WO = 1.0, WT = 0.5, WA = 0.2;
+
+  /// Dimension of the vector representation; set by the constructor from the
+  /// input vocabulary.
+  unsigned DIM = 0;
+
+  // Utility maps - these are used to store the vector representations of
+  // instructions, basic blocks and functions.
+  Embedding FuncVector;
+  SmallMapVector<const BasicBlock *, Embedding, 16> BBVecMap;
+  SmallMapVector<const Instruction *, Embedding, 128> InstVecMap;
+
+  Embeddings(const Function &F, const Vocab &Vocabulary, unsigned DIM)
+      : F(F), Vocabulary(Vocabulary), DIM(DIM) {}
+
+  /// Lookup vocabulary for a given Key. If the key is not found, it returns a
+  /// zero vector of dimension DIM.
+  Embedding lookupVocab(const std::string &Key) const;
+
+public:
+  virtual ~Embeddings() = default;
+
+  /// Top level function to compute embeddings. Given a function, it
+  /// generates embeddings for all the instructions and basic blocks in that
+  /// function. Logic of computing the embeddings is specific to the kind of
+  /// embeddings being computed.
+  virtual void computeEmbeddings() = 0;
+
+  /// Returns the dimension of the embedding vector.
+  unsigned getDimension() const { return DIM; }
+
+  /// Returns a map from each instruction of the function to its vector
+  /// representation.
+  const SmallMapVector<const Instruction *, Embedding, 128> &
+  getInstVecMap() const {
+    return InstVecMap;
+  }
+
+  /// Returns a map from each basic block of the function to its vector
+  /// representation.
+  const SmallMapVector<const BasicBlock *, Embedding, 16> &getBBVecMap() const {
+    return BBVecMap;
+  }
+
+  /// Returns the vector representation of the function.
+  const Embedding &getFunctionVector() const { return FuncVector; }
+};
+
+/// Class for computing the Symbolic embeddings of IR2Vec.
+class Symbolic final : public Embeddings {
+private:
+  /// Computes the vector of \p BB as the sum of the vectors of its
+  /// instructions, recording each instruction's vector in InstVecMap.
+  Embedding computeBB2Vec(const BasicBlock &BB);
+
+public:
+  Symbolic(const Function &F, const Vocab &Vocabulary, unsigned DIM)
+      : Embeddings(F, Vocabulary, DIM) {
+    FuncVector = Embedding(DIM, 0);
+  }
+  void computeEmbeddings() override;
+};
+
+/// Scales the vector Vec by Factor.
+void scaleVector(Embedding &Vec, const float Factor) {
+  std::transform(Vec.begin(), Vec.end(), Vec.begin(),
+                 [Factor](double X) { return X * Factor; });
+}
+
+/// Adds two vectors: Vec += Vec2. Both must have the same dimension;
+/// std::transform would otherwise read past the end of a shorter Vec2.
+void addVectors(Embedding &Vec, const Embedding &Vec2) {
+  assert(Vec.size() == Vec2.size() && "Vectors must have the same dimension");
+  std::transform(Vec.begin(), Vec.end(), Vec2.begin(), Vec.begin(),
+                 std::plus<double>());
+}
+
+/// Returns the vocabulary key describing a value of type \p Ty.
+const char *getTypeKey(const Type *Ty) {
+  if (Ty->isVoidTy())
+    return "voidTy";
+  if (Ty->isFloatingPointTy())
+    return "floatTy";
+  if (Ty->isIntegerTy())
+    return "integerTy";
+  if (Ty->isFunctionTy())
+    return "functionTy";
+  if (Ty->isStructTy())
+    return "structTy";
+  if (Ty->isArrayTy())
+    return "arrayTy";
+  if (Ty->isPointerTy())
+    return "pointerTy";
+  if (Ty->isVectorTy())
+    return "vectorTy";
+  if (Ty->isEmptyTy())
+    return "emptyTy";
+  if (Ty->isLabelTy())
+    return "labelTy";
+  if (Ty->isTokenTy())
+    return "tokenTy";
+  if (Ty->isMetadataTy())
+    return "metadataTy";
+  return "unknownTy";
+}
+
+/// Returns the vocabulary key describing the operand \p Op.
+const char *getOperandKey(const Value *Op) {
+  if (isa<Function>(Op))
+    return "function";
+  if (isa<PointerType>(Op->getType()))
+    return "pointer";
+  if (isa<Constant>(Op))
+    return "constant";
+  return "variable";
+}
+
+// ToDo: Currently lookups are string based. Use numeric Keys
+// for efficiency.
+Embedding Embeddings::lookupVocab(const std::string &Key) const {
+  // ToDo: Use zero vectors in vocab and assert failure for
+  // unknown entities rather than silently returning zeroes here.
+  // A single find() is used: operator[] needs a mutable map and would perform
+  // a second lookup.
+  auto It = Vocabulary.find(Key);
+  if (It == Vocabulary.end()) {
+    LLVM_DEBUG(errs() << "cannot find key in map : " << Key << "\n");
+    dataMissCounter++;
+    return Embedding(DIM, 0);
+  }
+  return It->second;
+}
+
+void Symbolic::computeEmbeddings() {
+  if (F.isDeclaration())
+    return;
+  // The function vector is the sum of its basic block vectors.
+  for (const BasicBlock &BB : F) {
+    Embedding BBVector = computeBB2Vec(BB);
+    addVectors(FuncVector, BBVector);
+    BBVecMap[&BB] = std::move(BBVector);
+  }
+}
+
+Embedding Symbolic::computeBB2Vec(const BasicBlock &BB) {
+  auto It = BBVecMap.find(&BB);
+  if (It != BBVecMap.end())
+    return It->second;
+
+  Embedding BBVector(DIM, 0);
+  for (const Instruction &I : BB) {
+    Embedding InstVector(DIM, 0);
+
+    // Opcode contribution, weighted by WO.
+    Embedding Vec = lookupVocab(I.getOpcodeName());
+    scaleVector(Vec, WO);
+    addVectors(InstVector, Vec);
+
+    // Result-type contribution, weighted by WT.
+    Vec = lookupVocab(getTypeKey(I.getType()));
+    scaleVector(Vec, WT);
+    addVectors(InstVector, Vec);
+
+    // Operand contributions, each weighted by WA.
+    for (const Use &Op : I.operands()) {
+      Embedding OpVec = lookupVocab(getOperandKey(Op.get()));
+      scaleVector(OpVec, WA);
+      addVectors(InstVector, OpVec);
+    }
+
+    // Record the vector only after all operands are folded in, so that
+    // instructions without operands (e.g. `ret void`) are recorded too.
+    InstVecMap[&I] = InstVector;
+    addVectors(BBVector, InstVector);
+  }
+  return BBVector;
+}
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// VocabResult and VocabAnalysis
+//===----------------------------------------------------------------------===//
+
+// The parameter is a const reference (as declared in IR2VecAnalysis.h), so the
+// vocabulary is copied here; std::move on a const reference would only
+// disguise that copy.
+VocabResult::VocabResult(const ir2vec::Vocab &Vocabulary, unsigned Dim)
+    : Vocabulary(Vocabulary), Valid(true), DIM(Dim) {}
+
+/// Returns the loaded vocabulary. Must only be called on a valid result.
+const ir2vec::Vocab &VocabResult::getVocabulary() const {
+  assert(Valid && "IR2Vec vocabulary requested from an invalid VocabResult");
+  return Vocabulary;
+}
+
+// The vocabulary is read from a file and does not depend on the IR. For now,
+// though, the result is treated like any other module analysis: it is
+// invalidated unless VocabAnalysis was preserved (explicitly or through
+// "all analyses") and not abandoned.
+bool VocabResult::invalidate(Module &M, const PreservedAnalyses &PA,
+                             ModuleAnalysisManager::Invalidator &Inv) {
+  const bool Preserved =
+      PA.getChecker<VocabAnalysis>().preservedWhenStateless();
+  return !Preserved;
+}
+
+// ToDo: Make this optional. We can avoid file reads
+// by auto-generating the vocabulary during the build time.
+/// Reads the JSON vocabulary named by -ir2vec-vocab-path into Vocabulary and
+/// sets DIM. The file is external input, so malformed contents are reported
+/// as errors rather than only asserted on. Asserts compile out in release
+/// builds, where an empty vocabulary would be dereferenced and vectors of
+/// mismatched dimension would later be combined out of bounds.
+Error VocabAnalysis::readVocabulary() {
+  auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
+  if (!BufOrError)
+    return createFileError(VocabFile, BufOrError.getError());
+
+  auto Content = BufOrError.get()->getBuffer();
+  json::Path::Root Path("");
+  Expected<json::Value> ParsedVocabValue = json::parse(Content);
+  if (!ParsedVocabValue)
+    return ParsedVocabValue.takeError();
+
+  if (!json::fromJSON(*ParsedVocabValue, Vocabulary, Path))
+    return createStringError(errc::illegal_byte_sequence,
+                             "Unable to parse the vocabulary");
+
+  if (Vocabulary.empty())
+    return createStringError(errc::invalid_argument,
+                             "IR2Vec vocabulary is empty");
+
+  unsigned Dim = Vocabulary.begin()->second.size();
+  if (Dim == 0)
+    return createStringError(errc::invalid_argument,
+                             "Dimension of IR2Vec vocabulary is zero");
+
+  // Take entries by reference: the map's value_type is
+  // std::pair<const std::string, Embedding>, and any other parameter type
+  // would copy every embedding.
+  if (!std::all_of(Vocabulary.begin(), Vocabulary.end(),
+                   [Dim](const auto &Entry) {
+                     return Entry.second.size() == Dim;
+                   }))
+    return createStringError(
+        errc::invalid_argument,
+        "All vectors in the IR2Vec vocabulary are not of the same dimension");
+
+  this->DIM = Dim;
+  return Error::success();
+}
+
+/// Loads the vocabulary and wraps it in a VocabResult. Returns an invalid
+/// result, after printing a diagnostic, when no path is given or the file
+/// cannot be read.
+VocabAnalysis::Result VocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
+  if (VocabFile.empty()) {
+    // ToDo: Use default vocabulary
+    errs() << "Error: IR2Vec vocabulary file path not specified.\n";
+    return VocabResult(); // Return invalid result
+  }
+
+  // A failed llvm::Error must be consumed. Otherwise it aborts on destruction
+  // in builds with ABI-breaking checks, and the cause is never reported.
+  if (Error Err = readVocabulary()) {
+    logAllUnhandledErrors(std::move(Err), errs(),
+                          "Error: failed to read IR2Vec vocabulary: ");
+    return VocabResult(); // Return invalid result
+  }
+
+  return VocabResult(std::move(Vocabulary), DIM);
+}
+
+//===----------------------------------------------------------------------===//
+// IR2VecResult and IR2VecAnalysis
+//===----------------------------------------------------------------------===//
+
+// The maps are taken by value and moved into place. They must not be declared
+// const here, because std::move of a const object silently degrades to a
+// copy. A top-level const on a by-value parameter is not part of the
+// signature, so this still matches the declaration in IR2VecAnalysis.h.
+IR2VecResult::IR2VecResult(
+    SmallMapVector<const Instruction *, Embedding, 128> InstMap,
+    SmallMapVector<const BasicBlock *, Embedding, 16> BBMap,
+    const Embedding &FuncVector, unsigned Dim)
+    : InstVecMap(std::move(InstMap)), BBVecMap(std::move(BBMap)),
+      FuncVector(FuncVector), DIM(Dim), Valid(true) {}
+
+/// Per-instruction embeddings; only meaningful on a valid result.
+const SmallMapVector<const Instruction *, Embedding, 128> &
+IR2VecResult::getInstVecMap() const {
+  assert(Valid && "IR2VecResult is invalid");
+  return InstVecMap;
+}
+/// Per-basic-block embeddings; only meaningful on a valid result.
+const SmallMapVector<const BasicBlock *, Embedding, 16> &
+IR2VecResult::getBBVecMap() const {
+  assert(Valid && "IR2VecResult is invalid");
+  return BBVecMap;
+}
+/// Embedding of the whole function; only meaningful on a valid result.
+const Embedding &IR2VecResult::getFunctionVector() const {
+  assert(Valid && "IR2VecResult is invalid");
+  return FuncVector;
+}
+/// Number of elements in each embedding (0 for an invalid result).
+unsigned IR2VecResult::getDimension() const { return DIM; }
+/// Computes the IR2Vec embeddings of \p F using the cached vocabulary of its
+/// module. Prints a diagnostic and returns an invalid result when the
+/// vocabulary is unavailable or the selected mode is not supported.
+IR2VecAnalysis::Result IR2VecAnalysis::run(Function &F,
+                                           FunctionAnalysisManager &FAM) {
+  // Function analyses can only see module analyses that are already cached
+  // (e.g. computed by IR2VecPrinterPass). getCachedResult returns null when
+  // VocabAnalysis has not been run, so check before dereferencing.
+  const VocabResult *VocabRes =
+      FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F)
+          .getCachedResult<VocabAnalysis>(*F.getParent());
+  if (!VocabRes) {
+    errs() << "Error: IR2Vec vocabulary analysis has not been computed.\n";
+    return IR2VecResult();
+  }
+  if (!VocabRes->isValid()) {
+    errs() << "Error: IR2Vec vocabulary is invalid.\n";
+    return IR2VecResult();
+  }
+
+  unsigned Dim = VocabRes->getDimension();
+  if (Dim == 0) {
+    errs() << "Error: IR2Vec vocabulary dimension is zero.\n";
+    return IR2VecResult();
+  }
+
+  // Bind by reference: the vocabulary is owned by the cached VocabResult, so
+  // there is no need to copy it for every function.
+  const Vocab &Vocabulary = VocabRes->getVocabulary();
+  std::unique_ptr<Embeddings> Emb;
+  switch (IR2VecMode) {
+  case IR2VecKind::symbolic:
+    Emb = std::make_unique<Symbolic>(F, Vocabulary, Dim);
+    break;
+  case IR2VecKind::flowaware:
+    // ToDo: Add support for flow-aware embeddings.
+    // This mode can be selected with -ir2vec-mode, so it is reported instead
+    // of being treated as unreachable (undefined behavior in release builds).
+    errs() << "Error: IR2Vec flow-aware embeddings are not supported yet.\n";
+    return IR2VecResult();
+  }
+  assert(Emb && "Unhandled IR2Vec mode");
+
+  Emb->computeEmbeddings();
+  // The accessors return const references; the maps are copied once into the
+  // by-value parameters of the result's constructor.
+  return IR2VecResult(Emb->getInstVecMap(), Emb->getBBVecMap(),
+                      Emb->getFunctionVector(), Dim);
+}
+
+//===----------------------------------------------------------------------===//
+// IR2VecPrinterPass
+//===----------------------------------------------------------------------===//
+
+/// Writes \p Vec as " [ e0  e1 ... ]" plus a newline, with every element
+/// formatted to two decimal places.
+void IR2VecPrinterPass::printVector(const Embedding &Vec) const {
+  OS << " [";
+  for (double Value : Vec)
+    OS << ' ' << format("%.2f", Value) << ' ';
+  OS << "]\n";
+}
+
+/// Prints the embeddings of every function in \p M. Stops with a diagnostic
+/// when the vocabulary or a function's embeddings are invalid.
+PreservedAnalyses IR2VecPrinterPass::run(Module &M,
+                                         ModuleAnalysisManager &MAM) {
+  // Compute (and cache) the vocabulary first: IR2VecAnalysis only queries the
+  // cached module result. Bind by reference, because copying the result
+  // would copy the whole vocabulary.
+  const auto &VocabRes = MAM.getResult<VocabAnalysis>(M);
+  // Report an unusable vocabulary rather than only asserting, so release
+  // builds also stop here with a diagnostic.
+  if (!VocabRes.isValid()) {
+    errs() << "Error: IR2Vec vocabulary is invalid.\n";
+    return PreservedAnalyses::all();
+  }
+
+  FunctionAnalysisManager &FAM =
+      MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
+
+  for (Function &F : M) {
+    // Bind by reference to avoid copying every embedding map.
+    const auto &IR2VecRes = FAM.getResult<IR2VecAnalysis>(F);
+    if (!IR2VecRes.isValid()) {
+      errs() << "Error: IR2Vec embeddings are invalid.\n";
+      return PreservedAnalyses::all();
+    }
+
+    OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
+    OS << "Function vector: ";
+    printVector(IR2VecRes.getFunctionVector());
+
+    OS << "Basic block vectors:\n";
+    for (const auto &[BB, Vec] : IR2VecRes.getBBVecMap()) {
+      OS << "Basic block: " << BB->getName() << ":\n";
+      printVector(Vec);
+    }
+
+    OS << "Instruction vectors:\n";
+    for (const auto &[Inst, Vec] : IR2VecRes.getInstVecMap()) {
+      OS << "Instruction: ";
+      Inst->print(OS);
+      printVector(Vec);
+    }
+  }
+  return PreservedAnalyses::all();
+}
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index 5cda1517e127d..3aecbd2f82c17 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -42,6 +42,7 @@
 #include "llvm/Analysis/EphemeralValuesCache.h"
 #include "llvm/Analysis/FunctionPropertiesAnalysis.h"
 #include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/IR2VecAnalysis.h"
 #include "llvm/Analysis/IRSimilarityIdentifier.h"
 #include "llvm/Analysis/IVUsers.h"
 #include "llvm/Analysis/InlineAdvisor.h"
diff --git a/llvm/lib/Passes/PassRegistry.def b/llvm/lib/Passes/PassRegistry.def
index 510a505995304..5baae1ff636ae 100644
--- a/llvm/lib/Passes/PassRegistry.def
+++ b/llvm/lib/Passes/PassRegistry.def
@@ -25,6 +25,7 @@ MODULE_ANALYSIS("dxil-metadata", DXILMetadataAnalysis())
 MODULE_ANALYSIS("dxil-resource-binding", DXILResourceBindingAnalysis())
 MODULE_ANALYSIS("dxil-resource-type", DXILResourceTypeAnalysis())
 MODULE_ANALYSIS("inline-advisor", InlineAdvisorAnalysis())
+MODULE_ANALYSIS("ir2vec-vocab", VocabAnalysis())
 MODULE_ANALYSIS("ir-similarity", IRSimilarityAnalysis())
 MODULE_ANALYSIS("last-run-tracking", LastRunTrackingAnalysis())
 MODULE_ANALYSIS("lcg", LazyCallGraphAnalysis())
@@ -131,6 +132,7 @@ MODULE_PASS("print<dxil-metadata>", DXILMetadataAnalysisPrinterPass(errs()))
 MODULE_PASS("print<dxil-resource-binding>",
             DXILResourceBindingPrinterPass(errs()))
 MODULE_PASS("print<inline-advisor>", InlineAdvisorAnalysisPrinterPass(errs()))
+MODULE_PASS("print<ir2vec>", IR2VecPrinterPass(errs()))
 MODULE_PASS("print<module-debuginfo>", ModuleDebugInfoPrinterPass(errs()))
 MODULE_PASS("print<reg-usage>", PhysicalRegisterUsageInfoPrinterPass(errs()))
 MODULE_PASS("pseudo-probe", SampleProfileProbePass(TM))
@@ -295,6 +297,7 @@ FUNCTION_ANALYSIS("func-properties", FunctionPropertiesAnalysis())
 FUNCTION_ANALYSIS("machine-function-info", MachineFunctionAnalysis(TM))
 FUNCTION_ANALYSIS("gc-function", GCFunctionAnalysis())
 FUNCTION_ANALYSIS("inliner-size-estimator", InlineSizeEstimatorAnalysis())
+FUNCTION_ANALYSIS("ir2vec", IR2VecAnalysis())
 FUNCTION_ANALYSIS("last-run-tracking", LastRunTrackingAnalysis())
 FUNCTION_ANALYSIS("lazy-value-info", LazyValueAnalysis())
 FUNCTION_ANALYSIS("loops", LoopAnalysis())
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_vocab.json
new file mode 100644
index 0000000000000..5a9efb9566424
--- /dev/null
+++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_3D_vocab.json
@@ -0,0 +1,7 @@
+{
+    "alloca": [1, 2, 3],
+    "load": [4, 5, 6],
+    "store": [7, 8, 9],
+    "add": [10, 11, 12],
+    "mul": [13, 14, 15]
+}
\ No newline at end of file
diff --git a/llvm/test/Analysis/IR2Vec/Inputs/dummy_5D_vocab.json b/llvm/test/Analysis/IR2Vec/Inputs/dummy_5D_vocab.json
new file mode 100644
index 0000000000000..44f39d3facca3
--- /dev/null
+++ b/llvm/test/Analysis/IR2Vec/Inputs/dummy_5D_vocab.json
@@ -0,0 +1,11 @@
+{
+    "alloca": [-0.1, -0.2, -0.3, 1, 2],
+    "load": [-0.4, -0.5, -0.6, 4, 5],
+    "store": [-0.7, -0.8, -0.9, 7, 8],
+    "add": [-1.0, -1.1, -1.2, 10, 11],
+    "mul": [12, 13, 14, -1.2, -1.3],
+    "integerTy": [0.2, 0.4, 0.6, 1, 0.5],
+    "pointer": [0, 1, 2, 3, 4],
+    "variable": [2, 4, 6, 8, 10],
+    "ret": [5, 6, 7, 8, 9]
+}
\ No newline at end of file
diff --git a/llvm/test/Analysis/IR2Vec/basic.ll b/llvm/test/Analysis/IR2Vec/basic.ll
new file mode 100644
index 0000000000000..b10e735a54587
--- /dev/null
+++ b/llvm/test/Analysis/IR2Vec/basic.ll
@@ -0,0 +1,50 @@
+; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_vocab.json %s 2>&1 | FileCheck %s -check-prefix=3D-CHECK
+; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_5D_vocab.json %s 2>&1 | FileCheck %s -check-prefix=5D-CHECK
+
+; A single basic block using alloca, store, load, mul, add and ret. Local
+; values are numbered sequentially (%2 follows the arguments %0 and %1),
+; matching the names printed in the expected output below.
+define dso_local i32 @abc(i32 %0, i32 %1) {
+entry:
+  %2 = alloca i32, align 4
+  %3 = alloca i32, align 4
+  store i32 %0, ptr %2, align 4
+  store i32 %1, ptr %3, align 4
+  %4 = load i32, ptr %2, align 4
+  %5 = load i32, ptr %3, align 4
+  %6 = load i32, ptr %2, align 4
+  %7 = mul nsw i32 %5, %6
+  %8 = add nsw i32 %4, %7
+  ret i32 %8
+}
+
+; 3D-CHECK: IR2Vec embeddings for function abc:
+; 3D-CHECK-NEXT: Function vector:  [ 51.00 60.00 69.00 ]
+; 3D-CHECK-NEXT: Basic block vectors:
+; 3D-CHECK-NEXT: Basic block: entry:
+; 3D-CHECK-NEXT:  [ 51.00 60.00 69.00 ]
+; 3D-CHECK-NEXT: Instruction vectors:
+; 3D-CHECK-NEXT: Instruction:   %2 = alloca i32, align 4 [ 1.00 2.00 3.00 ]
+; 3D-CHECK-NEXT: Instruction:   %3 = alloca i32, align 4 [ 1.00 2.00 3.00 ]
+; 3D-CHECK-NEXT: Instruction:   store i32 %0, ptr %2, align 4 [ 7.00 8.00 9.00 ]
+; 3D-CHECK-NEXT: Instruction:   store i32 %1, ptr %3, align 4 [ 7.00 8.00 9.00 ]
+; 3D-CHECK-NEXT: Instruction:   %4 = load i32, ptr %2, align 4 [ 4.00 5.00 6.00 ]
+; 3D-CHECK-NEXT: Instruction:   %5 = load i32, ptr %3, align 4 [ 4.00 5.00 6.00 ]
+; 3D-CHECK-NEXT: Instruction:   %6 = load i32, ptr %2, align 4 [ 4.00 5.00 6.00 ]
+; 3D-CHECK-NEXT: Instruction:   %7 = mul nsw i32 %5, %6 [ 13.00 14.00 15.00 ]
+; 3D-CHECK-NEXT: Instruction:   %8 = add nsw i32 %4, %7 [ 10.00 11.00 12.00 ]
+; 3D-CHECK-NEXT: Instruction:   ret i32 %8 [ 0.00 0.00 0.00 ]
+
+; 5D-CHECK: IR2Vec embeddings for function abc:
+; 5D-CHECK-NEXT: Function vector:  [ 16.50  22.00  27.50  61.50  72.95 ]
+; 5D-CHECK-NEXT: Basic block vectors:
+; 5D-CHECK-NEXT: Basic block: entry:
+; 5D-CHECK-NEXT:  [ 16.50  22.00  27.50  61.50  72.95 ]
+; 5D-CHECK-NEXT: Instruction vectors:
+; 5D-CHECK-NEXT: Instruction:   %2 = alloca i32, align 4 [ -0.10  -0.20  -0.30  1.00  2.00 ]
+; 5D-CHECK-NEXT: Instruction:   %3 = alloca i32, align 4 [ -0.10  -0.20  -0.30  1.00  2.00 ]
+; 5D-CHECK-NEXT: Instruction:   store i32 %0, ptr %2, align 4 [ -0.30  0.20  0.70  9.20  10.80 ]
+; 5D-CHECK-NEXT: Instruction:   store i32 %1, ptr %3, align 4 [ -0.30  0.20  0.70  9.20  10.80 ]
+; 5D-CHECK-NEXT: Instruction:   %4 = load i32, ptr %2, align 4 [ -0.30  -0.10  0.10  5.10  6.05 ]
+; 5D-CHECK-NEXT: Instruction:   %5 = load i32, ptr %3, align 4 [ -0.30  -0.10  0.10  5.10  6.05 ]
+; 5D-CHECK-NEXT: Instruction:   %6 = load i32, ptr %2, align 4 [ -0.30  -0.10  0.10  5.10  6.05 ]
+; 5D-CHECK-NEXT: Instruction:   %7 = mul nsw i32 %5, %6 [ 12.90  14.80  16.70  2.50  2.95 ]
+; 5D-CHECK-NEXT: Instruction:   %8 = add nsw i32 %4, %7 [ -0.10  0.70  1.50  13.70  15.25 ]
+; 5D-CHECK-NEXT: Instruction:   ret i32 %8 [ 5.40  6.80  8.20  9.60  11.00 ]
diff --git a/llvm/test/Analysis/IR2Vec/if-else.ll b/llvm/test/Analysis/IR2Vec/if-else.ll
new file mode 100644
index 0000000000000..b1c64224e5328
--- /dev/null
+++ b/llvm/test/Analysis/IR2Vec/if-else.ll
@@ -0,0 +1,38 @@
+; RUN: opt -passes='print<ir2vec>' -o /dev/null -ir2vec-vocab-path=%S/Inputs/dummy_3D_vocab.json %s 2>&1 | FileCheck %s
+
+; Diamond-shaped CFG: entry branches to if.then or if.else, and both join in
+; return. One vector is expected per basic block, in source order.
+define dso_local i32 @abc(i32 noundef %a, i32 noundef %b) {
+entry:
+  %retval = alloca i32, align 4
+  %a.addr = alloca i32, align 4
+  %b.addr = alloca i32, align 4
+  store i32 %a, ptr %a.addr, align 4
+  store i32 %b, ptr %b.addr, align 4
+  %0 = load i32, ptr %a.addr, align 4
+  %1 = load i32, ptr %b.addr, align 4
+  %cmp = icmp sgt i32 %0, %1
+  br i1 %cmp, label %if.then, label %if.else
+
+if.then:                                          ; preds = %entry
+  %2 = load i32, ptr %b.addr, align 4
+  store i32 %2, ptr %retval, align 4
+  br label %return
+
+if.else:                                          ; preds = %entry
+  %3 = load i32, ptr %a.addr, align 4
+  store i32 %3, ptr %retval, align 4
+  br label %return
+
+return:                                           ; preds = %if.else, %if.then
+  %4 = load i32, ptr %retval, align 4
+  ret i32 %4
+}
+
+; CHECK: Basic block vectors:
+; CHECK-NEXT: Basic block: entry:
+; CHECK-NEXT:  [ 25.00 32.00 39.00 ]
+; CHECK-NEXT: Basic block: if.then:
+; CHECK-NEXT:  [ 11.00 13.00 15.00 ]
+; CHECK-NEXT: Basic block: if.else:
+; CHECK-NEXT:  [ 11.00 13.00 15.00 ]
+; CHECK-NEXT: Basic block: return:
+; CHECK-NEXT:  [ 4.00 5.00 6.00 ]

>From 9ba9971d92951fba988999086eac7a71abc1e73e Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeerthy at google.com>
Date: Tue, 1 Apr 2025 23:10:14 +0000
Subject: [PATCH 2/2] Added 75D IR2Vec vocabulary

---
 llvm/lib/Analysis/models/seedEmbeddingVocab75D.json | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 llvm/lib/Analysis/models/seedEmbeddingVocab75D.json

diff --git a/llvm/lib/Analysis/models/seedEmbeddingVocab75D.json b/llvm/lib/Analysis/models/seedEmbeddingVocab75D.json
new file mode 100644
index 0000000000000..e69de29bb2d1d



More information about the llvm-commits mailing list