[llvm] 03e6efb - [InstCombine] Further simplify udiv -> lshr folding

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 23 06:29:31 PST 2022


Author: Nikita Popov
Date: 2022-02-23T15:29:21+01:00
New Revision: 03e6efb8c23f489e45353b6b6d941628d9c49ca2

URL: https://github.com/llvm/llvm-project/commit/03e6efb8c23f489e45353b6b6d941628d9c49ca2
DIFF: https://github.com/llvm/llvm-project/commit/03e6efb8c23f489e45353b6b6d941628d9c49ca2.diff

LOG: [InstCombine] Further simplify udiv -> lshr folding

Rather than queuing up actions, have one function that does the
log2() fold in the obvious way, but with a flag that allows us
to check whether the fold will succeed without actually performing
it.

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 36fb08f58221..aeae25476db6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -900,101 +900,55 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
 
 static const unsigned MaxDepth = 6;
 
-namespace {
-
-using FoldUDivOperandCb = Value *(*)(Value *V, InstCombinerImpl &IC);
-
-/// Used to maintain state for visitUDivOperand().
-struct UDivFoldAction {
-  /// Informs visitUDiv() how to fold this operand.  This can be zero if this
-  /// action joins two actions together.
-  FoldUDivOperandCb FoldAction;
-
-  /// Which operand to fold.
-  Value *OperandToFold;
-
-  union {
-    /// The instruction returned when FoldAction is invoked.
-    Value *FoldResult;
-
-    /// Stores the LHS action index if this action joins two actions together.
-    size_t SelectLHSIdx;
+// Take the exact integer log2 of the value. If DoFold is true, create the
+// actual instructions, otherwise return a non-null dummy value. Return nullptr
+// on failure.
+static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
+                       bool DoFold) {
+  auto IfFold = [DoFold](function_ref<Value *()> Fn) {
+    if (!DoFold)
+      return reinterpret_cast<Value *>(-1);
+    return Fn();
   };
 
-  UDivFoldAction(FoldUDivOperandCb FA, Value *InputOperand)
-      : FoldAction(FA), OperandToFold(InputOperand), FoldResult(nullptr) {}
-  UDivFoldAction(FoldUDivOperandCb FA, Value *InputOperand, size_t SLHS)
-      : FoldAction(FA), OperandToFold(InputOperand), SelectLHSIdx(SLHS) {}
-};
-
-} // end anonymous namespace
-
-// log2(2^C) -> C
-static Value *foldUDivPow2Cst(Value *V, InstCombinerImpl &IC) {
-  Constant *C = ConstantExpr::getExactLogBase2(cast<Constant>(V));
-  if (!C)
-    llvm_unreachable("Failed to constant fold udiv -> logbase2");
-  return C;
-}
-
-// log2(C1 << N) -> N+C2, where C1 is 1<<C2
-// log2(zext(C1 << N)) -> zext(N+C2), where C1 is 1<<C2
-static Value *foldUDivShl(Value *V, InstCombinerImpl &IC) {
-  Value *ShiftLeft;
-  if (!match(V, m_ZExt(m_Value(ShiftLeft))))
-    ShiftLeft = V;
-
-  Constant *CI;
-  Value *N;
-  if (!match(ShiftLeft, m_Shl(m_Constant(CI), m_Value(N))))
-    llvm_unreachable("match should never fail here!");
-  Constant *Log2Base = ConstantExpr::getExactLogBase2(CI);
-  if (!Log2Base)
-    llvm_unreachable("getLogBase2 should never fail here!");
-  N = IC.Builder.CreateAdd(N, Log2Base);
-  if (V != ShiftLeft)
-    N = IC.Builder.CreateZExt(N, V->getType());
-  return N;
-}
-
-// Recursively visits the possible right hand operands of a udiv
-// instruction, seeing through select instructions, to determine if we can
-// replace the udiv with something simpler.  If we find that an operand is not
-// able to simplify the udiv, we abort the entire transformation.
-static size_t visitUDivOperand(Value *Op,
-                               SmallVectorImpl<UDivFoldAction> &Actions,
-                               unsigned Depth = 0) {
   // FIXME: assert that Op1 isn't/doesn't contain undef.
 
-  // Check to see if this is an unsigned division with an exact power of 2,
-  // if so, convert to a right shift.
-  if (match(Op, m_Power2())) {
-    Actions.push_back(UDivFoldAction(foldUDivPow2Cst, Op));
-    return Actions.size();
-  }
-
-  // X udiv (C1 << N), where C1 is "1<<C2"  -->  X >> (N+C2)
-  if (match(Op, m_Shl(m_Power2(), m_Value())) ||
-      match(Op, m_ZExt(m_Shl(m_Power2(), m_Value())))) {
-    Actions.push_back(UDivFoldAction(foldUDivShl, Op));
-    return Actions.size();
-  }
+  // log2(2^C) -> C
+  if (match(Op, m_Power2()))
+    return IfFold([&]() {
+      Constant *C = ConstantExpr::getExactLogBase2(cast<Constant>(Op));
+      if (!C)
+        llvm_unreachable("Failed to constant fold udiv -> logbase2");
+      return C;
+    });
 
   // The remaining tests are all recursive, so bail out if we hit the limit.
   if (Depth++ == MaxDepth)
-    return 0;
+    return nullptr;
 
+  // log2(zext X) -> zext log2(X)
+  Value *X, *Y;
+  if (match(Op, m_ZExt(m_Value(X))))
+    if (Value *LogX = takeLog2(Builder, X, Depth, DoFold))
+      return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); });
+
+  // log2(X << Y) -> log2(X) + Y
+  if (match(Op, m_Shl(m_Value(X), m_Value(Y))))
+    if (Value *LogX = takeLog2(Builder, X, Depth, DoFold))
+      return IfFold([&]() { return Builder.CreateAdd(LogX, Y); });
+
+  // log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y)
+  // FIXME: missed optimization: if one of the hands of select is/contains
+  //        undef, just directly pick the other one.
+  // FIXME: can both hands contain undef?
   if (SelectInst *SI = dyn_cast<SelectInst>(Op))
