[llvm] [SandboxIR] Implement ConstantAggregateZero (PR #107172)

via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 3 17:55:54 PDT 2024


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/107172

This patch implements sandboxir::ConstantAggregateZero mirroring llvm::ConstantAggregateZero.

From 25c14b0a7fef588988c505d4ae9fc907a13449eb Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Tue, 20 Aug 2024 09:05:26 -0700
Subject: [PATCH] [SandboxIR] Implement ConstantAggregateZero

This patch implements sandboxir::ConstantAggregateZero mirroring
llvm::ConstantAggregateZero.
---
 llvm/include/llvm/SandboxIR/SandboxIR.h       | 43 ++++++++++
 .../llvm/SandboxIR/SandboxIRValues.def        |  1 +
 llvm/include/llvm/SandboxIR/Type.h            |  1 +
 llvm/lib/SandboxIR/SandboxIR.cpp              | 78 +++++++++++++++----
 llvm/lib/SandboxIR/Type.cpp                   |  5 ++
 llvm/unittests/SandboxIR/SandboxIRTest.cpp    | 69 ++++++++++++++++
 llvm/unittests/SandboxIR/TypesTest.cpp        |  4 +
 7 files changed, 185 insertions(+), 16 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 5c2d58c1b99dc3..ff25b56bd537ff 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -114,6 +114,7 @@ namespace sandboxir {
 class BasicBlock;
 class ConstantInt;
 class ConstantFP;
+class ConstantAggregateZero;
 class Context;
 class Function;
 class Instruction;
@@ -306,6 +307,7 @@ class Value {
   friend class CatchSwitchAddHandler; // For `Val`.
   friend class ConstantArray;         // For `Val`.
   friend class ConstantStruct;        // For `Val`.
+  friend class ConstantAggregateZero; // For `Val`.

   /// All values point to the context.
   Context &Ctx;
@@ -933,6 +935,47 @@ class ConstantVector final : public ConstantAggregate {
   }
 };

+class ConstantAggregateZero final : public Constant {
+  ConstantAggregateZero(llvm::ConstantAggregateZero *C, Context &Ctx)
+      : Constant(ClassID::ConstantAggregateZero, C, Ctx) {}
+  friend class Context; // For constructor.
+
+public:
+  static ConstantAggregateZero *get(Type *Ty);
+  /// For a CAZ of array or vector type, return the zero constant of its
+  /// element type; every element of such a CAZ is this same value.
+  Constant *getSequentialElement() const;
+  /// For a CAZ of struct type, return the zero constant of the type of field
+  /// \p Elt.
+  Constant *getStructElement(unsigned Elt) const;
+  /// Return the zero element selected by the GEP index \p C if it can be
+  /// determined, otherwise return null (e.g. if \p C is a ConstantExpr).
+  Constant *getElementValue(Constant *C) const;
+  /// Return the zero element selected by the GEP index \p Idx.
+  Constant *getElementValue(unsigned Idx) const;
+  /// Return the number of elements in the array, vector, or struct type.
+  ElementCount getElementCount() const {
+    return cast<llvm::ConstantAggregateZero>(Val)->getElementCount();
+  }
+
+  /// For isa/dyn_cast.
+  static bool classof(const sandboxir::Value *From) {
+    return From->getSubclassID() == ClassID::ConstantAggregateZero;
+  }
+  unsigned getUseOperandNo(const Use &Use) const final {
+    llvm_unreachable("ConstantAggregateZero has no operands!");
+  }
+#ifndef NDEBUG
+  void verify() const override {
+    assert(isa<llvm::ConstantAggregateZero>(Val) && "Expected a CAZ!");
+  }
+  void dumpOS(raw_ostream &OS) const override {
+    dumpCommonPrefix(OS);
+    dumpCommonSuffix(OS);
+  }
+#endif
+};
+
 /// Iterator for `Instruction`s in a `BasicBlock.
 /// \Returns an sandboxir::Instruction & when derereferenced.
 class BBIterator {
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index d2031bbdcfb543..577fc0a9dfa5c3 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -30,6 +30,7 @@ DEF_CONST(ConstantFP, ConstantFP)
 DEF_CONST(ConstantArray, ConstantArray)
 DEF_CONST(ConstantStruct, ConstantStruct)
 DEF_CONST(ConstantVector, ConstantVector)
+DEF_CONST(ConstantAggregateZero, ConstantAggregateZero)

 #ifndef DEF_INSTR
 #define DEF_INSTR(ID, OPCODE, CLASS)
diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index 2f9b94b8d71751..c35c2e5588ef95 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -291,6 +291,7 @@ class PointerType : public Type {

 class ArrayType : public Type {
 public:
+  static ArrayType *get(Type *ElementType, uint64_t NumElements);
   // TODO: add missing functions
   static bool classof(const Type *From) {
     return isa<llvm::ArrayType>(From->LLVMTy);
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index e8d081e6b17e74..f6bc8dcfefafab 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -2402,6 +2402,34 @@ StructType *ConstantStruct::getTypeForElements(Context &Ctx,
   return StructType::get(Ctx, EltTypes, Packed);
 }

+ConstantAggregateZero *ConstantAggregateZero::get(Type *Ty) {
+  auto *LLVMC = llvm::ConstantAggregateZero::get(Ty->LLVMTy);
+  return cast<ConstantAggregateZero>(
+      Ty->getContext().getOrCreateConstant(LLVMC));
+}
+
+// Use getOrCreateConstant() rather than getValue() so that these accessors do
+// not rely on the element constants having been created eagerly.
+Constant *ConstantAggregateZero::getSequentialElement() const {
+  return Ctx.getOrCreateConstant(
+      cast<llvm::ConstantAggregateZero>(Val)->getSequentialElement());
+}
+Constant *ConstantAggregateZero::getStructElement(unsigned Elt) const {
+  return Ctx.getOrCreateConstant(
+      cast<llvm::ConstantAggregateZero>(Val)->getStructElement(Elt));
+}
+Constant *ConstantAggregateZero::getElementValue(Constant *C) const {
+  auto *LLVMElm = cast<llvm::ConstantAggregateZero>(Val)->getElementValue(
+      cast<llvm::Constant>(C->Val));
+  // The LLVM API is documented to return null if the element can't be
+  // determined; forward that instead of asserting inside cast<>.
+  return LLVMElm != nullptr ? Ctx.getOrCreateConstant(LLVMElm) : nullptr;
+}
+Constant *ConstantAggregateZero::getElementValue(unsigned Idx) const {
+  return Ctx.getOrCreateConstant(
+      cast<llvm::ConstantAggregateZero>(Val)->getElementValue(Idx));
+}
+
 FunctionType *Function::getFunctionType() const {
   return cast<FunctionType>(
       Ctx.getType(cast<llvm::Function>(Val)->getFunctionType()));
@@ -2489,26 +2513,51 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
     return It->second.get();

   if (auto *C = dyn_cast<llvm::Constant>(LLVMV)) {
-    if (auto *CI = dyn_cast<llvm::ConstantInt>(C)) {
-      It->second = std::unique_ptr<ConstantInt>(new ConstantInt(CI, *this));
+    switch (C->getValueID()) {
+    case llvm::Value::ConstantIntVal:
+      It->second = std::unique_ptr<ConstantInt>(
+          new ConstantInt(cast<llvm::ConstantInt>(C), *this));
       return It->second.get();
-    }
-    if (auto *CF = dyn_cast<llvm::ConstantFP>(C)) {
-      It->second = std::unique_ptr<ConstantFP>(new ConstantFP(CF, *this));
+    case llvm::Value::ConstantFPVal:
+      It->second = std::unique_ptr<ConstantFP>(
+          new ConstantFP(cast<llvm::ConstantFP>(C), *this));
       return It->second.get();
+    case llvm::Value::ConstantAggregateZeroVal: {
+      auto *CAZ = cast<llvm::ConstantAggregateZero>(C);
+      It->second = std::unique_ptr<ConstantAggregateZero>(
+          new ConstantAggregateZero(CAZ, *this));
+      auto *Ret = It->second.get();
+      // Must create sandboxir for elements. Every element of an array or vector
+      // CAZ (fixed or scalable) is the same zero constant, so create it once
+      // rather than once per index; struct CAZs need one per field.
+      if (auto *STy = dyn_cast<llvm::StructType>(CAZ->getType())) {
+        for (auto ElmIdx : seq<unsigned>(0, STy->getNumElements()))
+          getOrCreateValueInternal(CAZ->getStructElement(ElmIdx), CAZ);
+      } else {
+        getOrCreateValueInternal(CAZ->getSequentialElement(), CAZ);
+      }
+      return Ret;
     }
-    if (auto *CA = dyn_cast<llvm::ConstantArray>(C))
-      It->second = std::unique_ptr<ConstantArray>(new ConstantArray(CA, *this));
-    else if (auto *CS = dyn_cast<llvm::ConstantStruct>(C))
-      It->second =
-          std::unique_ptr<ConstantStruct>(new ConstantStruct(CS, *this));
-    else if (auto *CV = dyn_cast<llvm::ConstantVector>(C))
-      It->second =
-          std::unique_ptr<ConstantVector>(new ConstantVector(CV, *this));
-    else if (auto *F = dyn_cast<llvm::Function>(LLVMV))
-      It->second = std::unique_ptr<Function>(new Function(F, *this));
-    else
+    case llvm::Value::ConstantArrayVal:
+      It->second = std::unique_ptr<ConstantArray>(
+          new ConstantArray(cast<llvm::ConstantArray>(C), *this));
+      break;
+    case llvm::Value::ConstantStructVal:
+      It->second = std::unique_ptr<ConstantStruct>(
+          new ConstantStruct(cast<llvm::ConstantStruct>(C), *this));
+      break;
+    case llvm::Value::ConstantVectorVal:
+      It->second = std::unique_ptr<ConstantVector>(
+          new ConstantVector(cast<llvm::ConstantVector>(C), *this));
+      break;
+    case llvm::Value::FunctionVal:
+      It->second = std::unique_ptr<Function>(
+          new Function(cast<llvm::Function>(C), *this));
+      break;
+    default:
       It->second = std::unique_ptr<Constant>(new Constant(C, *this));
+      break;
+    }
     auto *NewC = It->second.get();
     for (llvm::Value *COp : C->operands())
       getOrCreateValueInternal(COp, C);
diff --git a/llvm/lib/SandboxIR/Type.cpp b/llvm/lib/SandboxIR/Type.cpp
index 535b0f75fd8743..11a16e865213fb 100644
--- a/llvm/lib/SandboxIR/Type.cpp
+++ b/llvm/lib/SandboxIR/Type.cpp
@@ -47,6 +47,11 @@ PointerType *PointerType::get(Context &Ctx, unsigned AddressSpace) {
       Ctx.getType(llvm::PointerType::get(Ctx.LLVMCtx, AddressSpace)));
 }

+ArrayType *ArrayType::get(Type *ElementType, uint64_t NumElements) {
+  return cast<ArrayType>(ElementType->getContext().getType(
+      llvm::ArrayType::get(ElementType->LLVMTy, NumElements)));
+}
+
 StructType *StructType::get(Context &Ctx, ArrayRef<Type *> Elements,
                             bool IsPacked) {
   SmallVector<llvm::Type *> LLVMElements;
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index ca2a183e532683..eb5bbb759fa134 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -520,6 +520,71 @@ define void @foo() {
   EXPECT_EQ(StructTy2Packed, StructTyPacked);
 }

+TEST_F(SandboxIRTest, ConstantAggregateZero) {
+  parseIR(C, R"IR(
+define void @foo(ptr %ptr, {i32, i8} %v1, <2 x i8> %v2) {
+  %extr0 = extractvalue [2 x i8] zeroinitializer, 0
+  %extr1 = extractvalue {i32, i8} zeroinitializer, 0
+  %extr2 = extractelement <2 x i8> zeroinitializer, i32 0
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto &BB = *F.begin();
+  auto It = BB.begin();
+  auto *Extr0 = &*It++;
+  auto *Extr1 = &*It++;
+  auto *Extr2 = &*It++;
+  [[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
+
+  // Types and zero constants shared by the checks below.
+  auto *Int8Ty = sandboxir::Type::getInt8Ty(Ctx);
+  auto *Int32Ty = sandboxir::Type::getInt32Ty(Ctx);
+  auto *Zero8 = sandboxir::ConstantInt::get(Int8Ty, 0);
+  auto *Zero32 = sandboxir::ConstantInt::get(Int32Ty, 0);
+  auto *ArrayTy = sandboxir::ArrayType::get(Int8Ty, 2u);
+  auto *StructTy = sandboxir::StructType::get(Ctx, {Int32Ty, Int8Ty});
+  auto *VectorTy =
+      sandboxir::VectorType::get(Int8Ty, ElementCount::getFixed(2u));
+  auto *WideVectorTy =
+      sandboxir::VectorType::get(Int8Ty, ElementCount::getFixed(4u));
+
+  // Check creation and classof().
+  auto *ArrayCAZ = cast<sandboxir::ConstantAggregateZero>(Extr0->getOperand(0));
+  auto *StructCAZ =
+      cast<sandboxir::ConstantAggregateZero>(Extr1->getOperand(0));
+  auto *VectorCAZ =
+      cast<sandboxir::ConstantAggregateZero>(Extr2->getOperand(0));
+  EXPECT_EQ(ArrayCAZ->getType(), ArrayTy);
+  EXPECT_EQ(StructCAZ->getType(), StructTy);
+  EXPECT_EQ(VectorCAZ->getType(), VectorTy);
+
+  // Check get(): CAZs are uniqued per type.
+  EXPECT_EQ(sandboxir::ConstantAggregateZero::get(VectorTy), VectorCAZ);
+  auto *WideVectorCAZ = sandboxir::ConstantAggregateZero::get(WideVectorTy);
+  EXPECT_NE(WideVectorCAZ, VectorCAZ);
+
+  // Check getSequentialElement() and getStructElement().
+  EXPECT_EQ(VectorCAZ->getSequentialElement(), Zero8);
+  EXPECT_EQ(StructCAZ->getStructElement(0), Zero32);
+  EXPECT_EQ(StructCAZ->getStructElement(1), Zero8);
+
+  // Check getElementValue(Constant *) and getElementValue(unsigned).
+  EXPECT_EQ(ArrayCAZ->getElementValue(Zero32), Zero8);
+  EXPECT_EQ(ArrayCAZ->getElementValue(0u), Zero8);
+  EXPECT_EQ(StructCAZ->getElementValue(Zero32), Zero32);
+  EXPECT_EQ(StructCAZ->getElementValue(0u), Zero32);
+  EXPECT_EQ(VectorCAZ->getElementValue(Zero32), Zero8);
+  EXPECT_EQ(VectorCAZ->getElementValue(0u), Zero8);
+
+  // Check getElementCount().
+  EXPECT_EQ(ArrayCAZ->getElementCount(), ElementCount::getFixed(2));
+  EXPECT_EQ(WideVectorCAZ->getElementCount(), ElementCount::getFixed(4));
+}
+
 TEST_F(SandboxIRTest, Use) {
   parseIR(C, R"IR(
 define i32 @foo(i32 %v0, i32 %v1) {
diff --git a/llvm/unittests/SandboxIR/TypesTest.cpp b/llvm/unittests/SandboxIR/TypesTest.cpp
index d4c2de441268c1..36ef0cf8e52911 100644
--- a/llvm/unittests/SandboxIR/TypesTest.cpp
+++ b/llvm/unittests/SandboxIR/TypesTest.cpp
@@ -236,6 +236,10 @@ define void @foo([2 x i8] %v0) {
   // Check classof(), creation.
-  [[maybe_unused]] auto *ArrayTy =
-      cast<sandboxir::ArrayType>(F->getArg(0)->getType());
+  auto *ArrayTy = cast<sandboxir::ArrayType>(F->getArg(0)->getType());
+  // Check get(): ArrayType is uniqued, so requesting [2 x i8] again must
+  // return the same object.
+  auto *NewArrayTy =
+      sandboxir::ArrayType::get(sandboxir::Type::getInt8Ty(Ctx), 2u);
+  EXPECT_EQ(NewArrayTy, ArrayTy);
 }

 TEST_F(SandboxTypeTest, StructType) {



More information about the llvm-commits mailing list