[llvm] r310509 - [InstCombine] narrow rotate left/right patterns to eliminate zext/trunc (PR34046)

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 9 11:37:41 PDT 2017


Author: spatel
Date: Wed Aug  9 11:37:41 2017
New Revision: 310509

URL: http://llvm.org/viewvc/llvm-project?rev=310509&view=rev
Log:
[InstCombine] narrow rotate left/right patterns to eliminate zext/trunc (PR34046)

I couldn't find any smaller folds that would help the cases in:
https://bugs.llvm.org/show_bug.cgi?id=34046
which are still not handled after:
rL310141

The truncated rotate-by-variable patterns elude all of the existing transforms because
of multiple uses, and because the demanded-bits and known-bits information they need
is not available without looking at the whole pattern. So, unfortunately, we need a
large pattern match. But by simplifying this pattern in IR, the backend is already able
to generate rolb/rolw/rorb/rorw for x86 using its existing rotate matching logic
(although there is a likely extraneous 'and' of the rotate amount).

Note that rotate-by-constant doesn't have this problem - smaller folds should already 
produce the narrow IR ops.

Differential Revision: https://reviews.llvm.org/D36395


Added:
    llvm/trunk/test/Transforms/InstCombine/rotate.ll
Modified:
    llvm/trunk/lib/Transforms/InstCombine/InstCombineCasts.cpp
    llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineCasts.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineCasts.cpp?rev=310509&r1=310508&r2=310509&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineCasts.cpp (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineCasts.cpp Wed Aug  9 11:37:41 2017
@@ -443,13 +443,81 @@ static Instruction *foldVecTruncToExtElt
   return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt));
 }
 
+/// Rotate left/right may occur in a wider type than necessary because of type
+/// promotion rules. Try to narrow all of the component instructions.
+Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) {
+  assert((isa<VectorType>(Trunc.getSrcTy()) ||
+          shouldChangeType(Trunc.getSrcTy(), Trunc.getType())) &&
+         "Don't narrow to an illegal scalar type");
+
+  // First, find an or'd pair of opposite shifts with the same shifted operand:
+  // trunc (or (lshr ShVal, ShAmt0), (shl ShVal, ShAmt1))
+  Value *Or0, *Or1;
+  if (!match(Trunc.getOperand(0), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1)))))
+    return nullptr;
+
+  Value *ShVal, *ShAmt0, *ShAmt1;
+  if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal), m_Value(ShAmt0)))) ||
+      !match(Or1, m_OneUse(m_LogicalShift(m_Specific(ShVal), m_Value(ShAmt1)))))
+    return nullptr;
+
+  auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode();
+  auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode();
+  if (ShiftOpcode0 == ShiftOpcode1)
+    return nullptr;
+
+  // The shift amounts must add up to the narrow bit width.
+  Value *ShAmt;
+  bool SubIsOnLHS;
+  Type *DestTy = Trunc.getType();
+  unsigned NarrowWidth = DestTy->getScalarSizeInBits();
+  if (match(ShAmt0,
+            m_OneUse(m_Sub(m_SpecificInt(NarrowWidth), m_Specific(ShAmt1))))) {
+    ShAmt = ShAmt1;
+    SubIsOnLHS = true;
+  } else if (match(ShAmt1, m_OneUse(m_Sub(m_SpecificInt(NarrowWidth),
+                                          m_Specific(ShAmt0))))) {
+    ShAmt = ShAmt0;
+    SubIsOnLHS = false;
+  } else {
+    return nullptr;
+  }
+
+  // The shifted value must have high zeros in the wide type. Typically, this
+  // will be a zext, but it could also be the result of an 'and' or 'shift'.
+  unsigned WideWidth = Trunc.getSrcTy()->getScalarSizeInBits();
+  APInt HiBitMask = APInt::getHighBitsSet(WideWidth, WideWidth - NarrowWidth);
+  if (!MaskedValueIsZero(ShVal, HiBitMask, 0, &Trunc))
+    return nullptr;
+
+  // We have an unnecessarily wide rotate!
+  // trunc (or (lshr ShVal, ShAmt), (shl ShVal, BitWidth - ShAmt))
+  // Narrow it down to eliminate the zext/trunc:
+  // or (lshr trunc(ShVal), ShAmt0'), (shl trunc(ShVal), ShAmt1')
+  Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy);
+  Value *NegShAmt = Builder.CreateNeg(NarrowShAmt);
+
+  // Mask both shift amounts to ensure there's no UB from oversized shifts.
+  Constant *MaskC = ConstantInt::get(DestTy, NarrowWidth - 1);
+  Value *MaskedShAmt = Builder.CreateAnd(NarrowShAmt, MaskC);
+  Value *MaskedNegShAmt = Builder.CreateAnd(NegShAmt, MaskC);
+
+  // Truncate the original value and use narrow ops.
+  Value *X = Builder.CreateTrunc(ShVal, DestTy);
+  Value *NarrowShAmt0 = SubIsOnLHS ? MaskedNegShAmt : MaskedShAmt;
+  Value *NarrowShAmt1 = SubIsOnLHS ? MaskedShAmt : MaskedNegShAmt;
+  Value *NarrowSh0 = Builder.CreateBinOp(ShiftOpcode0, X, NarrowShAmt0);
+  Value *NarrowSh1 = Builder.CreateBinOp(ShiftOpcode1, X, NarrowShAmt1);
+  return BinaryOperator::CreateOr(NarrowSh0, NarrowSh1);
+}
+
 /// Try to narrow the width of math or bitwise logic instructions by pulling a
 /// truncate ahead of binary operators.
 /// TODO: Transforms for truncated shifts should be moved into here.
 Instruction *InstCombiner::narrowBinOp(TruncInst &Trunc) {
   Type *SrcTy = Trunc.getSrcTy();
   Type *DestTy = Trunc.getType();
-  if (isa<IntegerType>(SrcTy) && !shouldChangeType(SrcTy, DestTy))
+  if (!isa<VectorType>(SrcTy) && !shouldChangeType(SrcTy, DestTy))
     return nullptr;
 
   BinaryOperator *BinOp;
@@ -485,6 +553,9 @@ Instruction *InstCombiner::narrowBinOp(T
   default: break;
   }
 
+  if (Instruction *NarrowOr = narrowRotate(Trunc))
+    return NarrowOr;
+
   return nullptr;
 }
 

