[llvm] r289470 - [SLP] Fix sign-extends for type-shrinking

Matthew Simpson via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 12 13:11:05 PST 2016


Author: mssimpso
Date: Mon Dec 12 15:11:04 2016
New Revision: 289470

URL: http://llvm.org/viewvc/llvm-project?rev=289470&view=rev
Log:
[SLP] Fix sign-extends for type-shrinking

This patch ensures the correct minimum bit width during type-shrinking.
Previously when type-shrinking, we always sign-extended values back to their
original width. However, if we are going to sign-extend, and the sign bit is
unknown, we have to increase the minimum bit width by one bit so the
sign-extend will fill the upper bits correctly. If the sign bit is known to be
zero, we can perform a zero-extend instead. This should fix PR31243.

Reference: https://llvm.org/bugs/show_bug.cgi?id=31243
Differential Revision: https://reviews.llvm.org/D27466

Added:
    llvm/trunk/test/Transforms/SLPVectorizer/X86/minimum-sizes.ll
Modified:
    llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp

Modified: llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp?rev=289470&r1=289469&r2=289470&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp (original)
+++ llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp Mon Dec 12 15:11:04 2016
@@ -910,8 +910,11 @@ private:
   IRBuilder<> Builder;
 
   /// A map of scalar integer values to the smallest bit width with which they
-  /// can legally be represented.
-  MapVector<Value *, uint64_t> MinBWs;
+  /// can legally be represented. The values map to (width, signed) pairs,
+  /// where "width" indicates the minimum bit width and "signed" is True if the
+  /// value must be sign-extended, rather than zero-extended, back to its
+  /// original width.
+  MapVector<Value *, std::pair<uint64_t, bool>> MinBWs;
 };
 
 } // end namespace llvm
@@ -1572,8 +1575,8 @@ int BoUpSLP::getEntryCost(TreeEntry *E)
   // If we have computed a smaller type for the expression, update VecTy so
   // that the costs will be accurate.
   if (MinBWs.count(VL[0]))
-    VecTy = VectorType::get(IntegerType::get(F->getContext(), MinBWs[VL[0]]),
-                            VL.size());
+    VecTy = VectorType::get(
+        IntegerType::get(F->getContext(), MinBWs[VL[0]].first), VL.size());
 
   if (E->NeedToGather) {
     if (allConstant(VL))
@@ -1929,10 +1932,12 @@ int BoUpSLP::getTreeCost() {
     auto *VecTy = VectorType::get(EU.Scalar->getType(), BundleWidth);
     auto *ScalarRoot = VectorizableTree[0].Scalars[0];
     if (MinBWs.count(ScalarRoot)) {
-      auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot]);
+      auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot].first);
+      auto Extend =
+          MinBWs[ScalarRoot].second ? Instruction::SExt : Instruction::ZExt;
       VecTy = VectorType::get(MinTy, BundleWidth);
-      ExtractCost += TTI->getExtractWithExtendCost(
-          Instruction::SExt, EU.Scalar->getType(), VecTy, EU.Lane);
+      ExtractCost += TTI->getExtractWithExtendCost(Extend, EU.Scalar->getType(),
+                                                   VecTy, EU.Lane);
     } else {
       ExtractCost +=
           TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, EU.Lane);
@@ -2718,7 +2723,7 @@ Value *BoUpSLP::vectorizeTree() {
     if (auto *I = dyn_cast<Instruction>(VectorRoot))
       Builder.SetInsertPoint(&*++BasicBlock::iterator(I));
     auto BundleWidth = VectorizableTree[0].Scalars.size();
-    auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot]);
+    auto *MinTy = IntegerType::get(F->getContext(), MinBWs[ScalarRoot].first);
     auto *VecTy = VectorType::get(MinTy, BundleWidth);
     auto *Trunc = Builder.CreateTrunc(VectorRoot, VecTy);
     VectorizableTree[0].VectorizedValue = Trunc;
