[llvm] [SLP][NFC]Introduce CombinedVectorize nodes, NFC. (PR #99309)

via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 17 04:50:28 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Alexey Bataev (alexey-bataev)

<details>
<summary>Changes</summary>

This adds a combined vectorized node. It simplifies the handling of
combined nodes, such as select/cmp, which can be reduced to min/max, or
mul/add, which can be transformed to fma, etc. It improves cost model
handling and may end up with better codegen in the future (direct
emission of the intrinsics).


---
Full diff: https://github.com/llvm/llvm-project/pull/99309.diff


1 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp (+78-32) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 722590a840a54..1c2402bda551c 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -2936,13 +2936,24 @@ class BoUpSLP {
     /// (either with vector instruction or with scatter/gather
     /// intrinsics for store/load)?
     enum EntryState {
-      Vectorize,
-      ScatterVectorize,
-      StridedVectorize,
-      NeedToGather
+      Vectorize,         ///< The node is regularly vectorized.
+      ScatterVectorize,  ///< Masked scatter/gather node.
+      StridedVectorize,  ///< Strided loads (and stores)
+      NeedToGather,      ///< Gather/buildvector node.
+      CombinedVectorize, ///< Vectorized node, combined with its user into more
+                         ///< complex node like select/cmp to minmax, mul/add to
+                         ///< fma, etc. Must be used for the following nodes in
+                         ///< the pattern, not the very first one.
     };
     EntryState State;
 
+    /// List of combined opcodes supported by the vectorizer.
+    enum CombinedOpcode {
+      NotCombinedOp = -1,
+      MinMax = Instruction::OtherOpsEnd + 1,
+    };
+    CombinedOpcode CombinedOp = NotCombinedOp;
+
     /// Does this sequence require some shuffling?
     SmallVector<int, 4> ReuseShuffleIndices;
 
@@ -3130,6 +3141,9 @@ class BoUpSLP {
       case NeedToGather:
         dbgs() << "NeedToGather\n";
         break;
+      case CombinedVectorize:
+        dbgs() << "CombinedVectorize\n";
+        break;
       }
       dbgs() << "MainOp: ";
       if (MainOp)
@@ -7130,6 +7144,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
         buildTree_rec(PointerOps, Depth + 1, {TE, 0});
         LLVM_DEBUG(dbgs() << "SLP: added a vector of non-consecutive loads.\n");
         break;
+      case TreeEntry::CombinedVectorize:
       case TreeEntry::NeedToGather:
         llvm_unreachable("Unexpected loads state.");
       }
@@ -8188,6 +8203,22 @@ void BoUpSLP::transformNodes() {
       }
       break;
     }
+    case Instruction::Select: {
+      if (E.State != TreeEntry::Vectorize)
+        break;
+      auto [MinMaxID, SelectOnly] = canConvertToMinOrMaxIntrinsic(E.Scalars);
+      if (MinMaxID == Intrinsic::not_intrinsic)
+        break;
+      // This node is a minmax node.
+      E.CombinedOp = TreeEntry::MinMax;
+      TreeEntry *CondEntry = const_cast<TreeEntry *>(getOperandEntry(&E, 0));
+      if (SelectOnly && CondEntry->UserTreeIndices.size() == 1 &&
+          CondEntry->State == TreeEntry::Vectorize) {
+        // The condition node is part of the combined minmax node.
+        CondEntry->State = TreeEntry::CombinedVectorize;
+      }
+      break;
+    }
     default:
       break;
     }
@@ -9295,6 +9326,8 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
   Instruction *VL0 = E->getMainOp();
   unsigned ShuffleOrOp =
       E->isAltShuffle() ? (unsigned)Instruction::ShuffleVector : E->getOpcode();
+  if (E->CombinedOp != TreeEntry::NotCombinedOp)
+    ShuffleOrOp = E->CombinedOp;
   SetVector<Value *> UniqueValues(VL.begin(), VL.end());
   const unsigned Sz = UniqueValues.size();
   SmallBitVector UsedScalars(Sz, false);
@@ -9660,7 +9693,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
               CI->getOpcode(), OrigScalarTy, Builder.getInt1Ty(),
               CI->getPredicate(), CostKind, CI);
         }
-        ScalarCost = std::min(ScalarCost, IntrinsicCost);
+        ScalarCost = IntrinsicCost;
       }
 
       return ScalarCost;
@@ -9670,29 +9703,36 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
 
       InstructionCost VecCost = TTI->getCmpSelInstrCost(
           E->getOpcode(), VecTy, MaskTy, VecPred, CostKind, VL0);
