[llvm] 36b423e - [SLP] NFC. Refactor getSameOpcode and reduce for loop iterations. (#122241)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 9 17:06:10 PST 2025


Author: Han-Kuan Chen
Date: 2025-01-10T09:06:07+08:00
New Revision: 36b423e0f85cb35eb8b211662a0fab70d476f501

URL: https://github.com/llvm/llvm-project/commit/36b423e0f85cb35eb8b211662a0fab70d476f501
DIFF: https://github.com/llvm/llvm-project/commit/36b423e0f85cb35eb8b211662a0fab70d476f501.diff

LOG: [SLP] NFC. Refactor getSameOpcode and reduce for loop iterations. (#122241)

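Replace the index-based bookkeeping (Cnt and AltIndex) with direct Instruction pointers (MainOp and AltOp).
Reduce the number of iterations in the for loop by starting right after MainOp instead of at the beginning of VL.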

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 6360ddb57007d6..8ff70fdb1180b0 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -916,24 +916,22 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
   if (It == VL.end())
     return InstructionsState::invalid();
 
-  Value *V = *It;
+  Instruction *MainOp = cast<Instruction>(*It);
   unsigned InstCnt = std::count_if(It, VL.end(), IsaPred<Instruction>);
-  if ((VL.size() > 2 && !isa<PHINode>(V) && InstCnt < VL.size() / 2) ||
+  if ((VL.size() > 2 && !isa<PHINode>(MainOp) && InstCnt < VL.size() / 2) ||
       (VL.size() == 2 && InstCnt < 2))
     return InstructionsState::invalid();
 
-  bool IsCastOp = isa<CastInst>(V);
-  bool IsBinOp = isa<BinaryOperator>(V);
-  bool IsCmpOp = isa<CmpInst>(V);
-  CmpInst::Predicate BasePred =
-      IsCmpOp ? cast<CmpInst>(V)->getPredicate() : CmpInst::BAD_ICMP_PREDICATE;
-  unsigned Opcode = cast<Instruction>(V)->getOpcode();
+  bool IsCastOp = isa<CastInst>(MainOp);
+  bool IsBinOp = isa<BinaryOperator>(MainOp);
+  bool IsCmpOp = isa<CmpInst>(MainOp);
+  CmpInst::Predicate BasePred = IsCmpOp ? cast<CmpInst>(MainOp)->getPredicate()
+                                        : CmpInst::BAD_ICMP_PREDICATE;
+  Instruction *AltOp = MainOp;
+  unsigned Opcode = MainOp->getOpcode();
   unsigned AltOpcode = Opcode;
-  unsigned AltIndex = std::distance(VL.begin(), It);
 
-  bool SwappedPredsCompatible = [&]() {
-    if (!IsCmpOp)
-      return false;
+  bool SwappedPredsCompatible = IsCmpOp && [&]() {
     SetVector<unsigned> UniquePreds, UniqueNonSwappedPreds;
     UniquePreds.insert(BasePred);
     UniqueNonSwappedPreds.insert(BasePred);
@@ -956,18 +954,18 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
   }();
   // Check for one alternate opcode from another BinaryOperator.
   // TODO - generalize to support all operators (types, calls etc.).
-  auto *IBase = cast<Instruction>(V);
   Intrinsic::ID BaseID = 0;
   SmallVector<VFInfo> BaseMappings;
-  if (auto *CallBase = dyn_cast<CallInst>(IBase)) {
+  if (auto *CallBase = dyn_cast<CallInst>(MainOp)) {
     BaseID = getVectorIntrinsicIDForCall(CallBase, &TLI);
     BaseMappings = VFDatabase(*CallBase).getMappings(*CallBase);
     if (!isTriviallyVectorizable(BaseID) && BaseMappings.empty())
       return InstructionsState::invalid();
   }
   bool AnyPoison = InstCnt != VL.size();
-  for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) {
-    auto *I = dyn_cast<Instruction>(VL[Cnt]);
+  // Skip MainOp.
+  for (Value *V : iterator_range(It + 1, VL.end())) {
+    auto *I = dyn_cast<Instruction>(V);
     if (!I)
       continue;
 
@@ -983,11 +981,11 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
       if (Opcode == AltOpcode && isValidForAlternation(InstOpcode) &&
           isValidForAlternation(Opcode)) {
         AltOpcode = InstOpcode;
-        AltIndex = Cnt;
+        AltOp = I;
         continue;
       }
     } else if (IsCastOp && isa<CastInst>(I)) {
-      Value *Op0 = IBase->getOperand(0);
+      Value *Op0 = MainOp->getOperand(0);
       Type *Ty0 = Op0->getType();
       Value *Op1 = I->getOperand(0);
       Type *Ty1 = Op1->getType();
@@ -999,12 +997,12 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
                  isValidForAlternation(InstOpcode) &&
                  "Cast isn't safe for alternation, logic needs to be updated!");
           AltOpcode = InstOpcode;
-          AltIndex = Cnt;
+          AltOp = I;
           continue;
         }
       }
-    } else if (auto *Inst = dyn_cast<CmpInst>(VL[Cnt]); Inst && IsCmpOp) {
-      auto *BaseInst = cast<CmpInst>(V);
+    } else if (auto *Inst = dyn_cast<CmpInst>(I); Inst && IsCmpOp) {
+      auto *BaseInst = cast<CmpInst>(MainOp);
       Type *Ty0 = BaseInst->getOperand(0)->getType();
       Type *Ty1 = Inst->getOperand(0)->getType();
       if (Ty0 == Ty1) {
@@ -1018,21 +1016,21 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
         CmpInst::Predicate SwappedCurrentPred =
             CmpInst::getSwappedPredicate(CurrentPred);
 
-        if ((E == 2 || SwappedPredsCompatible) &&
+        if ((VL.size() == 2 || SwappedPredsCompatible) &&
             (BasePred == CurrentPred || BasePred == SwappedCurrentPred))
           continue;
 
         if (isCmpSameOrSwapped(BaseInst, Inst, TLI))
           continue;
-        auto *AltInst = cast<CmpInst>(VL[AltIndex]);
-        if (AltIndex) {
+        auto *AltInst = cast<CmpInst>(AltOp);
+        if (MainOp != AltOp) {
           if (isCmpSameOrSwapped(AltInst, Inst, TLI))
             continue;
         } else if (BasePred != CurrentPred) {
           assert(
               isValidForAlternation(InstOpcode) &&
               "CmpInst isn't safe for alternation, logic needs to be updated!");
-          AltIndex = Cnt;
+          AltOp = I;
           continue;
         }
         CmpInst::Predicate AltPred = AltInst->getPredicate();
@@ -1046,17 +1044,17 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
              "CastInst.");
       if (auto *Gep = dyn_cast<GetElementPtrInst>(I)) {
         if (Gep->getNumOperands() != 2 ||
-            Gep->getOperand(0)->getType() != IBase->getOperand(0)->getType())
+            Gep->getOperand(0)->getType() != MainOp->getOperand(0)->getType())
           return InstructionsState::invalid();
       } else if (auto *EI = dyn_cast<ExtractElementInst>(I)) {
         if (!isVectorLikeInstWithConstOps(EI))
           return InstructionsState::invalid();
       } else if (auto *LI = dyn_cast<LoadInst>(I)) {
-        auto *BaseLI = cast<LoadInst>(IBase);
+        auto *BaseLI = cast<LoadInst>(MainOp);
         if (!LI->isSimple() || !BaseLI->isSimple())
           return InstructionsState::invalid();
       } else if (auto *Call = dyn_cast<CallInst>(I)) {
-        auto *CallBase = cast<CallInst>(IBase);
+        auto *CallBase = cast<CallInst>(MainOp);
         if (Call->getCalledFunction() != CallBase->getCalledFunction())
           return InstructionsState::invalid();
         if (Call->hasOperandBundles() &&
@@ -1086,8 +1084,7 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
     return InstructionsState::invalid();
   }
 
-  return InstructionsState(cast<Instruction>(V),
-                           cast<Instruction>(VL[AltIndex]));
+  return InstructionsState(MainOp, AltOp);
 }
 
 /// \returns true if all of the values in \p VL have the same type or false


        


More information about the llvm-commits mailing list