[llvm] 285bc69 - [SLP]Fix PR80027: Fix costs processing for minbitwidth types.

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 30 10:36:32 PST 2024


Author: Alexey Bataev
Date: 2024-01-30T10:32:55-08:00
New Revision: 285bc69846e76af805cd106ea3ea538a12f5c9b6

URL: https://github.com/llvm/llvm-project/commit/285bc69846e76af805cd106ea3ea538a12f5c9b6
DIFF: https://github.com/llvm/llvm-project/commit/285bc69846e76af805cd106ea3ea538a12f5c9b6.diff

LOG: [SLP]Fix PR80027: Fix costs processing for minbitwidth types.

Need to switch the types: the destination type comes first in the
getCastInstrCost function.

Added: 
    llvm/test/Transforms/SLPVectorizer/SystemZ/minbitwidth-trunc.ll

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index e0a779467b1fa..bde65717ac1d4 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -7888,7 +7888,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
               unsigned BWSz = DL->getTypeSizeInBits(ScalarTy);
               unsigned SrcBWSz = DL->getTypeSizeInBits(UserScalarTy);
               unsigned VecOpcode;
-              auto *SrcVecTy =
+              auto *UserVecTy =
                   FixedVectorType::get(UserScalarTy, E->getVectorFactor());
               if (BWSz > SrcBWSz)
                 VecOpcode = Instruction::Trunc;
@@ -7896,11 +7896,10 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
                 VecOpcode =
                     It->second.second ? Instruction::SExt : Instruction::ZExt;
               TTI::CastContextHint CCH = GetCastContextHint(VL0);
-              VecCost += TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH,
+              VecCost += TTI->getCastInstrCost(VecOpcode, UserVecTy, VecTy, CCH,
                                                CostKind);
-              ScalarCost +=
-                  Sz * TTI->getCastInstrCost(VecOpcode, ScalarTy, UserScalarTy,
-                                             CCH, CostKind);
+              ScalarCost += Sz * TTI->getCastInstrCost(VecOpcode, UserScalarTy,
+                                                       ScalarTy, CCH, CostKind);
             }
           }
         }
@@ -8981,7 +8980,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
   SmallVector<std::pair<Value *, const TreeEntry *>> FirstUsers;
   SmallVector<APInt> DemandedElts;
   SmallDenseSet<Value *, 4> UsedInserts;
