[llvm] [SandboxIR] Implement missing ConstantVector member functions (PR #131390)

via llvm-commits llvm-commits at lists.llvm.org
Fri Mar 14 14:01:36 PDT 2025


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/131390

This patch implements the previously missing sandboxir::ConstantVector member functions get(), getSplat(), getSplatValue() and getType().

>From d5f20755c53592dab3b46a1e7cf07537254ff3bc Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Fri, 14 Mar 2025 13:58:27 -0700
Subject: [PATCH] [SandboxIR] Implement missing ConstantVector member functions

This patch implements get(), getSplat(), getSplatValue() and getType().
---
 llvm/include/llvm/SandboxIR/Constant.h     | 14 +++++++++++++-
 llvm/include/llvm/SandboxIR/Value.h        |  1 +
 llvm/lib/SandboxIR/Constant.cpp            | 22 ++++++++++++++++++++++
 llvm/unittests/SandboxIR/SandboxIRTest.cpp | 12 ++++++++++++
 4 files changed, 48 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/SandboxIR/Constant.h b/llvm/include/llvm/SandboxIR/Constant.h
index 17f55e973cd76..c4841e0b0dd66 100644
--- a/llvm/include/llvm/SandboxIR/Constant.h
+++ b/llvm/include/llvm/SandboxIR/Constant.h
@@ -424,7 +424,19 @@ class ConstantVector final : public ConstantAggregate {
   friend class Context; // For constructor.
 
 public:
-  // TODO: Missing functions: getSplat(), getType(), getSplatValue(), get().
+  static Constant *get(ArrayRef<Constant *> V);
+  /// Return a vector constant with the specified constant in each element.
+  /// Note that this might not return an instance of ConstantVector.
+  static Constant *getSplat(ElementCount EC, Constant *Elt);
+  /// Specialize the getType() method to always return a FixedVectorType,
+  /// which reduces the amount of casting needed in parts of the compiler.
+  inline FixedVectorType *getType() const {
+    return cast<FixedVectorType>(Value::getType());
+  }
+  /// If all elements of the vector constant have the same value, return that
+  /// value. Otherwise, return nullptr. Ignore poison elements by setting
+  /// AllowPoison to true.
+  Constant *getSplatValue(bool AllowPoison = false) const;
 
   /// For isa/dyn_cast.
   static bool classof(const Value *From) {
diff --git a/llvm/include/llvm/SandboxIR/Value.h b/llvm/include/llvm/SandboxIR/Value.h
index 2e91b96bb22e6..508978709f96c 100644
--- a/llvm/include/llvm/SandboxIR/Value.h
+++ b/llvm/include/llvm/SandboxIR/Value.h
@@ -145,6 +145,7 @@ class Value {
   friend class CmpInst;               // For getting `Val`.
   friend class ConstantArray;         // For `Val`.
   friend class ConstantStruct;        // For `Val`.
+  friend class ConstantVector;        // For `Val`.
   friend class ConstantAggregateZero; // For `Val`.
   friend class ConstantPointerNull;   // For `Val`.
   friend class UndefValue;            // For `Val`.
diff --git a/llvm/lib/SandboxIR/Constant.cpp b/llvm/lib/SandboxIR/Constant.cpp
index 3e13c935c4281..e2b0472bf6b08 100644
--- a/llvm/lib/SandboxIR/Constant.cpp
+++ b/llvm/lib/SandboxIR/Constant.cpp
@@ -174,6 +174,28 @@ StructType *ConstantStruct::getTypeForElements(Context &Ctx,
   return StructType::get(Ctx, EltTypes, Packed);
 }
 
+Constant *ConstantVector::get(ArrayRef<Constant *> V) {
+  assert(!V.empty() && "Expected non-empty V!");
+  auto &Ctx = V[0]->getContext();
+  SmallVector<llvm::Constant *, 8> LLVMV;
+  LLVMV.reserve(V.size());
+  for (auto *Elm : V)
+    LLVMV.push_back(cast<llvm::Constant>(Elm->Val));
+  return Ctx.getOrCreateConstant(llvm::ConstantVector::get(LLVMV));
+}
+
+Constant *ConstantVector::getSplat(ElementCount EC, Constant *Elt) {
+  auto *LLVMElt = cast<llvm::Constant>(Elt->Val);
+  auto &Ctx = Elt->getContext();
+  return Ctx.getOrCreateConstant(llvm::ConstantVector::getSplat(EC, LLVMElt));
+}
+
+Constant *ConstantVector::getSplatValue(bool AllowPoison) const {
+  auto *LLVMSplatValue = cast_or_null<llvm::Constant>(
+      cast<llvm::ConstantVector>(Val)->getSplatValue(AllowPoison));
+  return LLVMSplatValue ? Ctx.getOrCreateConstant(LLVMSplatValue) : nullptr;
+}
+
 ConstantAggregateZero *ConstantAggregateZero::get(Type *Ty) {
   auto *LLVMC = llvm::ConstantAggregateZero::get(Ty->LLVMTy);
   return cast<ConstantAggregateZero>(
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 088264e0429fd..bdc9c2c222ae5 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -524,6 +524,18 @@ define void @foo() {
   auto *StructTy2Packed = sandboxir::ConstantStruct::getTypeForElements(
       Ctx, {ZeroI42, OneI42}, /*Packed=*/true);
   EXPECT_EQ(StructTy2Packed, StructTyPacked);
+
+  // Check ConstantVector::get().
+  auto *NewCV = sandboxir::ConstantVector::get({ZeroI42, OneI42});
+  EXPECT_EQ(NewCV, Vector);
+  // Check ConstantVector::getSplat(), getType().
+  auto *SplatRaw =
+      sandboxir::ConstantVector::getSplat(ElementCount::getFixed(2), OneI42);
+  auto *Splat = cast<sandboxir::ConstantVector>(SplatRaw);
+  EXPECT_EQ(Splat->getType()->getNumElements(), 2u);
+  EXPECT_THAT(Splat->operands(), testing::ElementsAre(OneI42, OneI42));
+  // Check ConstantVector::getSplatValue().
+  EXPECT_EQ(Splat->getSplatValue(), OneI42);
 }
 
 TEST_F(SandboxIRTest, ConstantAggregateZero) {



More information about the llvm-commits mailing list