[llvm] VectorCombine: refactor foldShuffleToIdentity (NFC) (PR #92766)

Ramkumar Ramachandra via llvm-commits llvm-commits at lists.llvm.org
Mon May 20 08:08:17 PDT 2024


https://github.com/artagnon created https://github.com/llvm/llvm-project/pull/92766

Lift out the long lambdas into static functions, use C++ destructuring syntax (structured bindings), and fix other minor things to improve the readability of the function.

>From 1e82e2b754b21d5975abf4cbaac1b0032f48860a Mon Sep 17 00:00:00 2001
From: Ramkumar Ramachandra <r at artagnon.com>
Date: Mon, 20 May 2024 16:03:57 +0100
Subject: [PATCH] VectorCombine: refactor foldShuffleToIdentity (NFC)

Lift out the long lambdas into static functions, use C++ destructuring
syntax, and fix other minor things to improve the readability of the
function.
---
 .../Transforms/Vectorize/VectorCombine.cpp    | 228 +++++++++---------
 1 file changed, 119 insertions(+), 109 deletions(-)

diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 15deaf908422d..5d45c012b4b87 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -1668,6 +1668,86 @@ bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
   return true;
 }
 
+using InstLane = std::pair<Value *, int>;
+
+static InstLane lookThroughShuffles(Value *V, int Lane) {
+  while (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
+    unsigned NumElts =
+        cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
+    int M = SV->getMaskValue(Lane);
+    if (M < 0)
+      return {nullptr, PoisonMaskElem};
+    if (static_cast<unsigned>(M) < NumElts) {
+      V = SV->getOperand(0);
+      Lane = M;
+    } else {
+      V = SV->getOperand(1);
+      Lane = M - NumElts;
+    }
+  }
+  return InstLane{V, Lane};
+}
+
+static SmallVector<InstLane>
+generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) {
+  SmallVector<InstLane> NItem;
+  for (InstLane IL : Item) {
+    auto [V, Lane] = IL;
+    InstLane OpLane =
+        V ? lookThroughShuffles(cast<Instruction>(V)->getOperand(Op), Lane)
+          : InstLane{nullptr, PoisonMaskElem};
+    NItem.emplace_back(OpLane);
+  }
+  return NItem;
+}
+
+static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
+                                  const SmallPtrSet<Value *, 4> &IdentityLeafs,
+                                  const SmallPtrSet<Value *, 4> &SplatLeafs,
+                                  IRBuilder<> &Builder) {
+  auto [FrontV, FrontLane] = Item.front();
+
+  if (IdentityLeafs.contains(FrontV) &&
+      all_of(drop_begin(enumerate(Item)), [Item](const auto &E) {
+        Value *FrontV = Item.front().first;
+        auto [V, Lane] = E.value();
+        return !V || (V == FrontV && Lane == (int)E.index());
+      })) {
+    return FrontV;
+  }
+  if (SplatLeafs.contains(FrontV)) {
+    if (auto *ILI = dyn_cast<Instruction>(FrontV))
+      Builder.SetInsertPoint(*ILI->getInsertionPointAfterDef());
+    else if (auto *Arg = dyn_cast<Argument>(FrontV))
+      Builder.SetInsertPointPastAllocas(Arg->getParent());
+    SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane);
+    return Builder.CreateShuffleVector(FrontV, Mask);
+  }
+
+  auto *I = cast<Instruction>(FrontV);
+  auto *II = dyn_cast<IntrinsicInst>(I);
+  unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
+  SmallVector<Value *> Ops(NumOps);
+  for (unsigned Idx = 0; Idx < NumOps; Idx++) {
+    if (II && isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx)) {
+      Ops[Idx] = II->getOperand(Idx);
+      continue;
+    }
+    Ops[Idx] = generateNewInstTree(generateInstLaneVectorFromOperand(Item, Idx),
+                                   Ty, IdentityLeafs, SplatLeafs, Builder);
+  }
+  Builder.SetInsertPoint(I);
+  Type *DstTy =
+      FixedVectorType::get(I->getType()->getScalarType(), Ty->getNumElements());
+  if (auto *BI = dyn_cast<BinaryOperator>(I))
+    return Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(), Ops[0],
+                               Ops[1]);
+  if (II)
+    return Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops);
+  assert(isa<UnaryInstruction>(I) && "Unexpected instruction type");
+  return Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]);
+}
+
 // Starting from a shuffle, look up through operands tracking the shuffled index
 // of each lane. If we can simplify away the shuffles to identities then
 // do so.
@@ -1677,42 +1757,9 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
       !isa<Instruction>(I.getOperand(1)))
     return false;
 
