[llvm] [SLP] Use known bits of each value in computeMinimumValueSizes (PR #82013)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 16 09:42:17 PST 2024


https://github.com/preames created https://github.com/llvm/llvm-project/pull/82013

I'm not entirely sure, but I think the prior code is both imprecise and incorrect.  This was noticed by inspection, not from reducing an actual benchmark, so I've struggled to find good tests to illustrate the two points.

The imprecise bit is easier to argue.  The existing code adds one to the bitwidth of the whole tree if any tree root is possibly negative.  This is conservative, as we only need the extra bit for the nodes we are going to sign extend.  We don't need it for the ones which have known leading zeros; those we can zero extend.

In the test case, which shows an analysis improvement, we have a sext and an or, both with 8 required bits.  The previous code computed the maximum over those nodes, clamped to at least 8, *then* added one more bit for the sign.  The resulting 9 bits round up to 16, which is where the old <2 x i16> came from.
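To make the imprecision argument concrete, here is a minimal standalone C++ model of the two rules.  It is a sketch, not LLVM code: NodeFacts, oldWidth, newWidth, roundUpToPow2 and illustrativeNodes are hypothetical names, the per-node facts stand in for what ComputeNumSignBits and computeKnownBits would report, and the Floor parameter mirrors the `MaxBitWidth = 8u;` assignment in the patched function.  The numbers in illustrativeNodes are made up for illustration and are not taken from the test's IR.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical per-scalar facts; in LLVM these would come from
// ComputeNumSignBits and computeKnownBits.
struct NodeFacts {
  unsigned TypeBits;     // width of the scalar type, e.g. 32 for i32
  unsigned NumSignBits;  // leading bits known equal to the sign bit
  bool KnownNonNegative; // sign bit proven zero
};

// Old rule: widest node, plus one bit for the whole tree unless every
// root is known non-negative.
unsigned oldWidth(const std::vector<NodeFacts> &Nodes, bool RootsNonNegative,
                  unsigned Floor) {
  unsigned Width = Floor;
  for (const NodeFacts &N : Nodes) {
    assert(N.NumSignBits >= 1 && N.NumSignBits <= N.TypeBits);
    Width = std::max(Width, N.TypeBits - N.NumSignBits);
  }
  if (!RootsNonNegative)
    ++Width;
  return Width;
}

// New rule: each node adds a sign bit only if it may be negative.
unsigned newWidth(const std::vector<NodeFacts> &Nodes, unsigned Floor) {
  unsigned Width = Floor;
  for (const NodeFacts &N : Nodes) {
    assert(N.NumSignBits >= 1 && N.NumSignBits <= N.TypeBits);
    unsigned RequiredSignBit = N.KnownNonNegative ? 0 : 1;
    Width = std::max(Width, N.TypeBits - N.NumSignBits + RequiredSignBit);
  }
  return Width;
}

// Round up to the next power of two (written out to stay C++17).
unsigned roundUpToPow2(unsigned X) {
  unsigned P = 1;
  while (P < X)
    P *= 2;
  return P;
}

// Illustrative numbers only: one possibly negative node with 7 value bits,
// and one known non-negative node with 8 value bits.
std::vector<NodeFacts> illustrativeNodes() {
  return {{32, 25, false}, {32, 24, true}};
}
```

With a floor of 8, oldWidth(illustrativeNodes(), false, 8) is 9, which rounds up to 16, while newWidth(illustrativeNodes(), 8) is already 8.  The possibly negative node pays for its sign bit locally, so the tree as a whole does not.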

The incorrect bit is the part I'm less sure on.  I'm worried by the old code's implicit assumption that, if the root is non-negative, we don't need to represent the sign bit on the internal nodes.  I can't find an IR test case which exercises this, but consider the following rough sketch.

Consider the expression (add (sub a, b), c) on 8-bit values.  Assign a = 0x00, b = 0x07, and c = 0x07 at runtime (i.e. don't constant fold).  Then (a-b) = 0xf9, and ((a-b)+c) = 0x00.  Assume ComputeNumSignBits can figure out that a has 7 sign bits, b has 5, c has 5, (a-b) has 5, and (a-b)+c has 7.  Assume computeKnownBits can prove the leading bits of a are zero.

The old and new code should compute the following:
  subexpression         OldBW  NewBW
  (sub a, b)            3      4
  (add (sub a, b), c)   1      1
  Result                3      4
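The table can be reproduced with a small standalone C++ model.  This is a sketch under stated assumptions: the helper names are hypothetical, the sign-bit counts are the ones the sketch assumes rather than anything LLVM computes, the add (the root) is treated as known non-negative (which its 1-bit NewBW entry implies), and, like the table, it leaves out the 8-bit floor the patched code applies.  numSignBits8 gives the exact count for one concrete value, which helps check that the assumed counts are sound lower bounds for the runtime values.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Exact number of leading bits equal to the sign bit of one 8-bit value.
unsigned numSignBits8(uint8_t V) {
  unsigned Top = V >> 7;
  unsigned Count = 1;
  for (int Bit = 6; Bit >= 0 && ((V >> Bit) & 1u) == Top; --Bit)
    ++Count;
  return Count;
}

// Per-node width under the old rule (no per-node sign bit) and the new
// rule (one extra bit when the node may be negative).
unsigned oldNodeWidth(unsigned TypeBits, unsigned SignBits) {
  assert(SignBits >= 1 && SignBits <= TypeBits);
  return TypeBits - SignBits;
}
unsigned newNodeWidth(unsigned TypeBits, unsigned SignBits, bool NonNegative) {
  assert(SignBits >= 1 && SignBits <= TypeBits);
  return TypeBits - SignBits + (NonNegative ? 0 : 1);
}

// Assumed static facts from the sketch: (sub a, b) has 5 sign bits and may
// be negative; (add (sub a, b), c) has 7 sign bits and is non-negative.
constexpr unsigned SubSignBits = 5, AddSignBits = 7;

unsigned oldTreeWidth() {
  // The root is non-negative, so the old rule adds no extra bit.
  return std::max(oldNodeWidth(8, SubSignBits), oldNodeWidth(8, AddSignBits));
}
unsigned newTreeWidth() {
  return std::max(newNodeWidth(8, SubSignBits, false),
                  newNodeWidth(8, AddSignBits, true));
}
```

oldTreeWidth() gives 3 and newTreeWidth() gives 4, matching the Result row.  For the runtime values, 0x07 and 0xF9 each have exactly 5 sign bits and 0x00 has 8, so none of the assumed counts overstate what holds at runtime.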

Note that the prior result is wrong for the whole tree: using only 3 bits to represent the result of the sub and then extending produces 0x01, not 0xF9, and thus the add produces 0x08, not 0x00.

Note that the bottom three bits of the result are correct, and I think that's where the conceptual mistake came from.  We *aren't* proving a demanded-bits fact here; all of the bits were demanded.
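The arithmetic above can be checked with a few lines of C++.  This sketch (lowBits, signExtendFrom and evalWithNarrowSub are hypothetical helpers) evaluates (add (sub a, b), c) in 8 bits, except that the sub result is truncated to a narrower width and sign-extended back before the add, which is what the extra per-node sign bit is meant to make safe.  For the 3-bit case, zero- and sign-extension agree, because bit 2 of 0xF9 is clear.

```cpp
#include <cassert>
#include <cstdint>

// Keep only the low Bits bits of V, as storing it in a narrower lane would.
uint8_t lowBits(uint8_t V, unsigned Bits) {
  assert(Bits >= 1 && Bits <= 8);
  return static_cast<uint8_t>(V & ((1u << Bits) - 1));
}

// Sign-extend the low Bits bits of V back to 8 bits.
uint8_t signExtendFrom(uint8_t V, unsigned Bits) {
  uint8_t Low = lowBits(V, Bits);
  unsigned Mask = (1u << Bits) - 1;
  bool Negative = (Low >> (Bits - 1)) & 1u;
  return Negative ? static_cast<uint8_t>(Low | ~Mask) : Low;
}

// Evaluate (add (sub a, b), c) in 8 bits, but carry the sub result in only
// NarrowBits bits and sign-extend it before the add.
uint8_t evalWithNarrowSub(uint8_t A, uint8_t B, uint8_t C, unsigned NarrowBits) {
  uint8_t Sub = static_cast<uint8_t>(A - B); // wraps modulo 256
  uint8_t Widened = signExtendFrom(Sub, NarrowBits);
  return static_cast<uint8_t>(Widened + C);
}
```

With a = 0x00, b = 0x07, c = 0x07, a 3-bit sub yields 0x08 while a 4-bit sub yields the correct 0x00.  The low three bits of 0x08 and 0x00 agree, which is the "bottom three bits are correct" observation.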

From 19b3f18bbfe426934abd40283cfc4df03014b9dc Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Fri, 16 Feb 2024 07:25:40 -0800
Subject: [PATCH] [SLP] Use known bits of each value in
 computeMinimumValueSizes

---
 .../Transforms/Vectorize/SLPVectorizer.cpp    | 37 +++----------------
 .../SLPVectorizer/X86/minimum-sizes.ll        | 13 +++----
 2 files changed, 11 insertions(+), 39 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index c54d065cac6382..5cadec82a9b16a 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -13326,11 +13326,6 @@ void BoUpSLP::computeMinimumValueSizes() {
                                      MaxBitWidth);
   }
 
-  // True if the roots can be zero-extended back to their original type, rather
-  // than sign-extended. We know that if the leading bits are not demanded, we
-  // can safely zero-extend. So we initialize IsKnownPositive to True.
-  bool IsKnownPositive = true;
-
   // If all the bits of the roots are demanded, we can try a little harder to
   // compute a narrower type. This can happen, for example, if the roots are
   // getelementptr indices. InstCombine promotes these indices to the pointer
@@ -13347,38 +13342,16 @@ void BoUpSLP::computeMinimumValueSizes() {
       })) {
     MaxBitWidth = 8u;
 
-    // Determine if the sign bit of all the roots is known to be zero. If not,
-    // IsKnownPositive is set to False.
-    IsKnownPositive = llvm::all_of(TreeRoot, [&](Value *R) {
-      KnownBits Known = computeKnownBits(R, *DL);
-      return Known.isNonNegative();
-    });
-
     // Determine the maximum number of bits required to store the scalar
     // values.
     for (auto *Scalar : ToDemote) {
       auto NumSignBits = ComputeNumSignBits(Scalar, *DL, 0, AC, nullptr, DT);
       auto NumTypeBits = DL->getTypeSizeInBits(Scalar->getType());
-      MaxBitWidth = std::max<unsigned>(NumTypeBits - NumSignBits, MaxBitWidth);
-    }
-
-    // If we can't prove that the sign bit is zero, we must add one to the
-    // maximum bit width to account for the unknown sign bit. This preserves
-    // the existing sign bit so we can safely sign-extend the root back to the
-    // original type. Otherwise, if we know the sign bit is zero, we will
-    // zero-extend the root instead.
-    //
-    // FIXME: This is somewhat suboptimal, as there will be cases where adding
-    //        one to the maximum bit width will yield a larger-than-necessary
-    //        type. In general, we need to add an extra bit only if we can't
-    //        prove that the upper bit of the original type is equal to the
-    //        upper bit of the proposed smaller type. If these two bits are the
-    //        same (either zero or one) we know that sign-extending from the
-    //        smaller type will result in the same value. Here, since we can't
-    //        yet prove this, we are just making the proposed smaller type
-    //        larger to ensure correctness.
-    if (!IsKnownPositive)
-      ++MaxBitWidth;
+      KnownBits Known = computeKnownBits(Scalar, *DL);
+      unsigned RequiredSignBit = !Known.isNonNegative();
+      unsigned LocalBitWidth = NumTypeBits - NumSignBits + RequiredSignBit;
+      MaxBitWidth = std::max<unsigned>(LocalBitWidth, MaxBitWidth);
+    }
   }
 
   // Round MaxBitWidth up to the next power-of-two.
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/minimum-sizes.ll b/llvm/test/Transforms/SLPVectorizer/X86/minimum-sizes.ll
index 651631de2c35ad..35c4c9816594e5 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/minimum-sizes.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/minimum-sizes.ll
@@ -89,13 +89,12 @@ define i8 @PR31243_sext(i8 %v0, i8 %v1, i8 %v2, i8 %v3, ptr %ptr) {
 ; AVX-NEXT:    [[TMP0:%.*]] = insertelement <2 x i8> poison, i8 [[V0:%.*]], i64 0
 ; AVX-NEXT:    [[TMP1:%.*]] = insertelement <2 x i8> [[TMP0]], i8 [[V1:%.*]], i64 1
 ; AVX-NEXT:    [[TMP2:%.*]] = or <2 x i8> [[TMP1]], <i8 1, i8 1>
-; AVX-NEXT:    [[TMP3:%.*]] = sext <2 x i8> [[TMP2]] to <2 x i16>
-; AVX-NEXT:    [[TMP4:%.*]] = extractelement <2 x i16> [[TMP3]], i64 0
-; AVX-NEXT:    [[TMP5:%.*]] = sext i16 [[TMP4]] to i64
-; AVX-NEXT:    [[T4:%.*]] = getelementptr inbounds i8, ptr [[PTR:%.*]], i64 [[TMP5]]
-; AVX-NEXT:    [[TMP6:%.*]] = extractelement <2 x i16> [[TMP3]], i64 1
-; AVX-NEXT:    [[TMP7:%.*]] = sext i16 [[TMP6]] to i64
-; AVX-NEXT:    [[T5:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 [[TMP7]]
+; AVX-NEXT:    [[TMP3:%.*]] = extractelement <2 x i8> [[TMP2]], i64 0
+; AVX-NEXT:    [[TMP4:%.*]] = sext i8 [[TMP3]] to i64
+; AVX-NEXT:    [[T4:%.*]] = getelementptr inbounds i8, ptr [[PTR:%.*]], i64 [[TMP4]]
+; AVX-NEXT:    [[TMP5:%.*]] = extractelement <2 x i8> [[TMP2]], i64 1
+; AVX-NEXT:    [[TMP6:%.*]] = sext i8 [[TMP5]] to i64
+; AVX-NEXT:    [[T5:%.*]] = getelementptr inbounds i8, ptr [[PTR]], i64 [[TMP6]]
 ; AVX-NEXT:    [[T6:%.*]] = load i8, ptr [[T4]], align 1
 ; AVX-NEXT:    [[T7:%.*]] = load i8, ptr [[T5]], align 1
 ; AVX-NEXT:    [[T8:%.*]] = add i8 [[T6]], [[T7]]