[llvm] 46db90c - [SCEV] `MatchBinaryOp()`: try to recognize `or` as `add`-in-disguise (w/ no common bits set)

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 6 09:27:45 PST 2022


Author: Roman Lebedev
Date: 2022-12-06T20:26:53+03:00
New Revision: 46db90cc71d11df1dd7d397d253445671836dfa0

URL: https://github.com/llvm/llvm-project/commit/46db90cc71d11df1dd7d397d253445671836dfa0
DIFF: https://github.com/llvm/llvm-project/commit/46db90cc71d11df1dd7d397d253445671836dfa0.diff

LOG: [SCEV] `MatchBinaryOp()`: try to recognize `or` as `add`-in-disguise (w/ no common bits set)

LLVM *loves* to convert `add` of operands with no common bits
into an `or`. But SCEV really doesn't deal with `or` that well,
so try extra hard to recognize this `or` as an `add`.

I believe this wasn't done previously because of the recursive nature
of this analysis, but now that `createSCEV()` is no longer recursive,
this should be fine, unless it turns out to be *too* costly compile-time-wise...

https://alive2.llvm.org/ce/z/EfapCo

Added: 
    

Modified: 
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/test/Transforms/IndVarSimplify/pr58702-invalidate-scev-when-replacing-congruent-phis.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 751fad444a7d..546687bbe44a 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -5127,7 +5127,10 @@ struct BinaryOp {
 } // end anonymous namespace
 
 /// Try to map \p V into a BinaryOp, and return \c None on failure.
-static std::optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
+static std::optional<BinaryOp> MatchBinaryOp(Value *V, const DataLayout &DL,
+                                             AssumptionCache &AC,
+                                             const DominatorTree &DT,
+                                             const Instruction *CxtI) {
   auto *Op = dyn_cast<Operator>(V);
   if (!Op)
     return std::nullopt;
@@ -5143,11 +5146,21 @@ static std::optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
   case Instruction::UDiv:
   case Instruction::URem:
   case Instruction::And:
-  case Instruction::Or:
   case Instruction::AShr:
   case Instruction::Shl:
     return BinaryOp(Op);
 
+  case Instruction::Or: {
+    // LLVM loves to convert `add` of operands with no common bits
+    // into an `or`. But SCEV really doesn't deal with `or` that well,
+    // so try extra hard to recognize this `or` as an `add`.
+    if (haveNoCommonBitsSet(Op->getOperand(0), Op->getOperand(1), DL, &AC, CxtI,
+                            &DT, /*UseInstrInfo=*/true))
+      return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1),
+                      /*IsNSW=*/true, /*IsNUW=*/true);
+    return BinaryOp(Op);
+  }
+
   case Instruction::Xor:
     if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1)))
       // If the RHS of the xor is a signmask, then this is just an add.
@@ -5604,7 +5617,7 @@ const SCEV *ScalarEvolution::createSimpleAffineAddRec(PHINode *PN,
   assert(L && L->getHeader() == PN->getParent());
   assert(BEValueV && StartValueV);
 
-  auto BO = MatchBinaryOp(BEValueV, DT);
+  auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN);
   if (!BO)
     return nullptr;
 
@@ -5719,7 +5732,7 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
            cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
         SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
 
