[llvm-branch-commits] [llvm] bb82cf6 - Revert "[SLP] NFC. Refactor getSameOpcode and reduce for loop iterations. (#1…"

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Jan 9 18:09:19 PST 2025


Author: Vitaly Buka
Date: 2025-01-09T18:09:16-08:00
New Revision: bb82cf61a7a0b7c74dee596f85243cdcbe59e695

URL: https://github.com/llvm/llvm-project/commit/bb82cf61a7a0b7c74dee596f85243cdcbe59e695
DIFF: https://github.com/llvm/llvm-project/commit/bb82cf61a7a0b7c74dee596f85243cdcbe59e695.diff

LOG: Revert "[SLP] NFC. Refactor getSameOpcode and reduce for loop iterations. (#1…"

This reverts commit 36b423e0f85cb35eb8b211662a0fab70d476f501.

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 8ff70fdb1180b0..6360ddb57007d6 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -916,22 +916,24 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
   if (It == VL.end())
     return InstructionsState::invalid();
 
-  Instruction *MainOp = cast<Instruction>(*It);
+  Value *V = *It;
   unsigned InstCnt = std::count_if(It, VL.end(), IsaPred<Instruction>);
-  if ((VL.size() > 2 && !isa<PHINode>(MainOp) && InstCnt < VL.size() / 2) ||
+  if ((VL.size() > 2 && !isa<PHINode>(V) && InstCnt < VL.size() / 2) ||
       (VL.size() == 2 && InstCnt < 2))
     return InstructionsState::invalid();
 
-  bool IsCastOp = isa<CastInst>(MainOp);
-  bool IsBinOp = isa<BinaryOperator>(MainOp);
-  bool IsCmpOp = isa<CmpInst>(MainOp);
-  CmpInst::Predicate BasePred = IsCmpOp ? cast<CmpInst>(MainOp)->getPredicate()
-                                        : CmpInst::BAD_ICMP_PREDICATE;
-  Instruction *AltOp = MainOp;
-  unsigned Opcode = MainOp->getOpcode();
+  bool IsCastOp = isa<CastInst>(V);
+  bool IsBinOp = isa<BinaryOperator>(V);
+  bool IsCmpOp = isa<CmpInst>(V);
+  CmpInst::Predicate BasePred =
+      IsCmpOp ? cast<CmpInst>(V)->getPredicate() : CmpInst::BAD_ICMP_PREDICATE;
+  unsigned Opcode = cast<Instruction>(V)->getOpcode();
   unsigned AltOpcode = Opcode;
+  unsigned AltIndex = std::distance(VL.begin(), It);
 
-  bool SwappedPredsCompatible = IsCmpOp && [&]() {
+  bool SwappedPredsCompatible = [&]() {
+    if (!IsCmpOp)
+      return false;
     SetVector<unsigned> UniquePreds, UniqueNonSwappedPreds;
     UniquePreds.insert(BasePred);
     UniqueNonSwappedPreds.insert(BasePred);
@@ -954,18 +956,18 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
   }();
   // Check for one alternate opcode from another BinaryOperator.
   // TODO - generalize to support all operators (types, calls etc.).
+  auto *IBase = cast<Instruction>(V);
   Intrinsic::ID BaseID = 0;
   SmallVector<VFInfo> BaseMappings;
-  if (auto *CallBase = dyn_cast<CallInst>(MainOp)) {
+  if (auto *CallBase = dyn_cast<CallInst>(IBase)) {
     BaseID = getVectorIntrinsicIDForCall(CallBase, &TLI);
     BaseMappings = VFDatabase(*CallBase).getMappings(*CallBase);
     if (!isTriviallyVectorizable(BaseID) && BaseMappings.empty())
       return InstructionsState::invalid();
   }
   bool AnyPoison = InstCnt != VL.size();
-  // Skip MainOp.
-  for (Value *V : iterator_range(It + 1, VL.end())) {
-    auto *I = dyn_cast<Instruction>(V);
+  for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) {
+    auto *I = dyn_cast<Instruction>(VL[Cnt]);
     if (!I)
       continue;
 
@@ -981,11 +983,11 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
       if (Opcode == AltOpcode && isValidForAlternation(InstOpcode) &&
           isValidForAlternation(Opcode)) {
         AltOpcode = InstOpcode;
-        AltOp = I;
+        AltIndex = Cnt;
         continue;
       }
     } else if (IsCastOp && isa<CastInst>(I)) {
-      Value *Op0 = MainOp->getOperand(0);
+      Value *Op0 = IBase->getOperand(0);
       Type *Ty0 = Op0->getType();
       Value *Op1 = I->getOperand(0);
       Type *Ty1 = Op1->getType();
@@ -997,12 +999,12 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
                  isValidForAlternation(InstOpcode) &&
                  "Cast isn't safe for alternation, logic needs to be updated!");
           AltOpcode = InstOpcode;
-          AltOp = I;
+          AltIndex = Cnt;
           continue;
         }
       }
-    } else if (auto *Inst = dyn_cast<CmpInst>(I); Inst && IsCmpOp) {
-      auto *BaseInst = cast<CmpInst>(MainOp);
+    } else if (auto *Inst = dyn_cast<CmpInst>(VL[Cnt]); Inst && IsCmpOp) {
+      auto *BaseInst = cast<CmpInst>(V);
       Type *Ty0 = BaseInst->getOperand(0)->getType();
       Type *Ty1 = Inst->getOperand(0)->getType();
       if (Ty0 == Ty1) {
@@ -1016,21 +1018,21 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
         CmpInst::Predicate SwappedCurrentPred =
             CmpInst::getSwappedPredicate(CurrentPred);
 
-        if ((VL.size() == 2 || SwappedPredsCompatible) &&
+        if ((E == 2 || SwappedPredsCompatible) &&
             (BasePred == CurrentPred || BasePred == SwappedCurrentPred))
           continue;
 
         if (isCmpSameOrSwapped(BaseInst, Inst, TLI))
           continue;
-        auto *AltInst = cast<CmpInst>(AltOp);
-        if (MainOp != AltOp) {
+        auto *AltInst = cast<CmpInst>(VL[AltIndex]);
+        if (AltIndex) {
           if (isCmpSameOrSwapped(AltInst, Inst, TLI))
             continue;
         } else if (BasePred != CurrentPred) {
           assert(
               isValidForAlternation(InstOpcode) &&
               "CmpInst isn't safe for alternation, logic needs to be updated!");
-          AltOp = I;
+          AltIndex = Cnt;
           continue;
         }
         CmpInst::Predicate AltPred = AltInst->getPredicate();
@@ -1044,17 +1046,17 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
              "CastInst.");
       if (auto *Gep = dyn_cast<GetElementPtrInst>(I)) {
         if (Gep->getNumOperands() != 2 ||
-            Gep->getOperand(0)->getType() != MainOp->getOperand(0)->getType())
+            Gep->getOperand(0)->getType() != IBase->getOperand(0)->getType())
           return InstructionsState::invalid();
       } else if (auto *EI = dyn_cast<ExtractElementInst>(I)) {
         if (!isVectorLikeInstWithConstOps(EI))
           return InstructionsState::invalid();
       } else if (auto *LI = dyn_cast<LoadInst>(I)) {
-        auto *BaseLI = cast<LoadInst>(MainOp);
+        auto *BaseLI = cast<LoadInst>(IBase);
         if (!LI->isSimple() || !BaseLI->isSimple())
           return InstructionsState::invalid();
       } else if (auto *Call = dyn_cast<CallInst>(I)) {
-        auto *CallBase = cast<CallInst>(MainOp);
+        auto *CallBase = cast<CallInst>(IBase);
         if (Call->getCalledFunction() != CallBase->getCalledFunction())
           return InstructionsState::invalid();
         if (Call->hasOperandBundles() &&
@@ -1084,7 +1086,8 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
     return InstructionsState::invalid();
   }
 
-  return InstructionsState(MainOp, AltOp);
+  return InstructionsState(cast<Instruction>(V),
+                           cast<Instruction>(VL[AltIndex]));
 }
 
 /// \returns true if all of the values in \p VL have the same type or false


        


More information about the llvm-branch-commits mailing list