[llvm] 26a9f3f - [SLP][NFC]Cleanup getSameOpcode, return InstructionsState::invalid() for non-valid inputs

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Fri Nov 8 14:03:25 PST 2024


Author: Alexey Bataev
Date: 2024-11-08T14:00:32-08:00
New Revision: 26a9f3f5906c62cff7f2245b98affa432b504a87

URL: https://github.com/llvm/llvm-project/commit/26a9f3f5906c62cff7f2245b98affa432b504a87
DIFF: https://github.com/llvm/llvm-project/commit/26a9f3f5906c62cff7f2245b98affa432b504a87.diff

LOG: [SLP][NFC]Cleanup getSameOpcode, return InstructionsState::invalid() for non-valid inputs

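Just a cleanup and related changes: getSameOpcode now returns the new
InstructionsState::invalid() (all-null state) on failure instead of a state
carrying VL[BaseIndex], and the affected checks in buildTree_rec use
VL.front() (or isa_and_present on S.OpValue) rather than dereferencing
S.OpValue directly.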

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index a6accf0318a30f..4a73b9c2c4b34a 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -832,6 +832,7 @@ struct InstructionsState {
   InstructionsState() = delete;
   InstructionsState(Value *OpValue, Instruction *MainOp, Instruction *AltOp)
       : OpValue(OpValue), MainOp(MainOp), AltOp(AltOp) {}
+  static InstructionsState invalid() { return {nullptr, nullptr, nullptr}; }
 };
 
 } // end anonymous namespace
@@ -891,20 +892,19 @@ static bool isCmpSameOrSwapped(const CmpInst *BaseCI, const CmpInst *CI,
 /// could be vectorized even if its structure is diverse.
 static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
                                        const TargetLibraryInfo &TLI) {
-  constexpr unsigned BaseIndex = 0;
   // Make sure these are all Instructions.
-  if (llvm::any_of(VL, [](Value *V) { return !isa<Instruction>(V); }))
-    return InstructionsState(VL[BaseIndex], nullptr, nullptr);
+  if (!all_of(VL, IsaPred<Instruction>))
+    return InstructionsState::invalid();
 
-  bool IsCastOp = isa<CastInst>(VL[BaseIndex]);
-  bool IsBinOp = isa<BinaryOperator>(VL[BaseIndex]);
-  bool IsCmpOp = isa<CmpInst>(VL[BaseIndex]);
+  Value *V = VL.front();
+  bool IsCastOp = isa<CastInst>(V);
+  bool IsBinOp = isa<BinaryOperator>(V);
+  bool IsCmpOp = isa<CmpInst>(V);
   CmpInst::Predicate BasePred =
-      IsCmpOp ? cast<CmpInst>(VL[BaseIndex])->getPredicate()
-              : CmpInst::BAD_ICMP_PREDICATE;
-  unsigned Opcode = cast<Instruction>(VL[BaseIndex])->getOpcode();
+      IsCmpOp ? cast<CmpInst>(V)->getPredicate() : CmpInst::BAD_ICMP_PREDICATE;
+  unsigned Opcode = cast<Instruction>(V)->getOpcode();
   unsigned AltOpcode = Opcode;
-  unsigned AltIndex = BaseIndex;
+  unsigned AltIndex = 0;
 
   bool SwappedPredsCompatible = [&]() {
     if (!IsCmpOp)
@@ -931,14 +931,14 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
   }();
   // Check for one alternate opcode from another BinaryOperator.
   // TODO - generalize to support all operators (types, calls etc.).
-  auto *IBase = cast<Instruction>(VL[BaseIndex]);
+  auto *IBase = cast<Instruction>(V);
   Intrinsic::ID BaseID = 0;
   SmallVector<VFInfo> BaseMappings;
   if (auto *CallBase = dyn_cast<CallInst>(IBase)) {
     BaseID = getVectorIntrinsicIDForCall(CallBase, &TLI);
     BaseMappings = VFDatabase(*CallBase).getMappings(*CallBase);
     if (!isTriviallyVectorizable(BaseID) && BaseMappings.empty())
-      return InstructionsState(VL[BaseIndex], nullptr, nullptr);
+      return InstructionsState::invalid();
   }
   for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) {
     auto *I = cast<Instruction>(VL[Cnt]);
@@ -970,7 +970,7 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
         }
       }
     } else if (auto *Inst = dyn_cast<CmpInst>(VL[Cnt]); Inst && IsCmpOp) {
-      auto *BaseInst = cast<CmpInst>(VL[BaseIndex]);
+      auto *BaseInst = cast<CmpInst>(V);
       Type *Ty0 = BaseInst->getOperand(0)->getType();
       Type *Ty1 = Inst->getOperand(0)->getType();
       if (Ty0 == Ty1) {
@@ -988,7 +988,7 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
         if (isCmpSameOrSwapped(BaseInst, Inst, TLI))
           continue;
         auto *AltInst = cast<CmpInst>(VL[AltIndex]);
-        if (AltIndex != BaseIndex) {
+        if (AltIndex) {
           if (isCmpSameOrSwapped(AltInst, Inst, TLI))
             continue;
         } else if (BasePred != CurrentPred) {
@@ -1007,27 +1007,28 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
       if (auto *Gep = dyn_cast<GetElementPtrInst>(I)) {
         if (Gep->getNumOperands() != 2 ||
             Gep->getOperand(0)->getType() != IBase->getOperand(0)->getType())
-          return InstructionsState(VL[BaseIndex], nullptr, nullptr);
+          return InstructionsState::invalid();
       } else if (auto *EI = dyn_cast<ExtractElementInst>(I)) {
         if (!isVectorLikeInstWithConstOps(EI))
-          return InstructionsState(VL[BaseIndex], nullptr, nullptr);
+          return InstructionsState::invalid();
       } else if (auto *LI = dyn_cast<LoadInst>(I)) {
         auto *BaseLI = cast<LoadInst>(IBase);
         if (!LI->isSimple() || !BaseLI->isSimple())
-          return InstructionsState(VL[BaseIndex], nullptr, nullptr);
+          return InstructionsState::invalid();
       } else if (auto *Call = dyn_cast<CallInst>(I)) {
         auto *CallBase = cast<CallInst>(IBase);
         if (Call->getCalledFunction() != CallBase->getCalledFunction())
-          return InstructionsState(VL[BaseIndex], nullptr, nullptr);
-        if (Call->hasOperandBundles() && (!CallBase->hasOperandBundles() ||
-            !std::equal(Call->op_begin() + Call->getBundleOperandsStartIndex(),
-                        Call->op_begin() + Call->getBundleOperandsEndIndex(),
-                        CallBase->op_begin() +
-                            CallBase->getBundleOperandsStartIndex())))
-          return InstructionsState(VL[BaseIndex], nullptr, nullptr);
+          return InstructionsState::invalid();
+        if (Call->hasOperandBundles() &&
+            (!CallBase->hasOperandBundles() ||
+             !std::equal(Call->op_begin() + Call->getBundleOperandsStartIndex(),
+                         Call->op_begin() + Call->getBundleOperandsEndIndex(),
+                         CallBase->op_begin() +
+                             CallBase->getBundleOperandsStartIndex())))
+          return InstructionsState::invalid();
         Intrinsic::ID ID = getVectorIntrinsicIDForCall(Call, &TLI);
         if (ID != BaseID)
-          return InstructionsState(VL[BaseIndex], nullptr, nullptr);
+          return InstructionsState::invalid();
         if (!ID) {
           SmallVector<VFInfo> Mappings = VFDatabase(*Call).getMappings(*Call);
           if (Mappings.size() != BaseMappings.size() ||
@@ -1037,15 +1038,15 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
               Mappings.front().Shape.VF != BaseMappings.front().Shape.VF ||
               Mappings.front().Shape.Parameters !=
                   BaseMappings.front().Shape.Parameters)
-            return InstructionsState(VL[BaseIndex], nullptr, nullptr);
+            return InstructionsState::invalid();
         }
       }
       continue;
     }
-    return InstructionsState(VL[BaseIndex], nullptr, nullptr);
+    return InstructionsState::invalid();
   }
 
-  return InstructionsState(VL[BaseIndex], cast<Instruction>(VL[BaseIndex]),
+  return InstructionsState(V, cast<Instruction>(V),
                            cast<Instruction>(VL[AltIndex]));
 }
 
@@ -8019,7 +8020,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
   }
 
   // Don't handle vectors.
-  if (!SLPReVec && getValueType(S.OpValue)->isVectorTy()) {
+  if (!SLPReVec && getValueType(VL.front())->isVectorTy()) {
     LLVM_DEBUG(dbgs() << "SLP: Gathering due to vector type.\n");
     newTreeEntry(VL, std::nullopt /*not vectorized*/, S, UserTreeIdx);
     return;
@@ -8088,7 +8089,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
       UserTreeIdx.UserTE->State == TreeEntry::ScatterVectorize;
   bool AreAllSameBlock = S.getOpcode() && allSameBlock(VL);
   bool AreScatterAllGEPSameBlock =
-      (IsScatterVectorizeUserTE && S.OpValue->getType()->isPointerTy() &&
+      (IsScatterVectorizeUserTE && VL.front()->getType()->isPointerTy() &&
        VL.size() > 2 &&
        all_of(VL,
               [&BB](Value *V) {
@@ -8104,7 +8105,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
                        SortedIndices));
   bool AreAllSameInsts = AreAllSameBlock || AreScatterAllGEPSameBlock;
   if (!AreAllSameInsts || (!S.getOpcode() && allConstant(VL)) || isSplat(VL) ||
-      (isa<InsertElementInst, ExtractValueInst, ExtractElementInst>(
+      (isa_and_present<InsertElementInst, ExtractValueInst, ExtractElementInst>(
            S.OpValue) &&
        !all_of(VL, isVectorLikeInstWithConstOps)) ||
       NotProfitableForVectorization(VL)) {
@@ -8161,7 +8162,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
   // Special processing for sorted pointers for ScatterVectorize node with
   // constant indeces only.
   if (!AreAllSameBlock && AreScatterAllGEPSameBlock) {
-    assert(S.OpValue->getType()->isPointerTy() &&
+    assert(VL.front()->getType()->isPointerTy() &&
            count_if(VL, IsaPred<GetElementPtrInst>) >= 2 &&
            "Expected pointers only.");
     // Reset S to make it GetElementPtr kind of node.


        


More information about the llvm-commits mailing list