[llvm] Adding IR2Vec as an analysis pass (PR #134004)
Aiden Grossman via llvm-commits
llvm-commits at lists.llvm.org
Thu May 15 16:20:55 PDT 2025
================
@@ -0,0 +1,300 @@
+//===- IR2Vec.cpp - Implementation of IR2Vec -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions. See the LICENSE file for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+///
+/// \file
+/// This file implements the IR2Vec algorithm.
+///
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/IR2Vec.h"
+
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/Errc.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/Format.h"
+#include "llvm/Support/JSON.h"
+#include "llvm/Support/MemoryBuffer.h"
+
+using namespace llvm;
+using namespace ir2vec;
+
+#define DEBUG_TYPE "ir2vec"
+
+// Counts vocabulary misses: incremented by Embedder::lookupVocab() each time
+// the requested key is not present in the vocabulary.
+STATISTIC(VocabMissCounter,
+          "Number of lookups to entities not present in the vocabulary");
+
+// Every IR2Vec flag below is registered under this option category.
+static cl::OptionCategory IR2VecCategory("IR2Vec Options");
+
+// Location of the vocabulary file; defaults to the empty string.
+// FIXME: Use a default vocab when not specified
+static cl::opt<std::string>
+    VocabFile("ir2vec-vocab-path", cl::Optional, cl::init(""),
+              cl::desc("Path to the vocabulary file for IR2Vec"),
+              cl::cat(IR2VecCategory));
+
+// Weights for the opcode, type and argument embeddings. They default to 1.0,
+// 0.5 and 0.2 respectively and are read by the Embedder constructor.
+static cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional,
+                                cl::desc("Weight for opcode embeddings"),
+                                cl::init(1.0), cl::cat(IR2VecCategory));
+static cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional,
+                                 cl::desc("Weight for type embeddings"),
+                                 cl::init(0.5), cl::cat(IR2VecCategory));
+static cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional,
+                                cl::desc("Weight for argument embeddings"),
+                                cl::init(0.2), cl::cat(IR2VecCategory));
+
+// Out-of-line definition of the static key member of IR2VecVocabAnalysis.
+AnalysisKey IR2VecVocabAnalysis::Key;
+
+//===----------------------------------------------------------------------===//
+// Embedder and its subclasses
+//===----------------------------------------------------------------------===//
+
+#define RETURN_LOOKUP_IF(CONDITION, KEY_STR) \
+ if (CONDITION) \
+ return lookupVocab(KEY_STR);
+
+/// Constructs an embedder for \p F that looks entities up in \p Vocabulary
+/// and uses \p Dimension as the size of the embeddings it creates.
+///
+/// The weight members are initialized from the ir2vec-*-weight command-line
+/// options. The options must be named with a leading `::` here: inside a
+/// mem-initializer an unqualified `OpcWeight` is looked up in class scope
+/// first and therefore names the member being initialized, which would
+/// initialize each weight from its own indeterminate value.
+Embedder::Embedder(const Function &F, const Vocab &Vocabulary,
+                   unsigned Dimension)
+    : F(F), Vocabulary(Vocabulary), Dimension(Dimension),
+      OpcWeight(::OpcWeight), TypeWeight(::TypeWeight),
+      ArgWeight(::ArgWeight) {}
+
+/// Factory for embedders.
+///
+/// \param Mode       Kind of embedder requested.
+/// \param F          Function the embedder will operate on.
+/// \param Vocabulary Vocabulary forwarded to the embedder.
+/// \param Dimension  Embedding size forwarded to the embedder.
+/// \returns A SymbolicEmbedder when \p Mode is IR2VecKind::Symbolic;
+///          errc::invalid_argument for any other mode.
+ErrorOr<std::unique_ptr<Embedder>> Embedder::create(IR2VecKind Mode,
+                                                    const Function &F,
+                                                    const Vocab &Vocabulary,
+                                                    unsigned Dimension) {
+  // Symbolic is the only kind handled here; reject everything else up front.
+  if (Mode != IR2VecKind::Symbolic)
+    return errc::invalid_argument;
+
+  return std::make_unique<SymbolicEmbedder>(F, Vocabulary, Dimension);
+}
+
+/// Adds \p Src to \p Dst element-wise, storing the result in \p Dst.
+///
+/// Both vectors must have the same number of elements. std::transform reads
+/// one element of \p Src for every element of \p Dst, so a shorter \p Src
+/// would be read out of bounds; the assertion mirrors the check performed by
+/// addScaledVector().
+void Embedder::addVectors(Embedding &Dst, const Embedding &Src) {
+  assert(Dst.size() == Src.size() && "Vectors must have the same dimension");
+  std::transform(Dst.begin(), Dst.end(), Src.begin(), Dst.begin(),
+                 std::plus<double>());
+}
+
+/// Adds \p Src, multiplied element-wise by \p Factor, into \p Dst.
+///
+/// Asserts that both vectors have the same number of elements.
+void Embedder::addScaledVector(Embedding &Dst, const Embedding &Src,
+                               float Factor) {
+  assert(Dst.size() == Src.size() && "Vectors must have the same dimension");
+  // Walk Src in lock-step with Dst; one Src element is consumed per Dst slot.
+  auto SrcIt = Src.begin();
+  for (auto &DstElem : Dst) {
+    DstElem += *SrcIt * Factor;
+    ++SrcIt;
+  }
+}
+
+// FIXME: Currently lookups are string based. Use numeric Keys
+// for efficiency
+/// Looks \p Key up in the vocabulary.
+///
+/// \returns A copy of the stored embedding when \p Key is present. Otherwise
+///          the miss is reported under LLVM_DEBUG, VocabMissCounter is
+///          incremented, and a zero vector of Dimension elements is returned.
+Embedding Embedder::lookupVocab(const std::string &Key) {
+  auto It = Vocabulary.find(Key);
+  if (It != Vocabulary.end())
+    return It->second;
+
+  // FIXME: Use zero vectors in vocab and assert failure for
+  // unknown entities rather than silently returning zeroes here.
+  LLVM_DEBUG(errs() << "cannot find key in map : " << Key << "\n");
+  ++VocabMissCounter;
+  // The zero vector is built only on a miss, so a successful lookup does not
+  // pay for an allocation that is immediately discarded.
+  return Embedding(Dimension, 0);
+}
+
+/// Maps \p Ty to the embedding of its type category.
+///
+/// The predicates are tested in the order written and the first one that
+/// holds picks the vocabulary key; a type matching none of them is looked up
+/// as "unknownTy".
+Embedding SymbolicEmbedder::getTypeEmbedding(const Type *Ty) {
+  if (Ty->isVoidTy())
+    return lookupVocab("voidTy");
+  if (Ty->isFloatingPointTy())
+    return lookupVocab("floatTy");
+  if (Ty->isIntegerTy())
+    return lookupVocab("integerTy");
+  if (Ty->isFunctionTy())
+    return lookupVocab("functionTy");
+  if (Ty->isStructTy())
+    return lookupVocab("structTy");
+  if (Ty->isArrayTy())
+    return lookupVocab("arrayTy");
+  if (Ty->isPointerTy())
+    return lookupVocab("pointerTy");
+  if (Ty->isVectorTy())
+    return lookupVocab("vectorTy");
+  // NOTE(review): struct and array types have already returned above, so this
+  // check presumably never fires for them - confirm the intended ordering.
+  if (Ty->isEmptyTy())
+    return lookupVocab("emptyTy");
+  if (Ty->isLabelTy())
+    return lookupVocab("labelTy");
+  if (Ty->isTokenTy())
+    return lookupVocab("tokenTy");
+  if (Ty->isMetadataTy())
+    return lookupVocab("metadataTy");
+  return lookupVocab("unknownTy");
+}
+
+/// Maps the operand \p Op to the embedding of its operand category.
+///
+/// Categories are tried in priority order - function, pointer-typed value,
+/// constant - and anything left over is treated as a "variable".
+Embedding SymbolicEmbedder::getOperandEmbedding(const Value *Op) {
+  // Pick the vocabulary key first so that there is a single lookup site.
+  const char *Key = "variable";
+  if (isa<Function>(Op))
+    Key = "function";
+  else if (isa<PointerType>(Op->getType()))
+    Key = "pointer";
+  else if (isa<Constant>(Op))
+    Key = "constant";
+  return lookupVocab(Key);
+}
+
+/// Computes an embedding for every basic block of F, records it in BBVecMap
+/// keyed by the block's address, and adds it into FuncVector.
+///
+/// Does nothing when F is only a declaration.
+void SymbolicEmbedder::computeEmbeddings() {
+  if (F.isDeclaration())
+    return;
+
+  for (const auto &Block : F) {
+    // try_emplace yields {iterator, inserted}; each block is expected to be
+    // seen exactly once, so the insertion must succeed.
+    auto Result = BBVecMap.try_emplace(&Block, computeBB2Vec(Block));
+    assert(Result.second && "Basic block already exists in the map");
+    addVectors(FuncVector, Result.first->second);
+  }
+}
+
+Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) {
+ Embedding BBVector(Dimension, 0);
+
+ for (auto &I : BB) {
----------------
boomanaiden154 wrote:
Can this be `const auto &I`?
https://github.com/llvm/llvm-project/pull/134004
More information about the llvm-commits
mailing list