[llvm] [IR2Vec] Adding unit tests (PR #141873)

S. VenkataKeerthy via llvm-commits llvm-commits at lists.llvm.org
Thu May 29 10:59:59 PDT 2025


================
@@ -0,0 +1,254 @@
+//===- IR2VecTest.cpp - Unit tests for IR2Vec -----------------------------==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/IR2Vec.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Error.h"
+
+#include "gtest/gtest.h"
+#include <map>
+#include <vector>
+
+using namespace llvm;
+using namespace ir2vec;
+
+namespace {
+
+// Test-only Embedder subclass. It re-exports Embedder's (presumably
+// protected) helpers so the tests below can call them directly.
+class TestableEmbedder : public Embedder {
+public:
+  TestableEmbedder(const Function &F, const Vocab &V, unsigned Dim)
+      : Embedder(F, V, Dim) {}
+
+  // Intentionally empty: these tests exercise the helpers, not the
+  // embedding computation itself.
+  void computeEmbeddings() const override {}
+
+  // Widen access to the helpers under test without adding forwarding shims.
+  using Embedder::addScaledVector;
+  using Embedder::addVectors;
+  using Embedder::lookupVocab;
+};
+
+// Shared GoogleTest fixture for the IR2Vec tests. No per-test setup or
+// teardown is needed, so the default SetUp/TearDown are inherited as-is.
+class IR2VecTest : public ::testing::Test {};
+
+// Embedder::create with IR2VecKind::Symbolic should yield a non-null embedder
+// for a trivial function.
+TEST_F(IR2VecTest, CreateSymbolicEmbedder) {
+  Vocab V = {{"foo", {1.0, 2.0}}};
+
+  LLVMContext Ctx;
+  Module M("M", Ctx);
+  FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
+  Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
+
+  auto Result = Embedder::create(IR2VecKind::Symbolic, *F, V, 2);
+  // ASSERT rather than EXPECT: the Result-> access below must not run when
+  // Result holds an error. The streamed toString() is only evaluated on
+  // failure, where it consumes the error and reports why creation failed.
+  ASSERT_TRUE(static_cast<bool>(Result)) << toString(Result.takeError());
+
+  auto *Emb = Result->get();
+  EXPECT_NE(Emb, nullptr);
+}
+
+// Requesting an embedder for a kind value that is not a valid enumerator
+// must fail with an "Unknown IR2VecKind" error instead of producing one.
+TEST_F(IR2VecTest, CreateInvalidMode) {
+  Vocab V = {{"foo", {1.0, 2.0}}};
+
+  LLVMContext Ctx;
+  Module M("M", Ctx);
+  FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
+  Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
+
+  // Forge an out-of-range enumerator to hit the factory's error path.
+  const auto BogusKind = static_cast<IR2VecKind>(-1);
+  auto Result = Embedder::create(BogusKind, *F, V, 2);
+  EXPECT_FALSE(static_cast<bool>(Result));
+
+  // Drain the error (required by llvm::Expected) and keep its message.
+  std::string Message;
+  auto Capture = [&](const llvm::ErrorInfoBase &Info) {
+    Message = Info.message();
+  };
+  llvm::handleAllErrors(Result.takeError(), Capture);
+
+  const bool Mentioned =
+      Message.find("Unknown IR2VecKind") != std::string::npos;
+  EXPECT_NE(Mentioned, false);
+}
+
+// addVectors accumulates Src into Dst element-wise and leaves Src untouched.
+TEST_F(IR2VecTest, AddVectors) {
+  Embedding Acc = {1.0, 2.0, 3.0};
+  Embedding Addend = {0.5, 1.5, -1.0};
+
+  TestableEmbedder::addVectors(Acc, Addend);
+
+  const double WantSum[] = {1.5, 3.5, 2.0};
+  const double WantAddend[] = {0.5, 1.5, -1.0};
+  for (unsigned Idx = 0; Idx != 3; ++Idx) {
+    EXPECT_DOUBLE_EQ(Acc[Idx], WantSum[Idx]) << "sum at index " << Idx;
+    EXPECT_DOUBLE_EQ(Addend[Idx], WantAddend[Idx])
+        << "source modified at index " << Idx;
+  }
+}
+
+// addScaledVector computes Dst[i] += Factor * Src[i] and leaves Src untouched.
+TEST_F(IR2VecTest, AddScaledVector) {
+  Embedding Acc = {1.0, 2.0, 3.0};
+  Embedding Scaled = {2.0, 0.5, -1.0};
+  const float Factor = 0.5f;
+
+  TestableEmbedder::addScaledVector(Acc, Scaled, Factor);
+
+  const double WantAcc[] = {2.0, 2.25, 2.5};
+  const double WantScaled[] = {2.0, 0.5, -1.0};
+  for (unsigned Idx = 0; Idx != 3; ++Idx) {
+    EXPECT_DOUBLE_EQ(Acc[Idx], WantAcc[Idx]) << "result at index " << Idx;
+    EXPECT_DOUBLE_EQ(Scaled[Idx], WantScaled[Idx])
+        << "source modified at index " << Idx;
+  }
+}
+
+// The dimension check in the vector helpers is expected to be an assert()
+// (NOTE(review): confirm in IR2Vec.cpp). Assertions are compiled out when
+// NDEBUG is defined, so in such builds nothing dies and EXPECT_DEATH would
+// fail. Only run these death tests in assertion-enabled builds.
+#if GTEST_HAS_DEATH_TEST && !defined(NDEBUG)
+TEST_F(IR2VecTest, MismatchedDimensionsAddVectors) {
+  Embedding E1 = {1.0, 2.0};
+  Embedding E2 = {1.0};
+  EXPECT_DEATH(TestableEmbedder::addVectors(E1, E2),
+               "Vectors must have the same dimension");
+}
+
+TEST_F(IR2VecTest, MismatchedDimensionsAddScaledVector) {
+  Embedding E1 = {1.0, 2.0};
+  Embedding E2 = {1.0};
+  EXPECT_DEATH(TestableEmbedder::addScaledVector(E1, E2, 1.0f),
+               "Vectors must have the same dimension");
+}
+#endif
+
+// lookupVocab returns the stored vector for a known key. For an unknown key,
+// the test expects a zero vector of the embedder's dimension.
+TEST_F(IR2VecTest, LookupVocab) {
+  Vocab V = {{"foo", {1.0, 2.0}}, {"bar", {3.0, 4.0}}};
+  LLVMContext Ctx;
+  Module M("M", Ctx);
+  FunctionType *FTy = FunctionType::get(Type::getVoidTy(Ctx), false);
+  Function *F = Function::Create(FTy, Function::ExternalLinkage, "f", M);
+
+  TestableEmbedder E(*F, V, 2);
+
+  // ASSERT on the size: the indexed accesses that follow would read out of
+  // bounds if the returned vector were shorter than expected.
+  auto VFoo = E.lookupVocab("foo");
+  ASSERT_EQ(VFoo.size(), 2u);
+  EXPECT_DOUBLE_EQ(VFoo[0], 1.0);
+  EXPECT_DOUBLE_EQ(VFoo[1], 2.0);
+
+  auto VMissing = E.lookupVocab("missing");
+  ASSERT_EQ(VMissing.size(), 2u);
+  EXPECT_DOUBLE_EQ(VMissing[0], 0.0);
+  EXPECT_DOUBLE_EQ(VMissing[1], 0.0);
+}
+
+// Both vector helpers must tolerate empty embeddings: no crash, and the
+// destination stays empty.
+TEST_F(IR2VecTest, ZeroDimensionEmbedding) {
+  Embedding Dst;
+  Embedding Src;
+
+  TestableEmbedder::addVectors(Dst, Src);
+  TestableEmbedder::addScaledVector(Dst, Src, 1.0f);
+
+  EXPECT_EQ(Dst.size(), 0u);
+}
+
+// IR2VecVocabResult reports validity correctly. Its accessors are expected to
+// abort on an invalid result and to report the vocabulary dimension on a
+// valid one.
+TEST_F(IR2VecTest, IR2VecVocabResultValidity) {
+  // A default-constructed result is invalid.
+  IR2VecVocabResult InvalidResult;
+  EXPECT_FALSE(InvalidResult.isValid());
+  // The accessor check is expected to be an assert() (NOTE(review): confirm).
+  // It is compiled out under NDEBUG, so the death expectations only hold in
+  // assertion-enabled builds.
+#if GTEST_HAS_DEATH_TEST && !defined(NDEBUG)
+  EXPECT_DEATH(InvalidResult.getVocabulary(), "IR2Vec Vocabulary is invalid");
+  EXPECT_DEATH(InvalidResult.getDimension(), "IR2Vec Vocabulary is invalid");
+#endif
+
+  // A result built from a vocabulary with 2-element entries is valid and
+  // reports dimension 2.
+  Vocab V = {{"foo", {1.0, 2.0}}, {"bar", {3.0, 4.0}}};
+  IR2VecVocabResult ValidResult(std::move(V));
+  EXPECT_TRUE(ValidResult.isValid());
+  EXPECT_EQ(ValidResult.getDimension(), 2u);
+}
+
+// Helper to create a minimal function and embedder for getter tests
+struct GetterTestEnv {
+  Vocab V;
+  LLVMContext Ctx;
+  std::unique_ptr<Module> M;
+  Function *F;
----------------
svkeerthy wrote:

Done

https://github.com/llvm/llvm-project/pull/141873


More information about the llvm-commits mailing list