-    // FIXME: missed optimization: if one of the hands of select is/contains
-    //        undef, just directly pick the other one.
-    // FIXME: can both hands contain undef?
-    if (size_t LHSIdx = visitUDivOperand(SI->getOperand(1), Actions, Depth))
-      if (visitUDivOperand(SI->getOperand(2), Actions, Depth)) {
-        Actions.push_back(UDivFoldAction(nullptr, Op, LHSIdx - 1));
-        return Actions.size();
-      }
+    if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, DoFold))
+      if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth, DoFold))
+        return IfFold([&]() {
+          return Builder.CreateSelect(SI->getOperand(0), LogX, LogY);
+        });
 
-  return 0;
+  return nullptr;
 }
 
 /// If we have zero-extended operands of an unsigned div or rem, we may be able
@@ -1095,36 +1049,11 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
   }
 
   // Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away.
-  SmallVector<UDivFoldAction, 6> UDivActions;
-  if (visitUDivOperand(Op1, UDivActions))
-    for (unsigned i = 0, e = UDivActions.size(); i != e; ++i) {
-      FoldUDivOperandCb Action = UDivActions[i].FoldAction;
-      Value *ActionOp1 = UDivActions[i].OperandToFold;
-      Value *Res;
-      if (Action)
-        Res = Action(ActionOp1, *this);
-      else {
-        // This action joins two actions together.  The RHS of this action is
-        // simply the last action we processed, we saved the LHS action index in
-        // the joining action.
-        size_t SelectRHSIdx = i - 1;
-        Value *SelectRHS = UDivActions[SelectRHSIdx].FoldResult;
-        size_t SelectLHSIdx = UDivActions[i].SelectLHSIdx;
-        Value *SelectLHS = UDivActions[SelectLHSIdx].FoldResult;
-        Res = Builder.CreateSelect(cast<SelectInst>(ActionOp1)->getCondition(),
-                                   SelectLHS, SelectRHS);
-      }
-
-      // If this is the last action to process, return it to the InstCombiner.
-      // Otherwise, record it so that we may use it as part of a joining action
-      // (i.e., a SelectInst).
-      if (e - i != 1) {
-        UDivActions[i].FoldResult = Res;
-      } else {
-        return replaceInstUsesWith(
-            I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact()));
-      }
-    }
+  if (takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/false)) {
+    Value *Res = takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/true);
+    return replaceInstUsesWith(
+        I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact()));
+  }
 
   return nullptr;
 }


        


More information about the llvm-commits mailing list