[llvm] [WIP][SLP] SLP's copyable elements based upon Main/Alt operations. (PR #124242)

via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 24 00:37:50 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-vectorizers

Author: Dinar Temirbulatov (dtemirbulatov)

<details>
<summary>Changes</summary>

Added a testcase from https://github.com/llvm/llvm-project/issues/110740.
There are still several issues with this change that can be reproduced with LNT by adding "-mllvm -slp-vectorize-copyable=true -mllvm -slp-threshold=-99999", as well as missing support for demoting values and for floating-point operations with "fast-math".
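
For readers skimming the diff, a minimal sketch of the kind of scalar pattern the change targets (a hypothetical example, not taken from the PR or the linked issue): one lane performs a binary operation while the neighbouring lane is a plain copy of its operand, and treating the copy as a "copyable" element lets both lanes share the same main opcode so the pair can still be bundled.

```c
/* Hypothetical illustration only (not from the PR): lane 1 has an add,
 * lane 0 is a plain copy. If the copy is modeled as the main operation
 * with a neutral operand (b[0] + 0), both lanes share one opcode and the
 * adjacent stores become a vectorizable bundle. */
void copyable_example(int *a, const int *b) {
  a[0] = b[0];     /* copyable element: no binary operation in this lane */
  a[1] = b[1] + 5; /* main operation: add */
}
```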

---

Patch is 95.67 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/124242.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+648-115) 
- (modified) llvm/test/Transforms/SLPVectorizer/X86/vect_copyable_in_binops.ll (+436-98) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index c98d872fb6467f..47b61496b5e155 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -201,6 +201,10 @@ static cl::opt<bool> VectorizeNonPowerOf2(
     "slp-vectorize-non-power-of-2", cl::init(false), cl::Hidden,
     cl::desc("Try to vectorize with non-power-of-2 number of elements."));
 
+static cl::opt<bool>
+    VectorizeCopyable("slp-vectorize-copyable", cl::init(false), cl::Hidden,
+                      cl::desc("Try to vectorize with copyable elements."));
+
 // Limit the number of alias checks. The limit is chosen so that
 // it has no negative effect on the llvm benchmarks.
 static const unsigned AliasedCheckLimit = 10;
@@ -426,6 +430,8 @@ static bool isVectorLikeInstWithConstOps(Value *V) {
   if (isa<ExtractElementInst>(I))
     return isConstant(I->getOperand(1));
   assert(isa<InsertElementInst>(V) && "Expected only insertelement.");
+  if (I->getNumOperands() < 2)
+    return false;
   return isConstant(I->getOperand(2));
 }
 
@@ -594,6 +600,41 @@ static std::optional<unsigned> getElementIndex(const Value *Inst,
   return Index;
 }
 
+/// Checks if an instruction with the given \p Opcode can be considered as an
+/// operand of a (possibly) binary operation \p I.
+/// \returns The opcode of the operation \p I if an instruction with \p Opcode
+/// can be considered as an operand of \p I with a default value, and 0
+/// otherwise.
+static unsigned tryToRepresentAsInstArg(unsigned Opcode, Instruction *I) {
+  if (Opcode != Instruction::PHI && Opcode != Instruction::Invoke &&
+      !isa<FPMathOperator>(I) &&
+      ((I->getType()->isIntegerTy() &&
+        (I->getOpcode() == Instruction::Add ||
+         I->getOpcode() == Instruction::And ||
+         I->getOpcode() == Instruction::AShr ||
+         I->getOpcode() == Instruction::BitCast ||
+         I->getOpcode() == Instruction::Call ||
+         // Disabled due to a scheduling issue with
+         // isVectorLikeInstWithConstOps operations.
+         // I->getOpcode() == Instruction::ExtractElement ||
+         // I->getOpcode() == Instruction::ExtractValue ||
+         I->getOpcode() == Instruction::ICmp ||
+         I->getOpcode() == Instruction::Load ||
+         I->getOpcode() == Instruction::LShr ||
+         I->getOpcode() == Instruction::Mul ||
+         I->getOpcode() == Instruction::Or ||
+         I->getOpcode() == Instruction::PtrToInt ||
+         I->getOpcode() == Instruction::Select ||
+         I->getOpcode() == Instruction::SExt ||
+         I->getOpcode() == Instruction::Shl ||
+         I->getOpcode() == Instruction::Sub ||
+         I->getOpcode() == Instruction::Trunc ||
+         I->getOpcode() == Instruction::Xor ||
+         I->getOpcode() == Instruction::ZExt))))
+    return I->getOpcode();
+  return 0;
+}
+
 namespace {
 /// Specifies the way the mask should be analyzed for undefs/poisonous elements
 /// in the shuffle mask.
@@ -853,6 +894,16 @@ class InstructionsState {
 
 } // end anonymous namespace
 
+/// Chooses the correct key for scheduling data. If \p Op has the same (or
+/// alternate) opcode as the state \p S, the key is \p Op. Otherwise the key
+/// is the main operation of \p S.
+static Value *isOneOf(const InstructionsState &S, Value *Op) {
+  auto *I = dyn_cast<Instruction>(Op);
+  if (I && S.isOpcodeOrAlt(I))
+    return Op;
+  return S.getMainOp();
+}
+
 /// \returns true if \p Opcode is allowed as part of the main/alternate
 /// instruction for SLP vectorization.
 ///
@@ -865,6 +916,14 @@ static bool isValidForAlternation(unsigned Opcode) {
   return true;
 }
 
+// Check for inner dependencies; we cannot support such dependencies if they
+// come from a main operation, only from an alternative one.
+static bool checkCopyableInnerDep(ArrayRef<Value *> VL,
+                                  const InstructionsState &S);
+
+// Determine whether the vector could be vectorized with copyable elements.
+static bool isCopyableOp(ArrayRef<Value *> VL, Value *Main, Value *Alt);
+
 static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
                                        const TargetLibraryInfo &TLI);
 
@@ -917,19 +976,53 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
     return InstructionsState::invalid();
 
   Instruction *MainOp = cast<Instruction>(*It);
+  Instruction *AltOp = MainOp;
+  unsigned Opcode = MainOp->getOpcode();
+  unsigned AltOpcode = Opcode;
+  for (Value *V : iterator_range(It + 1, VL.end())) {
+    Instruction *Inst = dyn_cast<Instruction>(V);
+    if (!Inst)
+      continue;
+    unsigned VOpcode = Inst->getOpcode();
+    if (Inst && AltOpcode == Opcode && !isa<PHINode>(Inst) &&
+        VOpcode != Opcode && isValidForAlternation(VOpcode)) {
+      AltOpcode = VOpcode;
+      AltOp = Inst;
+      break;
+    }
+  }
   unsigned InstCnt = std::count_if(It, VL.end(), IsaPred<Instruction>);
   if ((VL.size() > 2 && !isa<PHINode>(MainOp) && InstCnt < VL.size() / 2) ||
       (VL.size() == 2 && InstCnt < 2))
     return InstructionsState::invalid();
+  bool IsBinOp = isa<BinaryOperator>(MainOp);
+  bool IsCopyable = false;
 
+  if (MainOp && AltOp && MainOp != AltOp) {
+    if (!IsBinOp && isa<BinaryOperator>(AltOp) && !isa<PHINode>(MainOp)) {
+      std::swap(MainOp, AltOp);
+      std::swap(AltOpcode, Opcode);
+      IsBinOp = true;
+    }
+    IsCopyable = VectorizeCopyable && isCopyableOp(VL, MainOp, AltOp);
+    if (IsCopyable && isa<CmpInst>(AltOp)) {
+      Type *Ty0 = MainOp->getOperand(0)->getType();
+      Type *Ty1 = AltOp->getOperand(0)->getType();
+      if (Ty0 != Ty1)
+        return InstructionsState::invalid();
+    }
+    if (!IsCopyable) {
+      MainOp = cast<Instruction>(*It);
+      AltOp = MainOp;
+      Opcode = MainOp->getOpcode();
+      AltOpcode = Opcode;
+      IsBinOp = isa<BinaryOperator>(MainOp);
+    }
+  }
   bool IsCastOp = isa<CastInst>(MainOp);
-  bool IsBinOp = isa<BinaryOperator>(MainOp);
   bool IsCmpOp = isa<CmpInst>(MainOp);
   CmpInst::Predicate BasePred = IsCmpOp ? cast<CmpInst>(MainOp)->getPredicate()
                                         : CmpInst::BAD_ICMP_PREDICATE;
-  Instruction *AltOp = MainOp;
-  unsigned Opcode = MainOp->getOpcode();
-  unsigned AltOpcode = Opcode;
 
   bool SwappedPredsCompatible = IsCmpOp && [&]() {
     SetVector<unsigned> UniquePreds, UniqueNonSwappedPreds;
@@ -984,7 +1077,7 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
         AltOp = I;
         continue;
       }
-    } else if (IsCastOp && isa<CastInst>(I)) {
+    } else if ((IsCastOp || IsCopyable) && isa<CastInst>(I)) {
       Value *Op0 = MainOp->getOperand(0);
       Type *Ty0 = Op0->getType();
       Value *Op1 = I->getOperand(0);
@@ -1001,13 +1094,15 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
           continue;
         }
       }
-    } else if (auto *Inst = dyn_cast<CmpInst>(I); Inst && IsCmpOp) {
+    } else if (auto *Inst = dyn_cast<CmpInst>(I);
+               Inst && (IsCmpOp || IsCopyable)) {
       auto *BaseInst = cast<CmpInst>(MainOp);
       Type *Ty0 = BaseInst->getOperand(0)->getType();
       Type *Ty1 = Inst->getOperand(0)->getType();
       if (Ty0 == Ty1) {
-        assert(InstOpcode == Opcode && "Expected same CmpInst opcode.");
-        assert(InstOpcode == AltOpcode &&
+        assert((IsCopyable || InstOpcode == Opcode) &&
+               "Expected same CmpInst opcode.");
+        assert((IsCopyable || InstOpcode == AltOpcode) &&
                "Alternate instructions are only supported by BinaryOperator "
                "and CastInst.");
         // Check for compatible operands. If the corresponding operands are not
@@ -1038,23 +1133,32 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
             AltPred == CurrentPred || AltPred == SwappedCurrentPred)
           continue;
       }
-    } else if (InstOpcode == Opcode) {
-      assert(InstOpcode == AltOpcode &&
+    } else if (InstOpcode == Opcode ||
+               (IsCopyable && InstOpcode == AltOpcode)) {
+      assert((IsCopyable || InstOpcode == AltOpcode) &&
              "Alternate instructions are only supported by BinaryOperator and "
              "CastInst.");
+      Instruction *Op = MainOp;
+      if (IsCopyable) {
+        if (InstOpcode != Opcode && InstOpcode != AltOpcode) {
+          Op = I;
+        } else if (Opcode != AltOpcode && InstOpcode == AltOpcode) {
+          Op = AltOp;
+        }
+      }
       if (auto *Gep = dyn_cast<GetElementPtrInst>(I)) {
         if (Gep->getNumOperands() != 2 ||
-            Gep->getOperand(0)->getType() != MainOp->getOperand(0)->getType())
+            Gep->getOperand(0)->getType() != Op->getOperand(0)->getType())
           return InstructionsState::invalid();
       } else if (auto *EI = dyn_cast<ExtractElementInst>(I)) {
         if (!isVectorLikeInstWithConstOps(EI))
           return InstructionsState::invalid();
       } else if (auto *LI = dyn_cast<LoadInst>(I)) {
-        auto *BaseLI = cast<LoadInst>(MainOp);
+        auto *BaseLI = cast<LoadInst>(Op);
         if (!LI->isSimple() || !BaseLI->isSimple())
           return InstructionsState::invalid();
       } else if (auto *Call = dyn_cast<CallInst>(I)) {
-        auto *CallBase = cast<CallInst>(MainOp);
+        auto *CallBase = cast<CallInst>(Op);
         if (Call->getCalledFunction() != CallBase->getCalledFunction())
           return InstructionsState::invalid();
         if (Call->hasOperandBundles() &&
@@ -1069,13 +1173,14 @@ static InstructionsState getSameOpcode(ArrayRef<Value *> VL,
           return InstructionsState::invalid();
         if (!ID) {
           SmallVector<VFInfo> Mappings = VFDatabase(*Call).getMappings(*Call);
-          if (Mappings.size() != BaseMappings.size() ||
-              Mappings.front().ISA != BaseMappings.front().ISA ||
-              Mappings.front().ScalarName != BaseMappings.front().ScalarName ||
-              Mappings.front().VectorName != BaseMappings.front().VectorName ||
-              Mappings.front().Shape.VF != BaseMappings.front().Shape.VF ||
-              Mappings.front().Shape.Parameters !=
-                  BaseMappings.front().Shape.Parameters)
+          if (Mappings.size() &&
+              (Mappings.size() != BaseMappings.size() ||
+               Mappings.front().ISA != BaseMappings.front().ISA ||
+               Mappings.front().ScalarName != BaseMappings.front().ScalarName ||
+               Mappings.front().VectorName != BaseMappings.front().VectorName ||
+               Mappings.front().Shape.VF != BaseMappings.front().Shape.VF ||
+               Mappings.front().Shape.Parameters !=
+                   BaseMappings.front().Shape.Parameters))
             return InstructionsState::invalid();
         }
       }
@@ -1124,6 +1229,46 @@ static bool doesInTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst,
   }
 }
 
+static bool checkCopyableInnerDep(ArrayRef<Value *> VL,
+                                  const InstructionsState &S) {
+  SmallSet<Value *, 4> Ops;
+  unsigned Opcode = S.getOpcode();
+  for (Value *V : VL) {
+    auto *I = dyn_cast<Instruction>(V);
+    if (!I)
+      continue;
+    if (I->getOpcode() == Opcode)
+      Ops.insert(V);
+  }
+  for (Value *V : VL) {
+    auto *I = dyn_cast<Instruction>(V);
+    if (!I)
+      continue;
+    for (Use &U : I->operands()) {
+      if (auto *Op = dyn_cast<Instruction>(U.get()); Op && Ops.contains(Op))
+        return false;
+    }
+  }
+  return true;
+}
+
+static bool isCopyableOp(ArrayRef<Value *> VL, Value *Main, Value *Alt) {
+  if (any_of(VL, IsaPred<PoisonValue>) || Main == Alt ||
+      !isa<BinaryOperator>(Main) || !isa<Instruction>(Alt) ||
+      find_if(VL, IsaPred<PHINode>) != VL.end())
+    return false;
+
+  Instruction *MainOp = cast<Instruction>(Main);
+  Instruction *AltOp = cast<Instruction>(Alt);
+
+  if (isa<BinaryOperator>(MainOp) && !isa<BinaryOperator>(AltOp) &&
+      isValidForAlternation(MainOp->getOpcode()) &&
+      isValidForAlternation(AltOp->getOpcode()) &&
+      tryToRepresentAsInstArg(MainOp->getOpcode(), AltOp) &&
+      tryToRepresentAsInstArg(AltOp->getOpcode(), MainOp))
+    return true;
+  return false;
+}
 /// \returns the AA location that is being access by the instruction.
 static MemoryLocation getLocation(Instruction *I) {
   if (StoreInst *SI = dyn_cast<StoreInst>(I))
@@ -1463,6 +1608,7 @@ class BoUpSLP {
     MultiNodeScalars.clear();
     MustGather.clear();
     NonScheduledFirst.clear();
+    CopyableAltOp.clear();
     EntryToLastInstruction.clear();
     LoadEntriesToVectorize.clear();
     IsGraphTransformMode = false;
@@ -2461,8 +2607,16 @@ class BoUpSLP {
           }
           bool IsInverseOperation = !isCommutative(cast<Instruction>(VL[Lane]));
           bool APO = (OpIdx == 0) ? false : IsInverseOperation;
-          OpsVec[OpIdx][Lane] = {cast<Instruction>(VL[Lane])->getOperand(OpIdx),
-                                 APO, false};
+          Instruction *Inst = cast<Instruction>(VL[Lane]);
+          if (Inst->getOpcode() != MainOp->getOpcode() &&
+              OpIdx > (Inst->getNumOperands() - 1)) {
+            OpsVec[OpIdx][Lane] = {
+                PoisonValue::get(MainOp->getOperand(OpIdx)->getType()), true,
+                false};
+          } else {
+            OpsVec[OpIdx][Lane] = {
+                cast<Instruction>(VL[Lane])->getOperand(OpIdx), APO, false};
+          }
         }
       }
     }
@@ -3298,6 +3452,7 @@ class BoUpSLP {
                          ///< complex node like select/cmp to minmax, mul/add to
                          ///< fma, etc. Must be used for the following nodes in
                          ///< the pattern, not the very first one.
+      CopyableVectorize, ///< The node for copyable elements.
     };
     EntryState State;
 
@@ -3357,7 +3512,8 @@ class BoUpSLP {
       if (Operands.size() < OpIdx + 1)
         Operands.resize(OpIdx + 1);
       assert(Operands[OpIdx].empty() && "Already resized?");
-      assert(OpVL.size() <= Scalars.size() &&
+      assert((State == TreeEntry::CopyableVectorize ||
+              OpVL.size() <= Scalars.size()) &&
              "Number of operands is greater than the number of scalars.");
       Operands[OpIdx].resize(OpVL.size());
       copy(OpVL, Operands[OpIdx].begin());
@@ -3401,7 +3557,9 @@ class BoUpSLP {
     }
 
     /// Some of the instructions in the list have alternate opcodes.
-    bool isAltShuffle() const { return S.isAltShuffle(); }
+    bool isAltShuffle() const {
+      return S.isAltShuffle() && State != TreeEntry::CopyableVectorize;
+    }
 
     bool isOpcodeOrAlt(Instruction *I) const { return S.isOpcodeOrAlt(I); }
 
@@ -3524,6 +3682,9 @@ class BoUpSLP {
       case CombinedVectorize:
         dbgs() << "CombinedVectorize\n";
         break;
+      case CopyableVectorize:
+        dbgs() << "CopyableVectorize\n";
+        break;
       }
       if (S) {
         dbgs() << "MainOp: " << *S.getMainOp() << "\n";
@@ -3619,6 +3780,7 @@ class BoUpSLP {
     // for non-power-of-two vectors.
     assert(
         (hasFullVectorsOrPowerOf2(*TTI, getValueType(VL.front()), VL.size()) ||
+         EntryState == TreeEntry::CopyableVectorize ||
          ReuseShuffleIndices.empty()) &&
         "Reshuffling scalars not yet supported for nodes with padding");
     Last->ReuseShuffleIndices.append(ReuseShuffleIndices.begin(),
@@ -3642,8 +3804,13 @@ class BoUpSLP {
       Last->ReorderIndices.append(ReorderIndices.begin(), ReorderIndices.end());
     }
     if (!Last->isGather()) {
-      for (Value *V : VL) {
+      unsigned Opcode = S.getOpcode();
+      for (unsigned i = 0; i < VL.size(); ++i) {
+        Value *V = VL[i];
         const TreeEntry *TE = getTreeEntry(V);
+        Instruction *I = dyn_cast<Instruction>(V);
+        bool IsAltInst = (I) ? I->getOpcode() != Opcode : false;
+
         assert((!TE || TE == Last || doesNotNeedToBeScheduled(V)) &&
                "Scalar already in tree!");
         if (TE) {
@@ -3651,6 +3818,10 @@ class BoUpSLP {
             MultiNodeScalars.try_emplace(V).first->getSecond().push_back(Last);
           continue;
         }
+        if (EntryState == TreeEntry::CopyableVectorize && IsAltInst) {
+          CopyableAltOp.insert(V);
+          continue;
+        }
         ScalarToTreeEntry[V] = Last;
       }
       // Update the scheduler bundle to point to this TreeEntry.
@@ -3725,6 +3896,10 @@ class BoUpSLP {
   bool areAltOperandsProfitable(const InstructionsState &S,
                                 ArrayRef<Value *> VL) const;
 
+  /// Check whether we can represent operations as copyable by looking at the
+  /// operations' operands.
+  bool canRepresentAsCopyable(const InstructionsState &S, ArrayRef<Value *> VL);
+
   /// Checks if the specified list of the instructions/values can be vectorized
   /// and fills required data before actual scheduling of the instructions.
   TreeEntry::EntryState
@@ -3746,6 +3921,9 @@ class BoUpSLP {
   /// A list of scalars that we found that we need to keep as scalars.
   ValueSet MustGather;
 
+  /// A set of scalars that we are considering as copyable operations.
+  ValueSet CopyableAltOp;
+
   /// A set of first non-schedulable values.
   ValueSet NonScheduledFirst;
 
@@ -3875,15 +4053,16 @@ class BoUpSLP {
 
     ScheduleData() = default;
 
-    void init(int BlockSchedulingRegionID, Instruction *I) {
+    void init(int BlockSchedulingRegionID, Value *OpVal) {
       FirstInBundle = this;
       NextInBundle = nullptr;
       NextLoadStore = nullptr;
       IsScheduled = false;
       SchedulingRegionID = BlockSchedulingRegionID;
       clearDependencies();
-      Inst = I;
+      OpValue = OpVal;
       TE = nullptr;
+      IsCopy = false;
     }
 
     /// Verify basic self consistency properties
@@ -3990,6 +4169,9 @@ class BoUpSLP {
 
     Instruction *Inst = nullptr;
 
+    /// The value used as a key for this entry in the schedule data.
+    Value *OpValue = nullptr;
+
     /// The TreeEntry that this instruction corresponds to.
     TreeEntry *TE = nullptr;
 
@@ -4037,6 +4219,9 @@ class BoUpSLP {
     /// True if this instruction is scheduled (or considered as scheduled in the
     /// dry-run).
     bool IsScheduled = false;
+
+    /// True if this instruction is a copy.
+    bool IsCopy = false;
   };
 
 #ifndef NDEBUG
@@ -4106,6 +4291,31 @@ class BoUpSLP {
       return nullptr;
     }
 
+    ScheduleData *getScheduleData(Value *V, Value *Key) {
+      auto I = ExtraScheduleDataMap.find(V);
+      if (I != ExtraScheduleDataMap.end()) {
+        ScheduleData *SD = I->second.lookup(Key);
+        if (SD && isInSchedulingRegion(SD))
+          return SD;
+      }
+      if (V == Key)
+        return getScheduleData(V);
+      return nullptr;
+    }
+
+    ScheduleData *getScheduleData(Value *V, const TreeEntry *E) {
+      ScheduleData *SD = getScheduleData(V);
+      if (SD && isInSchedulingRegion(SD) && SD->TE == E)
+        return SD;
+      auto I = ExtraScheduleDataMap.find(V);
+      if (I == ExtraScheduleDataMap.end())
+        return nullptr;
+      for (auto &P : I->second)
+        if (isInSchedulingRegion(P.second) && P.second->TE == E)
+          return P.second;
+      return nullptr;
+    }
+
     bool isInSchedulingRegion(ScheduleData *SD) const {
       return SD->SchedulingRegionID == SchedulingRegionID;
     }
@@ -4119,30 +4329,33 @@ class BoUpSLP {
 
       for (ScheduleData *BundleMember = SD; BundleMember;
            BundleMember = BundleMember->NextInBundle) {
-
         // Handle the def-use chain dependencies.
 
         // Decrement the unscheduled counter and insert to ready list if ready.
-        auto &&DecrUnsched = [this, &ReadyList](Instruction *I) {
-          ScheduleData *OpDef = getScheduleData(I);
-          if (OpDef && OpDef->hasValidDependencies() &&
-              OpDef->incrementUnscheduledDeps(-1) == 0) {
-            // There are no more unscheduled dependencies after
-            // decrementing, so we can put the dependent instruction
-            // into the ready list.
-            ScheduleData *DepBundle = OpDef->FirstInBundle;
-            assert(!DepBundle->IsScheduled &&
-                   "already scheduled bundle gets ready");
-            ReadyList.insert(DepBundle);
-            LLVM_DEBUG(dbgs()
-                       << "SLP:    gets ready (def): " << *DepBundle << "\n");
-          }
+        auto &&DecrUnsched = [this, &ReadyList, &Bu...
[truncated]

``````````

</details>
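
As a rough mental model of what a `CopyableVectorize` node computes (an assumption for illustration; the operand setup in the diff above pads missing operands with poison, so the emitted IR may differ), the copyable lane can be viewed as executing the main operation with that operation's identity element:

```c
/* Conceptual model only: both lanes execute the main opcode; the copyable
 * lane uses the identity element of the operation (0 for add), so its result
 * is just the copied operand. */
void copyable_example_model(int *a, const int *b) {
  const int addend[2] = {0, 5};  /* lane 0 is the copyable (identity) lane */
  for (int i = 0; i < 2; ++i)    /* stands in for a single 2-wide vector add */
    a[i] = b[i] + addend[i];
}
```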


https://github.com/llvm/llvm-project/pull/124242

