[llvm] [SandboxIR] Implement ConstantDataVector member functions (PR #136200)

via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 17 14:09:16 PDT 2025


https://github.com/vporpo created https://github.com/llvm/llvm-project/pull/136200

Mirroring LLVM IR.

>From 6c26f7eb46da79211816c95b0619e077128990d3 Mon Sep 17 00:00:00 2001
From: Vasileios Porpodas <vporpodas at google.com>
Date: Mon, 17 Mar 2025 12:57:59 -0700
Subject: [PATCH] [SandboxIR] Implement ConstantDataVector member functions

Mirroring LLVM IR.
---
 llvm/include/llvm/SandboxIR/Constant.h     | 91 +++++++++++++++++++++-
 llvm/include/llvm/SandboxIR/Value.h        |  1 +
 llvm/unittests/SandboxIR/SandboxIRTest.cpp | 69 ++++++++++++++++
 3 files changed, 160 insertions(+), 1 deletion(-)

diff --git a/llvm/include/llvm/SandboxIR/Constant.h b/llvm/include/llvm/SandboxIR/Constant.h
index fa00b29dbd803..2012cf8a8ed3e 100644
--- a/llvm/include/llvm/SandboxIR/Constant.h
+++ b/llvm/include/llvm/SandboxIR/Constant.h
@@ -670,7 +670,96 @@ class ConstantDataVector final : public ConstantDataSequential {
   friend class Context;
 
 public:
-  // TODO: Add missing functions.
+  /// Methods for support type inquiry through isa, cast, and dyn_cast:
+  static bool classof(const Value *From) {
+    return From->getSubclassID() == ClassID::ConstantDataVector;
+  }
+  /// get() constructors - Return a constant with vector type with an element
+  /// count and element type matching the ArrayRef passed in.  Note that this
+  /// can return a ConstantAggregateZero object.
+  static Constant *get(Context &Ctx, ArrayRef<uint8_t> Elts) {
+    auto *NewLLVMC = llvm::ConstantDataVector::get(Ctx.LLVMCtx, Elts);
+    return Ctx.getOrCreateConstant(NewLLVMC);
+  }
+  static Constant *get(Context &Ctx, ArrayRef<uint16_t> Elts) {
+    auto *NewLLVMC = llvm::ConstantDataVector::get(Ctx.LLVMCtx, Elts);
+    return Ctx.getOrCreateConstant(NewLLVMC);
+  }
+  static Constant *get(Context &Ctx, ArrayRef<uint32_t> Elts) {
+    auto *NewLLVMC = llvm::ConstantDataVector::get(Ctx.LLVMCtx, Elts);
+    return Ctx.getOrCreateConstant(NewLLVMC);
+  }
+  static Constant *get(Context &Ctx, ArrayRef<uint64_t> Elts) {
+    auto *NewLLVMC = llvm::ConstantDataVector::get(Ctx.LLVMCtx, Elts);
+    return Ctx.getOrCreateConstant(NewLLVMC);
+  }
+  static Constant *get(Context &Ctx, ArrayRef<float> Elts) {
+    auto *NewLLVMC = llvm::ConstantDataVector::get(Ctx.LLVMCtx, Elts);
+    return Ctx.getOrCreateConstant(NewLLVMC);
+  }
+  static Constant *get(Context &Ctx, ArrayRef<double> Elts) {
+    auto *NewLLVMC = llvm::ConstantDataVector::get(Ctx.LLVMCtx, Elts);
+    return Ctx.getOrCreateConstant(NewLLVMC);
+  }
+
+  /// getRaw() constructor - Return a constant with vector type with an element
+  /// count and element type matching the NumElements and ElementTy parameters
+  /// passed in. Note that this can return a ConstantAggregateZero object.
+  /// ElementTy must be one of i8/i16/i32/i64/half/bfloat/float/double. Data is
+  /// the buffer containing the elements. Be careful to make sure Data uses the
+  /// right endianness, the buffer will be used as-is.
+  static Constant *getRaw(StringRef Data, uint64_t NumElements,
+                          Type *ElementTy) {
+    auto *NewLLVMC =
+        llvm::ConstantDataVector::getRaw(Data, NumElements, ElementTy->LLVMTy);
+    return ElementTy->getContext().getOrCreateConstant(NewLLVMC);
+  }
+  /// getFP() constructors - Return a constant of vector type with a float
+  /// element type taken from argument `ElementType', and count taken from
+  /// argument `Elts'.  The amount of bits of the contained type must match the
+  /// number of bits of the type contained in the passed in ArrayRef.
+  /// (i.e. half or bfloat for 16bits, float for 32bits, double for 64bits) Note
+  /// that this can return a ConstantAggregateZero object.
+  static Constant *getFP(Type *ElementType, ArrayRef<uint16_t> Elts) {
+    auto *NewLLVMC = llvm::ConstantDataVector::getFP(ElementType->LLVMTy, Elts);
+    return ElementType->getContext().getOrCreateConstant(NewLLVMC);
+  }
+  static Constant *getFP(Type *ElementType, ArrayRef<uint32_t> Elts) {
+    auto *NewLLVMC = llvm::ConstantDataVector::getFP(ElementType->LLVMTy, Elts);
+    return ElementType->getContext().getOrCreateConstant(NewLLVMC);
+  }
+  static Constant *getFP(Type *ElementType, ArrayRef<uint64_t> Elts) {
+    auto *NewLLVMC = llvm::ConstantDataVector::getFP(ElementType->LLVMTy, Elts);
+    return ElementType->getContext().getOrCreateConstant(NewLLVMC);
+  }
+
+  /// Return a ConstantVector with the specified constant in each element.
+  /// The specified constant has to be of a compatible type (i8/i16/
+  /// i32/i64/half/bfloat/float/double) and must be a ConstantFP or ConstantInt.
+  static Constant *getSplat(unsigned NumElts, Constant *Elt) {
+    auto *NewLLVMC = llvm::ConstantDataVector::getSplat(
+        NumElts, cast<llvm::Constant>(Elt->Val));
+    return Elt->getContext().getOrCreateConstant(NewLLVMC);
+  }
+
+  /// Returns true if this is a splat constant, meaning that all elements have
+  /// the same value.
+  bool isSplat() const {
+    return cast<llvm::ConstantDataVector>(Val)->isSplat();
+  }
+
+  /// Return the common element if every element is equal, else nullptr. The
+  /// LLVM result is null-checked before it is wrapped by the Context.
+  Constant *getSplatValue() const {
+    auto *LLVMSplat = cast<llvm::ConstantDataVector>(Val)->getSplatValue();
+    return LLVMSplat ? Ctx.getOrCreateConstant(LLVMSplat) : nullptr;
+  }
+
+  /// Specialize the getType() method to always return a FixedVectorType,
+  /// which reduces the amount of casting needed in parts of the compiler.
+  inline FixedVectorType *getType() const {
+    return cast<FixedVectorType>(Value::getType());
+  }
 };
 
 // TODO: Inherit from ConstantData.
diff --git a/llvm/include/llvm/SandboxIR/Value.h b/llvm/include/llvm/SandboxIR/Value.h
index d45aa4059de69..dbd0208b4f3f3 100644
--- a/llvm/include/llvm/SandboxIR/Value.h
+++ b/llvm/include/llvm/SandboxIR/Value.h
@@ -171,6 +171,7 @@ class Value {
   friend class Region;
   friend class ScoreBoard; // Needs access to `Val` for the instruction cost.
   friend class ConstantDataArray; // For `Val`
+  friend class ConstantDataVector; // For `Val`
 
   /// All values point to the context.
   Context &Ctx;
diff --git a/llvm/unittests/SandboxIR/SandboxIRTest.cpp b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
index 8cce659596a4d..18882add59941 100644
--- a/llvm/unittests/SandboxIR/SandboxIRTest.cpp
+++ b/llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -622,6 +622,7 @@ define void @foo() {
   %fvector = extractelement <2 x double> <double 0.0, double 1.0>, i32 0
   %string = extractvalue [6 x i8] [i8 72, i8 69, i8 76, i8 76, i8 79, i8 0], 0
   %stringNoNull = extractvalue [5 x i8] [i8 72, i8 69, i8 76, i8 76, i8 79], 0
+  %splat = extractelement <4 x i8> <i8 1, i8 1, i8 1, i8 1>, i32 0
   ret void
 }
 )IR");
@@ -637,6 +638,7 @@ define void @foo() {
   auto *I3 = &*It++;
   auto *I4 = &*It++;
   auto *I5 = &*It++;
+  auto *I6 = &*It++;
   auto *Array = cast<sandboxir::ConstantDataArray>(I0->getOperand(0));
   EXPECT_TRUE(isa<sandboxir::ConstantDataSequential>(Array));
   auto *Vector = cast<sandboxir::ConstantDataVector>(I1->getOperand(0));
@@ -649,6 +651,8 @@ define void @foo() {
   EXPECT_TRUE(isa<sandboxir::ConstantDataArray>(String));
   auto *StringNoNull = cast<sandboxir::ConstantDataArray>(I5->getOperand(0));
   EXPECT_TRUE(isa<sandboxir::ConstantDataArray>(StringNoNull));
+  auto *Splat = cast<sandboxir::ConstantDataVector>(I6->getOperand(0));
+  EXPECT_TRUE(isa<sandboxir::ConstantDataVector>(Splat));
 
   auto *Zero8 = sandboxir::ConstantInt::get(sandboxir::Type::getInt8Ty(Ctx), 0);
   auto *One8 = sandboxir::ConstantInt::get(sandboxir::Type::getInt8Ty(Ctx), 1);
@@ -750,9 +754,74 @@ define void @foo() {
                            llvm::Type::getDoubleTy(C), Elts64))));
   // Check getString().
   EXPECT_EQ(sandboxir::ConstantDataArray::getString(Ctx, "HELLO"), String);
