[llvm] [SandboxIR] Implement ScalableVectorType (PR #108124)

via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 10 18:19:49 PDT 2024


https://github.com/Sterling-Augustine updated https://github.com/llvm/llvm-project/pull/108124

>From a0533d7272cc350a7681107b5f0a0e29db9aaf3f Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Tue, 10 Sep 2024 17:49:53 -0700
Subject: [PATCH 1/3] [SandboxIR] Implement ScalableVectorType

---
 llvm/include/llvm/SandboxIR/Type.h     | 83 +++++++++++++++++++++-----
 llvm/lib/SandboxIR/Type.cpp            |  6 ++
 llvm/unittests/SandboxIR/TypesTest.cpp | 59 ++++++++++++++++++
 3 files changed, 133 insertions(+), 15 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index ec141c249fb21e..4a55564d81ff8d 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -26,6 +26,7 @@ class Context;
 class PointerType;
 class VectorType;
 class FixedVectorType;
+class ScalableVectorType;
 class IntegerType;
 class FunctionType;
 class ArrayType;
@@ -39,21 +40,22 @@ class StructType;
 class Type {
 protected:
   llvm::Type *LLVMTy;
-  friend class ArrayType;      // For LLVMTy.
-  friend class StructType;     // For LLVMTy.
-  friend class VectorType;     // For LLVMTy.
-  friend class FixedVectorType; // For LLVMTy.
-  friend class PointerType;    // For LLVMTy.
-  friend class FunctionType;   // For LLVMTy.
-  friend class IntegerType;    // For LLVMTy.
-  friend class Function;       // For LLVMTy.
-  friend class CallBase;       // For LLVMTy.
-  friend class ConstantInt;    // For LLVMTy.
-  friend class ConstantArray;  // For LLVMTy.
-  friend class ConstantStruct; // For LLVMTy.
-  friend class ConstantVector; // For LLVMTy.
-  friend class CmpInst;        // For LLVMTy. TODO: Cleanup after
-                               // sandboxir::VectorType is more complete.
+  friend class ArrayType;        // For LLVMTy.
+  friend class StructType;       // For LLVMTy.
+  friend class VectorType;       // For LLVMTy.
+  friend class FixedVectorType;  // For LLVMTy.
+  friend class ScalableVectorType; // For LLVMTy.
+  friend class PointerType;      // For LLVMTy.
+  friend class FunctionType;     // For LLVMTy.
+  friend class IntegerType;      // For LLVMTy.
+  friend class Function;         // For LLVMTy.
+  friend class CallBase;         // For LLVMTy.
+  friend class ConstantInt;      // For LLVMTy.
+  friend class ConstantArray;    // For LLVMTy.
+  friend class ConstantStruct;   // For LLVMTy.
+  friend class ConstantVector;   // For LLVMTy.
+  friend class CmpInst;          // For LLVMTy. TODO: Cleanup after
+                                 // sandboxir::VectorType is more complete.
 
   // Friend all instruction classes because `create()` functions use LLVMTy.
 #define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;
@@ -390,6 +392,82 @@ class FixedVectorType : public VectorType {
   }
 };
 
+/// Wraps an llvm::ScalableVectorType, i.e. a vector such as
+/// <vscale x 4 x i16> whose element count is the compile-time minimum
+/// (getMinNumElements()) multiplied by the runtime factor vscale.
+class ScalableVectorType : public VectorType {
+public:
+  /// \Returns the scalable vector <vscale x \p MinNumElts x \p ElementType>.
+  static ScalableVectorType *get(Type *ElementType, unsigned MinNumElts);
+
+  /// \Returns a scalable vector of \p ElementType with the same minimum
+  /// element count as \p SVTy.
+  static ScalableVectorType *get(Type *ElementType,
+                                 const ScalableVectorType *SVTy) {
+    return get(ElementType, SVTy->getMinNumElements());
+  }
+
+  // The helpers below forward to the VectorType versions and cast the result
+  // back, as llvm::ScalableVectorType does; each helper keeps the vector
+  // scalable, so cast<> (not dyn_cast<>) is appropriate.
+
+  /// \Returns a vector with the same shape as \p VTy whose elements are
+  /// integers of the same bit width as \p VTy's elements.
+  static ScalableVectorType *getInteger(ScalableVectorType *VTy) {
+    return cast<ScalableVectorType>(VectorType::getInteger(VTy));
+  }
+
+  /// \Returns a vector with the same minimum element count as \p VTy and
+  /// integer elements twice as wide.
+  static ScalableVectorType *
+  getExtendedElementVectorType(ScalableVectorType *VTy) {
+    return cast<ScalableVectorType>(
+        VectorType::getExtendedElementVectorType(VTy));
+  }
+
+  /// \Returns a vector with the same minimum element count as \p VTy and
+  /// integer elements half as wide.
+  static ScalableVectorType *
+  getTruncatedElementVectorType(ScalableVectorType *VTy) {
+    return cast<ScalableVectorType>(
+        VectorType::getTruncatedElementVectorType(VTy));
+  }
+
+  /// \Returns a vector with 2^\p NumSubdivs times as many elements as \p VTy,
+  /// each 2^\p NumSubdivs times narrower (e.g. one subdivision turns
+  /// <vscale x 4 x i16> into <vscale x 8 x i8>).
+  static ScalableVectorType *getSubdividedVectorType(ScalableVectorType *VTy,
+                                                     int NumSubdivs) {
+    return cast<ScalableVectorType>(
+        VectorType::getSubdividedVectorType(VTy, NumSubdivs));
+  }
+
+  /// \Returns a vector with half the minimum element count of \p VTy and the
+  /// same element type.
+  static ScalableVectorType *
+  getHalfElementsVectorType(ScalableVectorType *VTy) {
+    return cast<ScalableVectorType>(VectorType::getHalfElementsVectorType(VTy));
+  }
+
+  /// \Returns a vector with twice the minimum element count of \p VTy and the
+  /// same element type.
+  static ScalableVectorType *
+  getDoubleElementsVectorType(ScalableVectorType *VTy) {
+    return cast<ScalableVectorType>(
+        VectorType::getDoubleElementsVectorType(VTy));
+  }
+
+  /// \Returns the compile-time minimum number of elements; the runtime count
+  /// is this value multiplied by vscale.
+  unsigned getMinNumElements() const {
+    return cast<llvm::ScalableVectorType>(LLVMTy)->getMinNumElements();
+  }
+
+  static bool classof(const Type *T) {
+    return isa<llvm::ScalableVectorType>(T->LLVMTy);
+  }
+};
+
 class FunctionType : public Type {
 public:
   // TODO: add missing functions
diff --git a/llvm/lib/SandboxIR/Type.cpp b/llvm/lib/SandboxIR/Type.cpp
index 26aa8b3743084c..87dcb726dde351 100644
--- a/llvm/lib/SandboxIR/Type.cpp
+++ b/llvm/lib/SandboxIR/Type.cpp
@@ -108,6 +108,13 @@ FixedVectorType *FixedVectorType::get(Type *ElementType, unsigned NumElts) {
       llvm::FixedVectorType::get(ElementType->LLVMTy, NumElts)));
 }
 
