[llvm] [SandboxIR] Implement ConstantAggregate (PR #107136)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 3 12:40:24 PDT 2024
https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/107136
>From 0cba5171db602b18dfbca78b45a5dc7f5c60c05d Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Tue, 20 Aug 2024 12:09:51 -0700
Subject: [PATCH] [SandboxIR] Implement ConstantAggregate
This patch implements sandboxir::ConstantAggregate, ConstantStruct,
ConstantArray and ConstantVector, mirroring LLVM IR.
---
llvm/include/llvm/SandboxIR/SandboxIR.h | 94 +++++++++++++++++++
.../llvm/SandboxIR/SandboxIRValues.def | 3 +
llvm/include/llvm/SandboxIR/Type.h | 45 +++++++--
llvm/lib/SandboxIR/SandboxIR.cpp | 48 +++++++++-
llvm/lib/SandboxIR/Type.cpp | 15 +++
llvm/unittests/SandboxIR/SandboxIRTest.cpp | 75 +++++++++++++++
llvm/unittests/SandboxIR/TypesTest.cpp | 38 ++++++++
7 files changed, 310 insertions(+), 8 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 0ac049af4db2bd..71fd9202703dc5 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -304,6 +304,8 @@ class Value {
friend class PHINode; // For getting `Val`.
friend class UnreachableInst; // For getting `Val`.
friend class CatchSwitchAddHandler; // For `Val`.
+ friend class ConstantArray; // For `Val`.
+ friend class ConstantStruct; // For `Val`.
/// All values point to the context.
Context &Ctx;
@@ -840,6 +842,97 @@ class ConstantFP final : public Constant {
#endif
};
+/// Common base class of the aggregate constants that carry operands:
+/// ConstantArray, ConstantStruct and ConstantVector.
+class ConstantAggregate : public Constant {
+protected:
+  ConstantAggregate(ClassID ID, llvm::Constant *C, Context &Ctx)
+      : Constant(ID, C, Ctx) {}
+
+public:
+  /// For isa/dyn_cast: matches any of the three aggregate subclasses.
+  static bool classof(const sandboxir::Value *From) {
+    switch (From->getSubclassID()) {
+    case ClassID::ConstantArray:
+    case ClassID::ConstantStruct:
+    case ClassID::ConstantVector:
+      return true;
+    default:
+      return false;
+    }
+  }
+};
+
+/// Sandbox IR counterpart of llvm::ConstantArray.
+class ConstantArray final : public ConstantAggregate {
+  ConstantArray(llvm::ConstantArray *C, Context &Ctx)
+      : ConstantAggregate(ClassID::ConstantArray, C, Ctx) {}
+  friend class Context; // For constructor.
+
+public:
+  /// \Returns a constant of array type \p T with elements \p V. The return
+  /// type is Constant rather than ConstantArray because the underlying
+  /// llvm::ConstantArray::get() may fold the result into another kind of
+  /// constant (e.g. a zero aggregate or a ConstantDataArray).
+  static Constant *get(ArrayType *T, ArrayRef<Constant *> V);
+  /// Specialization - reduce amount of casting.
+  ArrayType *getType() const;
+
+  /// For isa/dyn_cast.
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::ConstantArray;
+  }
+};
+
+/// Sandbox IR counterpart of llvm::ConstantStruct.
+class ConstantStruct final : public ConstantAggregate {
+  ConstantStruct(llvm::ConstantStruct *C, Context &Ctx)
+      : ConstantAggregate(ClassID::ConstantStruct, C, Ctx) {}
+  friend class Context; // For constructor.
+
+public:
+  /// \Returns a constant of struct type \p T with fields \p V. The return
+  /// type is Constant because the underlying llvm::ConstantStruct::get() is
+  /// allowed to return a folded constant instead of a ConstantStruct.
+  static Constant *get(StructType *T, ArrayRef<Constant *> V);
+
+  /// Variadic convenience form: get(T, C0, C1, ...) == get(T, {C0, C1, ...}).
+  template <typename... Csts>
+  static std::enable_if_t<are_base_of<Constant, Csts...>::value, Constant *>
+  get(StructType *T, Csts *...Vs) {
+    // The ArrayRef must stay a temporary inside the call expression so the
+    // backing initializer list outlives the callee.
+    return get(T, ArrayRef<Constant *>({Vs...}));
+  }
+
+  /// \Returns the literal struct type whose field types are the types of
+  /// \p V, packed if \p Packed is set. Unlike the overload below, \p V may be
+  /// empty because the context is passed explicitly.
+  static StructType *getTypeForElements(Context &Ctx, ArrayRef<Constant *> V,
+                                        bool Packed = false);
+  /// Same as above, but the context is taken from the first element, so
+  /// \p V must not be empty.
+  static StructType *getTypeForElements(ArrayRef<Constant *> V,
+                                        bool Packed = false) {
+    assert(!V.empty() &&
+           "ConstantStruct::getTypeForElements cannot be called on empty list");
+    auto &Ctx = V.front()->getContext();
+    return getTypeForElements(Ctx, V, Packed);
+  }
+
+  /// \Returns an anonymous struct constant with fields \p V. \p V must not be
+  /// empty; use the overload taking a Context when it might be.
+  static Constant *getAnon(ArrayRef<Constant *> V, bool Packed = false) {
+    StructType *AnonTy = getTypeForElements(V, Packed);
+    return get(AnonTy, V);
+  }
+  /// \Returns an anonymous struct constant with fields \p V; \p V may be
+  /// empty.
+  static Constant *getAnon(Context &Ctx, ArrayRef<Constant *> V,
+                           bool Packed = false) {
+    StructType *AnonTy = getTypeForElements(Ctx, V, Packed);
+    return get(AnonTy, V);
+  }
+
+  /// Specialization - reduce amount of casting.
+  StructType *getType() const { return cast<StructType>(Value::getType()); }
+
+  /// For isa/dyn_cast.
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::ConstantStruct;
+  }
+};
+
+/// Sandbox IR counterpart of llvm::ConstantVector. Instances are only created
+/// by the Context (the constructor is private); there is no factory yet.
+class ConstantVector final : public ConstantAggregate {
+  ConstantVector(llvm::ConstantVector *C, Context &Ctx)
+      : ConstantAggregate(ClassID::ConstantVector, C, Ctx) {}
+  friend class Context; // For constructor.
+
+public:
+  // TODO: Missing functions: get(), getSplat(), getType(), getSplatValue().
+
+  /// For isa/dyn_cast.
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::ConstantVector;
+  }
+};
+
/// Iterator for `Instruction`s in a `BasicBlock.
/// \Returns an sandboxir::Instruction & when derereferenced.
class BBIterator {
@@ -3353,6 +3446,7 @@ class Context {
friend class Type; // For LLVMCtx.
friend class PointerType; // For LLVMCtx.
friend class IntegerType; // For LLVMCtx.
+ friend class StructType; // For LLVMCtx.
Tracker IRTracker;
/// Maps LLVM Value to the corresponding sandboxir::Value. Owns all
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index 2fc24ed71c4cf6..d2031bbdcfb543 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -27,6 +27,9 @@ DEF_VALUE(Block, BasicBlock)
DEF_CONST(Constant, Constant)
DEF_CONST(ConstantInt, ConstantInt)
DEF_CONST(ConstantFP, ConstantFP)
+DEF_CONST(ConstantArray, ConstantArray)
+DEF_CONST(ConstantStruct, ConstantStruct)
+DEF_CONST(ConstantVector, ConstantVector)
#ifndef DEF_INSTR
#define DEF_INSTR(ID, OPCODE, CLASS)
diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index 39c545a6e6c6d2..2f9b94b8d71751 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -27,6 +27,8 @@ class PointerType;
class VectorType;
class IntegerType;
class FunctionType;
+class ArrayType;
+class StructType;
#define DEF_INSTR(ID, OPCODE, CLASS) class CLASS;
#define DEF_CONST(ID, CLASS) class CLASS;
#include "llvm/SandboxIR/SandboxIRValues.def"
@@ -36,13 +38,19 @@ class FunctionType;
class Type {
protected:
llvm::Type *LLVMTy;
- friend class VectorType; // For LLVMTy.
- friend class PointerType; // For LLVMTy.
- friend class FunctionType; // For LLVMTy.
- friend class IntegerType; // For LLVMTy.
- friend class Function; // For LLVMTy.
- friend class CallBase; // For LLVMTy.
- friend class ConstantInt; // For LLVMTy.
+ friend class ArrayType; // For LLVMTy.
+ friend class StructType; // For LLVMTy.
+ friend class VectorType; // For LLVMTy.
+ friend class PointerType; // For LLVMTy.
+ friend class FunctionType; // For LLVMTy.
+ friend class IntegerType; // For LLVMTy.
+ friend class Function; // For LLVMTy.
+ friend class CallBase; // For LLVMTy.
+ friend class ConstantInt; // For LLVMTy.
+ friend class ConstantArray; // For LLVMTy.
+ friend class ConstantStruct; // For LLVMTy.
+ friend class ConstantVector; // For LLVMTy.
+
// Friend all instruction classes because `create()` functions use LLVMTy.
#define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
#define DEF_CONST(ID, CLASS) friend class CLASS;
@@ -281,8 +289,31 @@ class PointerType : public Type {
}
};
+/// Sandbox IR counterpart of llvm::ArrayType.
+class ArrayType : public Type {
+public:
+  // TODO: add missing functions
+  /// For isa/dyn_cast: true when the wrapped LLVM type is an array type.
+  static bool classof(const Type *From) {
+    llvm::Type *WrappedTy = From->LLVMTy;
+    return isa<llvm::ArrayType>(WrappedTy);
+  }
+};
+
+/// Sandbox IR counterpart of llvm::StructType.
+class StructType : public Type {
+public:
+  /// Primary way to create a literal StructType with field types
+  /// \p Elements, packed if \p IsPacked is set.
+  static StructType *get(Context &Ctx, ArrayRef<Type *> Elements,
+                         bool IsPacked = false);
+
+  /// \Returns true if this is a packed struct type (forwards to
+  /// llvm::StructType::isPacked()).
+  bool isPacked() const {
+    auto *LLVMStructTy = cast<llvm::StructType>(LLVMTy);
+    return LLVMStructTy->isPacked();
+  }
+
+  // TODO: add missing functions
+  /// For isa/dyn_cast: true when the wrapped LLVM type is a struct type.
+  static bool classof(const Type *From) {
+    llvm::Type *WrappedTy = From->LLVMTy;
+    return isa<llvm::StructType>(WrappedTy);
+  }
+};
+
class VectorType : public Type {
public:
+ static VectorType *get(Type *ElementType, ElementCount EC);
// TODO: add missing functions
static bool classof(const Type *From) {
return isa<llvm::VectorType>(From->LLVMTy);
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 5af6fbdde42cb7..e8d081e6b17e74 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -2364,6 +2364,44 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat &V) {
return llvm::ConstantFP::isValueValidForType(Ty->LLVMTy, V);
}
+Constant *ConstantArray::get(ArrayType *T, ArrayRef<Constant *> V) {
+  auto &Ctx = T->getContext();
+  SmallVector<llvm::Constant *> LLVMValues;
+  LLVMValues.reserve(V.size());
+  for (auto *Elm : V)
+    LLVMValues.push_back(cast<llvm::Constant>(Elm->Val));
+  auto *LLVMC =
+      llvm::ConstantArray::get(cast<llvm::ArrayType>(T->LLVMTy), LLVMValues);
+  // llvm::ConstantArray::get() may fold to e.g. ConstantAggregateZero or
+  // ConstantDataArray, so the result must not be assumed to be a
+  // ConstantArray (a cast<> here would assert on such inputs).
+  return Ctx.getOrCreateConstant(LLVMC);
+}
+
+ArrayType *ConstantArray::getType() const {
+  auto *LLVMArrayTy = cast<llvm::ConstantArray>(Val)->getType();
+  return cast<ArrayType>(Ctx.getType(LLVMArrayTy));
+}
+
+Constant *ConstantStruct::get(StructType *T, ArrayRef<Constant *> V) {
+  auto &Ctx = T->getContext();
+  SmallVector<llvm::Constant *> LLVMValues;
+  LLVMValues.reserve(V.size());
+  for (auto *Elm : V)
+    LLVMValues.push_back(cast<llvm::Constant>(Elm->Val));
+  auto *LLVMC =
+      llvm::ConstantStruct::get(cast<llvm::StructType>(T->LLVMTy), LLVMValues);
+  // llvm::ConstantStruct::get() may fold to e.g. ConstantAggregateZero,
+  // UndefValue or PoisonValue, so the result must not be assumed to be a
+  // ConstantStruct (a cast<> here would assert on such inputs).
+  return Ctx.getOrCreateConstant(LLVMC);
+}
+
+StructType *ConstantStruct::getTypeForElements(Context &Ctx,
+                                               ArrayRef<Constant *> V,
+                                               bool Packed) {
+  // Collect the type of every element; these become the struct's fields.
+  SmallVector<Type *, 16> FieldTys;
+  FieldTys.reserve(V.size());
+  for (Constant *Field : V)
+    FieldTys.push_back(Field->getType());
+  return StructType::get(Ctx, FieldTys, Packed);
+}
+
FunctionType *Function::getFunctionType() const {
return cast<FunctionType>(
Ctx.getType(cast<llvm::Function>(Val)->getFunctionType()));
@@ -2459,7 +2497,15 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
It->second = std::unique_ptr<ConstantFP>(new ConstantFP(CF, *this));
return It->second.get();
}
- if (auto *F = dyn_cast<llvm::Function>(LLVMV))
+ if (auto *CA = dyn_cast<llvm::ConstantArray>(C))
+ It->second = std::unique_ptr<ConstantArray>(new ConstantArray(CA, *this));
+ else if (auto *CS = dyn_cast<llvm::ConstantStruct>(C))
+ It->second =
+ std::unique_ptr<ConstantStruct>(new ConstantStruct(CS, *this));
+ else if (auto *CV = dyn_cast<llvm::ConstantVector>(C))
+ It->second =
+ std::unique_ptr<ConstantVector>(new ConstantVector(CV, *this));
+ else if (auto *F = dyn_cast<llvm::Function>(LLVMV))
It->second = std::unique_ptr<Function>(new Function(F, *this));
else
It->second = std::unique_ptr<Constant>(new Constant(C, *this));
diff --git a/llvm/lib/SandboxIR/Type.cpp b/llvm/lib/SandboxIR/Type.cpp
index eee69c5ec7c893..535b0f75fd8743 100644
--- a/llvm/lib/SandboxIR/Type.cpp
+++ b/llvm/lib/SandboxIR/Type.cpp
@@ -47,6 +47,21 @@ PointerType *PointerType::get(Context &Ctx, unsigned AddressSpace) {
Ctx.getType(llvm::PointerType::get(Ctx.LLVMCtx, AddressSpace)));
}
+StructType *StructType::get(Context &Ctx, ArrayRef<Type *> Elements,
+                            bool IsPacked) {
+  // Map each sandboxir field type to the LLVM type it wraps.
+  SmallVector<llvm::Type *> LLVMFieldTys;
+  LLVMFieldTys.reserve(Elements.size());
+  for (Type *FieldTy : Elements)
+    LLVMFieldTys.push_back(FieldTy->LLVMTy);
+  auto *LLVMStructTy =
+      llvm::StructType::get(Ctx.LLVMCtx, LLVMFieldTys, IsPacked);
+  return cast<StructType>(Ctx.getType(LLVMStructTy));
+}
+
+VectorType *VectorType::get(Type *ElementType, ElementCount EC) {
+  auto &Ctx = ElementType->getContext();
+  auto *LLVMVecTy = llvm::VectorType::get(ElementType->LLVMTy, EC);
+  return cast<VectorType>(Ctx.getType(LLVMVecTy));
+}
+
IntegerType *IntegerType::get(Context &Ctx, unsigned NumBits) {
return cast<IntegerType>(
Ctx.getType(llvm::IntegerType::get(Ctx.LLVMCtx, NumBits)));
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 2ec8eefd8c323c..ca2a183e532683 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -445,6 +445,81 @@ define void @foo(float %v0, double %v1) {
EXPECT_TRUE(NegZero->isExactlyValue(-0.0));
}
+// Tests ConstantArray, ConstantStruct and ConstantVector.
+TEST_F(SandboxIRTest, ConstantAggregate) {
+  // Note: we are using i42 to avoid the creation of ConstantDataVector or
+  // ConstantDataArray.
+  parseIR(C, R"IR(
+define void @foo() {
+ %array = extractvalue [2 x i42] [i42 0, i42 1], 0
+ %struct = extractvalue {i42, i42} {i42 0, i42 1}, 0
+ %vector = extractelement <2 x i42> <i42 0, i42 1>, i32 0
+ ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto &BB = *F.begin();
+  auto It = BB.begin();
+  auto *I0 = &*It++;
+  auto *I1 = &*It++;
+  auto *I2 = &*It++;
+  // Check classof() and creation.
+  auto *Array = cast<sandboxir::ConstantArray>(I0->getOperand(0));
+  EXPECT_TRUE(isa<sandboxir::ConstantAggregate>(Array));
+  auto *Struct = cast<sandboxir::ConstantStruct>(I1->getOperand(0));
+  EXPECT_TRUE(isa<sandboxir::ConstantAggregate>(Struct));
+  auto *Vector = cast<sandboxir::ConstantVector>(I2->getOperand(0));
+  EXPECT_TRUE(isa<sandboxir::ConstantAggregate>(Vector));
+
+  auto *ZeroI42 = cast<sandboxir::ConstantInt>(Array->getOperand(0));
+  auto *OneI42 = cast<sandboxir::ConstantInt>(Array->getOperand(1));
+  // Check ConstantArray::get(), getType().
+  auto *NewCA =
+      sandboxir::ConstantArray::get(Array->getType(), {ZeroI42, OneI42});
+  EXPECT_EQ(NewCA, Array);
+
+  // Check ConstantStruct::get(), getType().
+  auto *NewCS =
+      sandboxir::ConstantStruct::get(Struct->getType(), {ZeroI42, OneI42});
+  EXPECT_EQ(NewCS, Struct);
+  // Check ConstantStruct::get(...).
+  auto *NewCS2 =
+      sandboxir::ConstantStruct::get(Struct->getType(), ZeroI42, OneI42);
+  EXPECT_EQ(NewCS2, Struct);
+  // Check ConstantStruct::getAnon(ArrayRef).
+  auto *AnonCS = sandboxir::ConstantStruct::getAnon({ZeroI42, OneI42});
+  EXPECT_FALSE(cast<sandboxir::StructType>(AnonCS->getType())->isPacked());
+  auto *AnonCSPacked =
+      sandboxir::ConstantStruct::getAnon({ZeroI42, OneI42}, /*Packed=*/true);
+  EXPECT_TRUE(cast<sandboxir::StructType>(AnonCSPacked->getType())->isPacked());
+  // Check ConstantStruct::getAnon(Ctx, ArrayRef).
+  auto *AnonCS2 = sandboxir::ConstantStruct::getAnon(Ctx, {ZeroI42, OneI42});
+  EXPECT_EQ(AnonCS2, AnonCS);
+  auto *AnonCS2Packed = sandboxir::ConstantStruct::getAnon(
+      Ctx, {ZeroI42, OneI42}, /*Packed=*/true);
+  EXPECT_EQ(AnonCS2Packed, AnonCSPacked);
+  // Check ConstantStruct::getTypeForElements(Ctx, ArrayRef).
+  auto *StructTy =
+      sandboxir::ConstantStruct::getTypeForElements(Ctx, {ZeroI42, OneI42});
+  EXPECT_EQ(StructTy, Struct->getType());
+  EXPECT_FALSE(StructTy->isPacked());
+  // Check ConstantStruct::getTypeForElements(Ctx, ArrayRef, Packed).
+  auto *StructTyPacked = sandboxir::ConstantStruct::getTypeForElements(
+      Ctx, {ZeroI42, OneI42}, /*Packed=*/true);
+  EXPECT_TRUE(StructTyPacked->isPacked());
+  // Check ConstantStruct::getTypeForElements(ArrayRef), which takes the
+  // context from the first element instead of an explicit Ctx.
+  auto *StructTy2 =
+      sandboxir::ConstantStruct::getTypeForElements({ZeroI42, OneI42});
+  EXPECT_EQ(StructTy2, Struct->getType());
+  EXPECT_FALSE(StructTy2->isPacked());
+  // Check ConstantStruct::getTypeForElements(ArrayRef, Packed).
+  auto *StructTy2Packed = sandboxir::ConstantStruct::getTypeForElements(
+      {ZeroI42, OneI42}, /*Packed=*/true);
+  EXPECT_EQ(StructTy2Packed, StructTyPacked);
+}
+
TEST_F(SandboxIRTest, Use) {
parseIR(C, R"IR(
define i32 @foo(i32 %v0, i32 %v1) {
diff --git a/llvm/unittests/SandboxIR/TypesTest.cpp b/llvm/unittests/SandboxIR/TypesTest.cpp
index dcbf65a20b2fd7..d4c2de441268c1 100644
--- a/llvm/unittests/SandboxIR/TypesTest.cpp
+++ b/llvm/unittests/SandboxIR/TypesTest.cpp
@@ -224,6 +224,44 @@ define void @foo(ptr %ptr) {
EXPECT_EQ(NewPtrTy2, PtrTy);
}
+TEST_F(SandboxTypeTest, ArrayType) {
+  parseIR(C, R"IR(
+define void @foo([2 x i8] %v0) {
+ ret void
+}
+)IR");
+  llvm::Function *LLVMFn = M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *Fn = Ctx.createFunction(LLVMFn);
+  // The [2 x i8] argument type must be classified as an ArrayType; in
+  // assertion-enabled builds cast<> fails if classof() rejects it.
+  [[maybe_unused]] auto *ArrTy =
+      cast<sandboxir::ArrayType>(Fn->getArg(0)->getType());
+}
+
+TEST_F(SandboxTypeTest, StructType) {
+  parseIR(C, R"IR(
+define void @foo({i32, i8} %v0) {
+ ret void
+}
+)IR");
+  llvm::Function *LLVMF = M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  auto *Int32Ty = sandboxir::Type::getInt32Ty(Ctx);
+  auto *Int8Ty = sandboxir::Type::getInt8Ty(Ctx);
+  // Check classof(), creation. The parsed {i32, i8} is a non-packed struct.
+  auto *StructTy = cast<sandboxir::StructType>(F->getArg(0)->getType());
+  EXPECT_FALSE(StructTy->isPacked());
+  // Check get().
+  auto *NewStructTy = sandboxir::StructType::get(Ctx, {Int32Ty, Int8Ty});
+  EXPECT_EQ(NewStructTy, StructTy);
+  // Check get(IsPacked). The argument comment must match the parameter name
+  // `IsPacked` (bugprone-argument-comment).
+  auto *NewStructTyPacked =
+      sandboxir::StructType::get(Ctx, {Int32Ty, Int8Ty}, /*IsPacked=*/true);
+  EXPECT_NE(NewStructTyPacked, StructTy);
+  EXPECT_TRUE(NewStructTyPacked->isPacked());
+}
+
TEST_F(SandboxTypeTest, VectorType) {
parseIR(C, R"IR(
define void @foo(<2 x i8> %v0) {
More information about the llvm-commits
mailing list