[llvm] [SLP] NFC. Refactor getSameOpcode and reduce for loop iterations. (PR #122241)

via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 9 02:01:39 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-vectorizers

Author: Han-Kuan Chen (HanKuanChen)

<details>
<summary>Changes</summary>

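Replace the index-based Cnt and AltIndex with the instruction pointers MainOp and AltOp, which are what the returned InstructionsState is constructed from.
Reduce the number of loop iterations by starting the scan at the element after MainOp instead of at the beginning of VL.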

---
Full diff: https://github.com/llvm/llvm-project/pull/122241.diff


1 File Affected:

- (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+28-32) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index c4582df89213d8..aa12af43dcf95c 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -915,24 +915,22 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
   if (It == VL.end())
     return InstructionsState::invalid();
 
-  Value *V = *It;
+  Instruction *MainOp = cast<Instruction>(*It);
   unsigned InstCnt = std::count_if(It, VL.end(), IsaPred<Instruction>);
-  if ((VL.size() > 2 && !isa<PHINode>(V) && InstCnt < VL.size() / 2) ||
+  if ((VL.size() > 2 && !isa<PHINode>(MainOp) && InstCnt < VL.size() / 2) ||
       (VL.size() == 2 && InstCnt < 2))
     return InstructionsState::invalid();
 
-  bool IsCastOp = isa<CastInst>(V);
-  bool IsBinOp = isa<BinaryOperator>(V);
-  bool IsCmpOp = isa<CmpInst>(V);
-  CmpInst::Predicate BasePred =
-      IsCmpOp ? cast<CmpInst>(V)->getPredicate() : CmpInst::BAD_ICMP_PREDICATE;
-  unsigned Opcode = cast<Instruction>(V)->getOpcode();
+  bool IsCastOp = isa<CastInst>(MainOp);
+  bool IsBinOp = isa<BinaryOperator>(MainOp);
+  bool IsCmpOp = isa<CmpInst>(MainOp);
+  CmpInst::Predicate BasePred = IsCmpOp ? cast<CmpInst>(MainOp)->getPredicate()
+                                        : CmpInst::BAD_ICMP_PREDICATE;
+  Instruction *AltOp = MainOp;
+  unsigned Opcode = MainOp->getOpcode();
   unsigned AltOpcode = Opcode;
-  unsigned AltIndex = std::distance(VL.begin(), It);
 
-  bool SwappedPredsCompatible = [&]() {
-    if (!IsCmpOp)
-      return false;
+  bool SwappedPredsCompatible = IsCmpOp && [&]() {
     SetVector<unsigned> UniquePreds, UniqueNonSwappedPreds;
     UniquePreds.insert(BasePred);
     UniqueNonSwappedPreds.insert(BasePred);
@@ -955,18 +953,18 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
   }();
   // Check for one alternate opcode from another BinaryOperator.
   // TODO - generalize to support all operators (types, calls etc.).
-  auto *IBase = cast<Instruction>(V);
   Intrinsic::ID BaseID = 0;
   SmallVector<VFInfo> BaseMappings;
-  if (auto *CallBase = dyn_cast<CallInst>(IBase)) {
+  if (auto *CallBase = dyn_cast<CallInst>(MainOp)) {
     BaseID = getVectorIntrinsicIDForCall(CallBase, &TLI);
     BaseMappings = VFDatabase(*CallBase).getMappings(*CallBase);
     if (!isTriviallyVectorizable(BaseID) && BaseMappings.empty())
       return InstructionsState::invalid();
   }
   bool AnyPoison = InstCnt != VL.size();
-  for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) {
-    auto *I = dyn_cast<Instruction>(VL[Cnt]);
+  // Skip MainOp.
+  while (++It != VL.end()) {
+    auto *I = dyn_cast<Instruction>(*It);
     if (!I)
       continue;
 
@@ -982,11 +980,11 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
       if (Opcode == AltOpcode && isValidForAlternation(InstOpcode) &&
           isValidForAlternation(Opcode)) {
         AltOpcode = InstOpcode;
-        AltIndex = Cnt;
+        AltOp = I;
         continue;
       }
     } else if (IsCastOp && isa<CastInst>(I)) {
-      Value *Op0 = IBase->getOperand(0);
+      Value *Op0 = MainOp->getOperand(0);
       Type *Ty0 = Op0->getType();
       Value *Op1 = I->getOperand(0);
       Type *Ty1 = Op1->getType();
@@ -998,12 +996,12 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
                  isValidForAlternation(InstOpcode) &&
                  "Cast isn't safe for alternation, logic needs to be updated!");
           AltOpcode = InstOpcode;
-          AltIndex = Cnt;
+          AltOp = I;
           continue;
         }
       }
-    } else if (auto *Inst = dyn_cast<CmpInst>(VL[Cnt]); Inst && IsCmpOp) {
-      auto *BaseInst = cast<CmpInst>(V);
+    } else if (auto *Inst = dyn_cast<CmpInst>(I); Inst && IsCmpOp) {
+      auto *BaseInst = cast<CmpInst>(MainOp);
       Type *Ty0 = BaseInst->getOperand(0)->getType();
       Type *Ty1 = Inst->getOperand(0)->getType();
       if (Ty0 == Ty1) {
@@ -1017,24 +1015,23 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
         CmpInst::Predicate SwappedCurrentPred =
             CmpInst::getSwappedPredicate(CurrentPred);
 
-        if ((E == 2 || SwappedPredsCompatible) &&
+        if ((VL.size() == 2 || SwappedPredsCompatible) &&
             (BasePred == CurrentPred || BasePred == SwappedCurrentPred))
           continue;
 
         if (isCmpSameOrSwapped(BaseInst, Inst, TLI))
           continue;
-        auto *AltInst = cast<CmpInst>(VL[AltIndex]);
-        if (AltIndex) {
-          if (isCmpSameOrSwapped(AltInst, Inst, TLI))
+        if (MainOp != AltOp) {
+          if (isCmpSameOrSwapped(cast<CmpInst>(AltOp), Inst, TLI))
             continue;
         } else if (BasePred != CurrentPred) {
           assert(
               isValidForAlternation(InstOpcode) &&
               "CmpInst isn't safe for alternation, logic needs to be updated!");
-          AltIndex = Cnt;
+          AltOp = I;
           continue;
         }
-        CmpInst::Predicate AltPred = AltInst->getPredicate();
+        CmpInst::Predicate AltPred = cast<CmpInst>(AltOp)->getPredicate();
         if (BasePred == CurrentPred || BasePred == SwappedCurrentPred ||
             AltPred == CurrentPred || AltPred == SwappedCurrentPred)
           continue;
@@ -1045,17 +1042,17 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
              "CastInst.");
       if (auto *Gep = dyn_cast<GetElementPtrInst>(I)) {
         if (Gep->getNumOperands() != 2 ||
-            Gep->getOperand(0)->getType() != IBase->getOperand(0)->getType())
+            Gep->getOperand(0)->getType() != MainOp->getOperand(0)->getType())
           return InstructionsState::invalid();
       } else if (auto *EI = dyn_cast<ExtractElementInst>(I)) {
         if (!isVectorLikeInstWithConstOps(EI))
           return InstructionsState::invalid();
       } else if (auto *LI = dyn_cast<LoadInst>(I)) {
-        auto *BaseLI = cast<LoadInst>(IBase);
+        auto *BaseLI = cast<LoadInst>(MainOp);
         if (!LI->isSimple() || !BaseLI->isSimple())
           return InstructionsState::invalid();
       } else if (auto *Call = dyn_cast<CallInst>(I)) {
-        auto *CallBase = cast<CallInst>(IBase);
+        auto *CallBase = cast<CallInst>(MainOp);
         if (Call->getCalledFunction() != CallBase->getCalledFunction())
           return InstructionsState::invalid();
         if (Call->hasOperandBundles() &&
@@ -1085,8 +1082,7 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
     return InstructionsState::invalid();
   }
 
-  return InstructionsState(cast<Instruction>(V),
-                           cast<Instruction>(VL[AltIndex]));
+  return InstructionsState(MainOp, cast<Instruction>(AltOp));
 }
 
 /// \returns true if all of the values in \p VL have the same type or false

``````````

</details>


https://github.com/llvm/llvm-project/pull/122241


More information about the llvm-commits mailing list