[llvm] 781d077 - [InstCombine] reassociateShiftAmtsOfTwoSameDirectionShifts(): fix miscompile (PR44802)

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Tue Feb 25 07:25:15 PST 2020


Author: Roman Lebedev
Date: 2020-02-25T18:23:51+03:00
New Revision: 781d077afb0ed9771c513d064c40170c1ccd21c9

URL: https://github.com/llvm/llvm-project/commit/781d077afb0ed9771c513d064c40170c1ccd21c9
DIFF: https://github.com/llvm/llvm-project/commit/781d077afb0ed9771c513d064c40170c1ccd21c9.diff

LOG: [InstCombine] reassociateShiftAmtsOfTwoSameDirectionShifts(): fix miscompile (PR44802)

As input, we have the following pattern:
  Sh0 (Sh1 X, Q), K
We want to rewrite that as:
  Sh X, (Q+K)  iff (Q+K) u< bitwidth(X)
While we know that originally (Q+K) would not overflow
(because 2 * (N-1) u<= iN -1), we may have looked past extensions of
shift amounts, so it may now overflow in the smaller bit width.

To ensure that does not happen, we need to ensure that the total maximal
shift amount is still representable in that smaller bit width.
If the overflow did happen, the (Q+K) u< bitwidth(X) check would be bogus.
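
The arithmetic behind this guard can be sketched outside of LLVM. The
following is a minimal, standalone illustration, not the committed code:
the helper names (maxUnsignedValue, addWrapped, shiftAmountSumCannotOverflow,
checkPr44802Example) are hypothetical, and plain uint64_t stands in for APInt.
It assumes the pr44802 case has two i3 shifts whose amounts are a zext'ed i1,
so that after looking past the zext the amounts are added in i1.

```cpp
#include <cassert>
#include <cstdint>

// Largest unsigned value representable in `bits` bits (1 <= bits <= 64).
static uint64_t maxUnsignedValue(unsigned bits) {
  return bits >= 64 ? ~uint64_t{0} : ((uint64_t{1} << bits) - 1);
}

// Addition as it behaves in a `bits`-wide integer type: wraps on overflow.
static uint64_t addWrapped(uint64_t q, uint64_t k, unsigned bits) {
  return (q + k) & maxUnsignedValue(bits);
}

// Hypothetical stand-in for the guard: the largest total shift amount,
// (bitwidth(Sh0) - 1) + (bitwidth(Sh1) - 1), must be representable in the
// (possibly narrower) type in which Q and K are actually added.
static bool shiftAmountSumCannotOverflow(unsigned sh0Bits, unsigned sh1Bits,
                                         unsigned amtBits) {
  const uint64_t maxTotal = static_cast<uint64_t>(sh0Bits - 1) +
                            static_cast<uint64_t>(sh1Bits - 1);
  return maxUnsignedValue(amtBits) >= maxTotal;
}

// Assumed shape of the pr44802 case: i3 shifts, i1 shift amounts Q = K = 1.
static void checkPr44802Example() {
  // In i1, 1 + 1 wraps to 0 ...
  assert(addWrapped(1, 1, 1) == 0);
  // ... so the (Q+K) u< bitwidth(X) check passes even though the real
  // total shift amount is 2.
  assert(addWrapped(1, 1, 1) < 3);
  // The guard rejects this: i1 can hold at most 1, but the total may reach 4.
  assert(!shiftAmountSumCannotOverflow(3, 3, 1));
  // Without the narrowing (amounts added in i3) the fold would stay allowed.
  assert(shiftAmountSumCannotOverflow(3, 3, 3));
}
```

In this sketch the guard is deliberately conservative: it bails out whenever
the narrow type could overflow, regardless of the actual values of Q and K.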

https://bugs.llvm.org/show_bug.cgi?id=44802

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
    llvm/test/Transforms/InstCombine/shift-amount-reassociation.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 49d6443d2277..3624295d0343 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -23,8 +23,11 @@ using namespace PatternMatch;
 // Given pattern:
 //   (x shiftopcode Q) shiftopcode K
 // we should rewrite it as
-//   x shiftopcode (Q+K)  iff (Q+K) u< bitwidth(x)
-// This is valid for any shift, but they must be identical.
+//   x shiftopcode (Q+K)  iff (Q+K) u< bitwidth(x) and
+//
+// This is valid for any shift, but they must be identical, and we must be
+// careful in case we have (zext(Q)+zext(K)) and look past extensions,
+// (Q+K) must not overflow or else (Q+K) u< bitwidth(x) is bogus.
 //
 // AnalyzeForSignBitExtraction indicates that we will only analyze whether this
 // pattern has any 2 right-shifts that sum to 1 less than original bit width.
@@ -58,6 +61,23 @@ Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts(
   if (ShAmt0->getType() != ShAmt1->getType())
     return nullptr;
 
+  // As input, we have the following pattern:
+  //   Sh0 (Sh1 X, Q), K
+  // We want to rewrite that as:
+  //   Sh x, (Q+K)  iff (Q+K) u< bitwidth(x)
+  // While we know that originally (Q+K) would not overflow
+  // (because  2 * (N-1) u<= iN -1), we have looked past extensions of
+  // shift amounts. so it may now overflow in smaller bitwidth.
+  // To ensure that does not happen, we need to ensure that the total maximal
+  // shift amount is still representable in that smaller bit width.
+  unsigned MaximalPossibleTotalShiftAmount =
+      (Sh0->getType()->getScalarSizeInBits() - 1) +
+      (Sh1->getType()->getScalarSizeInBits() - 1);
+  APInt MaximalRepresentableShiftAmount =
+      APInt::getAllOnesValue(ShAmt0->getType()->getScalarSizeInBits());
+  if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount))
+    return nullptr;
+
   // We are only looking for signbit extraction if we have two right shifts.
   bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) &&
                            match(Sh1, m_Shr(m_Value(), m_Value()));

diff  --git a/llvm/test/Transforms/InstCombine/shift-amount-reassociation.ll b/llvm/test/Transforms/InstCombine/shift-amount-reassociation.ll
index 0b8187d04172..96461691e70b 100644
--- a/llvm/test/Transforms/InstCombine/shift-amount-reassociation.ll
+++ b/llvm/test/Transforms/InstCombine/shift-amount-reassociation.ll
@@ -320,12 +320,15 @@ define i32 @n20(i32 %x, i32 %y) {
   ret i32 %t3
 }
 
-; FIXME: this is a miscompile. We should not transform this.
 ; See https://bugs.llvm.org/show_bug.cgi?id=44802
 define i3 @pr44802(i3 %t0) {
 ; CHECK-LABEL: @pr44802(
 ; CHECK-NEXT:    [[T1:%.*]] = sub i3 0, [[T0:%.*]]
-; CHECK-NEXT:    ret i3 [[T1]]
+; CHECK-NEXT:    [[T2:%.*]] = icmp ne i3 [[T0]], 0
+; CHECK-NEXT:    [[T3:%.*]] = zext i1 [[T2]] to i3
+; CHECK-NEXT:    [[T4:%.*]] = lshr i3 [[T1]], [[T3]]
+; CHECK-NEXT:    [[T5:%.*]] = lshr i3 [[T4]], [[T3]]
+; CHECK-NEXT:    ret i3 [[T5]]
 ;
   %t1 = sub i3 0, %t0
   %t2 = icmp ne i3 %t0, 0
