[llvm] b743bbc - Add ConstantDataVector::getRaw() to create a constant data vector from raw data.
Nick Lewycky via llvm-commits
llvm-commits at lists.llvm.org
Tue Mar 16 11:58:00 PDT 2021
Author: Nick Lewycky
Date: 2021-03-16T11:57:53-07:00
New Revision: b743bbc50586151514cd9f7f6487ad4d9838aded
URL: https://github.com/llvm/llvm-project/commit/b743bbc50586151514cd9f7f6487ad4d9838aded
DIFF: https://github.com/llvm/llvm-project/commit/b743bbc50586151514cd9f7f6487ad4d9838aded.diff
LOG: Add ConstantDataVector::getRaw() to create a constant data vector from raw data.
This parallels ConstantDataArray::getRaw() and can be used with ConstantDataSequential::getRawDataValues() in the base class for both types.
Update the BuildConstantData{Arrays,Vectors} tests to exercise the getRaw API. Also remove the unused Module from those tests.
In passing, update some comments to mention support for half and bfloat, and extend the tests to cover bfloat.
Differential Revision: https://reviews.llvm.org/D98302
Added:
Modified:
llvm/include/llvm/IR/Constants.h
llvm/unittests/IR/ConstantsTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 510163abe6eb..223e47aa84e7 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -558,10 +558,10 @@ class ConstantPointerNull final : public ConstantData {
//===----------------------------------------------------------------------===//
/// ConstantDataSequential - A vector or array constant whose element type is a
-/// simple 1/2/4/8-byte integer or float/double, and whose elements are just
-/// simple data values (i.e. ConstantInt/ConstantFP). This Constant node has no
-/// operands because it stores all of the elements of the constant as densely
-/// packed data, instead of as Value*'s.
+/// simple 1/2/4/8-byte integer or half/bfloat/float/double, and whose elements
+/// are just simple data values (i.e. ConstantInt/ConstantFP). This Constant
+/// node has no operands because it stores all of the elements of the constant
+/// as densely packed data, instead of as Value*'s.
///
/// This is the common base class of ConstantDataArray and ConstantDataVector.
///
@@ -700,11 +700,11 @@ class ConstantDataArray final : public ConstantDataSequential {
return ConstantDataArray::get(Context, makeArrayRef(Elts));
}
- /// get() constructor - Return a constant with array type with an element
+ /// getRaw() constructor - Return a constant with array type with an element
/// count and element type matching the NumElements and ElementTy parameters
/// passed in. Note that this can return a ConstantAggregateZero object.
- /// ElementTy needs to be one of i8/i16/i32/i64/float/double. Data is the
- /// buffer containing the elements. Be careful to make sure Data uses the
+ /// ElementTy must be one of i8/i16/i32/i64/half/bfloat/float/double. Data is
+ /// the buffer containing the elements. Be careful to make sure Data uses the
/// right endianness, the buffer will be used as-is.
static Constant *getRaw(StringRef Data, uint64_t NumElements,
Type *ElementTy) {
@@ -772,6 +772,18 @@ class ConstantDataVector final : public ConstantDataSequential {
static Constant *get(LLVMContext &Context, ArrayRef<float> Elts);
static Constant *get(LLVMContext &Context, ArrayRef<double> Elts);
+ /// getRaw() constructor - Return a constant with vector type with an element
+ /// count and element type matching the NumElements and ElementTy parameters
+ /// passed in. Note that this can return a ConstantAggregateZero object.
+ /// ElementTy must be one of i8/i16/i32/i64/half/bfloat/float/double. Data is
+ /// the buffer containing the elements. Be careful to make sure Data uses the
+ /// right endianness, the buffer will be used as-is.
+ static Constant *getRaw(StringRef Data, uint64_t NumElements,
+ Type *ElementTy) {
+ Type *Ty = VectorType::get(ElementTy, ElementCount::getFixed(NumElements));
+ return getImpl(Data, Ty);
+ }
+
/// getFP() constructors - Return a constant of vector type with a float
/// element type taken from argument `ElementType', and count taken from
/// argument `Elts'. The amount of bits of the contained type must match the
@@ -784,7 +796,7 @@ class ConstantDataVector final : public ConstantDataSequential {
/// Return a ConstantVector with the specified constant in each element.
/// The specified constant has to be a of a compatible type (i8/i16/
- /// i32/i64/float/double) and must be a ConstantFP or ConstantInt.
+ /// i32/i64/half/bfloat/float/double) and must be a ConstantFP or ConstantInt.
static Constant *getSplat(unsigned NumElts, Constant *Elt);
/// Returns true if this is a splat constant, meaning that all elements have
diff --git a/llvm/unittests/IR/ConstantsTest.cpp b/llvm/unittests/IR/ConstantsTest.cpp
index 44dbb90758ad..50eb3e0df1f5 100644
--- a/llvm/unittests/IR/ConstantsTest.cpp
+++ b/llvm/unittests/IR/ConstantsTest.cpp
@@ -418,45 +418,55 @@ static std::string getNameOfType(Type *T) {
TEST(ConstantsTest, BuildConstantDataArrays) {
LLVMContext Context;
- std::unique_ptr<Module> M(new Module("MyModule", Context));
for (Type *T : {Type::getInt8Ty(Context), Type::getInt16Ty(Context),
Type::getInt32Ty(Context), Type::getInt64Ty(Context)}) {
ArrayType *ArrayTy = ArrayType::get(T, 2);
Constant *Vals[] = {ConstantInt::get(T, 0), ConstantInt::get(T, 1)};
- Constant *CDV = ConstantArray::get(ArrayTy, Vals);
- ASSERT_TRUE(dyn_cast<ConstantDataArray>(CDV) != nullptr)
- << " T = " << getNameOfType(T);
+ Constant *CA = ConstantArray::get(ArrayTy, Vals);
+ ASSERT_TRUE(isa<ConstantDataArray>(CA)) << " T = " << getNameOfType(T);
+ auto *CDA = cast<ConstantDataArray>(CA);
+ Constant *CA2 = ConstantDataArray::getRaw(
+ CDA->getRawDataValues(), CDA->getNumElements(), CDA->getElementType());
+ ASSERT_TRUE(CA == CA2) << " T = " << getNameOfType(T);
}
- for (Type *T : {Type::getHalfTy(Context), Type::getFloatTy(Context),
- Type::getDoubleTy(Context)}) {
+ for (Type *T : {Type::getHalfTy(Context), Type::getBFloatTy(Context),
+ Type::getFloatTy(Context), Type::getDoubleTy(Context)}) {
ArrayType *ArrayTy = ArrayType::get(T, 2);
Constant *Vals[] = {ConstantFP::get(T, 0), ConstantFP::get(T, 1)};
- Constant *CDV = ConstantArray::get(ArrayTy, Vals);
- ASSERT_TRUE(dyn_cast<ConstantDataArray>(CDV) != nullptr)
- << " T = " << getNameOfType(T);
+ Constant *CA = ConstantArray::get(ArrayTy, Vals);
+ ASSERT_TRUE(isa<ConstantDataArray>(CA)) << " T = " << getNameOfType(T);
+ auto *CDA = cast<ConstantDataArray>(CA);
+ Constant *CA2 = ConstantDataArray::getRaw(
+ CDA->getRawDataValues(), CDA->getNumElements(), CDA->getElementType());
+ ASSERT_TRUE(CA == CA2) << " T = " << getNameOfType(T);
}
}
TEST(ConstantsTest, BuildConstantDataVectors) {
LLVMContext Context;
- std::unique_ptr<Module> M(new Module("MyModule", Context));
for (Type *T : {Type::getInt8Ty(Context), Type::getInt16Ty(Context),
Type::getInt32Ty(Context), Type::getInt64Ty(Context)}) {
Constant *Vals[] = {ConstantInt::get(T, 0), ConstantInt::get(T, 1)};
- Constant *CDV = ConstantVector::get(Vals);
- ASSERT_TRUE(dyn_cast<ConstantDataVector>(CDV) != nullptr)
- << " T = " << getNameOfType(T);
+ Constant *CV = ConstantVector::get(Vals);
+ ASSERT_TRUE(isa<ConstantDataVector>(CV)) << " T = " << getNameOfType(T);
+ auto *CDV = cast<ConstantDataVector>(CV);
+ Constant *CV2 = ConstantDataVector::getRaw(
+ CDV->getRawDataValues(), CDV->getNumElements(), CDV->getElementType());
+ ASSERT_TRUE(CV == CV2) << " T = " << getNameOfType(T);
}
- for (Type *T : {Type::getHalfTy(Context), Type::getFloatTy(Context),
- Type::getDoubleTy(Context)}) {
+ for (Type *T : {Type::getHalfTy(Context), Type::getBFloatTy(Context),
+ Type::getFloatTy(Context), Type::getDoubleTy(Context)}) {
Constant *Vals[] = {ConstantFP::get(T, 0), ConstantFP::get(T, 1)};
- Constant *CDV = ConstantVector::get(Vals);
- ASSERT_TRUE(dyn_cast<ConstantDataVector>(CDV) != nullptr)
- << " T = " << getNameOfType(T);
+ Constant *CV = ConstantVector::get(Vals);
+ ASSERT_TRUE(isa<ConstantDataVector>(CV)) << " T = " << getNameOfType(T);
+ auto *CDV = cast<ConstantDataVector>(CV);
+ Constant *CV2 = ConstantDataVector::getRaw(
+ CDV->getRawDataValues(), CDV->getNumElements(), CDV->getElementType());
+ ASSERT_TRUE(CV == CV2) << " T = " << getNameOfType(T);
}
}
More information about the llvm-commits
mailing list