+
   EXPECT_EQ(sandboxir::ConstantDataArray::getString(Ctx, "HELLO",
                                                     /*AddNull=*/false),
             StringNoNull);
+  EXPECT_EQ(
+      sandboxir::ConstantDataArray::getString(Ctx, "HELLO", /*AddNull=*/false),
+      StringNoNull);
+
+  {
+    // Check ConstantDataVector member functions
+    // -----------------------------------------
+    // Check get().
+    SmallVector<uint8_t> Elts8({0u, 1u});
+    SmallVector<uint16_t> Elts16({0u, 1u});
+    SmallVector<uint32_t> Elts32({0u, 1u});
+    SmallVector<uint64_t> Elts64({0u, 1u});
+    SmallVector<float> EltsF32({0.0, 1.0});
+    SmallVector<double> EltsF64({0.0, 1.0});
+    auto *CDV8 = sandboxir::ConstantDataVector::get(Ctx, Elts8);
+    EXPECT_EQ(CDV8, cast<sandboxir::ConstantDataVector>(
+                        Ctx.getValue(llvm::ConstantDataVector::get(C, Elts8))));
+    auto *CDV16 = sandboxir::ConstantDataVector::get(Ctx, Elts16);
+    EXPECT_EQ(CDV16, cast<sandboxir::ConstantDataVector>(Ctx.getValue(
+                         llvm::ConstantDataVector::get(C, Elts16))));
+    auto *CDV32 = sandboxir::ConstantDataVector::get(Ctx, Elts32);
+    EXPECT_EQ(CDV32, cast<sandboxir::ConstantDataVector>(Ctx.getValue(
+                         llvm::ConstantDataVector::get(C, Elts32))));
+    auto *CDVF32 = sandboxir::ConstantDataVector::get(Ctx, EltsF32);
+    EXPECT_EQ(CDVF32, cast<sandboxir::ConstantDataVector>(Ctx.getValue(
+                          llvm::ConstantDataVector::get(C, EltsF32))));
+    auto *CDVF64 = sandboxir::ConstantDataVector::get(Ctx, EltsF64);
+    EXPECT_EQ(CDVF64, cast<sandboxir::ConstantDataVector>(Ctx.getValue(
+                          llvm::ConstantDataVector::get(C, EltsF64))));
+    // Check getRaw().
+    auto *CDVRaw = sandboxir::ConstantDataVector::getRaw(
+        StringRef("HELLO"), 5, sandboxir::Type::getInt8Ty(Ctx));
+    EXPECT_EQ(CDVRaw,
+              cast<sandboxir::ConstantDataVector>(
+                  Ctx.getValue(llvm::ConstantDataVector::getRaw(
+                      StringRef("HELLO"), 5, llvm::Type::getInt8Ty(C)))));
+    // Check getFP().
+    auto *CDVFP16 = sandboxir::ConstantDataVector::getFP(F16Ty, Elts16);
+    EXPECT_EQ(CDVFP16, cast<sandboxir::ConstantDataVector>(
+                           Ctx.getValue(llvm::ConstantDataVector::getFP(
+                               llvm::Type::getHalfTy(C), Elts16))));
+    auto *CDVFP32 = sandboxir::ConstantDataVector::getFP(F32Ty, Elts32);
+    EXPECT_EQ(CDVFP32, cast<sandboxir::ConstantDataVector>(
+                           Ctx.getValue(llvm::ConstantDataVector::getFP(
+                               llvm::Type::getFloatTy(C), Elts32))));
+    auto *CDVFP64 = sandboxir::ConstantDataVector::getFP(F64Ty, Elts64);
+    EXPECT_EQ(CDVFP64, cast<sandboxir::ConstantDataVector>(
+                           Ctx.getValue(llvm::ConstantDataVector::getFP(
+                               llvm::Type::getDoubleTy(C), Elts64))));
+    // Check getSplat().
+    auto *NewSplat = cast<sandboxir::ConstantDataVector>(
+        sandboxir::ConstantDataVector::getSplat(4, One8));
+    EXPECT_EQ(NewSplat, Splat);
+    // Check isSplat().
+    EXPECT_TRUE(NewSplat->isSplat());
+    EXPECT_FALSE(Vector->isSplat());
+    // Check getSplatValue().
+    EXPECT_EQ(NewSplat->getSplatValue(), One8);
+    // Check getType().
+    EXPECT_TRUE(isa<sandboxir::FixedVectorType>(NewSplat->getType()));
+    EXPECT_EQ(
+        cast<sandboxir::FixedVectorType>(NewSplat->getType())->getNumElements(),
+        4u);
+  }
 }
 
 TEST_F(SandboxIRTest, ConstantPointerNull) {



More information about the llvm-commits mailing list