[llvm] [VectorCombine] Add a cost model for shuffleToIdentity. (PR #93937)

via llvm-commits llvm-commits at lists.llvm.org
Fri May 31 01:25:06 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: David Green (davemgreen)

<details>
<summary>Changes</summary>

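As with every combine that has ever touched identities, converting to identities can apparently also cause performance issues. This adds a simple cost model that compares the cost of the shuffles that would be removed against the cost of the identity and splat shuffles that would remain at the leaves, and bails out when the new cost is higher. This helps when the cost of concats might be high, and will hopefully be useful if more types of shuffles are supported.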

---
Full diff: https://github.com/llvm/llvm-project/pull/93937.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/Vectorize/VectorCombine.cpp (+56-16) 
- (added) llvm/test/Transforms/VectorCombine/X86/shuffleToIdentityCost.ll (+66) 


``````````diff
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 7ecfe5218ef67..8b94e6653f559 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1670,8 +1670,12 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
 
 using InstLane = std::pair<Value *, int>;
 
-static InstLane lookThroughShuffles(Value *V, int Lane) {
+static InstLane
+lookThroughShuffles(Value *V, int Lane,
+                    SmallPtrSetImpl<Instruction *> *VisitedShuffles) {
   while (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
+    if (VisitedShuffles)
+      VisitedShuffles->insert(SV);
     unsigned NumElts =
         cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
     int M = SV->getMaskValue(Lane);
@@ -1688,13 +1692,15 @@ static InstLane lookThroughShuffles(Value *V, int Lane) {
   return InstLane{V, Lane};
 }
 
-static SmallVector<InstLane>
-generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) {
+static SmallVector<InstLane> generateInstLaneVectorFromOperand(
+    ArrayRef<InstLane> Item, int Op,
+    SmallPtrSetImpl<Instruction *> *VisitedShuffles) {
   SmallVector<InstLane> NItem;
   for (InstLane IL : Item) {
     auto [V, Lane] = IL;
     InstLane OpLane =
-        V ? lookThroughShuffles(cast<Instruction>(V)->getOperand(Op), Lane)
+        V ? lookThroughShuffles(cast<Instruction>(V)->getOperand(Op), Lane,
+                                VisitedShuffles)
           : InstLane{nullptr, PoisonMaskElem};
     NItem.emplace_back(OpLane);
   }
@@ -1733,8 +1739,9 @@ static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
       Ops[Idx] = II->getOperand(Idx);
       continue;
     }
-    Ops[Idx] = generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx),
-                                   Ty, IdentityLeafs, SplatLeafs, Builder);
+    Ops[Idx] =
+        generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx, nullptr),
+                            Ty, IdentityLeafs, SplatLeafs, Builder);
   }
   Builder.SetInsertPoint(I);
   Type *DstTy =
@@ -1763,13 +1770,14 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
   if (!Ty)
     return false;
 
+  SmallPtrSet<Instruction *, 4> VisitedShuffles;
   SmallVector<InstLane> Start(Ty->getNumElements());
   for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M)
-    Start[M] = lookThroughShuffles(&I, M);
+    Start[M] = lookThroughShuffles(&I, M, &VisitedShuffles);
 
   SmallVector<SmallVector<InstLane>> Worklist;
   Worklist.push_back(Start);
-  SmallPtrSet<Value *, 4> IdentityLeafs, SplatLeafs;
+  SmallPtrSet<Value *, 4> IdentityLeafs, SplatLeafs, ConstLeafs;
   unsigned NumVisited = 0;
 
   while (!Worklist.empty()) {
@@ -1803,7 +1811,7 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
           Value *V = IL.first;
           return !V || V == FrontV;
         })) {
-      SplatLeafs.insert(FrontV);
+      ConstLeafs.insert(FrontV);
       continue;
     }
     // Look for a splat value.
@@ -1847,14 +1855,20 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
     if ((isa<BinaryOperator>(FrontV) &&
          !cast<BinaryOperator>(FrontV)->isIntDivRem()) ||
         isa<CmpInst>(FrontV)) {
-      Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
-      Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
+      Worklist.push_back(
+          generateInstLaneVectorFromOperand(Item, 0, &VisitedShuffles));
+      Worklist.push_back(
+          generateInstLaneVectorFromOperand(Item, 1, &VisitedShuffles));
     } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst>(FrontV)) {
-      Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
+      Worklist.push_back(
+          generateInstLaneVectorFromOperand(Item, 0, &VisitedShuffles));
     } else if (isa<SelectInst>(FrontV)) {
-      Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
-      Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
-      Worklist.push_back(generateInstLaneVectorFromOperand(Item, 2));
+      Worklist.push_back(
+          generateInstLaneVectorFromOperand(Item, 0, &VisitedShuffles));
+      Worklist.push_back(
+          generateInstLaneVectorFromOperand(Item, 1, &VisitedShuffles));
+      Worklist.push_back(
+          generateInstLaneVectorFromOperand(Item, 2, &VisitedShuffles));
     } else if (auto *II = dyn_cast<IntrinsicInst>(FrontV);
                II && isTriviallyVectorizable(II->getIntrinsicID())) {
       for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
@@ -1868,7 +1882,8 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
             return false;
           continue;
         }
-        Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op));
+        Worklist.push_back(
+            generateInstLaneVectorFromOperand(Item, Op, &VisitedShuffles));
       }
     } else {
       return false;
@@ -1878,6 +1893,31 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
   if (NumVisited <= 1)
     return false;
 
+  LLVM_DEBUG(dbgs() << "Found a set of shuffles that can be removed:\n");
+  InstructionCost OldShuffleCost;
+  for (auto *I : VisitedShuffles) {
+    InstructionCost C = TTI.getInstructionCost(I, TTI::TCK_RecipThroughput);
+    LLVM_DEBUG(dbgs() << C << *I << "\n");
+    OldShuffleCost += C;
+  }
+  LLVM_DEBUG(dbgs() << "  total cost " << OldShuffleCost << "\n");
+  SmallVector<int, 16> ExtractMask(Ty->getNumElements());
+  std::iota(ExtractMask.begin(), ExtractMask.end(), 0);
+  InstructionCost IdentityCost = TTI.getShuffleCost(
+      TTI::SK_PermuteSingleSrc, Ty, ExtractMask, TTI::TCK_RecipThroughput);
+  InstructionCost SplatCost = TTI.getShuffleCost(
+      TTI::SK_Broadcast, Ty, std::nullopt, TTI::TCK_RecipThroughput);
+  InstructionCost NewShuffleCost =
+      IdentityCost * IdentityLeafs.size() + SplatCost * SplatLeafs.size();
+  LLVM_DEBUG(dbgs() << "      vs     " << NewShuffleCost << " (" << IdentityCost
+                    << " * " << IdentityLeafs.size() << " + " << SplatCost
+                    << " * " << SplatLeafs.size() << ")\n");
+
+  if (OldShuffleCost < NewShuffleCost)
+    return false;
+
+  SplatLeafs.insert(ConstLeafs.begin(), ConstLeafs.end());
+
   // If we got this far, we know the shuffles are superfluous and can be
   // removed. Scan through again and generate the new tree of instructions.
   Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs, Builder);
diff --git a/llvm/test/Transforms/VectorCombine/X86/shuffleToIdentityCost.ll b/llvm/test/Transforms/VectorCombine/X86/shuffleToIdentityCost.ll
new file mode 100644
index 0000000000000..47cc97ddd32ff
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/X86/shuffleToIdentityCost.ll
@@ -0,0 +1,66 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=vector-combine -S %s | FileCheck %s
+
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
+target triple = "x86_64--linux-gnu"
+
+define void @test_op_pblend_b_0_avx(ptr %l8, ptr %lop_pblend_b_0) "target-cpu"="corei7-avx" {
+; CHECK-LABEL: define void @test_op_pblend_b_0_avx(
+; CHECK-SAME: ptr [[L8:%.*]], ptr [[LOP_PBLEND_B_0:%.*]]) #[[ATTR0:[0-9]+]] {
+; CHECK-NEXT:    [[LT162:%.*]] = load <32 x i8>, ptr [[L8]], align 16
+; CHECK-NEXT:    [[L9:%.*]] = shufflevector <32 x i8> [[LT162]], <32 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[L10:%.*]] = icmp ugt <16 x i8> [[L9]], <i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7>
+; CHECK-NEXT:    [[L11:%.*]] = shufflevector <32 x i8> [[LT162]], <32 x i8> poison, <16 x i32> <i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
+; CHECK-NEXT:    [[L12:%.*]] = icmp ugt <16 x i8> [[L11]], <i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7>
+; CHECK-NEXT:    [[L13:%.*]] = getelementptr inbounds i8, ptr [[L8]], i64 16
+; CHECK-NEXT:    [[L14:%.*]] = load <32 x i8>, ptr [[L13]], align 16
+; CHECK-NEXT:    [[L15:%.*]] = shufflevector <32 x i8> [[L14]], <32 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[L16:%.*]] = select <16 x i1> [[L10]], <16 x i8> [[L9]], <16 x i8> [[L15]]
+; CHECK-NEXT:    [[L17:%.*]] = shufflevector <32 x i8> [[L14]], <32 x i8> poison, <16 x i32> <i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
+; CHECK-NEXT:    [[L18:%.*]] = select <16 x i1> [[L12]], <16 x i8> [[L11]], <16 x i8> [[L17]]
+; CHECK-NEXT:    [[L19:%.*]] = shufflevector <16 x i8> [[L16]], <16 x i8> [[L18]], <32 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
+; CHECK-NEXT:    store <32 x i8> [[L19]], ptr [[LOP_PBLEND_B_0]], align 32
+; CHECK-NEXT:    ret void
+;
+  %lt162 = load <32 x i8>, ptr %l8, align 16
+  %l9 = shufflevector <32 x i8> %lt162, <32 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+  %l10 = icmp ugt <16 x i8> %l9, <i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7>
+  %l11 = shufflevector <32 x i8> %lt162, <32 x i8> poison, <16 x i32> <i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
+  %l12 = icmp ugt <16 x i8> %l11, <i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7>
+  %l13 = getelementptr inbounds i8, ptr %l8, i64 16
+  %l14 = load <32 x i8>, ptr %l13, align 16
+  %l15 = shufflevector <32 x i8> %l14, <32 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+  %l16 = select <16 x i1> %l10, <16 x i8> %l9, <16 x i8> %l15
+  %l17 = shufflevector <32 x i8> %l14, <32 x i8> poison, <16 x i32> <i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
+  %l18 = select <16 x i1> %l12, <16 x i8> %l11, <16 x i8> %l17
+  %l19 = shufflevector <16 x i8> %l16, <16 x i8> %l18, <32 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
+  store <32 x i8> %l19, ptr %lop_pblend_b_0, align 32
+  ret void
+}
+
+define void @test_op_pblend_b_0_nocpu(ptr %l8, ptr %lop_pblend_b_0) {
+; CHECK-LABEL: define void @test_op_pblend_b_0_nocpu(
+; CHECK-SAME: ptr [[L8:%.*]], ptr [[LOP_PBLEND_B_0:%.*]]) {
+; CHECK-NEXT:    [[LT162:%.*]] = load <32 x i8>, ptr [[L8]], align 16
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ugt <32 x i8> [[LT162]], <i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7>
+; CHECK-NEXT:    [[L13:%.*]] = getelementptr inbounds i8, ptr [[L8]], i64 16
+; CHECK-NEXT:    [[L14:%.*]] = load <32 x i8>, ptr [[L13]], align 16
+; CHECK-NEXT:    [[L19:%.*]] = select <32 x i1> [[TMP1]], <32 x i8> [[LT162]], <32 x i8> [[L14]]
+; CHECK-NEXT:    store <32 x i8> [[L19]], ptr [[LOP_PBLEND_B_0]], align 32
+; CHECK-NEXT:    ret void
+;
+  %lt162 = load <32 x i8>, ptr %l8, align 16
+  %l9 = shufflevector <32 x i8> %lt162, <32 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+  %l10 = icmp ugt <16 x i8> %l9, <i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7>
+  %l11 = shufflevector <32 x i8> %lt162, <32 x i8> poison, <16 x i32> <i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
+  %l12 = icmp ugt <16 x i8> %l11, <i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7, i8 7>
+  %l13 = getelementptr inbounds i8, ptr %l8, i64 16
+  %l14 = load <32 x i8>, ptr %l13, align 16
+  %l15 = shufflevector <32 x i8> %l14, <32 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+  %l16 = select <16 x i1> %l10, <16 x i8> %l9, <16 x i8> %l15
+  %l17 = shufflevector <32 x i8> %l14, <32 x i8> poison, <16 x i32> <i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
+  %l18 = select <16 x i1> %l12, <16 x i8> %l11, <16 x i8> %l17
+  %l19 = shufflevector <16 x i8> %l16, <16 x i8> %l18, <32 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 16, i32 17, i32 18, i32 19, i32 20, i32 21, i32 22, i32 23, i32 24, i32 25, i32 26, i32 27, i32 28, i32 29, i32 30, i32 31>
+  store <32 x i8> %l19, ptr %lop_pblend_b_0, align 32
+  ret void
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/93937


More information about the llvm-commits mailing list