-      // Check if it is possible and profitable to use min/max for selects
-      // in VL.
-      //
-      auto [MinMaxID, SelectOnly] = canConvertToMinOrMaxIntrinsic(VL);
-      if (MinMaxID != Intrinsic::not_intrinsic) {
-        Type *CanonicalType = VecTy;
-        if (CanonicalType->isPtrOrPtrVectorTy())
-          CanonicalType = CanonicalType->getWithNewType(IntegerType::get(
-              CanonicalType->getContext(),
-              DL->getTypeSizeInBits(CanonicalType->getScalarType())));
-        IntrinsicCostAttributes CostAttrs(MinMaxID, VecTy, {VecTy, VecTy});
-        InstructionCost IntrinsicCost =
-            TTI->getIntrinsicInstrCost(CostAttrs, CostKind);
-        // If the selects are the only uses of the compares, they will be
-        // dead and we can adjust the cost by removing their cost.
-        if (SelectOnly) {
-          auto *CI =
-              cast<CmpInst>(cast<Instruction>(VL.front())->getOperand(0));
-          IntrinsicCost -= TTI->getCmpSelInstrCost(CI->getOpcode(), VecTy,
-                                                   MaskTy, VecPred, CostKind);
-        }
-        VecCost = std::min(VecCost, IntrinsicCost);
-      }
+      return VecCost + CommonCost;
+    };
+    return GetCostDiff(GetScalarCost, GetVectorCost);
+  }
+  case TreeEntry::MinMax: {
+    auto [MinMaxID, SelectOnly] = canConvertToMinOrMaxIntrinsic(VL);
+    assert(MinMaxID != Intrinsic::not_intrinsic &&
+           "Expected min/max intrinsic");
+    auto GetScalarCost = [&, MinMaxID = MinMaxID](unsigned Idx) {
+      Type *CanonicalType = OrigScalarTy;
+      if (CanonicalType->isPtrOrPtrVectorTy())
+        CanonicalType = CanonicalType->getWithNewType(IntegerType::get(
+            CanonicalType->getContext(),
+            DL->getTypeSizeInBits(CanonicalType->getScalarType())));
+
+      IntrinsicCostAttributes CostAttrs(MinMaxID, CanonicalType,
+                                        {CanonicalType, CanonicalType});
+      InstructionCost ScalarCost =
+          TTI->getIntrinsicInstrCost(CostAttrs, CostKind);
+
+      return ScalarCost;
+    };
+    auto GetVectorCost = [&, MinMaxID = MinMaxID](InstructionCost CommonCost) {
+      Type *CanonicalType = VecTy;
+      if (CanonicalType->isPtrOrPtrVectorTy())
+        CanonicalType = CanonicalType->getWithNewType(IntegerType::get(
+            CanonicalType->getContext(),
+            DL->getTypeSizeInBits(CanonicalType->getScalarType())));
+      IntrinsicCostAttributes CostAttrs(MinMaxID, VecTy, {VecTy, VecTy});
+      InstructionCost VecCost = TTI->getIntrinsicInstrCost(CostAttrs, CostKind);
       return VecCost + CommonCost;
     };
     return GetCostDiff(GetScalarCost, GetVectorCost);
@@ -10432,6 +10472,15 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
   SmallPtrSet<Value *, 4> CheckedExtracts;
   for (unsigned I = 0, E = VectorizableTree.size(); I < E; ++I) {
     TreeEntry &TE = *VectorizableTree[I];
+    // No need to count the cost for combined entries, they are combined and
+    // just skip their cost.
+    if (TE.State == TreeEntry::CombinedVectorize) {
+      LLVM_DEBUG(
+          dbgs() << "SLP: Skipping cost for combined node that starts with "
+                 << *TE.Scalars[0] << ".\n";
+          TE.dump(); dbgs() << "SLP: Current total cost = " << Cost << "\n");
+      continue;
+    }
     if (TE.isGather()) {
       if (const TreeEntry *E = getTreeEntry(TE.getMainOp());
           E && E->getVectorFactor() == TE.getVectorFactor() &&
@@ -12779,10 +12828,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
     return ShuffleBuilder.finalize(E->ReuseShuffleIndices);
   };
 
-  assert((E->State == TreeEntry::Vectorize ||
-          E->State == TreeEntry::ScatterVectorize ||
-          E->State == TreeEntry::StridedVectorize) &&
-         "Unhandled state");
+  assert(!E->isGather() && "Unhandled state");
   unsigned ShuffleOrOp =
       E->isAltShuffle() ? (unsigned)Instruction::ShuffleVector : E->getOpcode();
   Instruction *VL0 = E->getMainOp();

``````````

</details>


https://github.com/llvm/llvm-project/pull/99309


More information about the llvm-commits mailing list