[llvm] 3b4e7c9 - [SandboxIR] Implement ScalableVectorType (#108124)

via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 10 19:11:17 PDT 2024


Author: Sterling-Augustine
Date: 2024-09-10T19:11:14-07:00
New Revision: 3b4e7c9c4502d41ece4ef3431bbc12f055adabb5

URL: https://github.com/llvm/llvm-project/commit/3b4e7c9c4502d41ece4ef3431bbc12f055adabb5
DIFF: https://github.com/llvm/llvm-project/commit/3b4e7c9c4502d41ece4ef3431bbc12f055adabb5.diff

LOG: [SandboxIR] Implement ScalableVectorType (#108124)

As in the heading.

Added: 
    

Modified: 
    llvm/include/llvm/SandboxIR/Type.h
    llvm/lib/SandboxIR/Type.cpp
    llvm/unittests/SandboxIR/TypesTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index ec141c249fb21e..a2ac9e014b44ab 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -26,6 +26,7 @@ class Context;
 class PointerType;
 class VectorType;
 class FixedVectorType;
+class ScalableVectorType;
 class IntegerType;
 class FunctionType;
 class ArrayType;
@@ -39,21 +40,22 @@ class StructType;
 class Type {
 protected:
   llvm::Type *LLVMTy;
-  friend class ArrayType;      // For LLVMTy.
-  friend class StructType;     // For LLVMTy.
-  friend class VectorType;     // For LLVMTy.
-  friend class FixedVectorType; // For LLVMTy.
-  friend class PointerType;    // For LLVMTy.
-  friend class FunctionType;   // For LLVMTy.
-  friend class IntegerType;    // For LLVMTy.
-  friend class Function;       // For LLVMTy.
-  friend class CallBase;       // For LLVMTy.
-  friend class ConstantInt;    // For LLVMTy.
-  friend class ConstantArray;  // For LLVMTy.
-  friend class ConstantStruct; // For LLVMTy.
-  friend class ConstantVector; // For LLVMTy.
-  friend class CmpInst;        // For LLVMTy. TODO: Cleanup after
-                               // sandboxir::VectorType is more complete.
+  friend class ArrayType;          // For LLVMTy.
+  friend class StructType;         // For LLVMTy.
+  friend class VectorType;         // For LLVMTy.
+  friend class FixedVectorType;    // For LLVMTy.
+  friend class ScalableVectorType; // For LLVMTy.
+  friend class PointerType;        // For LLVMTy.
+  friend class FunctionType;       // For LLVMTy.
+  friend class IntegerType;        // For LLVMTy.
+  friend class Function;           // For LLVMTy.
+  friend class CallBase;           // For LLVMTy.
+  friend class ConstantInt;        // For LLVMTy.
+  friend class ConstantArray;      // For LLVMTy.
+  friend class ConstantStruct;     // For LLVMTy.
+  friend class ConstantVector;     // For LLVMTy.
+  friend class CmpInst;            // For LLVMTy. TODO: Cleanup after
+                                   // sandboxir::VectorType is more complete.
 
   // Friend all instruction classes because `create()` functions use LLVMTy.
 #define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
@@ -390,6 +392,57 @@ class FixedVectorType : public VectorType {
   }
 };
 
+class ScalableVectorType : public VectorType {
+public:
+  static ScalableVectorType *get(Type *ElementType, unsigned MinNumElts);
+  /// Overload that takes the minimum element count from \p SVTy.
+  static ScalableVectorType *get(Type *ElementType,
+                                 const ScalableVectorType *SVTy) {
+    return get(ElementType, SVTy->getMinNumElements());
+  }
+  /// Scalable-typed wrapper of VectorType::getInteger().
+  static ScalableVectorType *getInteger(ScalableVectorType *VTy) {
+    return cast<ScalableVectorType>(VectorType::getInteger(VTy));
+  }
+  /// Scalable-typed wrapper of VectorType::getExtendedElementVectorType().
+  static ScalableVectorType *
+  getExtendedElementVectorType(ScalableVectorType *VTy) {
+    return cast<ScalableVectorType>(
+        VectorType::getExtendedElementVectorType(VTy));
+  }
+  /// Scalable-typed wrapper of VectorType::getTruncatedElementVectorType().
+  static ScalableVectorType *
+  getTruncatedElementVectorType(ScalableVectorType *VTy) {
+    return cast<ScalableVectorType>(
+        VectorType::getTruncatedElementVectorType(VTy));
+  }
+  /// Scalable-typed wrapper of VectorType::getSubdividedVectorType().
+  static ScalableVectorType *getSubdividedVectorType(ScalableVectorType *VTy,
+                                                     int NumSubdivs) {
+    return cast<ScalableVectorType>(
+        VectorType::getSubdividedVectorType(VTy, NumSubdivs));
+  }
+  /// Scalable-typed wrapper of VectorType::getHalfElementsVectorType().
+  static ScalableVectorType *
+  getHalfElementsVectorType(ScalableVectorType *VTy) {
+    return cast<ScalableVectorType>(VectorType::getHalfElementsVectorType(VTy));
+  }
+  /// Scalable-typed wrapper of VectorType::getDoubleElementsVectorType().
+  static ScalableVectorType *
+  getDoubleElementsVectorType(ScalableVectorType *VTy) {
+    return cast<ScalableVectorType>(
+        VectorType::getDoubleElementsVectorType(VTy));
+  }
+  /// Returns getMinNumElements() of the underlying llvm::ScalableVectorType.
+  unsigned getMinNumElements() const {
+    return cast<llvm::ScalableVectorType>(LLVMTy)->getMinNumElements();
+  }
+  /// True iff \p T wraps an llvm::ScalableVectorType (used by isa/cast).
+  static bool classof(const Type *T) {
+    return isa<llvm::ScalableVectorType>(T->LLVMTy);
+  }
+};
+
 class FunctionType : public Type {
 public:
   // TODO: add missing functions

diff  --git a/llvm/lib/SandboxIR/Type.cpp b/llvm/lib/SandboxIR/Type.cpp
index 26aa8b3743084c..87dcb726dde351 100644
--- a/llvm/lib/SandboxIR/Type.cpp
+++ b/llvm/lib/SandboxIR/Type.cpp
@@ -108,6 +108,12 @@ FixedVectorType *FixedVectorType::get(Type *ElementType, unsigned NumElts) {
       llvm::FixedVectorType::get(ElementType->LLVMTy, NumElts)));
 }
 
+ScalableVectorType *ScalableVectorType::get(Type *ElementType,
+                                            unsigned MinNumElts) {
+  return cast<ScalableVectorType>(ElementType->getContext().getType(
+      llvm::ScalableVectorType::get(ElementType->LLVMTy, MinNumElts)));
+}
+
 IntegerType *IntegerType::get(Context &Ctx, unsigned NumBits) {
   return cast<IntegerType>(
       Ctx.getType(llvm::IntegerType::get(Ctx.LLVMCtx, NumBits)));

diff  --git a/llvm/unittests/SandboxIR/TypesTest.cpp b/llvm/unittests/SandboxIR/TypesTest.cpp
index 3564ae66830147..40aa32fb08ed01 100644
--- a/llvm/unittests/SandboxIR/TypesTest.cpp
+++ b/llvm/unittests/SandboxIR/TypesTest.cpp
@@ -381,6 +381,65 @@ define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) {
   EXPECT_EQ(Vec8i16Ty->getElementCount(), ElementCount::getFixed(8));
 }
 
+TEST_F(SandboxTypeTest, ScalableVectorType) {
+  parseIR(C, R"IR(
+define void @foo(<vscale x 4 x i16> %vi0, <vscale x 4 x float> %vf1, i8 %i0) {
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  // Check classof(), creation, accessors
+  auto *Vec4i16Ty =
+      cast<sandboxir::ScalableVectorType>(F->getArg(0)->getType());
+  EXPECT_TRUE(Vec4i16Ty->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(Vec4i16Ty->getMinNumElements(), 4u);
+
+  // get(ElementType, MinNumElts)
+  EXPECT_EQ(
+      sandboxir::ScalableVectorType::get(sandboxir::Type::getInt16Ty(Ctx), 4),
+      F->getArg(0)->getType());
+  // get(ElementType, Other)
+  EXPECT_EQ(sandboxir::ScalableVectorType::get(
+                sandboxir::Type::getInt16Ty(Ctx),
+                cast<sandboxir::ScalableVectorType>(F->getArg(0)->getType())),
+            F->getArg(0)->getType());
+  auto *Vec4FTy = cast<sandboxir::ScalableVectorType>(F->getArg(1)->getType());
+  EXPECT_TRUE(Vec4FTy->getElementType()->isFloatTy());
+  // getInteger
+  auto *Vec4i32Ty = sandboxir::ScalableVectorType::getInteger(Vec4FTy);
+  EXPECT_TRUE(Vec4i32Ty->getElementType()->isIntegerTy(32));
+  EXPECT_EQ(Vec4i32Ty->getMinNumElements(), Vec4FTy->getMinNumElements());
+  // getExtendedElementVectorType
+  auto *ExtVec4i32Ty =
+      sandboxir::ScalableVectorType::getExtendedElementVectorType(Vec4i16Ty);
+  EXPECT_TRUE(ExtVec4i32Ty->getElementType()->isIntegerTy(32));
+  EXPECT_EQ(ExtVec4i32Ty->getMinNumElements(), Vec4i16Ty->getMinNumElements());
+  // getTruncatedElementVectorType
+  auto *Vec4i8Ty =
+      sandboxir::ScalableVectorType::getTruncatedElementVectorType(Vec4i16Ty);
+  EXPECT_TRUE(Vec4i8Ty->getElementType()->isIntegerTy(8));
+  EXPECT_EQ(Vec4i8Ty->getMinNumElements(), Vec4i16Ty->getMinNumElements());
+  // getSubdividedVectorType
+  auto *Vec8i8Ty =
+      sandboxir::ScalableVectorType::getSubdividedVectorType(Vec4i16Ty, 1);
+  EXPECT_TRUE(Vec8i8Ty->getElementType()->isIntegerTy(8));
+  EXPECT_EQ(Vec8i8Ty->getMinNumElements(), 8u);
+  // getMinNumElements (also on the float vector)
+  EXPECT_EQ(Vec4FTy->getMinNumElements(), 4u);
+  // getHalfElementsVectorType
+  auto *Vec2i16Ty =
+      sandboxir::ScalableVectorType::getHalfElementsVectorType(Vec4i16Ty);
+  EXPECT_TRUE(Vec2i16Ty->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(Vec2i16Ty->getMinNumElements(), 2u);
+  // getDoubleElementsVectorType
+  auto *Vec8i16Ty =
+      sandboxir::ScalableVectorType::getDoubleElementsVectorType(Vec4i16Ty);
+  EXPECT_TRUE(Vec8i16Ty->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(Vec8i16Ty->getMinNumElements(), 8u);
+}
+
 TEST_F(SandboxTypeTest, FunctionType) {
   parseIR(C, R"IR(
 define void @foo() {


        


More information about the llvm-commits mailing list