-        if (auto BO = MatchBinaryOp(BEValueV, DT)) {
+        if (auto BO = MatchBinaryOp(BEValueV, getDataLayout(), AC, DT, PN)) {
           if (BO->Opcode == Instruction::Add && BO->LHS == PN) {
             if (BO->IsNUW)
               Flags = setFlags(Flags, SCEV::FlagNUW);
@@ -7359,7 +7372,8 @@ ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
     return getUnknown(V);
 
   Operator *U = cast<Operator>(V);
-  if (auto BO = MatchBinaryOp(U, DT)) {
+  if (auto BO =
+          MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
     bool IsConstArg = isa<ConstantInt>(BO->RHS);
     switch (BO->Opcode) {
     case Instruction::Add:
@@ -7375,7 +7389,8 @@ ScalarEvolution::getOperandsToCreate(Value *V, SmallVectorImpl<Value *> &Ops) {
           }
         }
         Ops.push_back(BO->RHS);
-        auto NewBO = MatchBinaryOp(BO->LHS, DT);
+        auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
+                                   dyn_cast<Instruction>(V));
         if (!NewBO ||
             (U->getOpcode() == Instruction::Add &&
              (NewBO->Opcode != Instruction::Add &&
@@ -7546,7 +7561,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
   const SCEV *RHS;
 
   Operator *U = cast<Operator>(V);
-  if (auto BO = MatchBinaryOp(U, DT)) {
+  if (auto BO =
+          MatchBinaryOp(U, getDataLayout(), AC, DT, dyn_cast<Instruction>(V))) {
     switch (BO->Opcode) {
     case Instruction::Add: {
       // The simple thing to do would be to just call getSCEV on both operands
@@ -7587,7 +7603,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
         else
           AddOps.push_back(getSCEV(BO->RHS));
 
-        auto NewBO = MatchBinaryOp(BO->LHS, DT);
+        auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
+                                   dyn_cast<Instruction>(V));
         if (!NewBO || (NewBO->Opcode != Instruction::Add &&
                        NewBO->Opcode != Instruction::Sub)) {
           AddOps.push_back(getSCEV(BO->LHS));
@@ -7618,7 +7635,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
         }
 
         MulOps.push_back(getSCEV(BO->RHS));
-        auto NewBO = MatchBinaryOp(BO->LHS, DT);
+        auto NewBO = MatchBinaryOp(BO->LHS, getDataLayout(), AC, DT,
+                                   dyn_cast<Instruction>(V));
         if (!NewBO || NewBO->Opcode != Instruction::Mul) {
           MulOps.push_back(getSCEV(BO->LHS));
           break;
@@ -7703,22 +7721,6 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
       break;
 
     case Instruction::Or:
-      // If the RHS of the Or is a constant, we may have something like:
-      // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
-      // optimizations will transparently handle this case.
-      //
-      // In order for this transformation to be safe, the LHS must be of the
-      // form X*(2^n) and the Or constant must be less than 2^n.
-      if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) {
-        const SCEV *LHS = getSCEV(BO->LHS);
-        const APInt &CIVal = CI->getValue();
-        if (GetMinTrailingZeros(LHS) >=
-            (CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
-          // Build a plain add SCEV.
-          return getAddExpr(LHS, getSCEV(CI),
-                            (SCEV::NoWrapFlags)(SCEV::FlagNUW | SCEV::FlagNSW));
-        }
-      }
       // Binary `or` is a bit-wise `umax`.
       if (BO->LHS->getType()->isIntegerTy(1)) {
         LHS = getSCEV(BO->LHS);
@@ -7861,7 +7863,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
     return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
 
   case Instruction::SExt:
-    if (auto BO = MatchBinaryOp(U->getOperand(0), DT)) {
+    if (auto BO = MatchBinaryOp(U->getOperand(0), getDataLayout(), AC, DT,
+                                dyn_cast<Instruction>(V))) {
       // The NSW flag of a subtract does not always survive the conversion to
       // A + (-1)*B.  By pushing sign extension onto its operands we are much
       // more likely to preserve NSW and allow later AddRec optimisations.

diff  --git a/llvm/test/Transforms/IndVarSimplify/pr58702-invalidate-scev-when-replacing-congruent-phis.ll b/llvm/test/Transforms/IndVarSimplify/pr58702-invalidate-scev-when-replacing-congruent-phis.ll
index e12dd288b624..d3013655ae5d 100644
--- a/llvm/test/Transforms/IndVarSimplify/pr58702-invalidate-scev-when-replacing-congruent-phis.ll
+++ b/llvm/test/Transforms/IndVarSimplify/pr58702-invalidate-scev-when-replacing-congruent-phis.ll
@@ -4,8 +4,7 @@
 define i32 @test(i32 %p_16, i1 %c) {
 ; CHECK-LABEL: @test(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[OR:%.*]] = or i32 0, [[P_16:%.*]]
-; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[OR]], 6
+; CHECK-NEXT:    [[XOR:%.*]] = xor i32 [[P_16:%.*]], 6
 ; CHECK-NEXT:    [[OR_1:%.*]] = or i32 [[XOR]], [[P_16]]
 ; CHECK-NEXT:    [[XOR_1:%.*]] = xor i32 [[OR_1]], 6
 ; CHECK-NEXT:    [[OR_2:%.*]] = or i32 [[XOR_1]], [[P_16]]


        


More information about the llvm-commits mailing list