[llvm] [SLP] Use known bits of each value in computeMinimumValueSizes (PR #82013)

via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 16 09:42:53 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Philip Reames (preames)

<details>
<summary>Changes</summary>

I'm not entirely sure, but I think the prior code is both imprecise and incorrect.  I noticed this by inspection rather than by reducing an actual benchmark, so I've struggled to find good tests to illustrate the two points.

The imprecise bit is easier to argue.  The existing code adds one to the bitwidth if the tree root is possibly negative.  This is conservative as we only need the extra bit for the nodes we are going to sign extend.  We don't need it for the ones which have leading zeros; those we can zero extend.

In the test case (which is actually an analysis improvement) we have a sext and an or, both with 8 required bits.  The previous code computed the max of the per-value widths, rounded up to 8, and *then* added one more for the sign bit.
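To make the difference concrete, here is a minimal standalone sketch of the two width computations. It does not use LLVM; `ScalarFacts`, `oldWidth`, `newWidth`, and `roundUpToPowerOfTwo` are hypothetical names introduced only for illustration, and the facts fed in below (64-bit values with 57 known sign bits, not known non-negative) are assumed numbers chosen to match the "8 required bits" described above, not values taken from the analysis.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical summary of what the analyses report for one scalar.
struct ScalarFacts {
  unsigned typeBits;      // stands in for DL->getTypeSizeInBits(...)
  unsigned signBits;      // stands in for ComputeNumSignBits(...)
  bool knownNonNegative;  // stands in for computeKnownBits(...).isNonNegative()
};

// Smallest power of two that is >= n (n >= 1).
unsigned roundUpToPowerOfTwo(unsigned n) {
  unsigned p = 1;
  while (p < n)
    p *= 2;
  return p;
}

// Old scheme: max over (typeBits - signBits), then one extra bit for the
// whole tree if the root is not known to be non-negative.
unsigned oldWidth(const std::vector<ScalarFacts> &scalars,
                  bool rootKnownNonNegative) {
  unsigned maxBits = 8; // mirrors `MaxBitWidth = 8u` in the diff
  for (const ScalarFacts &s : scalars)
    maxBits = std::max(maxBits, s.typeBits - s.signBits);
  if (!rootKnownNonNegative)
    ++maxBits;
  return roundUpToPowerOfTwo(maxBits);
}

// New scheme: the sign bit is added per scalar, before taking the max.
unsigned newWidth(const std::vector<ScalarFacts> &scalars) {
  unsigned maxBits = 8;
  for (const ScalarFacts &s : scalars) {
    unsigned requiredSignBit = s.knownNonNegative ? 0u : 1u;
    maxBits = std::max(maxBits, s.typeBits - s.signBits + requiredSignBit);
  }
  return roundUpToPowerOfTwo(maxBits);
}
```

Under those assumed inputs, `oldWidth({{64, 57, false}, {64, 57, false}}, false)` reaches 9 bits and rounds up to 16, while `newWidth` on the same inputs stays at 8. That is consistent with the test diff moving from `<2 x i16>` to `<2 x i8>`.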

The incorrect bit is the part I'm less sure about.  I'm worried by the implicit assumption in the old code that if the root is non-negative, we don't need to represent the sign bit on the internal nodes.  I can't find an IR test case which exercises this, but consider the following rough sketch.

Consider the expression (add (sub a, b), c).  Assign a = 0x00, b = 0x07, and c = 0x07 at runtime (i.e. don't constant fold), so (a-b) = 0xF9 and ((a-b)+c) = 0x00.  Assume ComputeNumSignBits can figure out that a has 7 sign bits, b has 5, c has 5, (a-b) has 5, and ((a-b)+c) has 7.  Assume computeKnownBits can prove the leading bits on a are zero.

The old and new code should compute the following:
  subexpression         OldBW  NewBW
  (sub a, b)            3      4
  (add (sub a, b), c)   1      1
  Result                3      4

Note that the prior result is wrong for the whole tree: using only 3 bits to represent the result of the sub and then extending produces 0x01, not 0xF9, and thus the add produces 0x08, not 0x00.

Note that the bottom three bits are correct, and I think that's where the conceptual mistake came from.  We *aren't* proving a demanded fact here - all of the bits were demanded.
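The arithmetic in the example above can be checked with a small standalone sketch on 8-bit values. It only models what the text describes (keep the low N bits of the sub result, extend back to 8 bits, then add); it is not the vectorizer's actual transformation, and the helper names `zextFromBits`, `sextFromBits`, and `addAfterNarrowing` are hypothetical.

```cpp
#include <cstdint>

// Keep the low `bits` bits of v and zero-extend back to 8 bits (1 <= bits <= 8).
constexpr std::uint8_t zextFromBits(std::uint8_t v, unsigned bits) {
  return static_cast<std::uint8_t>(v & ((1u << bits) - 1u));
}

// Keep the low `bits` bits of v and sign-extend back to 8 bits (1 <= bits <= 8).
constexpr std::uint8_t sextFromBits(std::uint8_t v, unsigned bits) {
  const unsigned mask = (1u << bits) - 1u;
  const unsigned low = v & mask;
  const unsigned signBit = 1u << (bits - 1u);
  // If the top kept bit is set, fill all higher bits with ones.
  return static_cast<std::uint8_t>((low & signBit) ? (low | ~mask) : low);
}

// Compute (a - b) in 8 bits, squeeze it through `bits` bits, extend it back,
// then add c with 8-bit wraparound.
constexpr std::uint8_t addAfterNarrowing(std::uint8_t a, std::uint8_t b,
                                         std::uint8_t c, unsigned bits,
                                         bool signExtend) {
  const std::uint8_t sub = static_cast<std::uint8_t>(a - b);
  const std::uint8_t widened =
      signExtend ? sextFromBits(sub, bits) : zextFromBits(sub, bits);
  return static_cast<std::uint8_t>(widened + c);
}

// Old width (3 bits, zero-extended): the sub becomes 0x01 and the add 0x08.
static_assert(zextFromBits(0xF9, 3) == 0x01, "3 bits lose the sign");
static_assert(addAfterNarrowing(0x00, 0x07, 0x07, 3, false) == 0x08, "");

// New width (4 bits, sign-extended): the sub is 0xF9 again and the add 0x00.
static_assert(sextFromBits(0xF9, 4) == 0xF9, "4 bits keep the sign");
static_assert(addAfterNarrowing(0x00, 0x07, 0x07, 4, true) == 0x00, "");
```

Sign-extending from 3 bits does not help either: the kept sign position (bit 2 of 0xF9) is zero, so `sextFromBits(0xF9, 3)` is also 0x01. The fourth bit is what carries the sign of the intermediate value.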

---
Full diff: https://github.com/llvm/llvm-project/pull/82013.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+5-32) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/minimum-sizes.ll (+6-7) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index c54d065cac6382..5cadec82a9b16a 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -13326,11 +13326,6 @@ void BoUpSLP::computeMinimumValueSizes() {
                                      MaxBitWidth);
   }
 
-  // True if the roots can be zero-extended back to their original type, rather
-  // than sign-extended. We know that if the leading bits are not demanded, we
-  // can safely zero-extend. So we initialize IsKnownPositive to True.
-  bool IsKnownPositive = true;
-
   // If all the bits of the roots are demanded, we can try a little harder to
   // compute a narrower type. This can happen, for example, if the roots are
   // getelementptr indices. InstCombine promotes these indices to the pointer
@@ -13347,38 +13342,16 @@ void BoUpSLP::computeMinimumValueSizes() {
       })) {
     MaxBitWidth = 8u;
 
-    // Determine if the sign bit of all the roots is known to be zero. If not,
-    // IsKnownPositive is set to False.
-    IsKnownPositive = llvm::all_of(TreeRoot, [&](Value *R) {
-      KnownBits Known = computeKnownBits(R, *DL);
-      return Known.isNonNegative();
-    });
-
     // Determine the maximum number of bits required to store the scalar
     // values.
     for (auto *Scalar : ToDemote) {
       auto NumSignBits = ComputeNumSignBits(Scalar, *DL, 0, AC, nullptr, DT);
       auto NumTypeBits = DL->getTypeSizeInBits(Scalar->getType());
-      MaxBitWidth = std::max<unsigned>(NumTypeBits - NumSignBits, MaxBitWidth);
-    }
-
-    // If we can't prove that the sign bit is zero, we must add one to the
-    // maximum bit width to account for the unknown sign bit. This preserves
-    // the existing sign bit so we can safely sign-extend the root back to the
-    // original type. Otherwise, if we know the sign bit is zero, we will
-    // zero-extend the root instead.
-    //
-    // FIXME: This is somewhat suboptimal, as there will be cases where adding
-    //        one to the maximum bit width will yield a larger-than-necessary
-    //        type. In general, we need to add an extra bit only if we can't
-    //        prove that the upper bit of the original type is equal to the
-    //        upper bit of the proposed smaller type. If these two bits are the
-    //        same (either zero or one) we know that sign-extending from the
-    //        smaller type will result in the same value. Here, since we can't
-    //        yet prove this, we are just making the proposed smaller type
-    //        larger to ensure correctness.
-    if (!IsKnownPositive)
-      ++MaxBitWidth;
+      KnownBits Known = computeKnownBits(Scalar, *DL);
+      unsigned RequiredSignBit =  !Known.isNonNegative();
+      unsigned LocalBitWidth = NumTypeBits - NumSignBits + RequiredSignBit;
+      MaxBitWidth = std::max<unsigned>(LocalBitWidth, MaxBitWidth);
+    }
   }
 
   // Round MaxBitWidth up to the next power-of-two.
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/minimum-sizes.ll b/llvm/test/Transforms/SLPVectorizer/X86/minimum-sizes.ll
index 651631de2c35ad..35c4c9816594e5 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/minimum-sizes.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/minimum-sizes.ll
@@ -89,13 +89,12 @@ define i8 @PR31243_sext(i8 %v0, i8 %v1, i8 %v2, i8 %v3, ptr %ptr) {
 ; AVX-NEXT:    [[TMP0:%.*]] = insertelement <2 x i8> poison, i8 [[V0:%.*]], i64 0
 ; AVX-NEXT:    [[TMP1:%.*]] = insertelement <2 x i8> [[TMP0]], i8 [[V1:%.*]], i64 1
 ; AVX-NEXT:    [[TMP2:%.*]] = or <2 x i8> [[TMP1]], <i8 1, i8 1>
-; AVX-NEXT:    [[TMP3:%.*]] = sext <2 x i8> [[TMP2]] to <2 x i16>
-; AVX-NEXT:    [[TMP4:%.*]] = extractelement <2 x i16> [[TMP3]], i64 0
-; AVX-NEXT:    [[TMP5:%.*]] = sext i16 [[TMP4]] to i64
-; AVX-NEXT:    [[T4:%.*]] = getelementptr inbounds i8, ptr [[PTR:%.*]], i64 [[TMP5]]
-; AVX-NEXT:    [[TMP6:%.*]] = extractelement <2 x i16> [[TMP3]], i64 1
-; AVX-NEXT:    [[TMP7:%.*]] = sext i16 [[TMP6]] to i64
-; AVX-NEXT:    [[T5:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 [[TMP7]]
+; AVX-NEXT:    [[TMP3:%.*]] = extractelement <2 x i8> [[TMP2]], i64 0
+; AVX-NEXT:    [[TMP4:%.*]] = sext i8 [[TMP3]] to i64
+; AVX-NEXT:    [[T4:%.*]] = getelementptr inbounds i8, ptr [[PTR:%.*]], i64 [[TMP4]]
+; AVX-NEXT:    [[TMP5:%.*]] = extractelement <2 x i8> [[TMP2]], i64 1
+; AVX-NEXT:    [[TMP6:%.*]] = sext i8 [[TMP5]] to i64
+; AVX-NEXT:    [[T5:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 [[TMP6]]
 ; AVX-NEXT:    [[T6:%.*]] = load i8, ptr [[T4]], align 1
 ; AVX-NEXT:    [[T7:%.*]] = load i8, ptr [[T5]], align 1
 ; AVX-NEXT:    [[T8:%.*]] = add i8 [[T6]], [[T7]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/82013

