[llvm] [SandboxIR] Implement ScalableVectorType (PR #108124)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 10 18:19:49 PDT 2024
https://github.com/Sterling-Augustine updated https://github.com/llvm/llvm-project/pull/108124
>From a0533d7272cc350a7681107b5f0a0e29db9aaf3f Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Tue, 10 Sep 2024 17:49:53 -0700
Subject: [PATCH 1/3] [SandboxIR] Implement ScalableVectorType
---
llvm/include/llvm/SandboxIR/Type.h | 83 +++++++++++++++++++++-----
llvm/lib/SandboxIR/Type.cpp | 6 ++
llvm/unittests/SandboxIR/TypesTest.cpp | 59 ++++++++++++++++++
3 files changed, 133 insertions(+), 15 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index ec141c249fb21e..4a55564d81ff8d 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -26,6 +26,7 @@ class Context;
class PointerType;
class VectorType;
class FixedVectorType;
+class ScalableVectorType;
class IntegerType;
class FunctionType;
class ArrayType;
@@ -39,21 +40,22 @@ class StructType;
class Type {
protected:
llvm::Type *LLVMTy;
- friend class ArrayType; // For LLVMTy.
- friend class StructType; // For LLVMTy.
- friend class VectorType; // For LLVMTy.
- friend class FixedVectorType; // For LLVMTy.
- friend class PointerType; // For LLVMTy.
- friend class FunctionType; // For LLVMTy.
- friend class IntegerType; // For LLVMTy.
- friend class Function; // For LLVMTy.
- friend class CallBase; // For LLVMTy.
- friend class ConstantInt; // For LLVMTy.
- friend class ConstantArray; // For LLVMTy.
- friend class ConstantStruct; // For LLVMTy.
- friend class ConstantVector; // For LLVMTy.
- friend class CmpInst; // For LLVMTy. TODO: Cleanup after
- // sandboxir::VectorType is more complete.
+ friend class ArrayType; // For LLVMTy.
+ friend class StructType; // For LLVMTy.
+ friend class VectorType; // For LLVMTy.
+ friend class FixedVectorType; // For LLVMTy.
+ friend class ScalableVectorType; // For LLVMTy.
+ friend class PointerType; // For LLVMTy.
+ friend class FunctionType; // For LLVMTy.
+ friend class IntegerType; // For LLVMTy.
+ friend class Function; // For LLVMTy.
+ friend class CallBase; // For LLVMTy.
+ friend class ConstantInt; // For LLVMTy.
+ friend class ConstantArray; // For LLVMTy.
+ friend class ConstantStruct; // For LLVMTy.
+ friend class ConstantVector; // For LLVMTy.
+ friend class CmpInst; // For LLVMTy. TODO: Cleanup after
+ // sandboxir::VectorType is more complete.
// Friend all instruction classes because `create()` functions use LLVMTy.
#define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
@@ -390,6 +392,79 @@ class FixedVectorType : public VectorType {
   }
 };
 
+class ScalableVectorType : public VectorType {
+public:
+  /// \returns the scalable vector type `<vscale x MinNumElts x ElementType>`.
+  static ScalableVectorType *get(Type *ElementType, unsigned MinNumElts);
+
+  /// \returns a scalable vector type of \p ElementType with the same minimum
+  /// number of elements as \p SVTy.
+  static ScalableVectorType *get(Type *ElementType,
+                                 const ScalableVectorType *SVTy) {
+    const unsigned MinNumElts = SVTy->getMinNumElements();
+    return get(ElementType, MinNumElts);
+  }
+
+  /// Same as VectorType::getInteger(), with the result cast<> back to
+  /// ScalableVectorType.
+  static ScalableVectorType *getInteger(ScalableVectorType *VTy) {
+    auto *IntVTy = VectorType::getInteger(VTy);
+    return cast<ScalableVectorType>(IntVTy);
+  }
+
+  /// Same as VectorType::getExtendedElementVectorType(), with the result
+  /// cast<> back to ScalableVectorType.
+  static ScalableVectorType *
+  getExtendedElementVectorType(ScalableVectorType *VTy) {
+    auto *ExtVTy = VectorType::getExtendedElementVectorType(VTy);
+    return cast<ScalableVectorType>(ExtVTy);
+  }
+
+  /// Same as VectorType::getTruncatedElementVectorType(), with the result
+  /// cast<> back to ScalableVectorType.
+  static ScalableVectorType *
+  getTruncatedElementVectorType(ScalableVectorType *VTy) {
+    auto *TruncVTy = VectorType::getTruncatedElementVectorType(VTy);
+    return cast<ScalableVectorType>(TruncVTy);
+  }
+
+  /// Same as VectorType::getSubdividedVectorType(), with the result cast<>
+  /// back to ScalableVectorType.
+  static ScalableVectorType *getSubdividedVectorType(ScalableVectorType *VTy,
+                                                     int NumSubdivs) {
+    auto *SubVTy = VectorType::getSubdividedVectorType(VTy, NumSubdivs);
+    return cast<ScalableVectorType>(SubVTy);
+  }
+
+  /// Same as VectorType::getHalfElementsVectorType(), with the result cast<>
+  /// back to ScalableVectorType.
+  static ScalableVectorType *
+  getHalfElementsVectorType(ScalableVectorType *VTy) {
+    auto *HalfVTy = VectorType::getHalfElementsVectorType(VTy);
+    return cast<ScalableVectorType>(HalfVTy);
+  }
+
+  /// Same as VectorType::getDoubleElementsVectorType(), with the result
+  /// cast<> back to ScalableVectorType.
+  static ScalableVectorType *
+  getDoubleElementsVectorType(ScalableVectorType *VTy) {
+    auto *DblVTy = VectorType::getDoubleElementsVectorType(VTy);
+    return cast<ScalableVectorType>(DblVTy);
+  }
+
+  /// \returns N for a type of the form `<vscale x N x Ty>`, as reported by
+  /// the wrapped llvm::ScalableVectorType.
+  unsigned getMinNumElements() const {
+    const auto *SVTy = cast<llvm::ScalableVectorType>(LLVMTy);
+    return SVTy->getMinNumElements();
+  }
+
+  /// Enables isa<>/cast<>: true iff the underlying LLVM type of \p T is an
+  /// llvm::ScalableVectorType.
+  static bool classof(const Type *T) {
+    return isa<llvm::ScalableVectorType>(T->LLVMTy);
+  }
+};
+
class FunctionType : public Type {
public:
// TODO: add missing functions
diff --git a/llvm/lib/SandboxIR/Type.cpp b/llvm/lib/SandboxIR/Type.cpp
index 26aa8b3743084c..87dcb726dde351 100644
--- a/llvm/lib/SandboxIR/Type.cpp
+++ b/llvm/lib/SandboxIR/Type.cpp
@@ -108,6 +108,15 @@ FixedVectorType *FixedVectorType::get(Type *ElementType, unsigned NumElts) {
       llvm::FixedVectorType::get(ElementType->LLVMTy, NumElts)));
 }
 