@@ -2726,6 +2731,16 @@ Value *BoUpSLP::vectorizeTree() {
 
   DEBUG(dbgs() << "SLP: Extracting " << ExternalUses.size() << " values .\n");
 
+  // If necessary, sign-extend or zero-extend Ex (an element extracted from the
+  // vectorized ScalarRoot) to the larger type specified by ScalarType.
+  auto extend = [&](Value *ScalarRoot, Value *Ex, Type *ScalarType) {
+    if (!MinBWs.count(ScalarRoot))
+      return Ex;
+    if (MinBWs[ScalarRoot].second)
+      return Builder.CreateSExt(Ex, ScalarType);
+    return Builder.CreateZExt(Ex, ScalarType);
+  };
+
   // Extract all of the elements with the external uses.
   for (const auto &ExternalUse : ExternalUses) {
     Value *Scalar = ExternalUse.Scalar;
@@ -2760,8 +2775,7 @@ Value *BoUpSLP::vectorizeTree() {
               Builder.SetInsertPoint(PH->getIncomingBlock(i)->getTerminator());
             }
             Value *Ex = Builder.CreateExtractElement(Vec, Lane);
-            if (MinBWs.count(ScalarRoot))
-              Ex = Builder.CreateSExt(Ex, Scalar->getType());
+            Ex = extend(ScalarRoot, Ex, Scalar->getType());
             CSEBlocks.insert(PH->getIncomingBlock(i));
             PH->setOperand(i, Ex);
           }
@@ -2769,16 +2783,14 @@ Value *BoUpSLP::vectorizeTree() {
       } else {
         Builder.SetInsertPoint(cast<Instruction>(User));
         Value *Ex = Builder.CreateExtractElement(Vec, Lane);
-        if (MinBWs.count(ScalarRoot))
-          Ex = Builder.CreateSExt(Ex, Scalar->getType());
+        Ex = extend(ScalarRoot, Ex, Scalar->getType());
         CSEBlocks.insert(cast<Instruction>(User)->getParent());
         User->replaceUsesOfWith(Scalar, Ex);
      }
     } else {
       Builder.SetInsertPoint(&F->getEntryBlock().front());
       Value *Ex = Builder.CreateExtractElement(Vec, Lane);
-      if (MinBWs.count(ScalarRoot))
-        Ex = Builder.CreateSExt(Ex, Scalar->getType());
+      Ex = extend(ScalarRoot, Ex, Scalar->getType());
       CSEBlocks.insert(&F->getEntryBlock());
       User->replaceUsesOfWith(Scalar, Ex);
     }
@@ -3499,6 +3511,11 @@ void BoUpSLP::computeMinimumValueSizes()
         Mask.getBitWidth() - Mask.countLeadingZeros(), MaxBitWidth);
   }
 
+  // True if the roots can be zero-extended back to their original type, rather
+  // than sign-extended. We know that if the leading bits are not demanded, we
+  // can safely zero-extend. So we initialize IsKnownPositive to True.
+  bool IsKnownPositive = true;
+
   // If all the bits of the roots are demanded, we can try a little harder to
   // compute a narrower type. This can happen, for example, if the roots are
   // getelementptr indices. InstCombine promotes these indices to the pointer
@@ -3510,11 +3527,41 @@ void BoUpSLP::computeMinimumValueSizes()
   // compute the number of high-order bits we can truncate.
   if (MaxBitWidth == DL->getTypeSizeInBits(TreeRoot[0]->getType())) {
     MaxBitWidth = 8u;
+
+    // Determine if the sign bit of all the roots is known to be zero. If not,
+    // IsKnownPositive is set to False.
+    IsKnownPositive = all_of(TreeRoot, [&](Value *R) {
+      bool KnownZero = false;
+      bool KnownOne = false;
+      ComputeSignBit(R, KnownZero, KnownOne, *DL);
+      return KnownZero;
+    });
+
+    // Determine the maximum number of bits required to store the scalar
+    // values.
     for (auto *Scalar : ToDemote) {
       auto NumSignBits = ComputeNumSignBits(Scalar, *DL, 0, AC, 0, DT);
       auto NumTypeBits = DL->getTypeSizeInBits(Scalar->getType());
       MaxBitWidth = std::max<unsigned>(NumTypeBits - NumSignBits, MaxBitWidth);
     }
+
+    // If we can't prove that the sign bit is zero, we must add one to the
+    // maximum bit width to account for the unknown sign bit. This preserves
+    // the existing sign bit so we can safely sign-extend the root back to the
+    // original type. Otherwise, if we know the sign bit is zero, we will
+    // zero-extend the root instead.
+    //
+    // FIXME: This is somewhat suboptimal, as there will be cases where adding
+    //        one to the maximum bit width will yield a larger-than-necessary
+    //        type. In general, we need to add an extra bit only if we can't
+    //        prove that the upper bit of the original type is equal to the
+    //        upper bit of the proposed smaller type. If these two bits are the
+    //        same (either zero or one) we know that sign-extending from the
+    //        smaller type will result in the same value. Here, since we can't
+    //        yet prove this, we are just making the proposed smaller type
+    //        larger to ensure correctness.
+    if (!IsKnownPositive)
+      ++MaxBitWidth;
   }
 
   // Round MaxBitWidth up to the next power-of-two.
@@ -3534,7 +3581,7 @@ void BoUpSLP::computeMinimumValueSizes()
 
   // Finally, map the values we can demote to the maximum bit with we computed.
   for (auto *Scalar : ToDemote)
