[llvm] [SandboxIR] Implement FixedVectorType (PR #107930)

via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 10 15:17:18 PDT 2024


https://github.com/Sterling-Augustine updated https://github.com/llvm/llvm-project/pull/107930

>From cb85b818935ee69b666d7e8902e4ae032c071c93 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Mon, 9 Sep 2024 15:42:15 -0700
Subject: [PATCH 1/2] [SandboxIR] Implement FixedVectorType

---
 llvm/include/llvm/SandboxIR/Type.h     | 46 ++++++++++++++++++++
 llvm/lib/SandboxIR/Type.cpp            |  5 +++
 llvm/unittests/SandboxIR/TypesTest.cpp | 58 ++++++++++++++++++++++++++
 3 files changed, 109 insertions(+)

diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index 44aee4e4a5b46e..ec141c249fb21e 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -25,6 +25,7 @@ class Context;
 // Forward declare friend classes for MSVC.
 class PointerType;
 class VectorType;
+class FixedVectorType;
 class IntegerType;
 class FunctionType;
 class ArrayType;
@@ -41,6 +42,7 @@ class Type {
   friend class ArrayType;      // For LLVMTy.
   friend class StructType;     // For LLVMTy.
   friend class VectorType;     // For LLVMTy.
+  friend class FixedVectorType; // For LLVMTy.
   friend class PointerType;    // For LLVMTy.
   friend class FunctionType;   // For LLVMTy.
   friend class IntegerType;    // For LLVMTy.
@@ -344,6 +346,50 @@ class VectorType : public Type {
   }
 };
 
+class FixedVectorType : public VectorType {
+public:
+  static FixedVectorType *get(Type *ElementType, unsigned NumElts);
+
+  static FixedVectorType *get(Type *ElementType, const FixedVectorType *FVTy) {
+    return get(ElementType, FVTy->getNumElements());
+  }
+
+  static FixedVectorType *getInteger(FixedVectorType *VTy) {
+    return cast<FixedVectorType>(VectorType::getInteger(VTy));
+  }
+
+  static FixedVectorType *getExtendedElementVectorType(FixedVectorType *VTy) {
+    return cast<FixedVectorType>(VectorType::getExtendedElementVectorType(VTy));
+  }
+
+  static FixedVectorType *getTruncatedElementVectorType(FixedVectorType *VTy) {
+    return cast<FixedVectorType>(
+        VectorType::getTruncatedElementVectorType(VTy));
+  }
+
+  static FixedVectorType *getSubdividedVectorType(FixedVectorType *VTy,
+                                                  int NumSubdivs) {
+    return cast<FixedVectorType>(
+        VectorType::getSubdividedVectorType(VTy, NumSubdivs));
+  }
+
+  static FixedVectorType *getHalfElementsVectorType(FixedVectorType *VTy) {
+    return cast<FixedVectorType>(VectorType::getHalfElementsVectorType(VTy));
+  }
+
+  static FixedVectorType *getDoubleElementsVectorType(FixedVectorType *VTy) {
+    return cast<FixedVectorType>(VectorType::getDoubleElementsVectorType(VTy));
+  }
+
+  static bool classof(const Type *T) {
+    return isa<llvm::FixedVectorType>(T->LLVMTy);
+  }
+
+  unsigned getNumElements() const {
+    return cast<llvm::FixedVectorType>(LLVMTy)->getNumElements();
+  }
+};
+
 class FunctionType : public Type {
 public:
   // TODO: add missing functions
diff --git a/llvm/lib/SandboxIR/Type.cpp b/llvm/lib/SandboxIR/Type.cpp
index bf9f02e2ba3111..26aa8b3743084c 100644
--- a/llvm/lib/SandboxIR/Type.cpp
+++ b/llvm/lib/SandboxIR/Type.cpp
@@ -103,6 +103,11 @@ bool VectorType::isValidElementType(Type *ElemTy) {
   return llvm::VectorType::isValidElementType(ElemTy->LLVMTy);
 }
 
+FixedVectorType *FixedVectorType::get(Type *ElementType, unsigned NumElts) {
+  return cast<FixedVectorType>(ElementType->getContext().getType(
+      llvm::FixedVectorType::get(ElementType->LLVMTy, NumElts)));
+}
+
 IntegerType *IntegerType::get(Context &Ctx, unsigned NumBits) {
   return cast<IntegerType>(
       Ctx.getType(llvm::IntegerType::get(Ctx.LLVMCtx, NumBits)));
diff --git a/llvm/unittests/SandboxIR/TypesTest.cpp b/llvm/unittests/SandboxIR/TypesTest.cpp
index e4f9235c1ef3ca..2e10084205069e 100644
--- a/llvm/unittests/SandboxIR/TypesTest.cpp
+++ b/llvm/unittests/SandboxIR/TypesTest.cpp
@@ -323,6 +323,64 @@ define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) {
   EXPECT_FALSE(sandboxir::VectorType::isValidElementType(FVecTy));
 }
 
+TEST_F(SandboxTypeTest, FixedVectorType) {
+  parseIR(C, R"IR(
+define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) {
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  // Check classof(), creation, accessors
+  auto *VecTy = cast<sandboxir::FixedVectorType>(F->getArg(0)->getType());
+  EXPECT_TRUE(VecTy->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(VecTy->getElementCount(), ElementCount::getFixed(4));
+
+  // get(ElementType, NumElements)
+  EXPECT_EQ(
+      sandboxir::FixedVectorType::get(sandboxir::Type::getInt16Ty(Ctx), 4),
+      F->getArg(0)->getType());
+  // get(ElementType, Other)
+  EXPECT_EQ(sandboxir::FixedVectorType::get(
+                sandboxir::Type::getInt16Ty(Ctx),
+                cast<sandboxir::FixedVectorType>(F->getArg(0)->getType())),
+            F->getArg(0)->getType());
+  auto *FVecTy = cast<sandboxir::FixedVectorType>(F->getArg(1)->getType());
+  EXPECT_TRUE(FVecTy->getElementType()->isFloatTy());
+  // getInteger
+  auto *IVecTy = sandboxir::FixedVectorType::getInteger(FVecTy);
+  EXPECT_TRUE(IVecTy->getElementType()->isIntegerTy(32));
+  EXPECT_EQ(IVecTy->getElementCount(), FVecTy->getElementCount());
+  // getExtendedElementCountVectorType
+  auto *ExtVecTy =
+      sandboxir::FixedVectorType::getExtendedElementVectorType(IVecTy);
+  EXPECT_TRUE(ExtVecTy->getElementType()->isIntegerTy(64));
+  EXPECT_EQ(ExtVecTy->getElementCount(), VecTy->getElementCount());
+  // getTruncatedElementVectorType
+  auto *TruncVecTy =
+      sandboxir::FixedVectorType::getTruncatedElementVectorType(IVecTy);
+  EXPECT_TRUE(TruncVecTy->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(TruncVecTy->getElementCount(), VecTy->getElementCount());
+  // getSubdividedVectorType
+  auto *SubVecTy =
+      sandboxir::FixedVectorType::getSubdividedVectorType(VecTy, 1);
+  EXPECT_TRUE(SubVecTy->getElementType()->isIntegerTy(8));
+  EXPECT_EQ(SubVecTy->getElementCount(), ElementCount::getFixed(8));
+  // getNumElements
+  EXPECT_EQ(SubVecTy->getNumElements(), 8u);
+  // getHalfElementsVectorType
+  auto *HalfVecTy =
+      sandboxir::FixedVectorType::getHalfElementsVectorType(VecTy);
+  EXPECT_TRUE(HalfVecTy->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(HalfVecTy->getElementCount(), ElementCount::getFixed(2));
+  // getDoubleElementsVectorType
+  auto *DoubleVecTy =
+      sandboxir::FixedVectorType::getDoubleElementsVectorType(VecTy);
+  EXPECT_TRUE(DoubleVecTy->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(DoubleVecTy->getElementCount(), ElementCount::getFixed(8));
+}
+
 TEST_F(SandboxTypeTest, FunctionType) {
   parseIR(C, R"IR(
 define void @foo() {

>From f87c9845a2f1d0b5d044611af230ad0ab81cc7af Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Tue, 10 Sep 2024 15:16:19 -0700
Subject: [PATCH 2/2] Rename variables to reflect expected sizes.

---
 llvm/unittests/SandboxIR/TypesTest.cpp | 58 +++++++++++++-------------
 1 file changed, 29 insertions(+), 29 deletions(-)

diff --git a/llvm/unittests/SandboxIR/TypesTest.cpp b/llvm/unittests/SandboxIR/TypesTest.cpp
index 2e10084205069e..17486deb2325e8 100644
--- a/llvm/unittests/SandboxIR/TypesTest.cpp
+++ b/llvm/unittests/SandboxIR/TypesTest.cpp
@@ -333,9 +333,9 @@ define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) {
   sandboxir::Context Ctx(C);
   auto *F = Ctx.createFunction(LLVMF);
   // Check classof(), creation, accessors
-  auto *VecTy = cast<sandboxir::FixedVectorType>(F->getArg(0)->getType());
-  EXPECT_TRUE(VecTy->getElementType()->isIntegerTy(16));
-  EXPECT_EQ(VecTy->getElementCount(), ElementCount::getFixed(4));
+  auto *Vec4i16Ty = cast<sandboxir::FixedVectorType>(F->getArg(0)->getType());
+  EXPECT_TRUE(Vec4i16Ty->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(Vec4i16Ty->getElementCount(), ElementCount::getFixed(4));
 
   // get(ElementType, NumElements)
   EXPECT_EQ(
@@ -346,39 +346,39 @@ define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) {
                 sandboxir::Type::getInt16Ty(Ctx),
                 cast<sandboxir::FixedVectorType>(F->getArg(0)->getType())),
             F->getArg(0)->getType());
-  auto *FVecTy = cast<sandboxir::FixedVectorType>(F->getArg(1)->getType());
-  EXPECT_TRUE(FVecTy->getElementType()->isFloatTy());
+  auto *Vec4FTy = cast<sandboxir::FixedVectorType>(F->getArg(1)->getType());
+  EXPECT_TRUE(Vec4FTy->getElementType()->isFloatTy());
   // getInteger
-  auto *IVecTy = sandboxir::FixedVectorType::getInteger(FVecTy);
-  EXPECT_TRUE(IVecTy->getElementType()->isIntegerTy(32));
-  EXPECT_EQ(IVecTy->getElementCount(), FVecTy->getElementCount());
+  auto *Vec4i32Ty = sandboxir::FixedVectorType::getInteger(Vec4FTy);
+  EXPECT_TRUE(Vec4i32Ty->getElementType()->isIntegerTy(32));
+  EXPECT_EQ(Vec4i32Ty->getElementCount(), Vec4FTy->getElementCount());
   // getExtendedElementCountVectorType
-  auto *ExtVecTy =
-      sandboxir::FixedVectorType::getExtendedElementVectorType(IVecTy);
-  EXPECT_TRUE(ExtVecTy->getElementType()->isIntegerTy(64));
-  EXPECT_EQ(ExtVecTy->getElementCount(), VecTy->getElementCount());
+  auto *Vec4i64Ty =
+      sandboxir::FixedVectorType::getExtendedElementVectorType(Vec4i32Ty);
+  EXPECT_TRUE(Vec4i64Ty->getElementType()->isIntegerTy(64));
+  EXPECT_EQ(Vec4i64Ty->getElementCount(), Vec4i16Ty->getElementCount());
   // getTruncatedElementVectorType
-  auto *TruncVecTy =
-      sandboxir::FixedVectorType::getTruncatedElementVectorType(IVecTy);
-  EXPECT_TRUE(TruncVecTy->getElementType()->isIntegerTy(16));
-  EXPECT_EQ(TruncVecTy->getElementCount(), VecTy->getElementCount());
+  auto *TVec4i16Ty =
+      sandboxir::FixedVectorType::getTruncatedElementVectorType(Vec4i32Ty);
+  EXPECT_TRUE(TVec4i16Ty->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(TVec4i16Ty->getElementCount(), Vec4i16Ty->getElementCount());
   // getSubdividedVectorType
-  auto *SubVecTy =
-      sandboxir::FixedVectorType::getSubdividedVectorType(VecTy, 1);
-  EXPECT_TRUE(SubVecTy->getElementType()->isIntegerTy(8));
-  EXPECT_EQ(SubVecTy->getElementCount(), ElementCount::getFixed(8));
+  auto *Vec8i8Ty =
+      sandboxir::FixedVectorType::getSubdividedVectorType(Vec4i16Ty, 1);
+  EXPECT_TRUE(Vec8i8Ty->getElementType()->isIntegerTy(8));
+  EXPECT_EQ(Vec8i8Ty->getElementCount(), ElementCount::getFixed(8));
   // getNumElements
-  EXPECT_EQ(SubVecTy->getNumElements(), 8u);
+  EXPECT_EQ(Vec8i8Ty->getNumElements(), 8u);
   // getHalfElementsVectorType
-  auto *HalfVecTy =
-      sandboxir::FixedVectorType::getHalfElementsVectorType(VecTy);
-  EXPECT_TRUE(HalfVecTy->getElementType()->isIntegerTy(16));
-  EXPECT_EQ(HalfVecTy->getElementCount(), ElementCount::getFixed(2));
+  auto *Vec2i16Ty =
+      sandboxir::FixedVectorType::getHalfElementsVectorType(Vec4i16Ty);
+  EXPECT_TRUE(Vec2i16Ty->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(Vec2i16Ty->getElementCount(), ElementCount::getFixed(2));
   // getDoubleElementsVectorType
-  auto *DoubleVecTy =
-      sandboxir::FixedVectorType::getDoubleElementsVectorType(VecTy);
-  EXPECT_TRUE(DoubleVecTy->getElementType()->isIntegerTy(16));
-  EXPECT_EQ(DoubleVecTy->getElementCount(), ElementCount::getFixed(8));
+  auto *Vec8i16Ty =
+      sandboxir::FixedVectorType::getDoubleElementsVectorType(Vec4i16Ty);
+  EXPECT_TRUE(Vec8i16Ty->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(Vec8i16Ty->getElementCount(), ElementCount::getFixed(8));
 }
 
 TEST_F(SandboxTypeTest, FunctionType) {



More information about the llvm-commits mailing list