[llvm] [SandboxIR] Implement ConstantAggregateZero (PR #107172)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 3 17:55:54 PDT 2024
https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/107172
This patch implements sandboxir::ConstantAggregateZero mirroring llvm::ConstantAggregateZero.
>From 25c14b0a7fef588988c505d4ae9fc907a13449eb Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Tue, 20 Aug 2024 09:05:26 -0700
Subject: [PATCH] [SandboxIR] Implement ConstantAggregateZero
This patch implements sandboxir::ConstantAggregateZero mirroring
llvm::ConstantAggregateZero.
---
llvm/include/llvm/SandboxIR/SandboxIR.h | 43 ++++++++++
.../llvm/SandboxIR/SandboxIRValues.def | 1 +
llvm/include/llvm/SandboxIR/Type.h | 1 +
llvm/lib/SandboxIR/SandboxIR.cpp | 78 +++++++++++++++----
llvm/lib/SandboxIR/Type.cpp | 5 ++
llvm/unittests/SandboxIR/SandboxIRTest.cpp | 69 ++++++++++++++++
llvm/unittests/SandboxIR/TypesTest.cpp | 4 +
7 files changed, 185 insertions(+), 16 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 5c2d58c1b99dc3..ff25b56bd537ff 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -114,6 +114,7 @@ namespace sandboxir {
class BasicBlock;
class ConstantInt;
class ConstantFP;
+class ConstantAggregateZero;
class Context;
class Function;
class Instruction;
@@ -306,6 +307,7 @@ class Value {
friend class CatchSwitchAddHandler; // For `Val`.
friend class ConstantArray; // For `Val`.
friend class ConstantStruct; // For `Val`.
+ friend class ConstantAggregateZero; // For `Val`.
/// All values point to the context.
Context &Ctx;
@@ -933,6 +935,47 @@ class ConstantVector final : public ConstantAggregate {
}
};
+class ConstantAggregateZero final : public Constant {
+ ConstantAggregateZero(llvm::ConstantAggregateZero *C, Context &Ctx)
+ : Constant(ClassID::ConstantAggregateZero, C, Ctx) {}
+ friend class Context; // For constructor.
+
+public:
+ /// \Returns the uniqued all-zero constant of aggregate type \p Ty.
+ static ConstantAggregateZero *get(Type *Ty);
+ /// If this CAZ has array or vector type, return a zero of the element type.
+ Constant *getSequentialElement() const;
+ /// If this CAZ has struct type, return a zero of the type of field \p Elt.
+ Constant *getStructElement(unsigned Elt) const;
+ /// Return a zero of the right value for the specified GEP index \p C; for
+ /// struct types \p C selects the field.
+ Constant *getElementValue(Constant *C) const;
+ /// Return a zero of the right value for the specified GEP index \p Idx.
+ Constant *getElementValue(unsigned Idx) const;
+ /// Return the number of elements in the array, vector, or struct.
+ ElementCount getElementCount() const {
+ return cast<llvm::ConstantAggregateZero>(Val)->getElementCount();
+ }
+
+ /// For isa/dyn_cast.
+ static bool classof(const sandboxir::Value *From) {
+ return From->getSubclassID() == ClassID::ConstantAggregateZero;
+ }
+ /// A CAZ has no operands, so there is no operand number to report.
+ unsigned getUseOperandNo(const Use &Use) const final {
+ llvm_unreachable("ConstantAggregateZero has no operands!");
+ }
+#ifndef NDEBUG
+ void verify() const override {
+ assert(isa<llvm::ConstantAggregateZero>(Val) && "Expected a CAZ!");
+ }
+ void dumpOS(raw_ostream &OS) const override {
+ dumpCommonPrefix(OS);
+ dumpCommonSuffix(OS);
+ }
+#endif
+};
+
/// Iterator for `Instruction`s in a `BasicBlock.
/// \Returns an sandboxir::Instruction & when derereferenced.
class BBIterator {
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index d2031bbdcfb543..577fc0a9dfa5c3 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -30,6 +30,7 @@ DEF_CONST(ConstantFP, ConstantFP)
DEF_CONST(ConstantArray, ConstantArray)
DEF_CONST(ConstantStruct, ConstantStruct)
DEF_CONST(ConstantVector, ConstantVector)
+DEF_CONST(ConstantAggregateZero, ConstantAggregateZero)
#ifndef DEF_INSTR
#define DEF_INSTR(ID, OPCODE, CLASS)
diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index 2f9b94b8d71751..c35c2e5588ef95 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -291,6 +291,7 @@ class PointerType : public Type {
class ArrayType : public Type {
public:
+ static ArrayType *get(Type *ElementType, uint64_t NumElements);
// TODO: add missing functions
static bool classof(const Type *From) {
return isa<llvm::ArrayType>(From->LLVMTy);
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index e8d081e6b17e74..f6bc8dcfefafab 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -2402,6 +2402,30 @@ StructType *ConstantStruct::getTypeForElements(Context &Ctx,
return StructType::get(Ctx, EltTypes, Packed);
}
+ConstantAggregateZero *ConstantAggregateZero::get(Type *Ty) {
+ auto &TyCtx = Ty->getContext();
+ auto *LLVMZero = llvm::ConstantAggregateZero::get(Ty->LLVMTy);
+ return cast<ConstantAggregateZero>(TyCtx.getOrCreateConstant(LLVMZero));
+}
+
+Constant *ConstantAggregateZero::getSequentialElement() const {
+ auto *LLVMCAZ = cast<llvm::ConstantAggregateZero>(Val);
+ return cast<Constant>(Ctx.getValue(LLVMCAZ->getSequentialElement()));
+}
+Constant *ConstantAggregateZero::getStructElement(unsigned Elt) const {
+ auto *LLVMCAZ = cast<llvm::ConstantAggregateZero>(Val);
+ return cast<Constant>(Ctx.getValue(LLVMCAZ->getStructElement(Elt)));
+}
+Constant *ConstantAggregateZero::getElementValue(Constant *C) const {
+ auto *LLVMCAZ = cast<llvm::ConstantAggregateZero>(Val);
+ auto *LLVMIdx = cast<llvm::Constant>(C->Val);
+ return cast<Constant>(Ctx.getValue(LLVMCAZ->getElementValue(LLVMIdx)));
+}
+Constant *ConstantAggregateZero::getElementValue(unsigned Idx) const {
+ auto *LLVMCAZ = cast<llvm::ConstantAggregateZero>(Val);
+ return cast<Constant>(Ctx.getValue(LLVMCAZ->getElementValue(Idx)));
+}
+
FunctionType *Function::getFunctionType() const {
return cast<FunctionType>(
Ctx.getType(cast<llvm::Function>(Val)->getFunctionType()));
@@ -2489,26 +2513,48 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
return It->second.get();
if (auto *C = dyn_cast<llvm::Constant>(LLVMV)) {
- if (auto *CI = dyn_cast<llvm::ConstantInt>(C)) {
- It->second = std::unique_ptr<ConstantInt>(new ConstantInt(CI, *this));
+ switch (C->getValueID()) {
+ case llvm::Value::ConstantIntVal:
+ It->second = std::unique_ptr<ConstantInt>(
+ new ConstantInt(cast<llvm::ConstantInt>(C), *this));
return It->second.get();
- }
- if (auto *CF = dyn_cast<llvm::ConstantFP>(C)) {
- It->second = std::unique_ptr<ConstantFP>(new ConstantFP(CF, *this));
+ case llvm::Value::ConstantFPVal:
+ It->second = std::unique_ptr<ConstantFP>(
+ new ConstantFP(cast<llvm::ConstantFP>(C), *this));
return It->second.get();
+ case llvm::Value::ConstantAggregateZeroVal: {
+ auto *CAZ = cast<llvm::ConstantAggregateZero>(C);
+ It->second = std::unique_ptr<ConstantAggregateZero>(
+ new ConstantAggregateZero(CAZ, *this));
+ auto *Ret = It->second.get();
+ // Must create sandboxir for elements. Array/vector elements are all the same
+ // zero, so one suffices; this also covers scalable (non-fixed) vectors.
+ auto *STy = dyn_cast<llvm::StructType>(CAZ->getType());
+ unsigned NumDistinctElms = STy ? STy->getNumElements() : 1;
+ for (auto ElmIdx : seq<unsigned>(0, NumDistinctElms))
+ getOrCreateValueInternal(CAZ->getElementValue(ElmIdx), CAZ);
+ return Ret;
}
- if (auto *CA = dyn_cast<llvm::ConstantArray>(C))
- It->second = std::unique_ptr<ConstantArray>(new ConstantArray(CA, *this));
- else if (auto *CS = dyn_cast<llvm::ConstantStruct>(C))
- It->second =
- std::unique_ptr<ConstantStruct>(new ConstantStruct(CS, *this));
- else if (auto *CV = dyn_cast<llvm::ConstantVector>(C))
- It->second =
- std::unique_ptr<ConstantVector>(new ConstantVector(CV, *this));
- else if (auto *F = dyn_cast<llvm::Function>(LLVMV))
- It->second = std::unique_ptr<Function>(new Function(F, *this));
- else
+ case llvm::Value::ConstantArrayVal:
+ It->second = std::unique_ptr<ConstantArray>(
+ new ConstantArray(cast<llvm::ConstantArray>(C), *this));
+ break;
+ case llvm::Value::ConstantStructVal:
+ It->second = std::unique_ptr<ConstantStruct>(
+ new ConstantStruct(cast<llvm::ConstantStruct>(C), *this));
+ break;
+ case llvm::Value::ConstantVectorVal:
+ It->second = std::unique_ptr<ConstantVector>(
+ new ConstantVector(cast<llvm::ConstantVector>(C), *this));
+ break;
+ case llvm::Value::FunctionVal:
+ It->second = std::unique_ptr<Function>(
+ new Function(cast<llvm::Function>(C), *this));
+ break;
+ default:
It->second = std::unique_ptr<Constant>(new Constant(C, *this));
+ break;
+ }
auto *NewC = It->second.get();
for (llvm::Value *COp : C->operands())
getOrCreateValueInternal(COp, C);
diff --git a/llvm/lib/SandboxIR/Type.cpp b/llvm/lib/SandboxIR/Type.cpp
index 535b0f75fd8743..11a16e865213fb 100644
--- a/llvm/lib/SandboxIR/Type.cpp
+++ b/llvm/lib/SandboxIR/Type.cpp
@@ -47,6 +47,11 @@ PointerType *PointerType::get(Context &Ctx, unsigned AddressSpace) {
Ctx.getType(llvm::PointerType::get(Ctx.LLVMCtx, AddressSpace)));
}
+ArrayType *ArrayType::get(Type *ElementType, uint64_t NumElements) {
+ auto *LLVMArrayTy = llvm::ArrayType::get(ElementType->LLVMTy, NumElements);
+ return cast<ArrayType>(ElementType->getContext().getType(LLVMArrayTy));
+}
+
StructType *StructType::get(Context &Ctx, ArrayRef<Type *> Elements,
bool IsPacked) {
SmallVector<llvm::Type *> LLVMElements;
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index ca2a183e532683..eb5bbb759fa134 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -520,6 +520,75 @@ define void @foo() {
EXPECT_EQ(StructTy2Packed, StructTyPacked);
}
+// Checks sandboxir::ConstantAggregateZero for array, struct and vector types.
+TEST_F(SandboxIRTest, ConstantAggregateZero) {
+ parseIR(C, R"IR(
+define void @foo(ptr %ptr, {i32, i8} %v1, <2 x i8> %v2) {
+ %extr0 = extractvalue [2 x i8] zeroinitializer, 0
+ %extr1 = extractvalue {i32, i8} zeroinitializer, 0
+ %extr2 = extractelement <2 x i8> zeroinitializer, i32 0
+ ret void
+}
+)IR");
+ Function &LLVMF = *M->getFunction("foo");
+ sandboxir::Context Ctx(C);
+
+ auto &F = *Ctx.createFunction(&LLVMF);
+ auto &BB = *F.begin();
+ auto It = BB.begin();
+ auto *Extr0 = &*It++;
+ auto *Extr1 = &*It++;
+ auto *Extr2 = &*It++;
+ [[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+ // Types and scalar zeros shared by the checks below.
+ auto *Int8Ty = sandboxir::Type::getInt8Ty(Ctx);
+ auto *Int32Ty = sandboxir::Type::getInt32Ty(Ctx);
+ auto *Zero8 = sandboxir::ConstantInt::get(Int8Ty, 0);
+ auto *Zero32 = sandboxir::ConstantInt::get(Int32Ty, 0);
+ auto *ArrayTy = sandboxir::ArrayType::get(Int8Ty, 2u);
+ auto *StructTy = sandboxir::StructType::get(Ctx, {Int32Ty, Int8Ty});
+ auto *VectorTy =
+ sandboxir::VectorType::get(Int8Ty, ElementCount::getFixed(2u));
+
+ // Check creation and classof(): each zeroinitializer operand is a CAZ of
+ // the expected aggregate type.
+ auto *ArrayCAZ = cast<sandboxir::ConstantAggregateZero>(Extr0->getOperand(0));
+ auto *StructCAZ =
+ cast<sandboxir::ConstantAggregateZero>(Extr1->getOperand(0));
+ auto *VectorCAZ =
+ cast<sandboxir::ConstantAggregateZero>(Extr2->getOperand(0));
+ EXPECT_EQ(ArrayCAZ->getType(), ArrayTy);
+ EXPECT_EQ(StructCAZ->getType(), StructTy);
+ EXPECT_EQ(VectorCAZ->getType(), VectorTy);
+
+ // Check get(): the same type yields the same (uniqued) CAZ, and a different
+ // type yields a different one.
+ auto *SameVectorCAZ = sandboxir::ConstantAggregateZero::get(
+ sandboxir::VectorType::get(Int8Ty, ElementCount::getFixed(2)));
+ auto *NewVectorCAZ = sandboxir::ConstantAggregateZero::get(
+ sandboxir::VectorType::get(Int8Ty, ElementCount::getFixed(4)));
+ EXPECT_EQ(SameVectorCAZ, VectorCAZ);
+ EXPECT_NE(NewVectorCAZ, VectorCAZ);
+
+ // Check getSequentialElement().
+ auto *SeqElm = VectorCAZ->getSequentialElement();
+ EXPECT_EQ(SeqElm, Zero8);
+ // Check getStructElement(): one zero per field, of that field's type.
+ EXPECT_EQ(StructCAZ->getStructElement(0), Zero32);
+ EXPECT_EQ(StructCAZ->getStructElement(1), Zero8);
+ // Check getElementValue(Constant *).
+ EXPECT_EQ(ArrayCAZ->getElementValue(Zero32), Zero8);
+ EXPECT_EQ(StructCAZ->getElementValue(Zero32), Zero32);
+ EXPECT_EQ(VectorCAZ->getElementValue(Zero32), Zero8);
+ // Check getElementValue(unsigned).
+ EXPECT_EQ(ArrayCAZ->getElementValue(0u), Zero8);
+ EXPECT_EQ(StructCAZ->getElementValue(0u), Zero32);
+ EXPECT_EQ(VectorCAZ->getElementValue(0u), Zero8);
+ // Check getElementCount().
+ EXPECT_EQ(ArrayCAZ->getElementCount(), ElementCount::getFixed(2));
+ EXPECT_EQ(NewVectorCAZ->getElementCount(), ElementCount::getFixed(4));
+}
+
TEST_F(SandboxIRTest, Use) {
parseIR(C, R"IR(
define i32 @foo(i32 %v0, i32 %v1) {
diff --git a/llvm/unittests/SandboxIR/TypesTest.cpp b/llvm/unittests/SandboxIR/TypesTest.cpp
index d4c2de441268c1..36ef0cf8e52911 100644
--- a/llvm/unittests/SandboxIR/TypesTest.cpp
+++ b/llvm/unittests/SandboxIR/TypesTest.cpp
@@ -236,6 +236,10 @@ define void @foo([2 x i8] %v0) {
// Check classof(), creation.
[[maybe_unused]] auto *ArrayTy =
cast<sandboxir::ArrayType>(F->getArg(0)->getType());
+ // Check get().
+ auto *NewArrayTy =
+ sandboxir::ArrayType::get(sandboxir::Type::getInt8Ty(Ctx), 2u);
+ EXPECT_EQ(NewArrayTy, ArrayTy);
}
TEST_F(SandboxTypeTest, StructType) {
More information about the llvm-commits
mailing list