-    MinBWs[Scalar] = MaxBitWidth;
+    MinBWs[Scalar] = std::make_pair(MaxBitWidth, !IsKnownPositive);
 }
 
 namespace {

Added: llvm/trunk/test/Transforms/SLPVectorizer/X86/minimum-sizes.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/SLPVectorizer/X86/minimum-sizes.ll?rev=289470&view=auto
==============================================================================
--- llvm/trunk/test/Transforms/SLPVectorizer/X86/minimum-sizes.ll (added)
+++ llvm/trunk/test/Transforms/SLPVectorizer/X86/minimum-sizes.ll Mon Dec 12 15:11:04 2016
@@ -0,0 +1,72 @@
+; RUN: opt -S -slp-threshold=-6 -slp-vectorizer -instcombine < %s | FileCheck %s
+
+target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+; These tests ensure that we do not regress due to PR31243. Note that we set
+; the SLP threshold to force vectorization even when not profitable.
+
+; CHECK-LABEL: @PR31243_zext
+;
+; When computing minimum sizes, if we can prove the sign bit is zero, we can
+; zero-extend the roots back to their original sizes.
+;
+; CHECK: %[[OR:.+]] = or <2 x i8> {{.*}}, <i8 1, i8 1>
+; CHECK: %[[E0:.+]] = extractelement <2 x i8> %[[OR]], i32 0
+; CHECK: %[[Z0:.+]] = zext i8 %[[E0]] to i64
+; CHECK: getelementptr inbounds i8, i8* %ptr, i64 %[[Z0]]
+; CHECK: %[[E1:.+]] = extractelement <2 x i8> %[[OR]], i32 1
+; CHECK: %[[Z1:.+]] = zext i8 %[[E1]] to i64
+; CHECK: getelementptr inbounds i8, i8* %ptr, i64 %[[Z1]]
+;
+define i8 @PR31243_zext(i8 %v0, i8 %v1, i8 %v2, i8 %v3, i8* %ptr) {
+entry:
+  %tmp0 = zext i8 %v0 to i32
+  %tmp1 = zext i8 %v1 to i32
+  %tmp2 = or i32 %tmp0, 1
+  %tmp3 = or i32 %tmp1, 1
+  %tmp4 = getelementptr inbounds i8, i8* %ptr, i32 %tmp2
+  %tmp5 = getelementptr inbounds i8, i8* %ptr, i32 %tmp3
+  %tmp6 = load i8, i8* %tmp4
+  %tmp7 = load i8, i8* %tmp5
+  %tmp8 = add i8 %tmp6, %tmp7
+  ret i8 %tmp8
+}
+
+; CHECK-LABEL: @PR31243_sext
+;
+; When computing minimum sizes, if we cannot prove the sign bit is zero, we
+; have to include one extra bit for signedness since we will sign-extend the
+; roots.
+;
+; FIXME: This test is suboptimal since the computation can be performed in i8.
+;        In general, we need to add an extra bit to the maximum bit width only
+;        if we can't prove that the upper bit of the original type is equal to
+;        the upper bit of the proposed smaller type. If these two bits are the
+;        same (either zero or one) we know that sign-extending from the smaller
+;        type will result in the same value. Since we don't yet perform this
+;        optimization, we make the proposed smaller type (i8) larger (i16) to
+;        ensure correctness.
+;
+; CHECK: %[[S0:.+]] = sext <2 x i8> {{.*}} to <2 x i16>
+; CHECK: %[[OR:.+]] = or <2 x i16> %[[S0]], <i16 1, i16 1>
+; CHECK: %[[E0:.+]] = extractelement <2 x i16> %[[OR]], i32 0
+; CHECK: %[[S1:.+]] = sext i16 %[[E0]] to i64
+; CHECK: getelementptr inbounds i8, i8* %ptr, i64 %[[S1]]
+; CHECK: %[[E1:.+]] = extractelement <2 x i16> %[[OR]], i32 1
+; CHECK: %[[S2:.+]] = sext i16 %[[E1]] to i64
+; CHECK: getelementptr inbounds i8, i8* %ptr, i64 %[[S2]]
+;
+define i8 @PR31243_sext(i8 %v0, i8 %v1, i8 %v2, i8 %v3, i8* %ptr) {
+entry:
+  %tmp0 = sext i8 %v0 to i32
+  %tmp1 = sext i8 %v1 to i32
+  %tmp2 = or i32 %tmp0, 1
+  %tmp3 = or i32 %tmp1, 1
+  %tmp4 = getelementptr inbounds i8, i8* %ptr, i32 %tmp2
+  %tmp5 = getelementptr inbounds i8, i8* %ptr, i32 %tmp3
+  %tmp6 = load i8, i8* %tmp4
+  %tmp7 = load i8, i8* %tmp5
+  %tmp8 = add i8 %tmp6, %tmp7
+  ret i8 %tmp8
+}




More information about the llvm-commits mailing list