[llvm] [AArch64] Unroll some loops with early-continues on Apple Silicon. (PR #118499)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Sun Dec 22 02:08:11 PST 2024


================
@@ -4068,51 +4068,86 @@ getAppleRuntimeUnrollPreferences(Loop *L, ScalarEvolution &SE,
 
   // Try to unroll small, single block loops, if they have load/store
   // dependencies, to expose more parallel memory access streams.
-  if (L->getHeader() != L->getLoopLatch() || Size > 8)
-    return;
+  BasicBlock *Header = L->getHeader();
+  if (Header == L->getLoopLatch()) {
+    if (Size > 8)
+      return;
 
-  SmallPtrSet<Value *, 8> LoadedValues;
-  SmallVector<StoreInst *> Stores;
-  for (auto *BB : L->blocks()) {
-    for (auto &I : *BB) {
-      Value *Ptr = getLoadStorePointerOperand(&I);
-      if (!Ptr)
-        continue;
-      const SCEV *PtrSCEV = SE.getSCEV(Ptr);
-      if (SE.isLoopInvariant(PtrSCEV, L))
-        continue;
-      if (isa<LoadInst>(&I))
-        LoadedValues.insert(&I);
-      else
-        Stores.push_back(cast<StoreInst>(&I));
+    SmallPtrSet<Value *, 8> LoadedValues;
+    SmallVector<StoreInst *> Stores;
+    for (auto *BB : L->blocks()) {
+      for (auto &I : *BB) {
+        Value *Ptr = getLoadStorePointerOperand(&I);
+        if (!Ptr)
+          continue;
+        const SCEV *PtrSCEV = SE.getSCEV(Ptr);
+        if (SE.isLoopInvariant(PtrSCEV, L))
+          continue;
+        if (isa<LoadInst>(&I))
+          LoadedValues.insert(&I);
+        else
+          Stores.push_back(cast<StoreInst>(&I));
+      }
     }
-  }
 
-  // Try to find an unroll count that maximizes the use of the instruction
-  // window, i.e. trying to fetch as many instructions per cycle as possible.
-  unsigned MaxInstsPerLine = 16;
-  unsigned UC = 1;
-  unsigned BestUC = 1;
-  unsigned SizeWithBestUC = BestUC * Size;
-  while (UC <= 8) {
-    unsigned SizeWithUC = UC * Size;
-    if (SizeWithUC > 48)
-      break;
-    if ((SizeWithUC % MaxInstsPerLine) == 0 ||
-        (SizeWithBestUC % MaxInstsPerLine) < (SizeWithUC % MaxInstsPerLine)) {
-      BestUC = UC;
-      SizeWithBestUC = BestUC * Size;
+    // Try to find an unroll count that maximizes the use of the instruction
+    // window, i.e. trying to fetch as many instructions per cycle as possible.
+    unsigned MaxInstsPerLine = 16;
+    unsigned UC = 1;
+    unsigned BestUC = 1;
+    unsigned SizeWithBestUC = BestUC * Size;
+    while (UC <= 8) {
+      unsigned SizeWithUC = UC * Size;
+      if (SizeWithUC > 48)
+        break;
+      if ((SizeWithUC % MaxInstsPerLine) == 0 ||
+          (SizeWithBestUC % MaxInstsPerLine) < (SizeWithUC % MaxInstsPerLine)) {
+        BestUC = UC;
+        SizeWithBestUC = BestUC * Size;
+      }
+      UC++;
     }
-    UC++;
+
+    if (BestUC == 1 || none_of(Stores, [&LoadedValues](StoreInst *SI) {
+          return LoadedValues.contains(SI->getOperand(0));
+        }))
+      return;
+
+    UP.Runtime = true;
+    UP.DefaultUnrollRuntimeCount = BestUC;
+    return;
   }
 
-  if (BestUC == 1 || none_of(Stores, [&LoadedValues](StoreInst *SI) {
-        return LoadedValues.contains(SI->getOperand(0));
-      }))
+  // Try to runtime-unroll loops with early-continues depending on loop-varying
+  // loads; this helps with branch-prediction for the early-continues.
+  auto *Term = dyn_cast<BranchInst>(Header->getTerminator());
+  auto *Latch = L->getLoopLatch();
+  SmallVector<BasicBlock *> Preds(predecessors(Latch));
+  if (!Term || !Term->isConditional() || Preds.size() == 1 ||
+      none_of(Preds, [Header](BasicBlock *Pred) { return Header == Pred; }) ||
+      none_of(Preds, [L](BasicBlock *Pred) { return L->contains(Pred); }))
     return;
 
-  UP.Runtime = true;
-  UP.DefaultUnrollRuntimeCount = BestUC;
+  std::function<bool(Instruction *, unsigned)> DependsOnLoopLoad =
+      [&](Instruction *I, unsigned Depth) -> bool {
+    if (isa<PHINode>(I) || L->isLoopInvariant(I) || Depth > 8)
+      return false;
+
+    if (auto *LI = dyn_cast<LoadInst>(I))
+      return true;
+
+    return any_of(I->operands(), [&](Value *V) {
+      auto *I = dyn_cast<Instruction>(V);
+      return I && DependsOnLoopLoad(I, Depth + 1);
+    });
+  };
+  CmpInst::Predicate Pred;
+  Instruction *I;
+  if (match(Term, m_Br(m_ICmp(Pred, m_Instruction(I), m_Value()), m_Value(),
----------------
fhahn wrote:

For now, I think most cases should be caught due to canonicalization. I'll look into extending it as a follow-up.

https://github.com/llvm/llvm-project/pull/118499


More information about the llvm-commits mailing list