+// Builds `<vscale x MinNumElts x ElementType>` via
+// llvm::ScalableVectorType::get() and returns the corresponding sandboxir type
+// obtained from the element type's Context.
+ScalableVectorType *ScalableVectorType::get(Type *ElementType,
+                                            unsigned MinNumElts) {
+  return cast<ScalableVectorType>(ElementType->getContext().getType(
+      llvm::ScalableVectorType::get(ElementType->LLVMTy, MinNumElts)));
+}
+
IntegerType *IntegerType::get(Context &Ctx, unsigned NumBits) {
return cast<IntegerType>(
Ctx.getType(llvm::IntegerType::get(Ctx.LLVMCtx, NumBits)));
diff --git a/llvm/unittests/SandboxIR/TypesTest.cpp b/llvm/unittests/SandboxIR/TypesTest.cpp
index 3564ae66830147..40aa32fb08ed01 100644
--- a/llvm/unittests/SandboxIR/TypesTest.cpp
+++ b/llvm/unittests/SandboxIR/TypesTest.cpp
@@ -381,6 +381,65 @@ define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) {
EXPECT_EQ(Vec8i16Ty->getElementCount(), ElementCount::getFixed(8));
}
+TEST_F(SandboxTypeTest, ScalableVectorType) {
+  parseIR(C, R"IR(
+define void @foo(<vscale x 4 x i16> %vi0, <vscale x 4 x float> %vf1, i8 %i0) {
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  // Check classof(), creation, accessors
+  auto *Vec4i16Ty =
+      cast<sandboxir::ScalableVectorType>(F->getArg(0)->getType());
+  EXPECT_TRUE(Vec4i16Ty->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(Vec4i16Ty->getMinNumElements(), 4u);
+
+  // get(ElementType, NumElements)
+  EXPECT_EQ(
+      sandboxir::ScalableVectorType::get(sandboxir::Type::getInt16Ty(Ctx), 4),
+      F->getArg(0)->getType());
+  // get(ElementType, Other)
+  EXPECT_EQ(sandboxir::ScalableVectorType::get(
+                sandboxir::Type::getInt16Ty(Ctx),
+                cast<sandboxir::ScalableVectorType>(F->getArg(0)->getType())),
+            F->getArg(0)->getType());
+  auto *Vec4FTy = cast<sandboxir::ScalableVectorType>(F->getArg(1)->getType());
+  EXPECT_TRUE(Vec4FTy->getElementType()->isFloatTy());
+  // getInteger
+  auto *Vec4i32Ty = sandboxir::ScalableVectorType::getInteger(Vec4FTy);
+  EXPECT_TRUE(Vec4i32Ty->getElementType()->isIntegerTy(32));
+  EXPECT_EQ(Vec4i32Ty->getMinNumElements(), Vec4FTy->getMinNumElements());
+  // getExtendedElementVectorType
+  auto *ExtVec4i32Ty =
+      sandboxir::ScalableVectorType::getExtendedElementVectorType(Vec4i16Ty);
+  EXPECT_TRUE(ExtVec4i32Ty->getElementType()->isIntegerTy(32));
+  EXPECT_EQ(ExtVec4i32Ty->getMinNumElements(), Vec4i16Ty->getMinNumElements());
+  // getTruncatedElementVectorType
+  auto *Vec4i8Ty =
+      sandboxir::ScalableVectorType::getTruncatedElementVectorType(Vec4i16Ty);
+  EXPECT_TRUE(Vec4i8Ty->getElementType()->isIntegerTy(8));
+  EXPECT_EQ(Vec4i8Ty->getMinNumElements(), Vec4i16Ty->getMinNumElements());
+  // getSubdividedVectorType
+  auto *Vec8i8Ty =
+      sandboxir::ScalableVectorType::getSubdividedVectorType(Vec4i16Ty, 1);
+  EXPECT_TRUE(Vec8i8Ty->getElementType()->isIntegerTy(8));
+  EXPECT_EQ(Vec8i8Ty->getMinNumElements(), 8u);
+  // getElementCount
+  EXPECT_EQ(Vec8i8Ty->getElementCount(), ElementCount::getScalable(8));
+  // getHalfElementsVectorType
+  auto *Vec2i16Ty =
+      sandboxir::ScalableVectorType::getHalfElementsVectorType(Vec4i16Ty);
+  EXPECT_TRUE(Vec2i16Ty->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(Vec2i16Ty->getMinNumElements(), 2u);
+  // getDoubleElementsVectorType
+  auto *Vec8i16Ty =
+      sandboxir::ScalableVectorType::getDoubleElementsVectorType(Vec4i16Ty);
+  EXPECT_TRUE(Vec8i16Ty->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(Vec8i16Ty->getMinNumElements(), 8u);
+}
+
TEST_F(SandboxTypeTest, FunctionType) {
parseIR(C, R"IR(
define void @foo() {
>From b19051fae84632688196a8094f18320a07ac6989 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Tue, 10 Sep 2024 17:59:27 -0700
Subject: [PATCH 2/3] Fix formatting
---
llvm/include/llvm/SandboxIR/Type.h | 22 +++++++++++-----------
1 file changed, 11 insertions(+), 11 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index 4a55564d81ff8d..685563bf88228d 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -45,17 +45,17 @@ class Type {
friend class VectorType; // For LLVMTy.
friend class FixedVectorType; // For LLVMTy.
friend class ScalableVectorType; // For LLVMTy.
- friend class PointerType; // For LLVMTy.
- friend class FunctionType; // For LLVMTy.
- friend class IntegerType; // For LLVMTy.
- friend class Function; // For LLVMTy.
- friend class CallBase; // For LLVMTy.
- friend class ConstantInt; // For LLVMTy.
- friend class ConstantArray; // For LLVMTy.
- friend class ConstantStruct; // For LLVMTy.
- friend class ConstantVector; // For LLVMTy.
- friend class CmpInst; // For LLVMTy. TODO: Cleanup after
- // sandboxir::VectorType is more complete.
+ friend class PointerType; // For LLVMTy.
+ friend class FunctionType; // For LLVMTy.
+ friend class IntegerType; // For LLVMTy.
+ friend class Function; // For LLVMTy.
+ friend class CallBase; // For LLVMTy.
+ friend class ConstantInt; // For LLVMTy.
+ friend class ConstantArray; // For LLVMTy.
+ friend class ConstantStruct; // For LLVMTy.
+ friend class ConstantVector; // For LLVMTy.
+ friend class CmpInst; // For LLVMTy. TODO: Cleanup after
+ // sandboxir::VectorType is more complete.
// Friend all instruction classes because `create()` functions use LLVMTy.
#define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
>From 0bc631d7f38bc16b98534c967a7e9f3026abcfd7 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Tue, 10 Sep 2024 18:07:35 -0700
Subject: [PATCH 3/3] Another attempt to fix formatting
---
llvm/include/llvm/SandboxIR/Type.h | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index 685563bf88228d..a2ac9e014b44ab 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -40,10 +40,10 @@ class StructType;
class Type {
protected:
llvm::Type *LLVMTy;
- friend class ArrayType; // For LLVMTy.
- friend class StructType; // For LLVMTy.
- friend class VectorType; // For LLVMTy.
- friend class FixedVectorType; // For LLVMTy.
+ friend class ArrayType; // For LLVMTy.
+ friend class StructType; // For LLVMTy.
+ friend class VectorType; // For LLVMTy.
+ friend class FixedVectorType; // For LLVMTy.
friend class ScalableVectorType; // For LLVMTy.
friend class PointerType; // For LLVMTy.
friend class FunctionType; // For LLVMTy.
More information about the llvm-commits mailing list