-  DenseSet<Value *> VectorCasts;
+  DenseSet<std::pair<const TreeEntry *, Type *>> VectorCasts;
   for (ExternalUser &EU : ExternalUses) {
     // We only add extract cost once for the same scalar.
     if (!isa_and_nonnull<InsertElementInst>(EU.User) &&
@@ -9051,11 +9050,14 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
             DemandedElts.push_back(APInt::getZero(FTy->getNumElements()));
             VecId = FirstUsers.size() - 1;
             auto It = MinBWs.find(ScalarTE);
-            if (It != MinBWs.end() && VectorCasts.insert(EU.Scalar).second) {
+            if (It != MinBWs.end() &&
+                VectorCasts
+                    .insert(std::make_pair(ScalarTE, FTy->getElementType()))
+                    .second) {
               unsigned BWSz = It->second.second;
-              unsigned SrcBWSz = DL->getTypeSizeInBits(FTy->getElementType());
+              unsigned DstBWSz = DL->getTypeSizeInBits(FTy->getElementType());
               unsigned VecOpcode;
-              if (BWSz < SrcBWSz)
+              if (DstBWSz < BWSz)
                 VecOpcode = Instruction::Trunc;
               else
                 VecOpcode =
@@ -9108,17 +9110,20 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
   }
   // Add reduced value cost, if resized.
   if (!VectorizedVals.empty()) {
-    auto BWIt = MinBWs.find(VectorizableTree.front().get());
+    const TreeEntry &Root = *VectorizableTree.front().get();
+    auto BWIt = MinBWs.find(&Root);
     if (BWIt != MinBWs.end()) {
-      Type *DstTy = VectorizableTree.front()->Scalars.front()->getType();
+      Type *DstTy = Root.Scalars.front()->getType();
       unsigned OriginalSz = DL->getTypeSizeInBits(DstTy);
-      unsigned Opcode = Instruction::Trunc;
-      if (OriginalSz < BWIt->second.first)
-        Opcode = BWIt->second.second ? Instruction::SExt : Instruction::ZExt;
-      Type *SrcTy = IntegerType::get(DstTy->getContext(), BWIt->second.first);
-      Cost += TTI->getCastInstrCost(Opcode, DstTy, SrcTy,
-                                    TTI::CastContextHint::None,
-                                    TTI::TCK_RecipThroughput);
+      if (OriginalSz != BWIt->second.first) {
+        unsigned Opcode = Instruction::Trunc;
+        if (OriginalSz < BWIt->second.first)
+          Opcode = BWIt->second.second ? Instruction::SExt : Instruction::ZExt;
+        Type *SrcTy = IntegerType::get(DstTy->getContext(), BWIt->second.first);
+        Cost += TTI->getCastInstrCost(Opcode, DstTy, SrcTy,
+                                      TTI::CastContextHint::None,
+                                      TTI::TCK_RecipThroughput);
+      }
     }
   }
 
@@ -11419,9 +11424,10 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
           VecOpcode = Instruction::BitCast;
         } else if (BWSz < SrcBWSz) {
           VecOpcode = Instruction::Trunc;
-        } else if (It != MinBWs.end()) {
+        } else if (SrcIt != MinBWs.end()) {
           assert(BWSz > SrcBWSz && "Invalid cast!");
-          VecOpcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
+          VecOpcode =
+              SrcIt->second.second ? Instruction::SExt : Instruction::ZExt;
         }
       }
       Value *V = (VecOpcode != ShuffleOrOp && VecOpcode == Instruction::BitCast)
@@ -11929,7 +11935,7 @@ Value *BoUpSLP::vectorizeTree(
   // basic block. Only one extractelement per block should be emitted.
   DenseMap<Value *, DenseMap<BasicBlock *, Instruction *>> ScalarToEEs;
   SmallDenseSet<Value *, 4> UsedInserts;
-  DenseMap<Value *, Value *> VectorCasts;
+  DenseMap<std::pair<Value *, Type *>, Value *> VectorCasts;
   SmallDenseSet<Value *, 4> ScalarsWithNullptrUser;
   // Extract all of the elements with the external uses.
   for (const auto &ExternalUse : ExternalUses) {
@@ -12050,7 +12056,9 @@ Value *BoUpSLP::vectorizeTree(
           // Need to use original vector, if the root is truncated.
           auto BWIt = MinBWs.find(E);
           if (BWIt != MinBWs.end() && Vec->getType() != VU->getType()) {
-            auto VecIt = VectorCasts.find(Scalar);
+            auto *ScalarTy = FTy->getElementType();
+            auto Key = std::make_pair(Vec, ScalarTy);
+            auto VecIt = VectorCasts.find(Key);
             if (VecIt == VectorCasts.end()) {
               IRBuilder<>::InsertPointGuard Guard(Builder);
               if (auto *IVec = dyn_cast<Instruction>(Vec))
@@ -12058,10 +12066,10 @@ Value *BoUpSLP::vectorizeTree(
               Vec = Builder.CreateIntCast(
                   Vec,
                   FixedVectorType::get(
-                      cast<VectorType>(VU->getType())->getElementType(),
+                      ScalarTy,
                       cast<FixedVectorType>(Vec->getType())->getNumElements()),
                   BWIt->second.second);
-              VectorCasts.try_emplace(Scalar, Vec);
+              VectorCasts.try_emplace(Key, Vec);
             } else {
               Vec = VecIt->second;
             }

diff  --git a/llvm/test/Transforms/SLPVectorizer/SystemZ/minbitwidth-trunc.ll b/llvm/test/Transforms/SLPVectorizer/SystemZ/minbitwidth-trunc.ll
new file mode 100644
index 0000000000000..e1942eb326079
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/SystemZ/minbitwidth-trunc.ll
@@ -0,0 +1,48 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S --passes=slp-vectorizer -mtriple=s390x-unknown-linux -mcpu=z14 < %s | FileCheck %s
+
+define void @test() {
+; CHECK-LABEL: define void @test(
+; CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i8 0 to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i8 0 to i32
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x i32> <i32 0, i32 poison, i32 0, i32 0>, i32 [[TMP2]], i32 1
+; CHECK-NEXT:    [[TMP4:%.*]] = select <4 x i1> zeroinitializer, <4 x i32> zeroinitializer, <4 x i32> [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = select i1 false, i32 0, i32 0
+; CHECK-NEXT:    [[TMP6:%.*]] = select i1 false, i32 0, i32 [[TMP1]]
+; CHECK-NEXT:    [[TMP7:%.*]] = select i1 false, i32 0, i32 [[TMP2]]
+; CHECK-NEXT:    [[TMP8:%.*]] = call i32 @llvm.vector.reduce.xor.v4i32(<4 x i32> [[TMP4]])
+; CHECK-NEXT:    [[OP_RDX:%.*]] = xor i32 [[TMP8]], [[TMP5]]
+; CHECK-NEXT:    [[OP_RDX1:%.*]] = xor i32 [[TMP6]], [[TMP7]]
+; CHECK-NEXT:    [[OP_RDX2:%.*]] = xor i32 [[OP_RDX]], [[OP_RDX1]]
+; CHECK-NEXT:    [[TMP9:%.*]] = trunc i32 [[OP_RDX2]] to i16
+; CHECK-NEXT:    store i16 [[TMP9]], ptr null, align 2
+; CHECK-NEXT:    ret void
+;
+  %1 = zext i8 0 to i32
+  %.not = icmp sgt i32 0, %1
+  %2 = zext i8 0 to i32
+  %3 = select i1 %.not, i32 0, i32 0
+  %4 = zext i8 0 to i32
+  %.not.1 = icmp sgt i32 0, %4
+  %5 = zext i8 0 to i32
+  %6 = select i1 %.not.1, i32 0, i32 %5
+  %7 = xor i32 %6, %3
+  %8 = zext i8 0 to i32
+  %.not.2 = icmp sgt i32 0, %8
+  %9 = select i1 %.not.2, i32 0, i32 0
+  %10 = xor i32 %9, %7
+  %11 = zext i8 0 to i32
+  %.not.3 = icmp sgt i32 0, %11
+  %12 = select i1 %.not.3, i32 0, i32 0
+  %13 = xor i32 %12, %10
+  %14 = select i1 false, i32 0, i32 0
+  %15 = xor i32 %14, %13
+  %16 = select i1 false, i32 0, i32 %2
+  %17 = xor i32 %16, %15
+  %18 = select i1 false, i32 0, i32 %5
+  %19 = xor i32 %18, %17
+  %20 = trunc i32 %19 to i16
+  store i16 %20, ptr null, align 2
+  ret void
+}


        


More information about the llvm-commits mailing list