[PATCH] D82419: [SVE] add derived vector get(Type *, ElementCount) and get(Type *, VectorType)

Christopher Tetreault via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Tue Jun 23 16:44:12 PDT 2020


ctetreau created this revision.
Herald added subscribers: llvm-commits, psnobl, rkruppe, tschuett.
Herald added a reviewer: efriedma.
Herald added a project: LLVM.
ctetreau added reviewers: huntergr, david-arm, kmclaughlin, fpetrogalli.
ctetreau edited the summary of this revision.
ctetreau edited the summary of this revision.

These new getters in the derived vector types provide consistency with
all the other derived get functions. Prior to this change, if a
developer were to write:

auto *VTy = FixedVectorType::get(SomeTy, SomeVTy->getElementCount());

Not only is the static type of VTy VectorType * rather than
FixedVectorType *, which is inconsistent with how all the other getters
behave when called through a derived vector type's get() function, but
VTy might actually be an instance of ScalableVectorType, which is almost
certainly not what the caller expects.

This patch adds get functions that hide all of the base VectorType::get
overloads when called through a derived vector type. When one of these
functions is called and the result would not match the requested type,
nullptr is returned. This enables patterns like:

if (auto *SVTy = ScalableVectorType::get(SomeTy, SomeVTy)) {
  // stuff
}

... where, if SomeVTy is a fixed-width vector, the branch will not be
taken.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D82419

Files:
  llvm/include/llvm/IR/DerivedTypes.h
  llvm/unittests/IR/VectorTypesTest.cpp


Index: llvm/unittests/IR/VectorTypesTest.cpp
===================================================================
--- llvm/unittests/IR/VectorTypesTest.cpp
+++ llvm/unittests/IR/VectorTypesTest.cpp
@@ -377,4 +377,45 @@
   // non-scalable vector sizes.
 }
 
+TEST(VectorTypesTest, DerivedGettersHidingBaseGetters) {
+  // FixedVectorType and ScalableVectorType declare get(Type *, ElementCount)
+  // and get(Type *, const VectorType *), hiding the VectorType versions. They
+  // return the derived type, or null on a fixed/scalable mismatch.
+
+  LLVMContext Ctx;
+
+  Type *Ty = Type::getInt32Ty(Ctx);
+  // FV4 is a fixed count of 4; SV4 is a scalable count with minimum 4.
+  ElementCount FV4 = {4, false};
+  ElementCount SV4 = {4, true};
+  // Base-class results that the derived getters are compared against.
+  auto *BFV4Ty = VectorType::get(Ty, FV4);
+  auto *BSV4Ty = VectorType::get(Ty, SV4);
+
+  EXPECT_NE(nullptr, BFV4Ty);
+  EXPECT_NE(nullptr, BSV4Ty);
+
+  // ElementCount overloads: a fixed/scalable mismatch must yield null.
+  auto *FV4Ty = FixedVectorType::get(Ty, FV4);
+  auto *FV4TyNaught = FixedVectorType::get(Ty, SV4);
+  auto *SV4Ty = ScalableVectorType::get(Ty, SV4);
+  auto *SV4TyNaught = ScalableVectorType::get(Ty, FV4);
+
+  EXPECT_EQ(BFV4Ty, FV4Ty);
+  EXPECT_EQ(nullptr, FV4TyNaught);
+  EXPECT_EQ(BSV4Ty, SV4Ty);
+  EXPECT_EQ(nullptr, SV4TyNaught);
+
+  // VectorType overloads: the element count comes from the vector argument.
+  FV4Ty = FixedVectorType::get(Ty, BFV4Ty);
+  FV4TyNaught = FixedVectorType::get(Ty, BSV4Ty);
+  SV4Ty = ScalableVectorType::get(Ty, BSV4Ty);
+  SV4TyNaught = ScalableVectorType::get(Ty, BFV4Ty);
+
+  EXPECT_EQ(BFV4Ty, FV4Ty);
+  EXPECT_EQ(nullptr, FV4TyNaught);
+  EXPECT_EQ(BSV4Ty, SV4Ty);
+  EXPECT_EQ(nullptr, SV4TyNaught);
+}
+
 } // end anonymous namespace
Index: llvm/include/llvm/IR/DerivedTypes.h
===================================================================
--- llvm/include/llvm/IR/DerivedTypes.h
+++ llvm/include/llvm/IR/DerivedTypes.h
@@ -562,6 +562,17 @@
 public:
   static FixedVectorType *get(Type *ElementType, unsigned NumElts);
 
+  static FixedVectorType *get(Type *ElementType, ElementCount EC) {
+    if (EC.Scalable) // A scalable count cannot describe a fixed vector.
+      return nullptr;
+    // Fixed-width count: EC.Min is the exact number of elements.
+    return get(ElementType, EC.Min);
+  }
+  // As above, using VTy's element count; null if VTy is scalable.
+  static FixedVectorType *get(Type *ElementType, const VectorType *VTy) {
+    return get(ElementType, VTy->getElementCount());
+  }
+
   static FixedVectorType *get(Type *ElementType, const FixedVectorType *FVTy) {
     return get(ElementType, FVTy->getNumElements());
   }
@@ -607,6 +618,17 @@
 public:
   static ScalableVectorType *get(Type *ElementType, unsigned MinNumElts);
 
+  static ScalableVectorType *get(Type *ElementType, ElementCount EC) {
+    if (!EC.Scalable) // A fixed count cannot describe a scalable vector.
+      return nullptr;
+    // Scalable count: EC.Min is the minimum number of elements.
+    return get(ElementType, EC.Min);
+  }
+  // As above, using VTy's element count; null if VTy is fixed-width.
+  static ScalableVectorType *get(Type *ElementType, const VectorType *VTy) {
+    return get(ElementType, VTy->getElementCount());
+  }
+
   static ScalableVectorType *get(Type *ElementType,
                                  const ScalableVectorType *SVTy) {
     return get(ElementType, SVTy->getMinNumElements());


-------------- next part --------------
A non-text attachment was scrubbed...
Name: D82419.272862.patch
Type: text/x-patch
Size: 2887 bytes
Desc: not available
URL: <http://lists.llvm.org/pipermail/llvm-commits/attachments/20200623/109e3129/attachment.bin>


More information about the llvm-commits mailing list