-  using InstLane = std::pair<Value *, int>;
-
-  auto LookThroughShuffles = [](Value *V, int Lane) -> InstLane {
-    while (auto *SV = dyn_cast<ShuffleVectorInst>(V)) {
-      unsigned NumElts =
-          cast<FixedVectorType>(SV->getOperand(0)->getType())->getNumElements();
-      int M = SV->getMaskValue(Lane);
-      if (M < 0)
-        return {nullptr, PoisonMaskElem};
-      else if (M < (int)NumElts) {
-        V = SV->getOperand(0);
-        Lane = M;
-      } else {
-        V = SV->getOperand(1);
-        Lane = M - NumElts;
-      }
-    }
-    return InstLane{V, Lane};
-  };
-
-  auto GenerateInstLaneVectorFromOperand =
-      [&LookThroughShuffles](ArrayRef<InstLane> Item, int Op) {
-        SmallVector<InstLane> NItem;
-        for (InstLane V : Item) {
-          NItem.emplace_back(
-              !V.first
-                  ? InstLane{nullptr, PoisonMaskElem}
-                  : LookThroughShuffles(
-                        cast<Instruction>(V.first)->getOperand(Op), V.second));
-        }
-        return NItem;
-      };
-
   SmallVector<InstLane> Start(Ty->getNumElements());
   for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M)
-    Start[M] = LookThroughShuffles(&I, M);
+    Start[M] = lookThroughShuffles(&I, M);
 
   SmallVector<SmallVector<InstLane>> Worklist;
   Worklist.push_back(Start);
