[llvm] r375371 - [InstCombine] Shift amount reassociation in shifty sign bit test (PR43595)

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Sun Oct 20 12:38:50 PDT 2019


Author: lebedevri
Date: Sun Oct 20 12:38:50 2019
New Revision: 375371

URL: http://llvm.org/viewvc/llvm-project?rev=375371&view=rev
Log:
[InstCombine] Shift amount reassociation in shifty sign bit test (PR43595)

Summary:
This problem consists of several parts:
* Basic sign bit extraction - `trunc? (?shr %x, (bitwidth(x)-1))`.
  This is trivial and easy to do; we already have a fold for it.
* Shift amount reassociation - if we have two identical shifts,
  and we can simplify-add their shift amounts together,
  then we likely can just perform them as a single shift.
  But this is finicky, has one-use restrictions,
  and shift opcodes must be identical.
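
As a rough illustration of the reassociation described in the second bullet, the following standalone sketch (not LLVM code; the function name is hypothetical) exhaustively checks on 8-bit values that two identical shifts by `q` and `k` match a single shift by `q + k` whenever `q + k` is less than the bit width:

```cpp
#include <cstdint>

// Returns true if (x shift q) shift k == x shift (q + k) for every 8-bit x
// and every q, k with q + k < 8, for both logical right shift and left shift.
// Both shifts in each pair use the same opcode, as the fold requires.
bool reassociationHoldsFor8Bit() {
  for (unsigned x = 0; x < 256; ++x)
    for (unsigned q = 0; q < 8; ++q)
      for (unsigned k = 0; q + k < 8; ++k) {
        uint8_t v = static_cast<uint8_t>(x);
        uint8_t twoLshr =
            static_cast<uint8_t>(static_cast<uint8_t>(v >> q) >> k);
        uint8_t oneLshr = static_cast<uint8_t>(v >> (q + k));
        uint8_t twoShl =
            static_cast<uint8_t>(static_cast<uint8_t>(v << q) << k);
        uint8_t oneShl = static_cast<uint8_t>(v << (q + k));
        if (twoLshr != oneLshr || twoShl != oneShl)
          return false;
      }
  return true;
}
```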

But there is a super-pattern where both of these work together
to produce a sign bit test from two shifts + comparison.
We do indeed already handle this in most cases.
But since we get that fold transitively, it has one-use restrictions.
And what's worse, in this case the right-shifts aren't required to be
identical, and we can't handle that transitively:

If the total shift amount is bitwidth-1, only a sign bit will remain
in the output value. But if we look at this from the perspective of
two shifts, we can't fold - we can't possibly know what bit pattern
we'd produce via two shifts, it will be *some* kind of a mask
produced from the original sign bit, but we just can't tell its shape:
https://rise4fun.com/Alive/cM0 https://rise4fun.com/Alive/9IN

But it will *only* contain sign bit and zeros.
So from the perspective of sign bit test, we're good:
https://rise4fun.com/Alive/FRz https://rise4fun.com/Alive/qBU
Superb!
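
The claim can also be sanity-checked outside of Alive. The sketch below is a standalone 8-bit model, not LLVM code; `lshr8`, `ashr8` and `signBitTestHoldsForAllSplits` are hypothetical helper names. It tries every value, every split of the total shift amount 7, and all four lshr/ashr combinations, and checks that comparing the result against zero agrees with the sign of the input:

```cpp
#include <cstdint>

// Logical right shift of an 8-bit value.
uint8_t lshr8(uint8_t x, unsigned amt) {
  return static_cast<uint8_t>(x >> amt);
}

// Arithmetic right shift of an 8-bit value, spelled out portably:
// the vacated high bits are filled with copies of the sign bit.
uint8_t ashr8(uint8_t x, unsigned amt) {
  uint8_t shifted = static_cast<uint8_t>(x >> amt);
  if (x & 0x80u)
    shifted |= static_cast<uint8_t>(0xFFu << (8 - amt));
  return shifted;
}

using ShiftFn = uint8_t (*)(uint8_t, unsigned);

// For every x and every split q + k == 7, checks that
// (outer(inner(x, q), k) != 0) == (x is negative when read as i8),
// for all four combinations of lshr/ashr.
bool signBitTestHoldsForAllSplits() {
  const ShiftFn shifts[] = {lshr8, ashr8};
  for (unsigned x = 0; x < 256; ++x)
    for (unsigned q = 0; q <= 7; ++q)
      for (ShiftFn inner : shifts)
        for (ShiftFn outer : shifts) {
          uint8_t v = static_cast<uint8_t>(x);
          uint8_t r = outer(inner(v, q), 7 - q);
          bool isNegative = (v & 0x80u) != 0;
          if ((r != 0) != isNegative)
            return false;
        }
  return true;
}
```

In this model, for x = 0x80 split as 3 + 4, lshr+lshr gives 0x01, ashr+lshr gives 0x0F and ashr+ashr gives 0xFF: the mask shape depends on the opcodes, but the comparison with zero does not.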

So the simplest solution is to extend `reassociateShiftAmtsOfTwoSameDirectionShifts()` to also have a
pseudo-analysis mode that will ignore extra uses, and will only check
whether a) those are two right shifts and b) they end up with a bitwidth(x)-1
shift amount, and return either the original value that we are sign-checking,
or null.
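
A heavily simplified model of the two checks in that analysis mode might look like the sketch below. Everything here is hypothetical (`ShiftKind`, `Shift`, `isSignBitExtraction`): it assumes constant shift amounts and ignores truncation, whereas the real function works on IR values and relies on instsimplify to add the amounts.

```cpp
// Toy model of a shift instruction: an opcode plus a constant amount.
enum class ShiftKind { Shl, LShr, AShr };

struct Shift {
  ShiftKind kind;
  unsigned amount;
};

// True for either flavor of right shift.
bool isRightShift(ShiftKind kind) {
  return kind == ShiftKind::LShr || kind == ShiftKind::AShr;
}

// Models the analysis-only question: do these two shifts, applied one after
// the other to a value of width bitWidth, leave nothing but the sign bit?
// a) both must be right shifts (the opcodes need not be identical), and
// b) the amounts must sum to exactly bitWidth - 1.
bool isSignBitExtraction(Shift inner, Shift outer, unsigned bitWidth) {
  if (!isRightShift(inner.kind) || !isRightShift(outer.kind))
    return false;
  return inner.amount + outer.amount == bitWidth - 1;
}
```

Note that no one-use check appears in this model, mirroring the point that the analysis mode does not create instructions and so can ignore extra uses.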

This does not have any functionality change for
the existing `reassociateShiftAmtsOfTwoSameDirectionShifts()`.

All that being said, as discussed in the review, this yet again
increases usage of instsimplify in instcombine as a utility.
Some day that may need to be reevaluated.

https://bugs.llvm.org/show_bug.cgi?id=43595

Reviewers: spatel, efriedma, vsk

Reviewed By: spatel

Subscribers: xbolva00, hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D68930

Modified:
    llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h
    llvm/trunk/lib/Transforms/InstCombine/InstCombineShifts.cpp
    llvm/trunk/test/Transforms/InstCombine/sign-bit-test-via-right-shifting-all-other-bits.ll

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp?rev=375371&r1=375370&r2=375371&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp Sun Oct 20 12:38:50 2019
@@ -1358,19 +1358,28 @@ Instruction *InstCombiner::foldIRemByPow
 
 /// Fold equality-comparison between zero and any (maybe truncated) right-shift
 /// by one-less-than-bitwidth into a sign test on the original value.
-Instruction *foldSignBitTest(ICmpInst &I) {
+Instruction *InstCombiner::foldSignBitTest(ICmpInst &I) {
+  Instruction *Val;
   ICmpInst::Predicate Pred;
-  Value *X;
-  Constant *C;
-  if (!I.isEquality() ||
-      !match(&I, m_ICmp(Pred, m_TruncOrSelf(m_Shr(m_Value(X), m_Constant(C))),
-                        m_Zero())))
+  if (!I.isEquality() || !match(&I, m_ICmp(Pred, m_Instruction(Val), m_Zero())))
     return nullptr;
 
-  Type *XTy = X->getType();
-  unsigned XBitWidth = XTy->getScalarSizeInBits();
-  if (!match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
-                                   APInt(XBitWidth, XBitWidth - 1))))
+  Value *X;
+  Type *XTy;
+
+  Constant *C;
+  if (match(Val, m_TruncOrSelf(m_Shr(m_Value(X), m_Constant(C))))) {
+    XTy = X->getType();
+    unsigned XBitWidth = XTy->getScalarSizeInBits();
+    if (!match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
+                                     APInt(XBitWidth, XBitWidth - 1))))
+      return nullptr;
+  } else if (isa<BinaryOperator>(Val) &&
+             (X = reassociateShiftAmtsOfTwoSameDirectionShifts(
+                  cast<BinaryOperator>(Val), SQ.getWithInstruction(Val),
+                  /*AnalyzeForSignBitExtraction=*/true))) {
+    XTy = X->getType();
+  } else
     return nullptr;
 
   return ICmpInst::Create(Instruction::ICmp,

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h?rev=375371&r1=375370&r2=375371&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h Sun Oct 20 12:38:50 2019
@@ -390,6 +390,9 @@ public:
   Instruction *visitOr(BinaryOperator &I);
   Instruction *visitXor(BinaryOperator &I);
   Instruction *visitShl(BinaryOperator &I);
+  Value *reassociateShiftAmtsOfTwoSameDirectionShifts(
+      BinaryOperator *Sh0, const SimplifyQuery &SQ,
+      bool AnalyzeForSignBitExtraction = false);
   Instruction *foldVariableSignZeroExtensionOfVariableHighBitExtract(
       BinaryOperator &OldAShr);
   Instruction *visitAShr(BinaryOperator &I);
@@ -912,6 +915,7 @@ private:
   Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ);
   Instruction *foldICmpEquality(ICmpInst &Cmp);
   Instruction *foldIRemByPowerOfTwoToBitTest(ICmpInst &I);
+  Instruction *foldSignBitTest(ICmpInst &I);
   Instruction *foldICmpWithZero(ICmpInst &Cmp);
 
   Value *foldUnsignedMultiplicationOverflowCheck(ICmpInst &Cmp);

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineShifts.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineShifts.cpp?rev=375371&r1=375370&r2=375371&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineShifts.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineShifts.cpp Sun Oct 20 12:38:50 2019
@@ -25,10 +25,12 @@ using namespace PatternMatch;
 // we should rewrite it as
 //   x shiftopcode (Q+K)  iff (Q+K) u< bitwidth(x)
 // This is valid for any shift, but they must be identical.
-static Instruction *
-reassociateShiftAmtsOfTwoSameDirectionShifts(BinaryOperator *Sh0,
-                                             const SimplifyQuery &SQ,
-                                             InstCombiner::BuilderTy &Builder) {
+//
+// AnalyzeForSignBitExtraction indicates that we will only analyze whether this
+// pattern has any 2 right-shifts that sum to 1 less than original bit width.
+Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts(
+    BinaryOperator *Sh0, const SimplifyQuery &SQ,
+    bool AnalyzeForSignBitExtraction) {
   // Look for a shift of some instruction, ignore zext of shift amount if any.
   Instruction *Sh0Op0;
   Value *ShAmt0;
@@ -56,14 +58,25 @@ reassociateShiftAmtsOfTwoSameDirectionSh
   if (ShAmt0->getType() != ShAmt1->getType())
     return nullptr;
 
-  // The shift opcodes must be identical.
+  // We are only looking for signbit extraction if we have two right shifts.
+  bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) &&
+                           match(Sh1, m_Shr(m_Value(), m_Value()));
+  // ... and if it's not two right-shifts, we know the answer already.
+  if (AnalyzeForSignBitExtraction && !HadTwoRightShifts)
+    return nullptr;
+
+  // The shift opcodes must be identical, unless we are just checking whether
+  // this pattern can be interpreted as a sign-bit-extraction.
   Instruction::BinaryOps ShiftOpcode = Sh0->getOpcode();
-  if (ShiftOpcode != Sh1->getOpcode())
+  bool IdenticalShOpcodes = Sh0->getOpcode() == Sh1->getOpcode();
+  if (!IdenticalShOpcodes && !AnalyzeForSignBitExtraction)
     return nullptr;
 
   // If we saw truncation, we'll need to produce extra instruction,
-  // and for that one of the operands of the shift must be one-use.
-  if (Trunc && !match(Sh0, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
+  // and for that one of the operands of the shift must be one-use,
+  // unless of course we don't actually plan to produce any instructions here.
+  if (Trunc && !AnalyzeForSignBitExtraction &&
+      !match(Sh0, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
     return nullptr;
 
   // Can we fold (ShAmt0+ShAmt1) ?
@@ -80,14 +93,22 @@ reassociateShiftAmtsOfTwoSameDirectionSh
     return nullptr; // FIXME: could perform constant-folding.
 
   // If there was a truncation, and we have a right-shift, we can only fold if
-  // we are left with the original sign bit.
+  // we are left with the original sign bit. Likewise, if we were just checking
+  // that this is a sighbit extraction, this is the place to check it.
   // FIXME: zero shift amount is also legal here, but we can't *easily* check
   // more than one predicate so it's not really worth it.
-  if (Trunc && ShiftOpcode != Instruction::BinaryOps::Shl &&
-      !match(NewShAmt,
-             m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
-                                APInt(NewShAmtBitWidth, XBitWidth - 1))))
-    return nullptr;
+  if (HadTwoRightShifts && (Trunc || AnalyzeForSignBitExtraction)) {
+    // If it's not a sign bit extraction, then we're done.
+    if (!match(NewShAmt,
+               m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
+                                  APInt(NewShAmtBitWidth, XBitWidth - 1))))
+      return nullptr;
+    // If it is, and that was the question, return the base value.
+    if (AnalyzeForSignBitExtraction)
+      return X;
+  }
+
+  assert(IdenticalShOpcodes && "Should not get here with different shifts.");
 
   // All good, we can do this fold.
   NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, X->getType());
@@ -287,8 +308,8 @@ Instruction *InstCombiner::commonShiftTr
     if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I))
       return Res;
 
-  if (Instruction *NewShift =
-          reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ, Builder))
+  if (auto *NewShift = cast_or_null<Instruction>(
+          reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ)))
     return NewShift;
 
   // (C1 shift (A add C2)) -> (C1 shift C2) shift A)

Modified: llvm/trunk/test/Transforms/InstCombine/sign-bit-test-via-right-shifting-all-other-bits.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/sign-bit-test-via-right-shifting-all-other-bits.ll?rev=375371&r1=375370&r2=375371&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/sign-bit-test-via-right-shifting-all-other-bits.ll (original)
+++ llvm/trunk/test/Transforms/InstCombine/sign-bit-test-via-right-shifting-all-other-bits.ll Sun Oct 20 12:38:50 2019
@@ -45,7 +45,7 @@ define i1 @highest_bit_test_via_lshr_wit
 ; CHECK-NEXT:    call void @use32(i32 [[HIGH_BITS_EXTRACTED_NARROW]])
 ; CHECK-NEXT:    call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]])
 ; CHECK-NEXT:    call void @use32(i32 [[SIGNBIT]])
-; CHECK-NEXT:    [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0
+; CHECK-NEXT:    [[ISNEG:%.*]] = icmp slt i64 [[DATA]], 0
 ; CHECK-NEXT:    ret i1 [[ISNEG]]
 ;
   %num_low_bits_to_skip = sub i32 64, %nbits
@@ -107,7 +107,7 @@ define i1 @highest_bit_test_via_ashr_wit
 ; CHECK-NEXT:    call void @use32(i32 [[HIGH_BITS_EXTRACTED_NARROW]])
 ; CHECK-NEXT:    call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]])
 ; CHECK-NEXT:    call void @use32(i32 [[SIGNBIT]])
-; CHECK-NEXT:    [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0
+; CHECK-NEXT:    [[ISNEG:%.*]] = icmp slt i64 [[DATA]], 0
 ; CHECK-NEXT:    ret i1 [[ISNEG]]
 ;
   %num_low_bits_to_skip = sub i32 64, %nbits
@@ -138,7 +138,7 @@ define i1 @highest_bit_test_via_lshr_ash
 ; CHECK-NEXT:    call void @use32(i32 [[HIGH_BITS_EXTRACTED]])
 ; CHECK-NEXT:    call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]])
 ; CHECK-NEXT:    call void @use32(i32 [[SIGNBIT]])
-; CHECK-NEXT:    [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0
+; CHECK-NEXT:    [[ISNEG:%.*]] = icmp slt i32 [[DATA]], 0
 ; CHECK-NEXT:    ret i1 [[ISNEG]]
 ;
   %num_low_bits_to_skip = sub i32 32, %nbits
@@ -169,7 +169,7 @@ define i1 @highest_bit_test_via_lshr_ash
 ; CHECK-NEXT:    call void @use32(i32 [[HIGH_BITS_EXTRACTED_NARROW]])
 ; CHECK-NEXT:    call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]])
 ; CHECK-NEXT:    call void @use32(i32 [[SIGNBIT]])
-; CHECK-NEXT:    [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0
+; CHECK-NEXT:    [[ISNEG:%.*]] = icmp slt i64 [[DATA]], 0
 ; CHECK-NEXT:    ret i1 [[ISNEG]]
 ;
   %num_low_bits_to_skip = sub i32 64, %nbits
@@ -200,7 +200,7 @@ define i1 @highest_bit_test_via_ashr_lsh
 ; CHECK-NEXT:    call void @use32(i32 [[HIGH_BITS_EXTRACTED]])
 ; CHECK-NEXT:    call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]])
 ; CHECK-NEXT:    call void @use32(i32 [[SIGNBIT]])
-; CHECK-NEXT:    [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0
+; CHECK-NEXT:    [[ISNEG:%.*]] = icmp slt i32 [[DATA]], 0
 ; CHECK-NEXT:    ret i1 [[ISNEG]]
 ;
   %num_low_bits_to_skip = sub i32 32, %nbits
@@ -231,7 +231,7 @@ define i1 @highest_bit_test_via_ashr_lsh
 ; CHECK-NEXT:    call void @use32(i32 [[HIGH_BITS_EXTRACTED_NARROW]])
 ; CHECK-NEXT:    call void @use32(i32 [[SKIP_ALL_BITS_TILL_SIGNBIT]])
 ; CHECK-NEXT:    call void @use32(i32 [[SIGNBIT]])
-; CHECK-NEXT:    [[ISNEG:%.*]] = icmp ne i32 [[SIGNBIT]], 0
+; CHECK-NEXT:    [[ISNEG:%.*]] = icmp slt i64 [[DATA]], 0
 ; CHECK-NEXT:    ret i1 [[ISNEG]]
 ;
   %num_low_bits_to_skip = sub i32 64, %nbits



