[llvm] 246f345 - [SLP][REVEC] Make CastInst support vector instructions. (#103216)

via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 13 08:52:35 PDT 2024


Author: Han-Kuan Chen
Date: 2024-08-13T23:52:32+08:00
New Revision: 246f345152e933aa40fd20929b59b5c8ef04ce38

URL: https://github.com/llvm/llvm-project/commit/246f345152e933aa40fd20929b59b5c8ef04ce38
DIFF: https://github.com/llvm/llvm-project/commit/246f345152e933aa40fd20929b59b5c8ef04ce38.diff

LOG: [SLP][REVEC] Make CastInst support vector instructions. (#103216)

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/test/Transforms/SLPVectorizer/revec.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index ebfb11f841086..feffd9ae3c99b 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -9877,16 +9877,18 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
     auto *SrcVecTy = getWidenedType(SrcScalarTy, VL.size());
     unsigned Opcode = ShuffleOrOp;
     unsigned VecOpcode = Opcode;
-    if (!ScalarTy->isFloatingPointTy() && !SrcScalarTy->isFloatingPointTy() &&
+    if (!ScalarTy->isFPOrFPVectorTy() && !SrcScalarTy->isFPOrFPVectorTy() &&
         (SrcIt != MinBWs.end() || It != MinBWs.end())) {
       // Check if the values are candidates to demote.
-      unsigned SrcBWSz = DL->getTypeSizeInBits(SrcScalarTy);
+      unsigned SrcBWSz = DL->getTypeSizeInBits(SrcScalarTy->getScalarType());
       if (SrcIt != MinBWs.end()) {
         SrcBWSz = SrcIt->second.first;
+        unsigned SrcScalarTyNumElements = getNumElements(SrcScalarTy);
         SrcScalarTy = IntegerType::get(F->getContext(), SrcBWSz);
-        SrcVecTy = getWidenedType(SrcScalarTy, VL.size());
+        SrcVecTy =
+            getWidenedType(SrcScalarTy, VL.size() * SrcScalarTyNumElements);
       }
-      unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
+      unsigned BWSz = DL->getTypeSizeInBits(ScalarTy->getScalarType());
       if (BWSz == SrcBWSz) {
         VecOpcode = Instruction::BitCast;
       } else if (BWSz < SrcBWSz) {
@@ -13452,14 +13454,14 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
       Instruction::CastOps VecOpcode = CI->getOpcode();
       Type *SrcScalarTy = cast<VectorType>(InVec->getType())->getElementType();
       auto SrcIt = MinBWs.find(getOperandEntry(E, 0));
-      if (!ScalarTy->isFloatingPointTy() && !SrcScalarTy->isFloatingPointTy() &&
+      if (!ScalarTy->isFPOrFPVectorTy() && !SrcScalarTy->isFPOrFPVectorTy() &&
           (SrcIt != MinBWs.end() || It != MinBWs.end() ||
-           SrcScalarTy != CI->getOperand(0)->getType())) {
+           SrcScalarTy != CI->getOperand(0)->getType()->getScalarType())) {
         // Check if the values are candidates to demote.
         unsigned SrcBWSz = DL->getTypeSizeInBits(SrcScalarTy);
         if (SrcIt != MinBWs.end())
           SrcBWSz = SrcIt->second.first;
-        unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
+        unsigned BWSz = DL->getTypeSizeInBits(ScalarTy->getScalarType());
         if (BWSz == SrcBWSz) {
           VecOpcode = Instruction::BitCast;
         } else if (BWSz < SrcBWSz) {

diff  --git a/llvm/test/Transforms/SLPVectorizer/revec.ll b/llvm/test/Transforms/SLPVectorizer/revec.ll
index 31ee107c81cd4..59201da1d9ac1 100644
--- a/llvm/test/Transforms/SLPVectorizer/revec.ll
+++ b/llvm/test/Transforms/SLPVectorizer/revec.ll
@@ -296,3 +296,43 @@ for.body13:                                       ; preds = %for.body13, %entry
   store <4 x i32> %vmovl.i110, ptr %add.ptr29, align 4
   br label %for.body13
 }
+
+define void @test10() {
+; CHECK-LABEL: @test10(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = load <16 x i8>, ptr null, align 1
+; CHECK-NEXT:    [[TMP1:%.*]] = call <32 x i8> @llvm.vector.insert.v32i8.v16i8(<32 x i8> poison, <16 x i8> poison, i64 16)
+; CHECK-NEXT:    [[TMP2:%.*]] = call <32 x i8> @llvm.vector.insert.v32i8.v16i8(<32 x i8> [[TMP1]], <16 x i8> [[TMP0]], i64 0)
+; CHECK-NEXT:    [[TMP3:%.*]] = shufflevector <32 x i8> [[TMP2]], <32 x i8> poison, <32 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <32 x i8> [[TMP3]], <32 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
+; CHECK-NEXT:    [[TMP5:%.*]] = sext <16 x i8> [[TMP4]] to <16 x i16>
+; CHECK-NEXT:    [[TMP6:%.*]] = shufflevector <16 x i16> [[TMP5]], <16 x i16> poison, <32 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[TMP7:%.*]] = shufflevector <32 x i16> [[TMP6]], <32 x i16> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 12, i32 13, i32 14, i32 15, i32 16, i32 17, i32 18, i32 19, i32 28, i32 29, i32 30, i32 31>
+; CHECK-NEXT:    [[TMP8:%.*]] = trunc <16 x i16> [[TMP7]] to <16 x i8>
+; CHECK-NEXT:    [[TMP9:%.*]] = sext <16 x i8> [[TMP8]] to <16 x i32>
+; CHECK-NEXT:    store <16 x i32> [[TMP9]], ptr null, align 4
+; CHECK-NEXT:    ret void
+;
+entry:
+  %0 = load <16 x i8>, ptr null, align 1
+  %shuffle.i = shufflevector <16 x i8> %0, <16 x i8> zeroinitializer, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+  %shuffle.i107 = shufflevector <16 x i8> %0, <16 x i8> zeroinitializer, <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+  %vmovl.i106 = sext <8 x i8> %shuffle.i to <8 x i16>
+  %vmovl.i = sext <8 x i8> %shuffle.i107 to <8 x i16>
+  %shuffle.i113 = shufflevector <8 x i16> %vmovl.i106, <8 x i16> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %shuffle.i115 = shufflevector <8 x i16> %vmovl.i106, <8 x i16> zeroinitializer, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+  %shuffle.i112 = shufflevector <8 x i16> %vmovl.i, <8 x i16> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %shuffle.i114 = shufflevector <8 x i16> %vmovl.i, <8 x i16> zeroinitializer, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+  %vmovl.i111 = sext <4 x i16> %shuffle.i113 to <4 x i32>
+  %vmovl.i110 = sext <4 x i16> %shuffle.i115 to <4 x i32>
+  %vmovl.i109 = sext <4 x i16> %shuffle.i112 to <4 x i32>
+  %vmovl.i108 = sext <4 x i16> %shuffle.i114 to <4 x i32>
+  %add.ptr29 = getelementptr i8, ptr null, i64 16
+  %add.ptr32 = getelementptr i8, ptr null, i64 32
+  %add.ptr35 = getelementptr i8, ptr null, i64 48
+  store <4 x i32> %vmovl.i111, ptr null, align 4
+  store <4 x i32> %vmovl.i110, ptr %add.ptr29, align 4
+  store <4 x i32> %vmovl.i109, ptr %add.ptr32, align 4
+  store <4 x i32> %vmovl.i108, ptr %add.ptr35, align 4
+  ret void
+}


        


More information about the llvm-commits mailing list