[llvm] Adding IR2Vec as an analysis pass (PR #134004)

S. VenkataKeerthy via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 8 19:46:20 PDT 2025


================
@@ -0,0 +1,429 @@
+//===- IR2VecAnalysis.cpp - IR2Vec Analysis Implementation ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions. See the LICENSE file for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file implements the IR2Vec algorithm.
+///
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/IR2VecAnalysis.h"
+
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Errc.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/JSON.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include <cassert>
+
+using namespace llvm;
+using namespace ir2vec;
+
+// Component name used by LLVM_DEBUG output (selectable with
+// -debug-only=ir2vec) and as the group for the STATISTIC below.
+#define DEBUG_TYPE "ir2vec"
+
+// Incremented by Embeddings::lookupVocab each time a key is not found in the
+// vocabulary; the lookup then falls back to a zero vector.
+STATISTIC(DataMissCounter, "Number of data misses in the vocabulary");
+
+/// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
+/// Symbolic embeddings capture the "syntactic" and "statistical correlation"
+/// of the IR entities. Flow-aware embeddings build on top of symbolic
+/// embeddings and additionally capture the flow information in the IR.
+/// IR2VecKind selects which kind of embeddings to generate.
+// FIXME: Currently only Symbolic is supported. Add support for
+// Flow-aware in upcoming patches.
+enum class IR2VecKind { Symbolic, Flowaware };
+
+// Groups the IR2Vec command-line options together in -help output.
+static cl::OptionCategory IR2VecAnalysisCategory("IR2Vec Analysis Options");
+
+// -ir2vec-mode=symbolic|flowaware; defaults to symbolic.
+// NOTE(review): "flowaware" is accepted by the parser even though only
+// Symbolic is implemented (see the FIXME above); how that mode is handled is
+// not visible here. Unlike the other options, this one is not `static` and so
+// has external linkage -- confirm whether a header declares it `extern`.
+cl::opt<IR2VecKind>
+    IR2VecMode("ir2vec-mode",
+               cl::desc("Choose type of embeddings to generate:"),
+               cl::values(clEnumValN(IR2VecKind::Symbolic, "symbolic",
+                                     "Generates symbolic embeddings"),
+                          clEnumValN(IR2VecKind::Flowaware, "flowaware",
+                                     "Generates flow-aware embeddings")),
+               cl::init(IR2VecKind::Symbolic), cl::cat(IR2VecAnalysisCategory));
+
+// -ir2vec-vocab-path=<file>: vocabulary file to load; empty by default.
+// FIXME: Use a default vocab when not specified
+static cl::opt<std::string>
+    VocabFile("ir2vec-vocab-path", cl::Optional,
+              cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
+              cl::cat(IR2VecAnalysisCategory));
+
+// Definitions of the static Key members that the new pass manager uses to
+// identify these two analyses.
+AnalysisKey IR2VecVocabAnalysis::Key;
+AnalysisKey IR2VecAnalysis::Key;
+
+//===----------------------------------------------------------------------===//
+// Embeddings and its subclasses
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Abstract base of the IR2Vec embedding algorithms (Symbolic today,
+/// Flow-aware later). Built for a single function F and a vocabulary, a
+/// subclass's computeEmbeddings() fills in the vector representations, which
+/// callers then read back through the const getters below.
+class Embeddings {
+public:
+  virtual ~Embeddings() = default;
+
+  /// Entry point: computes the embeddings of F. The concrete algorithm is
+  /// supplied by the subclass.
+  virtual void computeEmbeddings() = 0;
+
+  /// \returns the map from instructions of F to their vectors, as maintained
+  /// by the subclass.
+  const SmallMapVector<const Instruction *, Embedding, 128> &
+  getInstVecMap() const {
+    return InstVecMap;
+  }
+
+  /// \returns the map from basic blocks of F to their vectors.
+  const SmallMapVector<const BasicBlock *, Embedding, 16> &getBBVecMap() const {
+    return BBVecMap;
+  }
+
+  /// \returns the vector representation of the whole function F.
+  const Embedding &getFunctionVector() const { return FuncVector; }
+
+protected:
+  /// Function being embedded.
+  const Function &F;
+  /// Vocabulary used for lookups. Held by value: the constructor copies the
+  /// vocabulary it is given.
+  Vocab Vocabulary;
+
+  /// Weights applied to the different entities of an instruction when its
+  /// vector is formed. WO scales the opcode component; WT and WA presumably
+  /// scale the type and argument components.
+  // FIXME: Defaults to the values used in the original algorithm. Can be
+  // parameterized later.
+  float WO = 1.0;
+  float WT = 0.5;
+  float WA = 0.2;
+
+  /// Number of components of every vector; the constructor replaces this
+  /// default with its DIM argument.
+  unsigned DIM = 300;
+
+  // Results: the function vector plus per-block and per-instruction vectors
+  // (MapVector keeps insertion order).
+  Embedding FuncVector;
+  SmallMapVector<const BasicBlock *, Embedding, 16> BBVecMap;
+  SmallMapVector<const Instruction *, Embedding, 128> InstVecMap;
+
+  Embeddings(const Function &F, const Vocab &Vocabulary, unsigned DIM)
+      : F(F), Vocabulary(Vocabulary), DIM(DIM) {}
+
+  /// Looks up \p Key in the vocabulary. A missing key yields a zero vector of
+  /// DIM components.
+  Embedding lookupVocab(const std::string &Key);
+};
+
+/// Symbolic flavour of IR2Vec. Each basic block gets a vector built from the
+/// vocabulary entries of its instructions, and the function vector is the sum
+/// of its block vectors.
+class Symbolic : public Embeddings {
+public:
+  Symbolic(const Function &F, const Vocab &Vocabulary, unsigned DIM)
+      : Embeddings(F, Vocabulary, DIM) {
+    // The function vector starts at zero and accumulates block vectors.
+    FuncVector = Embedding(DIM, 0);
+  }
+
+  void computeEmbeddings() override;
+
+private:
+  /// Computes the vector representation of a single basic block.
+  Embedding computeBB2Vec(const BasicBlock &BB);
+
+  /// Computes the vector representation of the whole function.
+  Embedding computeFunc2Vec();
+};
+
+/// Multiplies every component of \p Vec by \p Factor, in place.
+void scaleVector(Embedding &Vec, const float Factor) {
+  for (auto &Component : Vec)
+    Component *= Factor;
+}
+
+/// Adds \p Vec2 to \p Vec component-wise, in place (Vec += Vec2).
+/// Both vectors must have the same number of components.
+void addVectors(Embedding &Vec, const Embedding &Vec2) {
+  // std::transform reads Vec.size() elements from Vec2, so a shorter Vec2
+  // (e.g. a vocabulary entry whose size differs from DIM) would be read past
+  // its end.
+  assert(Vec.size() == Vec2.size() && "Vectors must have the same dimension");
+  std::transform(Vec.begin(), Vec.end(), Vec2.begin(), Vec.begin(),
+                 std::plus<double>());
+}
+
+// FIXME: Currently lookups are string based. Use numeric Keys
+// for efficiency.
+/// Returns a copy of the vocabulary entry for \p Key. When \p Key is absent,
+/// the miss is logged (debug builds), counted in DataMissCounter, and a zero
+/// vector of DIM components is returned.
+Embedding Embeddings::lookupVocab(const std::string &Key) {
+  // Search once and reuse the iterator. The previous code searched again via
+  // operator[], which also needs mutable access and inserts on a miss.
+  auto It = Vocabulary.find(Key);
+  if (It != Vocabulary.end())
+    return It->second;
+
+  // FIXME: Use zero vectors in vocab and assert failure for
+  // unknown entities rather than silently returning zeroes here.
+  LLVM_DEBUG(errs() << "cannot find key in map : " << Key << "\n");
+  ++DataMissCounter;
+  return Embedding(DIM, 0);
+}
+
+void Symbolic::computeEmbeddings() {
+  // A declaration has no body and therefore no blocks to embed.
+  if (F.isDeclaration())
+    return;
+
+  for (const BasicBlock &BB : F) {
+    // Blocks that already have an embedding are skipped, so their vector is
+    // never folded into FuncVector a second time.
+    if (BBVecMap.count(&BB))
+      continue;
+
+    // Record the block vector, then accumulate it into the function vector.
+    const Embedding &BBVector = (BBVecMap[&BB] = computeBB2Vec(BB));
+    addVectors(FuncVector, BBVector);
+  }
+}
+
+Embedding Symbolic::computeBB2Vec(const BasicBlock &BB) {
+  Embedding BBVector(DIM, 0);
+
+  for (auto &I : BB) {
+    Embedding InstVector(DIM, 0);
+
+    auto Vec = lookupVocab(I.getOpcodeName());
+    scaleVector(Vec, WO);
+    addVectors(InstVector, Vec);
+
+    auto Type = I.getType();
+    if (Type->isVoidTy()) {
+      Vec = lookupVocab("voidTy");
+    } else if (Type->isFloatingPointTy()) {
+      Vec = lookupVocab("floatTy");
+    } else if (Type->isIntegerTy()) {
+      Vec = lookupVocab("integerTy");
+    } else if (Type->isFunctionTy()) {
+      Vec = lookupVocab("functionTy");
+    } else if (Type->isStructTy()) {
+      Vec = lookupVocab("structTy");
+    } else if (Type->isArrayTy()) {
+      Vec = lookupVocab("arrayTy");
+    } else if (Type->isPointerTy()) {
+      Vec = lookupVocab("pointerTy");
+    } else if (Type->isVectorTy()) {
+      Vec = lookupVocab("vectorTy");
+    } else if (Type->isEmptyTy()) {
+      Vec = lookupVocab("emptyTy");
+    } else if (Type->isLabelTy()) {
+      Vec = lookupVocab("labelTy");
+    } else if (Type->isTokenTy()) {
+      Vec = lookupVocab("tokenTy");
+    } else if (Type->isMetadataTy()) {
+      Vec = lookupVocab("metadataTy");
+    } else {
+      Vec = lookupVocab("unknownTy");
----------------
svkeerthy wrote:

We don't want to throw an error right now. I am planning to handle it in subsequent patches once the default vocab is in place.

https://github.com/llvm/llvm-project/pull/134004


More information about the llvm-commits mailing list