[llvm] bb72865 - [SandboxIR] Implement FixedVectorType (#107930)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 10 15:44:33 PDT 2024
Author: Sterling-Augustine
Date: 2024-09-10T15:44:30-07:00
New Revision: bb7286515c0b285382f370232f97ffa7cfcbc550
URL: https://github.com/llvm/llvm-project/commit/bb7286515c0b285382f370232f97ffa7cfcbc550
DIFF: https://github.com/llvm/llvm-project/commit/bb7286515c0b285382f370232f97ffa7cfcbc550.diff
LOG: [SandboxIR] Implement FixedVectorType (#107930)
Added:
Modified:
llvm/include/llvm/SandboxIR/Type.h
llvm/lib/SandboxIR/Type.cpp
llvm/unittests/SandboxIR/TypesTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index 44aee4e4a5b46e..ec141c249fb21e 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -25,6 +25,7 @@ class Context;
// Forward declare friend classes for MSVC.
class PointerType;
class VectorType;
+class FixedVectorType;
class IntegerType;
class FunctionType;
class ArrayType;
@@ -41,6 +42,7 @@ class Type {
friend class ArrayType; // For LLVMTy.
friend class StructType; // For LLVMTy.
friend class VectorType; // For LLVMTy.
+ friend class FixedVectorType; // For LLVMTy.
friend class PointerType; // For LLVMTy.
friend class FunctionType; // For LLVMTy.
friend class IntegerType; // For LLVMTy.
@@ -344,6 +346,50 @@ class VectorType : public Type {
}
};
+/// Sandbox IR wrapper around llvm::FixedVectorType, i.e. a vector type whose
+/// element count is fixed (as opposed to scalable).
+///
+/// Instance queries read the wrapped llvm::FixedVectorType stored in LLVMTy.
+/// The static transforming helpers delegate to their VectorType counterparts
+/// and narrow the result with cast<>, which asserts (in builds with
+/// assertions enabled) if the result is not a FixedVectorType.
+class FixedVectorType : public VectorType {
+public:
+  /// LLVM-style RTTI hook: true iff the wrapped LLVM type is an
+  /// llvm::FixedVectorType.
+  static bool classof(const Type *T) {
+    auto *WrappedTy = T->LLVMTy;
+    return isa<llvm::FixedVectorType>(WrappedTy);
+  }
+
+  /// \returns the number of elements in this vector, read from the wrapped
+  /// llvm::FixedVectorType.
+  unsigned getNumElements() const {
+    auto *WrappedVecTy = cast<llvm::FixedVectorType>(LLVMTy);
+    return WrappedVecTy->getNumElements();
+  }
+
+  /// \returns the vector type with \p NumElts elements of type \p ElementType.
+  /// Defined out of line in terms of llvm::FixedVectorType::get().
+  static FixedVectorType *get(Type *ElementType, unsigned NumElts);
+
+  /// \returns a vector of \p ElementType with as many elements as \p FVTy.
+  static FixedVectorType *get(Type *ElementType, const FixedVectorType *FVTy) {
+    const unsigned EltCount = FVTy->getNumElements();
+    return get(ElementType, EltCount);
+  }
+
+  /// \returns the integer-element counterpart of \p VTy with the same element
+  /// count (e.g. <4 x float> -> <4 x i32>); see VectorType::getInteger.
+  static FixedVectorType *getInteger(FixedVectorType *VTy) {
+    auto *IntVecTy = VectorType::getInteger(VTy);
+    return cast<FixedVectorType>(IntVecTy);
+  }
+
+  /// \returns \p VTy with wider integer elements and the same element count
+  /// (e.g. <4 x i16> -> <4 x i32>); see
+  /// VectorType::getExtendedElementVectorType.
+  static FixedVectorType *getExtendedElementVectorType(FixedVectorType *VTy) {
+    auto *WideVecTy = VectorType::getExtendedElementVectorType(VTy);
+    return cast<FixedVectorType>(WideVecTy);
+  }
+
+  /// \returns \p VTy with narrower integer elements and the same element count
+  /// (e.g. <4 x i16> -> <4 x i8>); see
+  /// VectorType::getTruncatedElementVectorType.
+  static FixedVectorType *getTruncatedElementVectorType(FixedVectorType *VTy) {
+    auto *NarrowVecTy = VectorType::getTruncatedElementVectorType(VTy);
+    return cast<FixedVectorType>(NarrowVecTy);
+  }
+
+  /// \returns \p VTy with its elements subdivided \p NumSubdivs times
+  /// (e.g. <4 x i16> with NumSubdivs == 1 -> <8 x i8>); see
+  /// VectorType::getSubdividedVectorType.
+  static FixedVectorType *getSubdividedVectorType(FixedVectorType *VTy,
+                                                  int NumSubdivs) {
+    auto *SplitVecTy = VectorType::getSubdividedVectorType(VTy, NumSubdivs);
+    return cast<FixedVectorType>(SplitVecTy);
+  }
+
+  /// \returns \p VTy with half as many elements of the same type
+  /// (e.g. <4 x i16> -> <2 x i16>); see VectorType::getHalfElementsVectorType.
+  static FixedVectorType *getHalfElementsVectorType(FixedVectorType *VTy) {
+    auto *HalfVecTy = VectorType::getHalfElementsVectorType(VTy);
+    return cast<FixedVectorType>(HalfVecTy);
+  }
+
+  /// \returns \p VTy with twice as many elements of the same type
+  /// (e.g. <4 x i16> -> <8 x i16>); see
+  /// VectorType::getDoubleElementsVectorType.
+  static FixedVectorType *getDoubleElementsVectorType(FixedVectorType *VTy) {
+    auto *DoubledVecTy = VectorType::getDoubleElementsVectorType(VTy);
+    return cast<FixedVectorType>(DoubledVecTy);
+  }
+};
+
class FunctionType : public Type {
public:
// TODO: add missing functions
diff --git a/llvm/lib/SandboxIR/Type.cpp b/llvm/lib/SandboxIR/Type.cpp
index bf9f02e2ba3111..26aa8b3743084c 100644
--- a/llvm/lib/SandboxIR/Type.cpp
+++ b/llvm/lib/SandboxIR/Type.cpp
@@ -103,6 +103,11 @@ bool VectorType::isValidElementType(Type *ElemTy) {
return llvm::VectorType::isValidElementType(ElemTy->LLVMTy);
}
+/// Builds the LLVM fixed vector type with \p NumElts elements of the LLVM type
+/// wrapped by \p ElementType, then returns the sandboxir type that
+/// \p ElementType's Context maps it to.
+FixedVectorType *FixedVectorType::get(Type *ElementType, unsigned NumElts) {
+  auto &SBCtx = ElementType->getContext();
+  auto *LLVMVecTy = llvm::FixedVectorType::get(ElementType->LLVMTy, NumElts);
+  return cast<FixedVectorType>(SBCtx.getType(LLVMVecTy));
+}
+
IntegerType *IntegerType::get(Context &Ctx, unsigned NumBits) {
return cast<IntegerType>(
Ctx.getType(llvm::IntegerType::get(Ctx.LLVMCtx, NumBits)));
diff --git a/llvm/unittests/SandboxIR/TypesTest.cpp b/llvm/unittests/SandboxIR/TypesTest.cpp
index e4f9235c1ef3ca..3564ae66830147 100644
--- a/llvm/unittests/SandboxIR/TypesTest.cpp
+++ b/llvm/unittests/SandboxIR/TypesTest.cpp
@@ -323,6 +323,64 @@ define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) {
EXPECT_FALSE(sandboxir::VectorType::isValidElementType(FVecTy));
}
+// Exercises sandboxir::FixedVectorType: classof()/cast<>, both get()
+// overloads, the element-type/element-count transforming helpers, and
+// getNumElements().
+TEST_F(SandboxTypeTest, FixedVectorType) {
+  parseIR(C, R"IR(
+define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) {
+ ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  // Check classof(), creation, accessors
+  auto *Vec4i16Ty = cast<sandboxir::FixedVectorType>(F->getArg(0)->getType());
+  EXPECT_TRUE(Vec4i16Ty->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(Vec4i16Ty->getElementCount(), ElementCount::getFixed(4));
+
+  // get(ElementType, NumElements)
+  EXPECT_EQ(
+      sandboxir::FixedVectorType::get(sandboxir::Type::getInt16Ty(Ctx), 4),
+      F->getArg(0)->getType());
+  // get(ElementType, Other)
+  EXPECT_EQ(sandboxir::FixedVectorType::get(
+                sandboxir::Type::getInt16Ty(Ctx),
+                cast<sandboxir::FixedVectorType>(F->getArg(0)->getType())),
+            F->getArg(0)->getType());
+  auto *Vec4FTy = cast<sandboxir::FixedVectorType>(F->getArg(1)->getType());
+  EXPECT_TRUE(Vec4FTy->getElementType()->isFloatTy());
+  // getInteger: <4 x float> -> <4 x i32>.
+  auto *Vec4i32Ty = sandboxir::FixedVectorType::getInteger(Vec4FTy);
+  EXPECT_TRUE(Vec4i32Ty->getElementType()->isIntegerTy(32));
+  EXPECT_EQ(Vec4i32Ty->getElementCount(), Vec4FTy->getElementCount());
+  // getExtendedElementVectorType: <4 x i16> -> <4 x i32>.
+  auto *ExtVec4i32Ty =
+      sandboxir::FixedVectorType::getExtendedElementVectorType(Vec4i16Ty);
+  EXPECT_TRUE(ExtVec4i32Ty->getElementType()->isIntegerTy(32));
+  EXPECT_EQ(ExtVec4i32Ty->getElementCount(), Vec4i16Ty->getElementCount());
+  // getTruncatedElementVectorType: <4 x i16> -> <4 x i8>.
+  auto *Vec4i8Ty =
+      sandboxir::FixedVectorType::getTruncatedElementVectorType(Vec4i16Ty);
+  EXPECT_TRUE(Vec4i8Ty->getElementType()->isIntegerTy(8));
+  // Compare against the source type: the element count must be preserved.
+  EXPECT_EQ(Vec4i8Ty->getElementCount(), Vec4i16Ty->getElementCount());
+  // getSubdividedVectorType: <4 x i16> subdivided once -> <8 x i8>.
+  auto *Vec8i8Ty =
+      sandboxir::FixedVectorType::getSubdividedVectorType(Vec4i16Ty, 1);
+  EXPECT_TRUE(Vec8i8Ty->getElementType()->isIntegerTy(8));
+  EXPECT_EQ(Vec8i8Ty->getElementCount(), ElementCount::getFixed(8));
+  // getNumElements
+  EXPECT_EQ(Vec8i8Ty->getNumElements(), 8u);
+  // getHalfElementsVectorType: <4 x i16> -> <2 x i16>.
+  auto *Vec2i16Ty =
+      sandboxir::FixedVectorType::getHalfElementsVectorType(Vec4i16Ty);
+  EXPECT_TRUE(Vec2i16Ty->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(Vec2i16Ty->getElementCount(), ElementCount::getFixed(2));
+  // getDoubleElementsVectorType: <4 x i16> -> <8 x i16>.
+  auto *Vec8i16Ty =
+      sandboxir::FixedVectorType::getDoubleElementsVectorType(Vec4i16Ty);
+  EXPECT_TRUE(Vec8i16Ty->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(Vec8i16Ty->getElementCount(), ElementCount::getFixed(8));
+}
+
TEST_F(SandboxTypeTest, FunctionType) {
parseIR(C, R"IR(
define void @foo() {
More information about the llvm-commits mailing list