+ScalableVectorType *ScalableVectorType::get(Type *ElementType,
+                                            unsigned MinNumElts) {
+  // Build the llvm::ScalableVectorType and map it to its SandboxIR wrapper.
+  return cast<ScalableVectorType>(ElementType->getContext().getType(
+      llvm::ScalableVectorType::get(ElementType->LLVMTy, MinNumElts)));
+}
+
 IntegerType *IntegerType::get(Context &Ctx, unsigned NumBits) {
   return cast<IntegerType>(
       Ctx.getType(llvm::IntegerType::get(Ctx.LLVMCtx, NumBits)));
diff --git a/llvm/unittests/SandboxIR/TypesTest.cpp b/llvm/unittests/SandboxIR/TypesTest.cpp
index 3564ae66830147..40aa32fb08ed01 100644
--- a/llvm/unittests/SandboxIR/TypesTest.cpp
+++ b/llvm/unittests/SandboxIR/TypesTest.cpp
@@ -381,6 +381,65 @@ define void @foo(<4 x i16> %vi0, <4 x float> %vf1, i8 %i0) {
   EXPECT_EQ(Vec8i16Ty->getElementCount(), ElementCount::getFixed(8));
 }
 
+TEST_F(SandboxTypeTest, ScalableVectorType) {
+  parseIR(C, R"IR(
+define void @foo(<vscale x 4 x i16> %vi0, <vscale x 4 x float> %vf1, i8 %i0) {
+  ret void
+}
+)IR");
+  llvm::Function *LLVMF = &*M->getFunction("foo");
+  sandboxir::Context Ctx(C);
+  auto *F = Ctx.createFunction(LLVMF);
+  // Check classof(), creation, accessors
+  auto *Vec4i16Ty =
+      cast<sandboxir::ScalableVectorType>(F->getArg(0)->getType());
+  EXPECT_TRUE(Vec4i16Ty->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(Vec4i16Ty->getMinNumElements(), 4u);
+
+  // get(ElementType, NumElements)
+  EXPECT_EQ(
+      sandboxir::ScalableVectorType::get(sandboxir::Type::getInt16Ty(Ctx), 4),
+      F->getArg(0)->getType());
+  // get(ElementType, Other)
+  EXPECT_EQ(sandboxir::ScalableVectorType::get(
+                sandboxir::Type::getInt16Ty(Ctx),
+                cast<sandboxir::ScalableVectorType>(F->getArg(0)->getType())),
+            F->getArg(0)->getType());
+  auto *Vec4FTy = cast<sandboxir::ScalableVectorType>(F->getArg(1)->getType());
+  EXPECT_TRUE(Vec4FTy->getElementType()->isFloatTy());
+  // getInteger
+  auto *Vec4i32Ty = sandboxir::ScalableVectorType::getInteger(Vec4FTy);
+  EXPECT_TRUE(Vec4i32Ty->getElementType()->isIntegerTy(32));
+  EXPECT_EQ(Vec4i32Ty->getMinNumElements(), Vec4FTy->getMinNumElements());
+  // getExtendedElementVectorType
+  auto *ExtVec4i32Ty =
+      sandboxir::ScalableVectorType::getExtendedElementVectorType(Vec4i16Ty);
+  EXPECT_TRUE(ExtVec4i32Ty->getElementType()->isIntegerTy(32));
+  EXPECT_EQ(ExtVec4i32Ty->getMinNumElements(), Vec4i16Ty->getMinNumElements());
+  // getTruncatedElementVectorType
+  auto *Vec4i8Ty =
+      sandboxir::ScalableVectorType::getTruncatedElementVectorType(Vec4i16Ty);
+  EXPECT_TRUE(Vec4i8Ty->getElementType()->isIntegerTy(8));
+  EXPECT_EQ(Vec4i8Ty->getMinNumElements(), Vec4i16Ty->getMinNumElements());
+  // getSubdividedVectorType
+  auto *Vec8i8Ty =
+      sandboxir::ScalableVectorType::getSubdividedVectorType(Vec4i16Ty, 1);
+  EXPECT_TRUE(Vec8i8Ty->getElementType()->isIntegerTy(8));
+  EXPECT_EQ(Vec8i8Ty->getMinNumElements(), 8u);
+  // getElementCount
+  EXPECT_EQ(Vec8i8Ty->getElementCount(), ElementCount::getScalable(8));
+  // getHalfElementsVectorType
+  auto *Vec2i16Ty =
+      sandboxir::ScalableVectorType::getHalfElementsVectorType(Vec4i16Ty);
+  EXPECT_TRUE(Vec2i16Ty->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(Vec2i16Ty->getMinNumElements(), 2u);
+  // getDoubleElementsVectorType
+  auto *Vec8i16Ty =
+      sandboxir::ScalableVectorType::getDoubleElementsVectorType(Vec4i16Ty);
+  EXPECT_TRUE(Vec8i16Ty->getElementType()->isIntegerTy(16));
+  EXPECT_EQ(Vec8i16Ty->getMinNumElements(), 8u);
+}
+
 TEST_F(SandboxTypeTest, FunctionType) {
   parseIR(C, R"IR(
 define void @foo() {

>From b19051fae84632688196a8094f18320a07ac6989 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Tue, 10 Sep 2024 17:59:27 -0700
Subject: [PATCH 2/3] Fix formatting

---
 llvm/include/llvm/SandboxIR/Type.h | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index 4a55564d81ff8d..685563bf88228d 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -45,17 +45,17 @@ class Type {
   friend class VectorType;       // For LLVMTy.
   friend class FixedVectorType;  // For LLVMTy.
   friend class ScalableVectorType; // For LLVMTy.
-  friend class PointerType;      // For LLVMTy.
-  friend class FunctionType;     // For LLVMTy.
-  friend class IntegerType;      // For LLVMTy.
-  friend class Function;         // For LLVMTy.
-  friend class CallBase;         // For LLVMTy.
-  friend class ConstantInt;      // For LLVMTy.
-  friend class ConstantArray;    // For LLVMTy.
-  friend class ConstantStruct;   // For LLVMTy.
-  friend class ConstantVector;   // For LLVMTy.
-  friend class CmpInst;          // For LLVMTy. TODO: Cleanup after
-                                 // sandboxir::VectorType is more complete.
+  friend class PointerType;        // For LLVMTy.
+  friend class FunctionType;       // For LLVMTy.
+  friend class IntegerType;        // For LLVMTy.
+  friend class Function;           // For LLVMTy.
+  friend class CallBase;           // For LLVMTy.
+  friend class ConstantInt;        // For LLVMTy.
+  friend class ConstantArray;      // For LLVMTy.
+  friend class ConstantStruct;     // For LLVMTy.
+  friend class ConstantVector;     // For LLVMTy.
+  friend class CmpInst;            // For LLVMTy. TODO: Cleanup after
+                                   // sandboxir::VectorType is more complete.
 
   // Friend all instruction classes because `create()` functions use LLVMTy.
 #define DEF_INSTR(ID, OPCODE, CLASS) friend class CLASS;

>From 0bc631d7f38bc16b98534c967a7e9f3026abcfd7 Mon Sep 17 00:00:00 2001
From: Sterling Augustine <saugustine at google.com>
Date: Tue, 10 Sep 2024 18:07:35 -0700
Subject: [PATCH 3/3] Another try to fix format

---
 llvm/include/llvm/SandboxIR/Type.h | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/llvm/include/llvm/SandboxIR/Type.h b/llvm/include/llvm/SandboxIR/Type.h
index 685563bf88228d..a2ac9e014b44ab 100644
--- a/llvm/include/llvm/SandboxIR/Type.h
+++ b/llvm/include/llvm/SandboxIR/Type.h
@@ -40,10 +40,10 @@ class StructType;
 class Type {
 protected:
   llvm::Type *LLVMTy;
-  friend class ArrayType;        // For LLVMTy.
-  friend class StructType;       // For LLVMTy.
-  friend class VectorType;       // For LLVMTy.
-  friend class FixedVectorType;  // For LLVMTy.
+  friend class ArrayType;          // For LLVMTy.
+  friend class StructType;         // For LLVMTy.
+  friend class VectorType;         // For LLVMTy.
+  friend class FixedVectorType;    // For LLVMTy.
   friend class ScalableVectorType; // For LLVMTy.
   friend class PointerType;        // For LLVMTy.
   friend class FunctionType;       // For LLVMTy.



More information about the llvm-commits mailing list