[llvm] 1ee740a - [VFABI] Add support for vector functions that return struct types (#119000)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Dec 18 01:46:48 PST 2024
Author: Benjamin Maxwell
Date: 2024-12-18T09:46:45Z
New Revision: 1ee740a79620aa680f68d873d6a7b5cfa1df7b19
URL: https://github.com/llvm/llvm-project/commit/1ee740a79620aa680f68d873d6a7b5cfa1df7b19
DIFF: https://github.com/llvm/llvm-project/commit/1ee740a79620aa680f68d873d6a7b5cfa1df7b19.diff
LOG: [VFABI] Add support for vector functions that return struct types (#119000)
This patch updates the `VFABIDemangler` to support vector functions that
return struct types. For example, a vector variant of `sincos` that
returns a vector of sine values and a vector of cosine values within a
struct.
This patch also adds some helpers for vectorizing types (including
struct types). Some of these are used in the `VFABIDemangler`, and
others will be used in subsequent patches, so this patch simply adds
tests for them.
Added:
llvm/include/llvm/IR/VectorTypeUtils.h
llvm/lib/IR/VectorTypeUtils.cpp
llvm/unittests/IR/VectorTypeUtilsTest.cpp
Modified:
llvm/include/llvm/Analysis/VectorUtils.h
llvm/lib/IR/CMakeLists.txt
llvm/lib/IR/VFABIDemangler.cpp
llvm/unittests/IR/CMakeLists.txt
llvm/unittests/IR/VFABIDemanglerTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index c1016dd7bdddbd..7f8a0c9c0af7be 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -18,6 +18,7 @@
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/VFABIDemangler.h"
+#include "llvm/IR/VectorTypeUtils.h"
#include "llvm/Support/CheckedArithmetic.h"
namespace llvm {
@@ -127,19 +128,6 @@ namespace Intrinsic {
typedef unsigned ID;
}
-/// A helper function for converting Scalar types to vector types. If
-/// the incoming type is void, we return void. If the EC represents a
-/// scalar, we return the scalar type.
-inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
- if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
- return Scalar;
- return VectorType::get(Scalar, EC);
-}
-
-inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
- return ToVectorTy(Scalar, ElementCount::getFixed(VF));
-}
-
/// Identify if the intrinsic is trivially vectorizable.
/// This method returns true if the intrinsic's argument types are all scalars
/// for the scalar form of the intrinsic and all vectors (or scalars handled by
diff --git a/llvm/include/llvm/IR/VectorTypeUtils.h b/llvm/include/llvm/IR/VectorTypeUtils.h
new file mode 100644
index 00000000000000..f30bf9ee9240b0
--- /dev/null
+++ b/llvm/include/llvm/IR/VectorTypeUtils.h
@@ -0,0 +1,94 @@
+//===------- VectorTypeUtils.h - Vector type utility functions -*- C++ -*-====//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_IR_VECTORTYPEUTILS_H
+#define LLVM_IR_VECTORTYPEUTILS_H
+
+#include "llvm/IR/DerivedTypes.h"
+
+namespace llvm {
+
+/// A helper function for converting Scalar types to vector types.
+/// \p Scalar is returned unchanged (no vector type is created) if it is void,
+/// if it is a metadata type, or if \p EC represents a scalar (e.g. a fixed
+/// element count of 1). Otherwise returns the vector type with element type
+/// \p Scalar and element count \p EC. Struct types are not widened
+/// element-wise here; use toVectorizedTy for those.
+inline Type *ToVectorTy(Type *Scalar, ElementCount EC) {
+ if (Scalar->isVoidTy() || Scalar->isMetadataTy() || EC.isScalar())
+ return Scalar;
+ return VectorType::get(Scalar, EC);
+}
+
+/// Convenience overload of ToVectorTy for a fixed-width (non-scalable)
+/// vectorization factor \p VF.
+inline Type *ToVectorTy(Type *Scalar, unsigned VF) {
+ return ToVectorTy(Scalar, ElementCount::getFixed(VF));
+}
+
+/// A helper for converting structs of scalar types to structs of vector types.
+/// Each element type is widened to a vector of \p EC elements, and the literal
+/// struct of those vector types is returned.
+/// Note:
+/// - If \p EC is scalar, \p StructTy is returned unchanged
+/// - Only unpacked literal struct types are supported (checked by assertion)
+/// - Every element type must be a valid vector element type (checked by
+///   assertion)
+Type *toVectorizedStructTy(StructType *StructTy, ElementCount EC);
+
+/// A helper for converting structs of vector types to structs of scalar types.
+/// Each element type is replaced by its scalar type (Type::getScalarType), so
+/// elements that are not vectors are kept as-is; the resulting literal struct
+/// is returned.
+/// Note: Only unpacked literal struct types are supported (checked by
+/// assertion).
+Type *toScalarizedStructTy(StructType *StructTy);
+
+/// Returns true if `StructTy` is an unpacked literal struct where all elements
+/// are vectors of matching element count. This does not include empty structs.
+/// Named or packed structs always return false.
+bool isVectorizedStructTy(StructType *StructTy);
+
+/// A helper for converting to vectorized types. For non-struct types, this
+/// forwards to `ToVectorTy` (so void and metadata types are returned
+/// unchanged). For struct types, this returns a struct where each element type
+/// has been widened to a vector type (see toVectorizedStructTy).
+/// Note:
+/// - If the incoming type is void, we return void
+/// - If \p EC is scalar, \p Ty is returned unchanged
+/// - Only unpacked literal struct types are supported
+inline Type *toVectorizedTy(Type *Ty, ElementCount EC) {
+ if (StructType *StructTy = dyn_cast<StructType>(Ty))
+ return toVectorizedStructTy(StructTy, EC);
+ return ToVectorTy(Ty, EC);
+}
+
+/// A helper for converting vectorized types to scalarized (non-vector) types.
+/// For non-struct types, this returns Ty->getScalarType(): the element type of
+/// a vector, or the type itself otherwise (e.g. void). For struct types, this
+/// returns a struct where each element type has been converted to its scalar
+/// type. Intended as the inverse of toVectorizedTy.
+/// Note: Only unpacked literal struct types are supported.
+inline Type *toScalarizedTy(Type *Ty) {
+ if (StructType *StructTy = dyn_cast<StructType>(Ty))
+ return toScalarizedStructTy(StructTy);
+ return Ty->getScalarType();
+}
+
+/// Returns true if `Ty` is a vector type, or an unpacked literal struct with at
+/// least one element whose elements are all vector types sharing the same
+/// element count (VF). Named, packed, and empty structs return false.
+inline bool isVectorizedTy(Type *Ty) {
+ if (StructType *StructTy = dyn_cast<StructType>(Ty))
+ return isVectorizedStructTy(StructTy);
+ return Ty->isVectorTy();
+}
+
+/// Returns the types contained in `Ty`. For struct types, it returns a view of
+/// the struct's elements; all other types are returned directly as a
+/// one-element list.
+///
+/// Lifetime: for non-struct types the returned ArrayRef points at the caller's
+/// `Ty` pointer object itself (which is why `Ty` is taken by reference), so
+/// `Ty` must be an lvalue that outlives the result. Passing a temporary (e.g.
+/// the direct result of a function call) leaves the ArrayRef dangling.
+inline ArrayRef<Type *> getContainedTypes(Type *const &Ty) {
+ if (auto *StructTy = dyn_cast<StructType>(Ty))
+ return StructTy->elements();
+ return ArrayRef<Type *>(&Ty, 1);
+}
+
+/// Returns the number of vector elements (VF) for a vectorized type. For a
+/// struct, this is the element count of its first element; isVectorizedTy
+/// guarantees every element has the same count.
+/// \pre isVectorizedTy(Ty)
+inline ElementCount getVectorizedTypeVF(Type *Ty) {
+ assert(isVectorizedTy(Ty) && "expected vectorized type");
+ return cast<VectorType>(getContainedTypes(Ty).front())->getElementCount();
+}
+
+/// Returns true if \p StructTy is a literal (unnamed) struct that is not
+/// packed. These are the only struct types the struct-vectorization helpers
+/// in this header support.
+inline bool isUnpackedStructLiteral(StructType *StructTy) {
+ return StructTy->isLiteral() && !StructTy->isPacked();
+}
+
+} // namespace llvm
+
+#endif
diff --git a/llvm/lib/IR/CMakeLists.txt b/llvm/lib/IR/CMakeLists.txt
index 544f4ea9223d0e..5f6254b2313180 100644
--- a/llvm/lib/IR/CMakeLists.txt
+++ b/llvm/lib/IR/CMakeLists.txt
@@ -73,6 +73,7 @@ add_llvm_component_library(LLVMCore
Value.cpp
ValueSymbolTable.cpp
VectorBuilder.cpp
+ VectorTypeUtils.cpp
Verifier.cpp
VFABIDemangler.cpp
RuntimeLibcalls.cpp
diff --git a/llvm/lib/IR/VFABIDemangler.cpp b/llvm/lib/IR/VFABIDemangler.cpp
index 897583084bf38c..62f96b10cea4ac 100644
--- a/llvm/lib/IR/VFABIDemangler.cpp
+++ b/llvm/lib/IR/VFABIDemangler.cpp
@@ -11,6 +11,7 @@
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/IR/Module.h"
+#include "llvm/IR/VectorTypeUtils.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <limits>
@@ -346,12 +347,20 @@ getScalableECFromSignature(const FunctionType *Signature, const VFISAKind ISA,
// Also check the return type if not void.
Type *RetTy = Signature->getReturnType();
if (!RetTy->isVoidTy()) {
- std::optional<ElementCount> ReturnEC = getElementCountForTy(ISA, RetTy);
- // If we have an unknown scalar element type we can't find a reasonable VF.
- if (!ReturnEC)
+ // If the return type is a struct, only allow unpacked struct literals.
+ StructType *StructTy = dyn_cast<StructType>(RetTy);
+ if (StructTy && !isUnpackedStructLiteral(StructTy))
return std::nullopt;
- if (ElementCount::isKnownLT(*ReturnEC, MinEC))
- MinEC = *ReturnEC;
+
+ for (Type *RetTy : getContainedTypes(RetTy)) {
+ std::optional<ElementCount> ReturnEC = getElementCountForTy(ISA, RetTy);
+ // If we have an unknown scalar element type we can't find a reasonable
+ // VF.
+ if (!ReturnEC)
+ return std::nullopt;
+ if (ElementCount::isKnownLT(*ReturnEC, MinEC))
+ MinEC = *ReturnEC;
+ }
}
// The SVE Vector function call ABI bases the VF on the widest element types
@@ -566,7 +575,7 @@ FunctionType *VFABI::createFunctionType(const VFInfo &Info,
auto *RetTy = ScalarFTy->getReturnType();
if (!RetTy->isVoidTy())
- RetTy = VectorType::get(RetTy, VF);
+ RetTy = toVectorizedTy(RetTy, VF);
return FunctionType::get(RetTy, VecTypes, false);
}
diff --git a/llvm/lib/IR/VectorTypeUtils.cpp b/llvm/lib/IR/VectorTypeUtils.cpp
new file mode 100644
index 00000000000000..e6e265414a2b8e
--- /dev/null
+++ b/llvm/lib/IR/VectorTypeUtils.cpp
@@ -0,0 +1,54 @@
+//===------- VectorTypeUtils.cpp - Vector type utility functions ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/VectorTypeUtils.h"
+#include "llvm/ADT/SmallVectorExtras.h"
+
+using namespace llvm;
+
+/// A helper for converting structs of scalar types to structs of vector types.
+/// Note: Only unpacked literal struct types are supported.
+/// A scalar \p EC short-circuits before the assertions below, so in that case
+/// any struct (including named or packed ones) is returned unchanged.
+Type *llvm::toVectorizedStructTy(StructType *StructTy, ElementCount EC) {
+ if (EC.isScalar())
+ return StructTy;
+ assert(isUnpackedStructLiteral(StructTy) &&
+ "expected unpacked struct literal");
+ assert(all_of(StructTy->elements(), VectorType::isValidElementType) &&
+ "expected all element types to be valid vector element types");
+ // Widen every element to a vector of EC elements and build the matching
+ // literal struct in the same context.
+ return StructType::get(
+ StructTy->getContext(),
+ map_to_vector(StructTy->elements(), [&](Type *ElTy) -> Type * {
+ return VectorType::get(ElTy, EC);
+ }));
+}
+
+/// A helper for converting structs of vector types to structs of scalar types.
+/// Note: Only unpacked literal struct types are supported.
+Type *llvm::toScalarizedStructTy(StructType *StructTy) {
+ assert(isUnpackedStructLiteral(StructTy) &&
+ "expected unpacked struct literal");
+ // getScalarType() yields a vector's element type, or the type itself for
+ // non-vector elements, so already-scalar elements pass through unchanged.
+ return StructType::get(
+ StructTy->getContext(),
+ map_to_vector(StructTy->elements(), [](Type *ElTy) -> Type * {
+ return ElTy->getScalarType();
+ }));
+}
+
+/// Returns true if `StructTy` is an unpacked literal struct where all elements
+/// are vectors of matching element count. This does not include empty structs.
+bool llvm::isVectorizedStructTy(StructType *StructTy) {
+ if (!isUnpackedStructLiteral(StructTy))
+ return false;
+ auto ElemTys = StructTy->elements();
+ if (ElemTys.empty() || !ElemTys.front()->isVectorTy())
+ return false;
+ // The first element fixes the VF; every element must be a vector with it.
+ ElementCount VF = cast<VectorType>(ElemTys.front())->getElementCount();
+ return all_of(ElemTys, [&](Type *Ty) {
+ return Ty->isVectorTy() && cast<VectorType>(Ty)->getElementCount() == VF;
+ });
+}
diff --git a/llvm/unittests/IR/CMakeLists.txt b/llvm/unittests/IR/CMakeLists.txt
index ed93ee547d2231..b3dfe3d72fd385 100644
--- a/llvm/unittests/IR/CMakeLists.txt
+++ b/llvm/unittests/IR/CMakeLists.txt
@@ -51,6 +51,7 @@ add_llvm_unittest(IRTests
ValueMapTest.cpp
ValueTest.cpp
VectorBuilderTest.cpp
+ VectorTypeUtilsTest.cpp
VectorTypesTest.cpp
VerifierTest.cpp
VFABIDemanglerTest.cpp
diff --git a/llvm/unittests/IR/VFABIDemanglerTest.cpp b/llvm/unittests/IR/VFABIDemanglerTest.cpp
index 07bff16df49335..e30e0f865f7199 100644
--- a/llvm/unittests/IR/VFABIDemanglerTest.cpp
+++ b/llvm/unittests/IR/VFABIDemanglerTest.cpp
@@ -40,7 +40,9 @@ class VFABIParserTest : public ::testing::Test {
VFInfo Info;
/// Reset the data needed for the test.
void reset(const StringRef ScalarFTyStr) {
- M = parseAssemblyString("declare void @dummy()", Err, Ctx);
+ M = parseAssemblyString("%dummy_named_struct = type { double, double }\n"
+ "declare void @dummy()",
+ Err, Ctx);
EXPECT_NE(M.get(), nullptr)
<< "Loading an invalid module.\n " << Err.getMessage() << "\n";
Type *Ty = parseType(ScalarFTyStr, Err, *(M));
@@ -753,6 +755,87 @@ TEST_F(VFABIParserTest, ParseVoidReturnTypeSVE) {
EXPECT_EQ(VectorName, "vector_foo");
}
+// Masked SVE variant (`_ZGVsMxv`: SVE, masked, scalable VF, one vector
+// parameter) of a function returning a literal `{double, double}` struct. The
+// vector variant must return a struct of scalable vectors; the VF follows the
+// widest element type involved (double), giving vscale x 2.
+TEST_F(VFABIParserTest, ParseWideStructReturnTypeSVE) {
+ EXPECT_TRUE(
+ invokeParser("_ZGVsMxv_foo(vector_foo)", "{double, double}(float)"));
+ EXPECT_EQ(ISA, VFISAKind::SVE);
+ EXPECT_TRUE(isMasked());
+ ElementCount NXV2 = ElementCount::getScalable(2);
+ FunctionType *FTy = FunctionType::get(
+ StructType::get(VectorType::get(Type::getDoubleTy(Ctx), NXV2),
+ VectorType::get(Type::getDoubleTy(Ctx), NXV2)),
+ {
+ VectorType::get(Type::getFloatTy(Ctx), NXV2),
+ VectorType::get(Type::getInt1Ty(Ctx), NXV2),
+ },
+ false);
+ EXPECT_EQ(getFunctionType(), FTy);
+ EXPECT_EQ(Parameters.size(), 2U);
+ EXPECT_EQ(Parameters[0], VFParameter({0, VFParamKind::Vector}));
+ EXPECT_EQ(Parameters[1], VFParameter({1, VFParamKind::GlobalPredicate}));
+ EXPECT_EQ(VF, NXV2);
+ EXPECT_EQ(ScalarName, "foo");
+ EXPECT_EQ(VectorName, "vector_foo");
+}
+
+// Masked SVE variant returning a mixed `{float, i64}` struct. The 64-bit i64
+// element is the widest type involved, so the VF is vscale x 2, and each
+// struct element is widened independently to a scalable vector of that VF.
+TEST_F(VFABIParserTest, ParseWideStructMixedReturnTypeSVE) {
+ EXPECT_TRUE(invokeParser("_ZGVsMxv_foo(vector_foo)", "{float, i64}(float)"));
+ EXPECT_EQ(ISA, VFISAKind::SVE);
+ EXPECT_TRUE(isMasked());
+ ElementCount NXV2 = ElementCount::getScalable(2);
+ FunctionType *FTy = FunctionType::get(
+ StructType::get(VectorType::get(Type::getFloatTy(Ctx), NXV2),
+ VectorType::get(Type::getInt64Ty(Ctx), NXV2)),
+ {
+ VectorType::get(Type::getFloatTy(Ctx), NXV2),
+ VectorType::get(Type::getInt1Ty(Ctx), NXV2),
+ },
+ false);
+ EXPECT_EQ(getFunctionType(), FTy);
+ EXPECT_EQ(Parameters.size(), 2U);
+ EXPECT_EQ(Parameters[0], VFParameter({0, VFParamKind::Vector}));
+ EXPECT_EQ(Parameters[1], VFParameter({1, VFParamKind::GlobalPredicate}));
+ EXPECT_EQ(VF, NXV2);
+ EXPECT_EQ(ScalarName, "foo");
+ EXPECT_EQ(VectorName, "vector_foo");
+}
+
+// Unmasked Advanced SIMD (NEON) variant with a fixed VF of 4 (`_ZGVnN4v`)
+// returning a `{float, float}` struct: the vector variant returns a struct of
+// two <4 x float> vectors and takes no mask parameter.
+TEST_F(VFABIParserTest, ParseWideStructReturnTypeNEON) {
+ EXPECT_TRUE(
+ invokeParser("_ZGVnN4v_foo(vector_foo)", "{float, float}(float)"));
+ EXPECT_EQ(ISA, VFISAKind::AdvancedSIMD);
+ EXPECT_FALSE(isMasked());
+ ElementCount V4 = ElementCount::getFixed(4);
+ FunctionType *FTy = FunctionType::get(
+ StructType::get(VectorType::get(Type::getFloatTy(Ctx), V4),
+ VectorType::get(Type::getFloatTy(Ctx), V4)),
+ {
+ VectorType::get(Type::getFloatTy(Ctx), V4),
+ },
+ false);
+ EXPECT_EQ(getFunctionType(), FTy);
+ EXPECT_EQ(Parameters.size(), 1U);
+ EXPECT_EQ(Parameters[0], VFParameter({0, VFParamKind::Vector}));
+ EXPECT_EQ(VF, V4);
+ EXPECT_EQ(ScalarName, "foo");
+ EXPECT_EQ(VectorName, "vector_foo");
+}
+
+// Struct return types are only accepted when they are unpacked literal structs
+// whose elements each have a known vector element count; every case below must
+// be rejected by the parser.
+TEST_F(VFABIParserTest, ParseUnsupportedStructReturnTypesSVE) {
+ // Struct with array element type.
+ EXPECT_FALSE(
+ invokeParser("_ZGVsMxv_foo(vector_foo)", "{double, [4 x float]}(float)"));
+ // Nested struct type.
+ EXPECT_FALSE(
+ invokeParser("_ZGVsMxv_foo(vector_foo)", "{{float, float}}(float)"));
+ // Packed struct type.
+ EXPECT_FALSE(
+ invokeParser("_ZGVsMxv_foo(vector_foo)", "<{double, float}>(float)"));
+ // Named struct type (%dummy_named_struct is declared by the fixture's
+ // reset() module).
+ EXPECT_FALSE(
+ invokeParser("_ZGVsMxv_foo(vector_foo)", "%dummy_named_struct(float)"));
+}
+
// Make sure we reject unsupported parameter types.
TEST_F(VFABIParserTest, ParseUnsupportedElementTypeSVE) {
EXPECT_FALSE(invokeParser("_ZGVsMxv_foo(vector_foo)", "void(i128)"));
diff --git a/llvm/unittests/IR/VectorTypeUtilsTest.cpp b/llvm/unittests/IR/VectorTypeUtilsTest.cpp
new file mode 100644
index 00000000000000..c77f183e921de4
--- /dev/null
+++ b/llvm/unittests/IR/VectorTypeUtilsTest.cpp
@@ -0,0 +1,149 @@
+//===------- VectorTypeUtilsTest.cpp - Vector utils tests -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/VectorTypeUtils.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/LLVMContext.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+// NOTE(review): this fixture is unused - the tests below are declared with
+// TEST rather than TEST_F, and each creates its own LLVMContext.
+class VectorTypeUtilsTest : public ::testing::Test {};
+
+// toVectorizedTy widens scalars to vectors and each element of a literal struct
+// to a vector of the same VF, for both fixed and scalable VFs; void and a
+// scalar VF leave the input type unchanged.
+TEST(VectorTypeUtilsTest, TestToVectorizedTy) {
+ LLVMContext C;
+
+ Type *ITy = Type::getInt32Ty(C);
+ Type *FTy = Type::getFloatTy(C);
+ Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy);
+ Type *MixedStructTy = StructType::get(FTy, ITy);
+ Type *VoidTy = Type::getVoidTy(C);
+
+ for (ElementCount VF :
+ {ElementCount::getFixed(4), ElementCount::getScalable(2)}) {
+ Type *IntVec = toVectorizedTy(ITy, VF);
+ EXPECT_TRUE(isa<VectorType>(IntVec));
+ EXPECT_EQ(IntVec, VectorType::get(ITy, VF));
+
+ Type *FloatVec = toVectorizedTy(FTy, VF);
+ EXPECT_TRUE(isa<VectorType>(FloatVec));
+ EXPECT_EQ(FloatVec, VectorType::get(FTy, VF));
+
+ Type *WideHomogeneousStructTy = toVectorizedTy(HomogeneousStructTy, VF);
+ EXPECT_TRUE(isa<StructType>(WideHomogeneousStructTy));
+ EXPECT_TRUE(
+ cast<StructType>(WideHomogeneousStructTy)->containsHomogeneousTypes());
+ EXPECT_TRUE(cast<StructType>(WideHomogeneousStructTy)->getNumElements() ==
+ 3);
+ EXPECT_TRUE(cast<StructType>(WideHomogeneousStructTy)->getElementType(0) ==
+ VectorType::get(FTy, VF));
+
+ Type *WideMixedStructTy = toVectorizedTy(MixedStructTy, VF);
+ EXPECT_TRUE(isa<StructType>(WideMixedStructTy));
+ EXPECT_TRUE(cast<StructType>(WideMixedStructTy)->getNumElements() == 2);
+ EXPECT_TRUE(cast<StructType>(WideMixedStructTy)->getElementType(0) ==
+ VectorType::get(FTy, VF));
+ EXPECT_TRUE(cast<StructType>(WideMixedStructTy)->getElementType(1) ==
+ VectorType::get(ITy, VF));
+
+ EXPECT_EQ(toVectorizedTy(VoidTy, VF), VoidTy);
+ }
+
+ // A scalar VF must return every input type (including structs) unchanged.
+ ElementCount ScalarVF = ElementCount::getFixed(1);
+ for (Type *Ty : {ITy, FTy, HomogeneousStructTy, MixedStructTy, VoidTy}) {
+ EXPECT_EQ(toVectorizedTy(Ty, ScalarVF), Ty);
+ }
+}
+
+// toScalarizedTy must undo toVectorizedTy for every VF, including the scalar
+// VF (where toVectorizedTy returns its input unchanged).
+TEST(VectorTypeUtilsTest, TestToScalarizedTy) {
+ LLVMContext C;
+
+ Type *ITy = Type::getInt32Ty(C);
+ Type *FTy = Type::getFloatTy(C);
+ Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy);
+ Type *MixedStructTy = StructType::get(FTy, ITy);
+ Type *VoidTy = Type::getVoidTy(C);
+
+ for (ElementCount VF : {ElementCount::getFixed(1), ElementCount::getFixed(4),
+ ElementCount::getScalable(2)}) {
+ for (Type *Ty : {ITy, FTy, HomogeneousStructTy, MixedStructTy, VoidTy}) {
+ // toScalarizedTy should be the inverse of toVectorizedTy.
+ EXPECT_EQ(toScalarizedTy(toVectorizedTy(Ty, VF)), Ty);
+ };
+ }
+}
+
+// Non-struct types (including void) yield a one-element list containing
+// themselves; struct types yield their element types in declaration order.
+// All arguments are named lvalues, as getContainedTypes requires for
+// non-struct types.
+TEST(VectorTypeUtilsTest, TestGetContainedTypes) {
+ LLVMContext C;
+
+ Type *ITy = Type::getInt32Ty(C);
+ Type *FTy = Type::getFloatTy(C);
+ Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy);
+ Type *MixedStructTy = StructType::get(FTy, ITy);
+ Type *VoidTy = Type::getVoidTy(C);
+
+ EXPECT_EQ(getContainedTypes(ITy), ArrayRef<Type *>({ITy}));
+ EXPECT_EQ(getContainedTypes(FTy), ArrayRef<Type *>({FTy}));
+ EXPECT_EQ(getContainedTypes(VoidTy), ArrayRef<Type *>({VoidTy}));
+ EXPECT_EQ(getContainedTypes(HomogeneousStructTy),
+ ArrayRef<Type *>({FTy, FTy, FTy}));
+ EXPECT_EQ(getContainedTypes(MixedStructTy), ArrayRef<Type *>({FTy, ITy}));
+}
+
+// Only vectors, and unpacked literal structs whose elements are all vectors of
+// one VF, count as vectorized; scalars, void, mixed-VF structs, named structs
+// and packed structs do not.
+TEST(VectorTypeUtilsTest, TestIsVectorizedTy) {
+ LLVMContext C;
+
+ Type *ITy = Type::getInt32Ty(C);
+ Type *FTy = Type::getFloatTy(C);
+ Type *NarrowStruct = StructType::get(FTy, ITy);
+ Type *VoidTy = Type::getVoidTy(C);
+
+ EXPECT_FALSE(isVectorizedTy(ITy));
+ EXPECT_FALSE(isVectorizedTy(NarrowStruct));
+ EXPECT_FALSE(isVectorizedTy(VoidTy));
+
+ ElementCount VF = ElementCount::getFixed(4);
+ EXPECT_TRUE(isVectorizedTy(toVectorizedTy(ITy, VF)));
+ EXPECT_TRUE(isVectorizedTy(toVectorizedTy(NarrowStruct, VF)));
+
+ Type *MixedVFStruct =
+ StructType::get(VectorType::get(ITy, ElementCount::getFixed(2)),
+ VectorType::get(ITy, ElementCount::getFixed(4)));
+ EXPECT_FALSE(isVectorizedTy(MixedVFStruct));
+
+ // Currently only literals types are considered wide.
+ Type *NamedWideStruct = StructType::create("Named", VectorType::get(ITy, VF),
+ VectorType::get(ITy, VF));
+ EXPECT_FALSE(isVectorizedTy(NamedWideStruct));
+
+ // Currently only unpacked types are considered wide.
+ Type *PackedWideStruct = StructType::get(
+ C, ArrayRef<Type *>{VectorType::get(ITy, VF), VectorType::get(ITy, VF)},
+ /*isPacked=*/true);
+ EXPECT_FALSE(isVectorizedTy(PackedWideStruct));
+}
+
+// The VF reported for a vectorized type must match the VF it was widened with,
+// for plain vectors and struct types alike, with fixed and scalable VFs.
+TEST(VectorTypeUtilsTest, TestGetVectorizedTypeVF) {
+ LLVMContext C;
+
+ Type *ITy = Type::getInt32Ty(C);
+ Type *FTy = Type::getFloatTy(C);
+ Type *HomogeneousStructTy = StructType::get(FTy, FTy, FTy);
+ Type *MixedStructTy = StructType::get(FTy, ITy);
+
+ for (ElementCount VF :
+ {ElementCount::getFixed(4), ElementCount::getScalable(2)}) {
+ for (Type *Ty : {ITy, FTy, HomogeneousStructTy, MixedStructTy}) {
+ EXPECT_EQ(getVectorizedTypeVF(toVectorizedTy(Ty, VF)), VF);
+ };
+ }
+}
+
+} // namespace
More information about the llvm-commits
mailing list