Modified: llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h?rev=310509&r1=310508&r2=310509&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h (original)
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h Wed Aug  9 11:37:41 2017
@@ -440,6 +440,7 @@ private:
   Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask);
   Instruction *foldCastedBitwiseLogic(BinaryOperator &I);
   Instruction *narrowBinOp(TruncInst &Trunc);
+  Instruction *narrowRotate(TruncInst &Trunc);
   Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN);
 
   /// Determine if a pair of casts can be replaced by a single cast.

Added: llvm/trunk/test/Transforms/InstCombine/rotate.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/InstCombine/rotate.ll?rev=310509&view=auto
==============================================================================
--- llvm/trunk/test/Transforms/InstCombine/rotate.ll (added)
+++ llvm/trunk/test/Transforms/InstCombine/rotate.ll Wed Aug  9 11:37:41 2017
@@ -0,0 +1,123 @@
+; RUN: opt < %s -instcombine -S | FileCheck %s
+
+target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:32:64-v64:64:64-v128:128:128-a0:0:64-f80:128:128"
+
+; These are UB-free rotate left/right patterns that are narrowed to a smaller bitwidth.
+; See PR34046 and PR16726 for motivating examples:
+; https://bugs.llvm.org/show_bug.cgi?id=34046
+; https://bugs.llvm.org/show_bug.cgi?id=16726
+
+define i16 @rotate_left_16bit(i16 %v, i32 %shift) {
+; CHECK-LABEL: @rotate_left_16bit(
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 %shift to i16
+; CHECK-NEXT:    [[TMP2:%.*]] = and i16 [[TMP1]], 15
+; CHECK-NEXT:    [[TMP3:%.*]] = sub i16 0, [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = and i16 [[TMP3]], 15
+; CHECK-NEXT:    [[TMP5:%.*]] = lshr i16 %v, [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = shl i16 %v, [[TMP2]]
+; CHECK-NEXT:    [[CONV2:%.*]] = or i16 [[TMP5]], [[TMP6]]
+; CHECK-NEXT:    ret i16 [[CONV2]]
+;
+  %and = and i32 %shift, 15
+  %conv = zext i16 %v to i32
+  %shl = shl i32 %conv, %and
+  %sub = sub i32 16, %and
+  %shr = lshr i32 %conv, %sub
+  %or = or i32 %shr, %shl
+  %conv2 = trunc i32 %or to i16
+  ret i16 %conv2
+}
+
+; Commute the 'or' operands and try a vector type.
+
+define <2 x i16> @rotate_left_commute_16bit_vec(<2 x i16> %v, <2 x i32> %shift) {
+; CHECK-LABEL: @rotate_left_commute_16bit_vec(
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc <2 x i32> %shift to <2 x i16>
+; CHECK-NEXT:    [[TMP2:%.*]] = and <2 x i16> [[TMP1]], <i16 15, i16 15>
+; CHECK-NEXT:    [[TMP3:%.*]] = sub <2 x i16> zeroinitializer, [[TMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = and <2 x i16> [[TMP3]], <i16 15, i16 15>
+; CHECK-NEXT:    [[TMP5:%.*]] = shl <2 x i16> %v, [[TMP2]]
+; CHECK-NEXT:    [[TMP6:%.*]] = lshr <2 x i16> %v, [[TMP4]]
+; CHECK-NEXT:    [[CONV2:%.*]] = or <2 x i16> [[TMP5]], [[TMP6]]
+; CHECK-NEXT:    ret <2 x i16> [[CONV2]]
+;
+  %and = and <2 x i32> %shift, <i32 15, i32 15>
+  %conv = zext <2 x i16> %v to <2 x i32>
+  %shl = shl <2 x i32> %conv, %and
+  %sub = sub <2 x i32> <i32 16, i32 16>, %and
+  %shr = lshr <2 x i32> %conv, %sub
+  %or = or <2 x i32> %shl, %shr
+  %conv2 = trunc <2 x i32> %or to <2 x i16>
+  ret <2 x i16> %conv2
+}
+
+; Change the size, rotation direction (the subtract is on the left-shift), and mask op.
+
+define i8 @rotate_right_8bit(i8 %v, i3 %shift) {
+; CHECK-LABEL: @rotate_right_8bit(
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i3 %shift to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = sub i3 0, %shift
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i3 [[TMP2]] to i8
+; CHECK-NEXT:    [[TMP4:%.*]] = shl i8 %v, [[TMP3]]
+; CHECK-NEXT:    [[TMP5:%.*]] = lshr i8 %v, [[TMP1]]
+; CHECK-NEXT:    [[CONV2:%.*]] = or i8 [[TMP4]], [[TMP5]]
+; CHECK-NEXT:    ret i8 [[CONV2]]
+;
+  %and = zext i3 %shift to i32
+  %conv = zext i8 %v to i32
+  %shr = lshr i32 %conv, %and
+  %sub = sub i32 8, %and
+  %shl = shl i32 %conv, %sub
+  %or = or i32 %shl, %shr
+  %conv2 = trunc i32 %or to i8
+  ret i8 %conv2
+}
+
+; The shifted value does not need to be a zexted value; here it is masked.
+; The shift mask could be less than the bitwidth, but this is still ok.
+
+define i8 @rotate_right_commute_8bit(i32 %v, i32 %shift) {
+; CHECK-LABEL: @rotate_right_commute_8bit(
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 %shift to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = and i8 [[TMP1]], 3
+; CHECK-NEXT:    [[TMP3:%.*]] = sub nsw i8 0, [[TMP2]]
+; CHECK-NEXT:    [[TMP4:%.*]] = and i8 [[TMP3]], 7
+; CHECK-NEXT:    [[TMP5:%.*]] = trunc i32 %v to i8
+; CHECK-NEXT:    [[TMP6:%.*]] = lshr i8 [[TMP5]], [[TMP2]]
+; CHECK-NEXT:    [[TMP7:%.*]] = shl i8 [[TMP5]], [[TMP4]]
+; CHECK-NEXT:    [[CONV2:%.*]] = or i8 [[TMP6]], [[TMP7]]
+; CHECK-NEXT:    ret i8 [[CONV2]]
+;
+  %and = and i32 %shift, 3
+  %conv = and i32 %v, 255
+  %shr = lshr i32 %conv, %and
+  %sub = sub i32 8, %and
+  %shl = shl i32 %conv, %sub
+  %or = or i32 %shr, %shl
+  %conv2 = trunc i32 %or to i8
+  ret i8 %conv2
+}
+
+; If the original source does not mask the shift amount,
+; we still do the transform by adding masks to make it safe.
+
+define i8 @rotate8_not_safe(i8 %v, i32 %shamt) {
+; CHECK-LABEL: @rotate8_not_safe(
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i32 %shamt to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = sub i8 0, [[TMP1]]
+; CHECK-NEXT:    [[TMP3:%.*]] = and i8 [[TMP1]], 7
+; CHECK-NEXT:    [[TMP4:%.*]] = and i8 [[TMP2]], 7
+; CHECK-NEXT:    [[TMP5:%.*]] = lshr i8 %v, [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = shl i8 %v, [[TMP3]]
+; CHECK-NEXT:    [[RET:%.*]] = or i8 [[TMP5]], [[TMP6]]
+; CHECK-NEXT:    ret i8 [[RET]]
+;
+  %conv = zext i8 %v to i32
+  %sub = sub i32 8, %shamt
+  %shr = lshr i32 %conv, %sub
+  %shl = shl i32 %conv, %shamt
+  %or = or i32 %shr, %shl
+  %ret = trunc i32 %or to i8
+  ret i8 %ret
+}
+




More information about the llvm-commits mailing list