[llvm] [SandboxIR] Implement ConstantAggregate (PR #107136)

via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 3 12:40:24 PDT 2024


https://github.com/vporpo updated https://github.com/llvm/llvm-project/pull/107136

>From 0cba5171db602b18dfbca78b45a5dc7f5c60c05d Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Tue, 20 Aug 2024 12:09:51 -0700
Subject: [PATCH] [SandboxIR] Implement ConstantAggregate

This patch implements sandboxir::ConstantAggregate, ConstantStruct,
ConstantArray and ConstantVector, mirroring LLVM IR. It also adds the
supporting sandboxir::ArrayType and StructType classes and VectorType::get().
---
 llvm/include/llvm/SandboxIR/SandboxIR.h       | 94 +++++++++++++++++++
 .../llvm/SandboxIR/SandboxIRValues.def        |  3 +
 llvm/include/llvm/SandboxIR/Type.h            | 45 +++++++--
 llvm/lib/SandboxIR/SandboxIR.cpp              | 48 +++++++++-
 llvm/lib/SandboxIR/Type.cpp                   | 15 +++
 llvm/unittests/SandboxIR/SandboxIRTest.cpp    | 75 +++++++++++++++
 llvm/unittests/SandboxIR/TypesTest.cpp        | 38 ++++++++
 7 files changed, 310 insertions(+), 8 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/SandboxIR.h b/llvm/include/llvm/SandboxIR/SandboxIR.h
index 0ac049af4db2bd..71fd9202703dc5 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIR.h
+++ b/llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -304,6 +304,8 @@ class Value {
   friend class PHINode;               // For getting `Val`.
   friend class UnreachableInst;       // For getting `Val`.
   friend class CatchSwitchAddHandler; // For `Val`.
+  friend class ConstantArray;         // For `Val`.
+  friend class ConstantStruct;        // For `Val`.
 
   /// All values point to the context.
   Context &Ctx;
@@ -840,6 +842,97 @@ class ConstantFP final : public Constant {
 #endif
 };
 
+/// Common base of the aggregate constants with operands (array/struct/vector).
+class ConstantAggregate : public Constant {
+protected:
+  ConstantAggregate(ClassID ID, llvm::Constant *C, Context &Ctx)
+      : Constant(ID, C, Ctx) {}
+
+public:
+  /// For isa/dyn_cast: matches every aggregate-constant subclass.
+  static bool classof(const sandboxir::Value *From) {
+    const auto Kind = From->getSubclassID();
+    return Kind == ClassID::ConstantArray || Kind == ClassID::ConstantStruct ||
+           Kind == ClassID::ConstantVector;
+  }
+};
+
+class ConstantArray final : public ConstantAggregate {
+  ConstantArray(llvm::ConstantArray *C, Context &Ctx)
+      : ConstantAggregate(ClassID::ConstantArray, C, Ctx) {}
+  friend class Context; // For constructor.
+
+public:
+  /// \Returns the constant of type \p T with elements \p V. It is typed as a
+  /// plain Constant because LLVM may fold it (e.g., into a ConstantDataArray).
+  static Constant *get(ArrayType *T, ArrayRef<Constant *> V);
+  ArrayType *getType() const; // Specialization - reduce amount of casting.
+
+  /// For isa/dyn_cast.
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::ConstantArray;
+  }
+};
+
+class ConstantStruct final : public ConstantAggregate {
+  ConstantStruct(llvm::ConstantStruct *C, Context &Ctx)
+      : ConstantAggregate(ClassID::ConstantStruct, C, Ctx) {}
+  friend class Context; // For constructor.
+
+public:
+  /// \Returns the constant of struct type \p T with elements \p V.
+  static Constant *get(StructType *T, ArrayRef<Constant *> V);
+
+  /// Convenience overload of get() taking the elements as separate arguments.
+  template <typename... Csts>
+  static std::enable_if_t<are_base_of<Constant, Csts...>::value, Constant *>
+  get(StructType *T, Csts *...Vs) {
+    return get(T, ArrayRef<Constant *>({Vs...}));
+  }
+  /// Return an anonymous struct that has the specified elements.
+  /// If the struct is possibly empty, then you must specify a context.
+  static Constant *getAnon(ArrayRef<Constant *> V, bool Packed = false) {
+    return get(getTypeForElements(V, Packed), V);
+  }
+  static Constant *getAnon(Context &Ctx, ArrayRef<Constant *> V,
+                           bool Packed = false) {
+    return get(getTypeForElements(Ctx, V, Packed), V);
+  }
+  /// This version of the method allows an empty list.
+  static StructType *getTypeForElements(Context &Ctx, ArrayRef<Constant *> V,
+                                        bool Packed = false);
+  /// Return an anonymous struct type to use for a constant with the specified
+  /// set of elements. The list must not be empty.
+  static StructType *getTypeForElements(ArrayRef<Constant *> V,
+                                        bool Packed = false) {
+    assert(!V.empty() &&
+           "ConstantStruct::getTypeForElements cannot be called on empty list");
+    return getTypeForElements(V[0]->getContext(), V, Packed);
+  }
+
+  /// Specialization - reduce amount of casting.
+  StructType *getType() const { return cast<StructType>(Value::getType()); }
+
+  /// For isa/dyn_cast.
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::ConstantStruct;
+  }
+};
+
+/// Sandbox IR counterpart of llvm::ConstantVector.
+class ConstantVector final : public ConstantAggregate {
+  ConstantVector(llvm::ConstantVector *C, Context &Ctx)
+      : ConstantAggregate(ClassID::ConstantVector, C, Ctx) {}
+  friend class Context; // For constructor.
+
+public:
+  // TODO: Missing functions: get(), getSplat(), getType(), getSplatValue().
+  /// For isa/dyn_cast.
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::ConstantVector;
+  }
+};
+
 /// Iterator for `Instruction`s in a `BasicBlock.
 /// \Returns an sandboxir::Instruction & when derereferenced.
 class BBIterator {
@@ -3353,6 +3446,7 @@ class Context {
   friend class Type;        // For LLVMCtx.
   friend class PointerType; // For LLVMCtx.
   friend class IntegerType; // For LLVMCtx.
+  friend class StructType;  // For LLVMCtx.
   Tracker IRTracker;
 
   /// Maps LLVM Value to the corresponding sandboxir::Value. Owns all
diff --git a/llvm/include/llvm/SandboxIR/SandboxIRValues.def b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
index 2fc24ed71c4cf6..d2031bbdcfb543 100644
--- a/llvm/include/llvm/SandboxIR/SandboxIRValues.def
+++ b/llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -27,6 +27,9 @@ DEF_VALUE(Block, BasicBlock)
 DEF_CONST(Constant, Constant)
 DEF_CONST(ConstantInt, ConstantInt)
 DEF_CONST(ConstantFP, ConstantFP)
+DEF_CONST(ConstantArray, ConstantArray)
+DEF_CONST(ConstantStruct, ConstantStruct)
+DEF_CONST(ConstantVector, ConstantVector)
 
 #ifndef DEF_INSTR
 #define DEF_INSTR(ID, OPCODE, CLASS)
diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index 39c545a6e6c6d2..2f9b94b8d71751 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -27,6 +27,8 @@ class PointerType;
 class VectorType;
 class IntegerType;
 class FunctionType;
+class ArrayType;
+class StructType;
 #define DEF_INSTR(ID, OPCODE, CLASS) class CLASS;
 #define DEF_CONST(ID, CLASS) class CLASS;
 #include "llvm/SandboxIR/SandboxIRValues.def"
@@ -36,13 +38,19 @@ class FunctionType;
 class Type {
 protected:
   llvm::Type *LLVMTy;
-  friend class VectorType;   // For LLVMTy.
-  friend class PointerType;  // For LLVMTy.
-  friend class FunctionType; // For LLVMTy.
-  friend class IntegerType;  // For LLVMTy.
-  friend class Function;     // For LLVMTy.
-  friend class CallBase;     // For LLVMTy.
-  friend class ConstantInt;  // For LLVMTy.
+  friend class ArrayType;      // For LLVMTy.
+  friend class StructType;     // For LLVMTy.
+  friend class VectorType;     // For LLVMTy.
+  friend class PointerType;    // For LLVMTy.
+  friend class FunctionType;   // For LLVMTy.
+  friend class IntegerType;    // For LLVMTy.
+  friend class Function;       // For LLVMTy.
+  friend class CallBase;       // For LLVMTy.
+  friend class ConstantInt;    // For LLVMTy.
+  friend class ConstantArray;  // For LLVMTy.
+  friend class ConstantStruct; // For LLVMTy.
+  friend class ConstantVector; // For LLVMTy.
+
   // Friend all instruction classes because `create()` functions use LLVMTy.
 #define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
 #define DEF_CONST(ID, CLASS) friend class CLASS;
@@ -281,8 +289,31 @@ class PointerType : public Type {
   }
 };
 
+class ArrayType : public Type {
+public:
+  // TODO: add missing functions
+
+  /// For isa/dyn_cast: true iff the wrapped llvm::Type is an array type.
+  static bool classof(const Type *From) { return From->LLVMTy->isArrayTy(); }
+};
+
+class StructType : public Type {
+public:
+  /// \Returns the literal struct type with element types \p Elements, packed
+  /// if \p IsPacked. This is the primary way to create a literal StructType.
+  static StructType *get(Context &Ctx, ArrayRef<Type *> Elements,
+                         bool IsPacked = false);
+  /// \Returns true if this is a packed struct (as llvm::StructType reports).
+  bool isPacked() const { return cast<llvm::StructType>(LLVMTy)->isPacked(); }
+  // TODO: add missing functions
+
+  /// For isa/dyn_cast: true iff the wrapped llvm::Type is a struct type.
+  static bool classof(const Type *From) { return From->LLVMTy->isStructTy(); }
+};
+
 class VectorType : public Type {
 public:
+  static VectorType *get(Type *ElementType, ElementCount EC);
   // TODO: add missing functions
   static bool classof(const Type *From) {
     return isa<llvm::VectorType>(From->LLVMTy);
diff --git a/llvm/lib/SandboxIR/SandboxIR.cpp b/llvm/lib/SandboxIR/SandboxIR.cpp
index 5af6fbdde42cb7..e8d081e6b17e74 100644
--- a/llvm/lib/SandboxIR/SandboxIR.cpp
+++ b/llvm/lib/SandboxIR/SandboxIR.cpp
@@ -2364,6 +2364,44 @@ bool ConstantFP::isValueValidForType(Type *Ty, const APFloat &V) {
   return llvm::ConstantFP::isValueValidForType(Ty->LLVMTy, V);
 }
 
+Constant *ConstantArray::get(ArrayType *T, ArrayRef<Constant *> V) {
+  SmallVector<llvm::Constant *> LLVMValues;
+  LLVMValues.reserve(V.size());
+  for (auto *Elm : V)
+    LLVMValues.push_back(cast<llvm::Constant>(Elm->Val));
+  auto *LLVMC =
+      llvm::ConstantArray::get(cast<llvm::ArrayType>(T->LLVMTy), LLVMValues);
+  // LLVM may fold (e.g. to ConstantAggregateZero), so don't cast the result.
+  return T->getContext().getOrCreateConstant(LLVMC);
+}
+
+ArrayType *ConstantArray::getType() const {
+  auto *LLVMArrTy = cast<llvm::ConstantArray>(Val)->getType();
+  return cast<ArrayType>(Ctx.getType(LLVMArrTy));
+}
+
+Constant *ConstantStruct::get(StructType *T, ArrayRef<Constant *> V) {
+  SmallVector<llvm::Constant *> LLVMValues;
+  LLVMValues.reserve(V.size());
+  for (auto *Elm : V)
+    LLVMValues.push_back(cast<llvm::Constant>(Elm->Val));
+  auto *LLVMC =
+      llvm::ConstantStruct::get(cast<llvm::StructType>(T->LLVMTy), LLVMValues);
+  // LLVM may fold (e.g. to UndefValue/ConstantAggregateZero): don't cast.
+  return T->getContext().getOrCreateConstant(LLVMC);
+}
+
+StructType *ConstantStruct::getTypeForElements(Context &Ctx,
+                                               ArrayRef<Constant *> V,
+                                               bool Packed) {
+  // Collect each element's type, in order, and build a literal struct of them.
+  SmallVector<Type *, 16> ElementTys;
+  ElementTys.reserve(V.size());
+  for (Constant *C : V)
+    ElementTys.push_back(C->getType());
+  return StructType::get(Ctx, ElementTys, Packed);
+}
+
 FunctionType *Function::getFunctionType() const {
   return cast<FunctionType>(
       Ctx.getType(cast<llvm::Function>(Val)->getFunctionType()));
@@ -2459,7 +2497,15 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
       It->second = std::unique_ptr<ConstantFP>(new ConstantFP(CF, *this));
       return It->second.get();
     }
-    if (auto *F = dyn_cast<llvm::Function>(LLVMV))
+    if (auto *CA = dyn_cast<llvm::ConstantArray>(C))
+      It->second = std::unique_ptr<ConstantArray>(new ConstantArray(CA, *this));
+    else if (auto *CS = dyn_cast<llvm::ConstantStruct>(C))
+      It->second =
+          std::unique_ptr<ConstantStruct>(new ConstantStruct(CS, *this));
+    else if (auto *CV = dyn_cast<llvm::ConstantVector>(C))
+      It->second =
+          std::unique_ptr<ConstantVector>(new ConstantVector(CV, *this));
+    else if (auto *F = dyn_cast<llvm::Function>(LLVMV))
       It->second = std::unique_ptr<Function>(new Function(F, *this));
     else
       It->second = std::unique_ptr<Constant>(new Constant(C, *this));
diff --git a/llvm/lib/SandboxIR/Type.cpp b/llvm/lib/SandboxIR/Type.cpp
index eee69c5ec7c893..535b0f75fd8743 100644
--- a/llvm/lib/SandboxIR/Type.cpp
+++ b/llvm/lib/SandboxIR/Type.cpp
@@ -47,6 +47,21 @@ PointerType *PointerType::get(Context &Ctx, unsigned AddressSpace) {
       Ctx.getType(llvm::PointerType::get(Ctx.LLVMCtx, AddressSpace)));
 }
 
+StructType *StructType::get(Context &Ctx, ArrayRef<Type *> Elements,
+                            bool IsPacked) {
+  SmallVector<llvm::Type *> LLVMTys;
+  LLVMTys.reserve(Elements.size());
+  for (Type *ElmTy : Elements) // Unwrap to the underlying llvm::Type.
+    LLVMTys.push_back(ElmTy->LLVMTy);
+  auto *LLVMStructTy = llvm::StructType::get(Ctx.LLVMCtx, LLVMTys, IsPacked);
+  return cast<StructType>(Ctx.getType(LLVMStructTy));
+}
+
+VectorType *VectorType::get(Type *ElementType, ElementCount EC) {
+  auto *LLVMVecTy = llvm::VectorType::get(ElementType->LLVMTy, EC);
+  return cast<VectorType>(ElementType->getContext().getType(LLVMVecTy));
+}
+
 IntegerType *IntegerType::get(Context &Ctx, unsigned NumBits) {
   return cast<IntegerType>(
       Ctx.getType(llvm::IntegerType::get(Ctx.LLVMCtx, NumBits)));
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 2ec8eefd8c323c..ca2a183e532683 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -445,6 +445,81 @@ define void @foo(float %v0, double %v1) {
   EXPECT_TRUE(NegZero->isExactlyValue(-0.0));
 }
 
+// Tests ConstantArray, ConstantStruct and ConstantVector.
+TEST_F(SandboxIRTest, ConstantAggregate) {
+  // Note: we are using i42 to avoid the creation of ConstantDataVector or
+  // ConstantDataArray.
+  parseIR(C, R"IR(
+define void @foo() {
+  %array = extractvalue [2 x i42] [i42 0, i42 1], 0
+  %struct = extractvalue {i42, i42} {i42 0, i42 1}, 0
+  %vector = extractelement <2 x i42> <i42 0, i42 1>, i32 0
+  ret void
+}
+)IR");
+  Function &LLVMF = *M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+
+  auto &F = *Ctx.createFunction(&LLVMF);
+  auto &BB = *F.begin();
+  auto It = BB.begin();
+  auto *I0 = &*It++;
+  auto *I1 = &*It++;
+  auto *I2 = &*It++;
+  // Check classof() and creation.
+  auto *Array = cast<sandboxir::ConstantArray>(I0->getOperand(0));
+  EXPECT_TRUE(isa<sandboxir::ConstantAggregate>(Array));
+  auto *Struct = cast<sandboxir::ConstantStruct>(I1->getOperand(0));
+  EXPECT_TRUE(isa<sandboxir::ConstantAggregate>(Struct));
+  auto *Vector = cast<sandboxir::ConstantVector>(I2->getOperand(0));
+  EXPECT_TRUE(isa<sandboxir::ConstantAggregate>(Vector));
+
+  auto *ZeroI42 = cast<sandboxir::ConstantInt>(Array->getOperand(0));
+  auto *OneI42 = cast<sandboxir::ConstantInt>(Array->getOperand(1));
+  // Check ConstantArray::get(), getType().
+  auto *NewCA =
+      sandboxir::ConstantArray::get(Array->getType(), {ZeroI42, OneI42});
+  EXPECT_EQ(NewCA, Array);
+
+  // Check ConstantStruct::get(), getType().
+  auto *NewCS =
+      sandboxir::ConstantStruct::get(Struct->getType(), {ZeroI42, OneI42});
+  EXPECT_EQ(NewCS, Struct);
+  // Check ConstantStruct::get(...).
+  auto *NewCS2 =
+      sandboxir::ConstantStruct::get(Struct->getType(), ZeroI42, OneI42);
+  EXPECT_EQ(NewCS2, Struct);
+  // Check ConstantStruct::getAnon(ArrayRef).
+  auto *AnonCS = sandboxir::ConstantStruct::getAnon({ZeroI42, OneI42});
+  EXPECT_FALSE(cast<sandboxir::StructType>(AnonCS->getType())->isPacked());
+  auto *AnonCSPacked =
+      sandboxir::ConstantStruct::getAnon({ZeroI42, OneI42}, /*Packed=*/true);
+  EXPECT_TRUE(cast<sandboxir::StructType>(AnonCSPacked->getType())->isPacked());
+  // Check ConstantStruct::getAnon(Ctx, ArrayRef).
+  auto *AnonCS2 = sandboxir::ConstantStruct::getAnon(Ctx, {ZeroI42, OneI42});
+  EXPECT_EQ(AnonCS2, AnonCS);
+  auto *AnonCS2Packed = sandboxir::ConstantStruct::getAnon(
+      Ctx, {ZeroI42, OneI42}, /*Packed=*/true);
+  EXPECT_EQ(AnonCS2Packed, AnonCSPacked);
+  // Check ConstantStruct::getTypeForElements(Ctx, ArrayRef).
+  auto *StructTy =
+      sandboxir::ConstantStruct::getTypeForElements(Ctx, {ZeroI42, OneI42});
+  EXPECT_EQ(StructTy, Struct->getType());
+  EXPECT_FALSE(StructTy->isPacked());
+  // Check ConstantStruct::getTypeForElements(Ctx, ArrayRef, Packed).
+  auto *StructTyPacked = sandboxir::ConstantStruct::getTypeForElements(
+      Ctx, {ZeroI42, OneI42}, /*Packed=*/true);
+  EXPECT_TRUE(StructTyPacked->isPacked());
+  // Check ConstantStruct::getTypeForElements(ArrayRef).
+  auto *StructTy2 =
+      sandboxir::ConstantStruct::getTypeForElements({ZeroI42, OneI42});
+  EXPECT_EQ(StructTy2, Struct->getType());
+  // Check ConstantStruct::getTypeForElements(ArrayRef, Packed).
+  auto *StructTy2Packed = sandboxir::ConstantStruct::getTypeForElements(
+      {ZeroI42, OneI42}, /*Packed=*/true);
+  EXPECT_EQ(StructTy2Packed, StructTyPacked);
+}
+
 TEST_F(SandboxIRTest, Use) {
   parseIR(C, R"IR(
 define i32 @foo(i32 %v0, i32 %v1) {
diff --git a/llvm/unittests/SandboxIR/TypesTest.cpp b/llvm/unittests/SandboxIR/TypesTest.cpp
index dcbf65a20b2fd7..d4c2de441268c1 100644
--- a/llvm/unittests/SandboxIR/TypesTest.cpp
+++ b/llvm/unittests/SandboxIR/TypesTest.cpp
@@ -224,6 +224,44 @@ define void @foo(ptr %ptr) {
   EXPECT_EQ(NewPtrTy2, PtrTy);
 }
 
+TEST_F(SandboxTypeTest, ArrayType) {
+  parseIR(C, R"IR(
+define void @foo([2 x i8] %v0) {
+  ret void
+}
+)IR");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(M->getFunction("foo"));
+  // Check classof(): use isa<> (not a bare cast<>) so the check is a real
+  // test expectation; an array argument matches, a plain i8 does not.
+  EXPECT_TRUE(isa<sandboxir::ArrayType>(F->getArg(0)->getType()));
+  EXPECT_FALSE(isa<sandboxir::ArrayType>(sandboxir::Type::getInt8Ty(Ctx)));
+}
+
+TEST_F(SandboxTypeTest, StructType) {
+  parseIR(C, R"IR(
+define void @foo({i32, i8} %v0) {
+  ret void
+}
+)IR");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(M->getFunction("foo"));
+  auto *Int32Ty = sandboxir::Type::getInt32Ty(Ctx);
+  auto *Int8Ty = sandboxir::Type::getInt8Ty(Ctx);
+  // Check classof(), creation.
+  auto *StructTy = dyn_cast<sandboxir::StructType>(F->getArg(0)->getType());
+  ASSERT_NE(StructTy, nullptr);
+  // Check get(), which is non-packed by default.
+  auto *NewStructTy = sandboxir::StructType::get(Ctx, {Int32Ty, Int8Ty});
+  EXPECT_EQ(NewStructTy, StructTy);
+  EXPECT_FALSE(NewStructTy->isPacked());
+  // Check get(Packed).
+  auto *NewStructTyPacked =
+      sandboxir::StructType::get(Ctx, {Int32Ty, Int8Ty}, /*Packed=*/true);
+  EXPECT_NE(NewStructTyPacked, StructTy);
+  EXPECT_TRUE(NewStructTyPacked->isPacked());
+}
+
 TEST_F(SandboxTypeTest, VectorType) {
   parseIR(C, R"IR(
 define void @foo(<2 x i8> %v0) {



More information about the llvm-commits mailing list