[llvm] 5ccb058 - [InstCombine] Simplify udiv -> lshr folding

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 23 05:55:31 PST 2022


Author: Nikita Popov
Date: 2022-02-23T14:55:23+01:00
New Revision: 5ccb0582c2b199913829d75a1dbbc866a707f400

URL: https://github.com/llvm/llvm-project/commit/5ccb0582c2b199913829d75a1dbbc866a707f400
DIFF: https://github.com/llvm/llvm-project/commit/5ccb0582c2b199913829d75a1dbbc866a707f400.diff

LOG: [InstCombine] Simplify udiv -> lshr folding

What we're really doing here is converting Op0 udiv Op1 into
Op0 lshr log2(Op1), so phrase the code that way. Actually pushing
the lshr into the log2(Op1) expression (e.g. into the arms of a
select) should be seen as a separate transform.

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
    llvm/test/Transforms/InstCombine/div-shift.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index db239385aed0..36fb08f58221 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -902,9 +902,7 @@ static const unsigned MaxDepth = 6;
 
 namespace {
 
-using FoldUDivOperandCb = Instruction *(*)(Value *Op0, Value *Op1,
-                                           const BinaryOperator &I,
-                                           InstCombinerImpl &IC);
+using FoldUDivOperandCb = Value *(*)(Value *V, InstCombinerImpl &IC);
 
 /// Used to maintain state for visitUDivOperand().
 struct UDivFoldAction {
@@ -917,7 +915,7 @@ struct UDivFoldAction {
 
   union {
     /// The instruction returned when FoldAction is invoked.
-    Instruction *FoldResult;
+    Value *FoldResult;
 
     /// Stores the LHS action index if this action joins two actions together.
     size_t SelectLHSIdx;
@@ -931,26 +929,20 @@ struct UDivFoldAction {
 
 } // end anonymous namespace
 
-// X udiv 2^C -> X >> C
-static Instruction *foldUDivPow2Cst(Value *Op0, Value *Op1,
-                                    const BinaryOperator &I,
-                                    InstCombinerImpl &IC) {
-  Constant *C1 = ConstantExpr::getExactLogBase2(cast<Constant>(Op1));
-  if (!C1)
+// log2(2^C) -> C
+static Value *foldUDivPow2Cst(Value *V, InstCombinerImpl &IC) {
+  Constant *C = ConstantExpr::getExactLogBase2(cast<Constant>(V));
+  if (!C)
     llvm_unreachable("Failed to constant fold udiv -> logbase2");
-  BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, C1);
-  if (I.isExact())
-    LShr->setIsExact();
-  return LShr;
+  return C;
 }
 
-// X udiv (C1 << N), where C1 is "1<<C2"  -->  X >> (N+C2)
-// X udiv (zext (C1 << N)), where C1 is "1<<C2"  -->  X >> (N+C2)
-static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I,
-                                InstCombinerImpl &IC) {
+// log2(C1 << N) -> N+C2, where C1 is 1<<C2
+// log2(zext(C1 << N)) -> zext(N+C2), where C1 is 1<<C2
+static Value *foldUDivShl(Value *V, InstCombinerImpl &IC) {
   Value *ShiftLeft;
-  if (!match(Op1, m_ZExt(m_Value(ShiftLeft))))
-    ShiftLeft = Op1;
+  if (!match(V, m_ZExt(m_Value(ShiftLeft))))
+    ShiftLeft = V;
 
   Constant *CI;
   Value *N;
@@ -960,19 +952,16 @@ static Instruction *foldUDivShl(Value *Op0, Value *Op1, const BinaryOperator &I,
   if (!Log2Base)
     llvm_unreachable("getLogBase2 should never fail here!");
   N = IC.Builder.CreateAdd(N, Log2Base);
-  if (Op1 != ShiftLeft)
-    N = IC.Builder.CreateZExt(N, Op1->getType());
-  BinaryOperator *LShr = BinaryOperator::CreateLShr(Op0, N);
-  if (I.isExact())
-    LShr->setIsExact();
-  return LShr;
+  if (V != ShiftLeft)
+    N = IC.Builder.CreateZExt(N, V->getType());
+  return N;
 }
 
 // Recursively visits the possible right hand operands of a udiv
 // instruction, seeing through select instructions, to determine if we can
 // replace the udiv with something simpler.  If we find that an operand is not
 // able to simplify the udiv, we abort the entire transformation.
-static size_t visitUDivOperand(Value *Op, const BinaryOperator &I,
+static size_t visitUDivOperand(Value *Op,
                                SmallVectorImpl<UDivFoldAction> &Actions,
                                unsigned Depth = 0) {
   // FIXME: assert that Op1 isn't/doesn't contain undef.
@@ -999,8 +988,8 @@ static size_t visitUDivOperand(Value *Op, const BinaryOperator &I,
     // FIXME: missed optimization: if one of the hands of select is/contains
     //        undef, just directly pick the other one.
     // FIXME: can both hands contain undef?
-    if (size_t LHSIdx = visitUDivOperand(SI->getOperand(1), I, Actions, Depth))
-      if (visitUDivOperand(SI->getOperand(2), I, Actions, Depth)) {
+    if (size_t LHSIdx = visitUDivOperand(SI->getOperand(1), Actions, Depth))
+      if (visitUDivOperand(SI->getOperand(2), Actions, Depth)) {
         Actions.push_back(UDivFoldAction(nullptr, Op, LHSIdx - 1));
         return Actions.size();
       }
@@ -1105,15 +1094,15 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
       return BinaryOperator::CreateUDiv(A, X);
   }
 
-  // (LHS udiv (select (select (...)))) -> (LHS >> (select (select (...))))
+  // Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away.
   SmallVector<UDivFoldAction, 6> UDivActions;
-  if (visitUDivOperand(Op1, I, UDivActions))
+  if (visitUDivOperand(Op1, UDivActions))
     for (unsigned i = 0, e = UDivActions.size(); i != e; ++i) {
       FoldUDivOperandCb Action = UDivActions[i].FoldAction;
       Value *ActionOp1 = UDivActions[i].OperandToFold;
-      Instruction *Inst;
+      Value *Res;
       if (Action)
-        Inst = Action(Op0, ActionOp1, I, *this);
+        Res = Action(ActionOp1, *this);
       else {
         // This action joins two actions together.  The RHS of this action is
         // simply the last action we processed, we saved the LHS action index in
@@ -1122,18 +1111,19 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
         Value *SelectRHS = UDivActions[SelectRHSIdx].FoldResult;
         size_t SelectLHSIdx = UDivActions[i].SelectLHSIdx;
         Value *SelectLHS = UDivActions[SelectLHSIdx].FoldResult;
-        Inst = SelectInst::Create(cast<SelectInst>(ActionOp1)->getCondition(),
-                                  SelectLHS, SelectRHS);
+        Res = Builder.CreateSelect(cast<SelectInst>(ActionOp1)->getCondition(),
+                                   SelectLHS, SelectRHS);
       }
 
       // If this is the last action to process, return it to the InstCombiner.
-      // Otherwise, we insert it before the UDiv and record it so that we may
-      // use it as part of a joining action (i.e., a SelectInst).
+      // Otherwise, record it so that we may use it as part of a joining action
+      // (i.e., a SelectInst).
       if (e - i != 1) {
-        Inst->insertBefore(&I);
-        UDivActions[i].FoldResult = Inst;
-      } else
-        return Inst;
+        UDivActions[i].FoldResult = Res;
+      } else {
+        return replaceInstUsesWith(
+            I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact()));
+      }
     }
 
   return nullptr;
@@ -1241,8 +1231,10 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
     if (match(Op1, m_NegatedPower2())) {
       // X sdiv (-(1 << C)) -> -(X sdiv (1 << C)) ->
       //                    -> -(X udiv (1 << C)) -> -(X u>> C)
-      return BinaryOperator::CreateNeg(Builder.Insert(foldUDivPow2Cst(
-          Op0, ConstantExpr::getNeg(cast<Constant>(Op1)), I, *this)));
+      Constant *CNegLog2 = ConstantExpr::getExactLogBase2(
+          ConstantExpr::getNeg(cast<Constant>(Op1)));
+      Value *Shr = Builder.CreateLShr(Op0, CNegLog2, I.getName(), I.isExact());
+      return BinaryOperator::CreateNeg(Shr);
     }
 
     if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, &I)) {

diff  --git a/llvm/test/Transforms/InstCombine/div-shift.ll b/llvm/test/Transforms/InstCombine/div-shift.ll
index 8ee9063894d9..6d285ab2f099 100644
--- a/llvm/test/Transforms/InstCombine/div-shift.ll
+++ b/llvm/test/Transforms/InstCombine/div-shift.ll
@@ -6,8 +6,8 @@ define i32 @t1(i16 zeroext %x, i32 %y) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CONV:%.*]] = zext i16 [[X:%.*]] to i32
 ; CHECK-NEXT:    [[TMP0:%.*]] = add i32 [[Y:%.*]], 1
-; CHECK-NEXT:    [[D:%.*]] = lshr i32 [[CONV]], [[TMP0]]
-; CHECK-NEXT:    ret i32 [[D]]
+; CHECK-NEXT:    [[D1:%.*]] = lshr i32 [[CONV]], [[TMP0]]
+; CHECK-NEXT:    ret i32 [[D1]]
 ;
 entry:
   %conv = zext i16 %x to i32
@@ -21,8 +21,8 @@ define <2 x i32> @t1vec(<2 x i16> %x, <2 x i32> %y) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CONV:%.*]] = zext <2 x i16> [[X:%.*]] to <2 x i32>
 ; CHECK-NEXT:    [[TMP0:%.*]] = add <2 x i32> [[Y:%.*]], <i32 1, i32 1>
-; CHECK-NEXT:    [[D:%.*]] = lshr <2 x i32> [[CONV]], [[TMP0]]
-; CHECK-NEXT:    ret <2 x i32> [[D]]
+; CHECK-NEXT:    [[D1:%.*]] = lshr <2 x i32> [[CONV]], [[TMP0]]
+; CHECK-NEXT:    ret <2 x i32> [[D1]]
 ;
 entry:
   %conv = zext <2 x i16> %x to <2 x i32>
@@ -61,9 +61,9 @@ define i64 @t3(i64 %x, i32 %y) {
 define i32 @t4(i32 %x, i32 %y) {
 ; CHECK-LABEL: @t4(
 ; CHECK-NEXT:    [[TMP1:%.*]] = icmp ugt i32 [[Y:%.*]], 5
-; CHECK-NEXT:    [[DOTV:%.*]] = select i1 [[TMP1]], i32 [[Y]], i32 5
-; CHECK-NEXT:    [[TMP2:%.*]] = lshr i32 [[X:%.*]], [[DOTV]]
-; CHECK-NEXT:    ret i32 [[TMP2]]
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[TMP1]], i32 [[Y]], i32 5
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr i32 [[X:%.*]], [[TMP2]]
+; CHECK-NEXT:    ret i32 [[TMP3]]
 ;
   %1 = shl i32 1, %y
   %2 = icmp ult i32 %1, 32
@@ -74,10 +74,10 @@ define i32 @t4(i32 %x, i32 %y) {
 
 define i32 @t5(i1 %x, i1 %y, i32 %V) {
 ; CHECK-LABEL: @t5(
-; CHECK-NEXT:    [[DOTV:%.*]] = select i1 [[X:%.*]], i32 5, i32 6
-; CHECK-NEXT:    [[TMP1:%.*]] = lshr i32 [[V:%.*]], [[DOTV]]
-; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[Y:%.*]], i32 [[TMP1]], i32 0
-; CHECK-NEXT:    ret i32 [[TMP2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = select i1 [[X:%.*]], i32 5, i32 6
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[Y:%.*]], i32 [[TMP1]], i32 [[V:%.*]]
+; CHECK-NEXT:    [[TMP3:%.*]] = lshr i32 [[V]], [[TMP2]]
+; CHECK-NEXT:    ret i32 [[TMP3]]
 ;
   %1 = shl i32 1, %V
   %2 = select i1 %x, i32 32, i32 64


        


More information about the llvm-commits mailing list