@@ -1721,73 +1768,78 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
 
   while (!Worklist.empty()) {
     SmallVector<InstLane> Item = Worklist.pop_back_val();
+    auto [FrontV, FrontLane] = Item.front();
     if (++NumVisited > MaxInstrsToScan)
       return false;
 
     // If we found an undef first lane then bail out to keep things simple.
-    if (!Item[0].first)
+    if (!FrontV)
       return false;
 
     // Look for an identity value.
-    if (Item[0].second == 0 &&
-        cast<FixedVectorType>(Item[0].first->getType())->getNumElements() ==
+    if (FrontLane == 0 &&
+        cast<FixedVectorType>(FrontV->getType())->getNumElements() ==
             Ty->getNumElements() &&
-        all_of(drop_begin(enumerate(Item)), [&](const auto &E) {
-          return !E.value().first || (E.value().first == Item[0].first &&
+        all_of(drop_begin(enumerate(Item)), [&Item](const auto &E) {
+          Value *FrontV = Item.front().first;
+          return !E.value().first || (E.value().first == FrontV &&
                                       E.value().second == (int)E.index());
         })) {
-      IdentityLeafs.insert(Item[0].first);
+      IdentityLeafs.insert(FrontV);
       continue;
     }
     // Look for a splat value.
-    if (all_of(drop_begin(Item), [&](InstLane &IL) {
-          return !IL.first ||
-                 (IL.first == Item[0].first && IL.second == Item[0].second);
+    if (all_of(drop_begin(Item), [&Item](InstLane &IL) {
+          auto [FrontV, FrontLane] = Item.front();
+          auto [V, Lane] = IL;
+          return !V || (V == FrontV && Lane == FrontLane);
         })) {
-      SplatLeafs.insert(Item[0].first);
+      SplatLeafs.insert(FrontV);
       continue;
     }
 
     // We need each element to be the same type of value, and check that each
     // element has a single use.
-    if (!all_of(drop_begin(Item), [&](InstLane IL) {
-          if (!IL.first)
+    if (!all_of(drop_begin(Item), [&Item](InstLane IL) {
+          Value *FrontV = Item.front().first;
+          Value *V = IL.first;
+          if (!V)
             return true;
-          if (auto *I = dyn_cast<Instruction>(IL.first); I && !I->hasOneUse())
+          if (auto *I = dyn_cast<Instruction>(V); I && !I->hasOneUse())
             return false;
-          if (IL.first->getValueID() != Item[0].first->getValueID())
+          if (V->getValueID() != FrontV->getValueID())
             return false;
-          if (isa<CallInst>(IL.first) && !isa<IntrinsicInst>(IL.first))
+          if (isa<CallInst>(V) && !isa<IntrinsicInst>(V))
             return false;
-          auto *II = dyn_cast<IntrinsicInst>(IL.first);
-          return !II ||
-                 (isa<IntrinsicInst>(Item[0].first) &&
-                  II->getIntrinsicID() ==
-                      cast<IntrinsicInst>(Item[0].first)->getIntrinsicID());
+          auto *II = dyn_cast<IntrinsicInst>(V);
+          return !II || (isa<IntrinsicInst>(FrontV) &&
+                         II->getIntrinsicID() ==
+                             cast<IntrinsicInst>(FrontV)->getIntrinsicID());
         }))
       return false;
 
     // Check the operator is one that we support. We exclude div/rem in case
     // they hit UB from poison lanes.
-    if (isa<BinaryOperator>(Item[0].first) &&
-        !cast<BinaryOperator>(Item[0].first)->isIntDivRem()) {
-      Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 0));
-      Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 1));
-    } else if (isa<UnaryOperator>(Item[0].first)) {
-      Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, 0));
-    } else if (auto *II = dyn_cast<IntrinsicInst>(Item[0].first);
+    if (isa<BinaryOperator>(FrontV) &&
+        !cast<BinaryOperator>(FrontV)->isIntDivRem()) {
+      Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
+      Worklist.push_back(generateInstLaneVectorFromOperand(Item, 1));
+    } else if (isa<UnaryOperator>(FrontV)) {
+      Worklist.push_back(generateInstLaneVectorFromOperand(Item, 0));
+    } else if (auto *II = dyn_cast<IntrinsicInst>(FrontV);
                II && isTriviallyVectorizable(II->getIntrinsicID())) {
       for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
         if (isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Op)) {
-          if (!all_of(drop_begin(Item), [&](InstLane &IL) {
-                return !IL.first ||
-                       (cast<Instruction>(IL.first)->getOperand(Op) ==
-                        cast<Instruction>(Item[0].first)->getOperand(Op));
+          if (!all_of(drop_begin(Item), [&Item, Op](InstLane &IL) {
+                Value *FrontV = Item.front().first;
+                Value *V = IL.first;
+                return !V || (cast<Instruction>(V)->getOperand(Op) ==
+                              cast<Instruction>(FrontV)->getOperand(Op));
               }))
             return false;
           continue;
         }
-        Worklist.push_back(GenerateInstLaneVectorFromOperand(Item, Op));
+        Worklist.push_back(generateInstLaneVectorFromOperand(Item, Op));
       }
     } else {
       return false;
@@ -1799,49 +1851,7 @@ bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
 
   // If we got this far, we know the shuffles are superfluous and can be
   // removed. Scan through again and generate the new tree of instructions.
-  std::function<Value *(ArrayRef<InstLane>)> Generate =
-      [&](ArrayRef<InstLane> Item) -> Value * {
-    if (IdentityLeafs.contains(Item[0].first) &&
-        all_of(drop_begin(enumerate(Item)), [&](const auto &E) {
-          return !E.value().first || (E.value().first == Item[0].first &&
-                                      E.value().second == (int)E.index());
-        })) {
-      return Item[0].first;
-    }
-    if (SplatLeafs.contains(Item[0].first)) {
-      if (auto ILI = dyn_cast<Instruction>(Item[0].first))
-        Builder.SetInsertPoint(*ILI->getInsertionPointAfterDef());
-      else if (isa<Argument>(Item[0].first))
-        Builder.SetInsertPointPastAllocas(I.getParent()->getParent());
-      SmallVector<int, 16> Mask(Ty->getNumElements(), Item[0].second);
-      return Builder.CreateShuffleVector(Item[0].first, Mask);
-    }
-
-    auto *I = cast<Instruction>(Item[0].first);
-    auto *II = dyn_cast<IntrinsicInst>(I);
-    unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
-    SmallVector<Value *> Ops(NumOps);
-    for (unsigned Idx = 0; Idx < NumOps; Idx++) {
-      if (II && isVectorIntrinsicWithScalarOpAtArg(II->getIntrinsicID(), Idx)) {
-        Ops[Idx] = II->getOperand(Idx);
-        continue;
-      }
-      Ops[Idx] = Generate(GenerateInstLaneVectorFromOperand(Item, Idx));
-    }
-    Builder.SetInsertPoint(I);
-    Type *DstTy = FixedVectorType::get(I->getType()->getScalarType(),
-                                       Ty->getNumElements());
-    if (auto BI = dyn_cast<BinaryOperator>(I))
-      return Builder.CreateBinOp((Instruction::BinaryOps)BI->getOpcode(),
-                                 Ops[0], Ops[1]);
-    if (II)
-      return Builder.CreateIntrinsic(DstTy, II->getIntrinsicID(), Ops);
-    assert(isa<UnaryInstruction>(I) &&
-           "Unexpected instruction type in Generate");
-    return Builder.CreateUnOp((Instruction::UnaryOps)I->getOpcode(), Ops[0]);
-  };
-
-  Value *V = Generate(Start);
+  Value *V = generateNewInstTree(Start, Ty, IdentityLeafs, SplatLeafs, Builder);
   replaceValue(I, *V);
   return true;
 }



More information about the llvm-commits mailing list