[llvm] [SandboxIR] Add missing VectorType functions (PR #107650)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 9 12:05:30 PDT 2024


https://github.com/Sterling-Augustine updated https://github.com/llvm/llvm-project/pull/107650

>From 33332c4f19f5d19518307c0818333bdcb57a771f Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Thu, 5 Sep 2024 15:18:21 -0700
Subject: [PATCH 1/2] [SandboxIR] Add missing VectorType functions

---
 llvm/include/llvm/SandboxIR/Type.h     | 31 +++++++++++++++++--
 llvm/lib/SandboxIR/Type.cpp            | 41 +++++++++++++++++++++++++-
 llvm/unittests/SandboxIR/TypesTest.cpp | 41 +++++++++++++++++++++++---
 3 files changed, 105 insertions(+), 8 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index 69ca156e82101c..c76ad505f9e137 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -50,8 +50,8 @@ class Type {
   friend class ConstantArray;  // For LLVMTy.
   friend class ConstantStruct; // For LLVMTy.
   friend class ConstantVector; // For LLVMTy.
-  friend class CmpInst; // For LLVMTy. TODO: Cleanup after sandboxir::VectorType
-                        // is more complete.
+  friend class CmpInst;        // For LLVMTy. TODO: Cleanup after
+                               // sandboxir::VectorType is more complete.
 
   // Friend all instruction classes because `create()` functions use LLVMTy.
 #define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
@@ -317,7 +317,32 @@ class StructType : public Type {
 class VectorType : public Type {
 public:
   static VectorType *get(Type *ElementType, ElementCount EC);
-  // TODO: add missing functions
+  static VectorType *get(Type *ElementType, unsigned NumElements,
+                         bool Scalable) {
+    return VectorType::get(ElementType,
+                           ElementCount::get(NumElements, Scalable));
+  }
+  // Needs tests
+  Type *getElementType() const;
+
+  static VectorType *get(Type *ElementType, const VectorType *Other) {
+    return VectorType::get(ElementType, Other->getElementCount());
+  }
+
+  inline ElementCount getElementCount() const {
+    return cast<llvm::VectorType>(LLVMTy)->getElementCount();
+  }
+  static VectorType *getInteger(Context &Ctx, VectorType *VTy);
+  static VectorType *getExtendedElementVectorType(Context &Ctx,
+                                                  VectorType *VTy);
+  static VectorType *getTruncatedElementVectorType(Context &Ctx,
+                                                   VectorType *VTy);
+  static VectorType *getSubdividedVectorType(Context &Ctx, VectorType *VTy,
+                                             int NumSubdivs);
+  static VectorType *getHalfElementsVectorType(Context &Ctx, VectorType *VTy);
+  static VectorType *getDoubleElementsVectorType(Context &Ctx, VectorType *VTy);
+  static bool isValidElementType(Context &Ctx, Type *ElemTy);
+
   static bool classof(const Type *From) {
     return isa<llvm::VectorType>(From->LLVMTy);
   }
diff --git a/llvm/lib/SandboxIR/Type.cpp b/llvm/lib/SandboxIR/Type.cpp
index 11a16e865213fb..1421526e9e7e31 100644
--- a/llvm/lib/SandboxIR/Type.cpp
+++ b/llvm/lib/SandboxIR/Type.cpp
@@ -36,7 +36,6 @@ Type *Type::getDoubleTy(Context &Ctx) {
 Type *Type::getFloatTy(Context &Ctx) {
   return Ctx.getType(llvm::Type::getFloatTy(Ctx.LLVMCtx));
 }
-
 PointerType *PointerType::get(Type *ElementType, unsigned AddressSpace) {
   return cast<PointerType>(ElementType->getContext().getType(
       llvm::PointerType::get(ElementType->LLVMTy, AddressSpace)));
@@ -67,6 +66,46 @@ VectorType *VectorType::get(Type *ElementType, ElementCount EC) {
       llvm::VectorType::get(ElementType->LLVMTy, EC)));
 }
 
+Type *VectorType::getElementType() const {
+  return Ctx.getType(cast<llvm::VectorType>(LLVMTy)->getElementType());
+}
+VectorType *VectorType::getInteger(Context &Ctx, VectorType *VTy) {
+  return cast<VectorType>(Ctx.getType(
+      llvm::VectorType::getInteger(cast<llvm::VectorType>(VTy->LLVMTy))));
+}
+VectorType *VectorType::getExtendedElementVectorType(Context &Ctx,
+                                                     VectorType *VTy) {
+  return cast<VectorType>(
+      Ctx.getType(llvm::VectorType::getExtendedElementVectorType(
+          cast<llvm::VectorType>(VTy->LLVMTy))));
+}
+VectorType *VectorType::getTruncatedElementVectorType(Context &Ctx,
+                                                      VectorType *VTy) {
+  return cast<VectorType>(
+      Ctx.getType(llvm::VectorType::getTruncatedElementVectorType(
+          cast<llvm::VectorType>(VTy->LLVMTy))));
+}
+VectorType *VectorType::getSubdividedVectorType(Context &Ctx, VectorType *VTy,
+                                                int NumSubdivs) {
+  return cast<VectorType>(Ctx.getType(llvm::VectorType::getSubdividedVectorType(
+      cast<llvm::VectorType>(VTy->LLVMTy), NumSubdivs)));
+}
+VectorType *VectorType::getHalfElementsVectorType(Context &Ctx,
+                                                  VectorType *VTy) {
+  return cast<VectorType>(
+      Ctx.getType(llvm::VectorType::getHalfElementsVectorType(
+          cast<llvm::VectorType>(VTy->LLVMTy))));
+}
+VectorType *VectorType::getDoubleElementsVectorType(Context &Ctx,
+                                                    VectorType *VTy) {
+  return cast<VectorType>(
+      Ctx.getType(llvm::VectorType::getDoubleElementsVectorType(
+          cast<llvm::VectorType>(VTy->LLVMTy))));
+}
+bool VectorType::isValidElementType(Context &Ctx, Type *ElemTy) {
+  return llvm::VectorType::isValidElementType(ElemTy->LLVMTy);
+}
+
 IntegerType *IntegerType::get(Context &Ctx, unsigned NumBits) {
   return cast<IntegerType>(
       Ctx.getType(llvm::IntegerType::get(Ctx.LLVMCtx, NumBits)));
diff --git a/llvm/unittests/SandboxIR/TypesTest.cpp b/llvm/unittests/SandboxIR/TypesTest.cpp
index 36ef0cf8e52911..619d0c9af2c068 100644
--- a/llvm/unittests/SandboxIR/TypesTest.cpp
+++ b/llvm/unittests/SandboxIR/TypesTest.cpp
@@ -268,16 +268,49 @@ define void @foo({i32, i8} %v0) {
 
 TEST_F(SandboxTypeTest, VectorType) {
   parseIR(C, R"IR(
-define void @foo(<2 x i8> %v0) {
+define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) {
   ret void
 }
 )IR");
   llvm::Function *LLVMF = &*M->getFunction("foo");
   sandboxir::Context Ctx(C);
   auto *F = Ctx.createFunction(LLVMF);
-  // Check classof(), creation.
-  [[maybe_unused]] auto *VecTy =
-      cast<sandboxir::VectorType>(F->getArg(0)->getType());
+  // Check classof(), creation, and the element type/count accessors.
+  auto *VecTy = cast<sandboxir::VectorType>(F->getArg(0)->getType());
+  EXPECT_TRUE(VecTy->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(VecTy->getElementCount(), ElementCount::getFixed(4));
+
+  auto *FVecTy = cast<sandboxir::VectorType>(F->getArg(1)->getType());
+  EXPECT_TRUE(FVecTy->getElementType()->isFloatTy());
+  auto *IVecTy = sandboxir::VectorType::getInteger(Ctx, FVecTy);
+  EXPECT_TRUE(IVecTy->getElementType()->isIntegerTy(32));
+  EXPECT_EQ(IVecTy->getElementCount(), FVecTy->getElementCount());
+
+  auto *ExtVecTy =
+      sandboxir::VectorType::getExtendedElementVectorType(Ctx, IVecTy);
+  EXPECT_TRUE(ExtVecTy->getElementType()->isIntegerTy(64));
+  EXPECT_EQ(ExtVecTy->getElementCount(), VecTy->getElementCount());
+  auto *TruncVecTy =
+      sandboxir::VectorType::getTruncatedElementVectorType(Ctx, IVecTy);
+  EXPECT_TRUE(TruncVecTy->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(TruncVecTy->getElementCount(), VecTy->getElementCount());
+  auto *SubVecTy =
+      sandboxir::VectorType::getSubdividedVectorType(Ctx, VecTy, 1);
+  EXPECT_TRUE(SubVecTy->getElementType()->isIntegerTy(8));
+  EXPECT_EQ(SubVecTy->getElementCount(), ElementCount::getFixed(8));
+  auto *HalfVecTy =
+      sandboxir::VectorType::getHalfElementsVectorType(Ctx, VecTy);
+  EXPECT_TRUE(HalfVecTy->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(HalfVecTy->getElementCount(), ElementCount::getFixed(2));
+  auto *DoubleVecTy =
+      sandboxir::VectorType::getDoubleElementsVectorType(Ctx, VecTy);
+  EXPECT_TRUE(DoubleVecTy->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(DoubleVecTy->getElementCount(), ElementCount::getFixed(8));
+
+  auto *I8Type = F->getArg(2)->getType();
+  EXPECT_TRUE(I8Type->isIntegerTy());
+  EXPECT_TRUE(sandboxir::VectorType::isValidElementType(Ctx, I8Type));
+  EXPECT_FALSE(sandboxir::VectorType::isValidElementType(Ctx, FVecTy));
 }
 
 TEST_F(SandboxTypeTest, FunctionType) {

>From 8d14f0c7dae0d151a7c536aa0032b4a0aeeae4c9 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Mon, 9 Sep 2024 12:04:42 -0700
Subject: [PATCH 2/2] Address review comments: drop Context& from VectorType helpers

---
 llvm/include/llvm/SandboxIR/Type.h     | 18 ++++++--------
 llvm/lib/SandboxIR/Type.cpp            | 33 ++++++++++++--------------
 llvm/unittests/SandboxIR/TypesTest.cpp | 29 +++++++++++-----------
 3 files changed, 37 insertions(+), 43 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index c76ad505f9e137..44aee4e4a5b46e 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -322,7 +322,6 @@ class VectorType : public Type {
     return VectorType::get(ElementType,
                            ElementCount::get(NumElements, Scalable));
   }
-  // Needs tests
   Type *getElementType() const;
 
   static VectorType *get(Type *ElementType, const VectorType *Other) {
@@ -332,16 +331,13 @@ class VectorType : public Type {
   inline ElementCount getElementCount() const {
     return cast<llvm::VectorType>(LLVMTy)->getElementCount();
   }
-  static VectorType *getInteger(Context &Ctx, VectorType *VTy);
-  static VectorType *getExtendedElementVectorType(Context &Ctx,
-                                                  VectorType *VTy);
-  static VectorType *getTruncatedElementVectorType(Context &Ctx,
-                                                   VectorType *VTy);
-  static VectorType *getSubdividedVectorType(Context &Ctx, VectorType *VTy,
-                                             int NumSubdivs);
-  static VectorType *getHalfElementsVectorType(Context &Ctx, VectorType *VTy);
-  static VectorType *getDoubleElementsVectorType(Context &Ctx, VectorType *VTy);
-  static bool isValidElementType(Context &Ctx, Type *ElemTy);
+  static VectorType *getInteger(VectorType *VTy);
+  static VectorType *getExtendedElementVectorType(VectorType *VTy);
+  static VectorType *getTruncatedElementVectorType(VectorType *VTy);
+  static VectorType *getSubdividedVectorType(VectorType *VTy, int NumSubdivs);
+  static VectorType *getHalfElementsVectorType(VectorType *VTy);
+  static VectorType *getDoubleElementsVectorType(VectorType *VTy);
+  static bool isValidElementType(Type *ElemTy);
 
   static bool classof(const Type *From) {
     return isa<llvm::VectorType>(From->LLVMTy);
diff --git a/llvm/lib/SandboxIR/Type.cpp b/llvm/lib/SandboxIR/Type.cpp
index 1421526e9e7e31..bf9f02e2ba3111 100644
--- a/llvm/lib/SandboxIR/Type.cpp
+++ b/llvm/lib/SandboxIR/Type.cpp
@@ -69,40 +69,37 @@ VectorType *VectorType::get(Type *ElementType, ElementCount EC) {
 Type *VectorType::getElementType() const {
   return Ctx.getType(cast<llvm::VectorType>(LLVMTy)->getElementType());
 }
-VectorType *VectorType::getInteger(Context &Ctx, VectorType *VTy) {
-  return cast<VectorType>(Ctx.getType(
+VectorType *VectorType::getInteger(VectorType *VTy) {
+  return cast<VectorType>(VTy->getContext().getType(
       llvm::VectorType::getInteger(cast<llvm::VectorType>(VTy->LLVMTy))));
 }
-VectorType *VectorType::getExtendedElementVectorType(Context &Ctx,
-                                                     VectorType *VTy) {
+VectorType *VectorType::getExtendedElementVectorType(VectorType *VTy) {
   return cast<VectorType>(
-      Ctx.getType(llvm::VectorType::getExtendedElementVectorType(
+      VTy->getContext().getType(llvm::VectorType::getExtendedElementVectorType(
           cast<llvm::VectorType>(VTy->LLVMTy))));
 }
-VectorType *VectorType::getTruncatedElementVectorType(Context &Ctx,
-                                                      VectorType *VTy) {
+VectorType *VectorType::getTruncatedElementVectorType(VectorType *VTy) {
   return cast<VectorType>(
-      Ctx.getType(llvm::VectorType::getTruncatedElementVectorType(
+      VTy->getContext().getType(llvm::VectorType::getTruncatedElementVectorType(
           cast<llvm::VectorType>(VTy->LLVMTy))));
 }
-VectorType *VectorType::getSubdividedVectorType(Context &Ctx, VectorType *VTy,
+VectorType *VectorType::getSubdividedVectorType(VectorType *VTy,
                                                 int NumSubdivs) {
-  return cast<VectorType>(Ctx.getType(llvm::VectorType::getSubdividedVectorType(
-      cast<llvm::VectorType>(VTy->LLVMTy), NumSubdivs)));
+  return cast<VectorType>(
+      VTy->getContext().getType(llvm::VectorType::getSubdividedVectorType(
+          cast<llvm::VectorType>(VTy->LLVMTy), NumSubdivs)));
 }
-VectorType *VectorType::getHalfElementsVectorType(Context &Ctx,
-                                                  VectorType *VTy) {
+VectorType *VectorType::getHalfElementsVectorType(VectorType *VTy) {
   return cast<VectorType>(
-      Ctx.getType(llvm::VectorType::getHalfElementsVectorType(
+      VTy->getContext().getType(llvm::VectorType::getHalfElementsVectorType(
           cast<llvm::VectorType>(VTy->LLVMTy))));
 }
-VectorType *VectorType::getDoubleElementsVectorType(Context &Ctx,
-                                                    VectorType *VTy) {
+VectorType *VectorType::getDoubleElementsVectorType(VectorType *VTy) {
   return cast<VectorType>(
-      Ctx.getType(llvm::VectorType::getDoubleElementsVectorType(
+      VTy->getContext().getType(llvm::VectorType::getDoubleElementsVectorType(
           cast<llvm::VectorType>(VTy->LLVMTy))));
 }
-bool VectorType::isValidElementType(Context &Ctx, Type *ElemTy) {
+bool VectorType::isValidElementType(Type *ElemTy) {
   return llvm::VectorType::isValidElementType(ElemTy->LLVMTy);
 }
 
diff --git a/llvm/unittests/SandboxIR/TypesTest.cpp b/llvm/unittests/SandboxIR/TypesTest.cpp
index 619d0c9af2c068..dda12b856e3992 100644
--- a/llvm/unittests/SandboxIR/TypesTest.cpp
+++ b/llvm/unittests/SandboxIR/TypesTest.cpp
@@ -282,35 +282,36 @@ define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) {
 
   auto *FVecTy = cast<sandboxir::VectorType>(F->getArg(1)->getType());
   EXPECT_TRUE(FVecTy->getElementType()->isFloatTy());
-  auto *IVecTy = sandboxir::VectorType::getInteger(Ctx, FVecTy);
+  // getInteger: <4 x float> -> <4 x i32>
+  auto *IVecTy = sandboxir::VectorType::getInteger(FVecTy);
   EXPECT_TRUE(IVecTy->getElementType()->isIntegerTy(32));
   EXPECT_EQ(IVecTy->getElementCount(), FVecTy->getElementCount());
-
-  auto *ExtVecTy =
-      sandboxir::VectorType::getExtendedElementVectorType(Ctx, IVecTy);
+  // getExtendedElementVectorType: <4 x i32> -> <4 x i64>
+  auto *ExtVecTy = sandboxir::VectorType::getExtendedElementVectorType(IVecTy);
   EXPECT_TRUE(ExtVecTy->getElementType()->isIntegerTy(64));
   EXPECT_EQ(ExtVecTy->getElementCount(), VecTy->getElementCount());
+  // getTruncatedElementVectorType: <4 x i32> -> <4 x i16>
   auto *TruncVecTy =
-      sandboxir::VectorType::getTruncatedElementVectorType(Ctx, IVecTy);
+      sandboxir::VectorType::getTruncatedElementVectorType(IVecTy);
   EXPECT_TRUE(TruncVecTy->getElementType()->isIntegerTy(16));
   EXPECT_EQ(TruncVecTy->getElementCount(), VecTy->getElementCount());
-  auto *SubVecTy =
-      sandboxir::VectorType::getSubdividedVectorType(Ctx, VecTy, 1);
+  // getSubdividedVectorType, one subdivision: <4 x i16> -> <8 x i8>
+  auto *SubVecTy = sandboxir::VectorType::getSubdividedVectorType(VecTy, 1);
   EXPECT_TRUE(SubVecTy->getElementType()->isIntegerTy(8));
   EXPECT_EQ(SubVecTy->getElementCount(), ElementCount::getFixed(8));
-  auto *HalfVecTy =
-      sandboxir::VectorType::getHalfElementsVectorType(Ctx, VecTy);
+  // getHalfElementsVectorType: <4 x i16> -> <2 x i16>
+  auto *HalfVecTy = sandboxir::VectorType::getHalfElementsVectorType(VecTy);
   EXPECT_TRUE(HalfVecTy->getElementType()->isIntegerTy(16));
   EXPECT_EQ(HalfVecTy->getElementCount(), ElementCount::getFixed(2));
-  auto *DoubleVecTy =
-      sandboxir::VectorType::getDoubleElementsVectorType(Ctx, VecTy);
+  // getDoubleElementsVectorType: <4 x i16> -> <8 x i16>
+  auto *DoubleVecTy = sandboxir::VectorType::getDoubleElementsVectorType(VecTy);
   EXPECT_TRUE(DoubleVecTy->getElementType()->isIntegerTy(16));
   EXPECT_EQ(DoubleVecTy->getElementCount(), ElementCount::getFixed(8));
-
+  // isValidElementType: i8 is a valid element type, a vector type is not
   auto *I8Type = F->getArg(2)->getType();
   EXPECT_TRUE(I8Type->isIntegerTy());
-  EXPECT_TRUE(sandboxir::VectorType::isValidElementType(Ctx, I8Type));
-  EXPECT_FALSE(sandboxir::VectorType::isValidElementType(Ctx, FVecTy));
+  EXPECT_TRUE(sandboxir::VectorType::isValidElementType(I8Type));
+  EXPECT_FALSE(sandboxir::VectorType::isValidElementType(FVecTy));
 }
 
 TEST_F(SandboxTypeTest, FunctionType) {



More information about the llvm-commits mailing list