[llvm] [Vectorize] Add StridedLoopUnroll + Versioning for 2-D strided loop nests (PR #157749)

via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 9 14:04:47 PDT 2025


github-actions[bot] wrote:



:warning: C/C++ code formatter, clang-format found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff origin/main HEAD --extensions h,cpp -- llvm/include/llvm/Transforms/Vectorize/StridedLoopUnroll.h llvm/lib/Transforms/Vectorize/StridedLoopUnroll.cpp llvm/include/llvm/Transforms/Utils/LoopVersioning.h llvm/lib/Passes/PassBuilder.cpp llvm/lib/Target/RISCV/RISCVTargetMachine.cpp llvm/lib/Transforms/Utils/LoopVersioning.cpp
``````````

:warning:
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing `origin/main` to the base branch/commit you want to compare against
(see the example command below).
:warning:

</details>
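For instance, a hypothetical invocation that compares against a stacked PR's immediate base branch instead of `origin/main` could look like this (the branch name and the reduced file list are illustrative only):

``````````bash
# Compare only against the base of this PR in a stacked workflow;
# "users/me/base-pr-branch" is a placeholder for your actual base branch or commit.
git-clang-format --diff users/me/base-pr-branch HEAD --extensions h,cpp -- \
  llvm/lib/Transforms/Vectorize/StridedLoopUnroll.cpp \
  llvm/lib/Transforms/Utils/LoopVersioning.cpp
``````````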

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/llvm/include/llvm/Transforms/Utils/LoopVersioning.h b/llvm/include/llvm/Transforms/Utils/LoopVersioning.h
index d58995b76..f6ccc8803 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopVersioning.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopVersioning.h
@@ -46,7 +46,8 @@ public:
   /// object having no checks and we expect the user to add them.
   LoopVersioning(const LoopAccessInfo &LAI,
                  ArrayRef<RuntimePointerCheck> Checks, Loop *L, LoopInfo *LI,
-                 DominatorTree *DT, ScalarEvolution *SE, bool HoistRuntimeChecks = false);
+                 DominatorTree *DT, ScalarEvolution *SE,
+                 bool HoistRuntimeChecks = false);
 
   /// Performs the CFG manipulation part of versioning the loop including
   /// the DominatorTree and LoopInfo updates.
diff --git a/llvm/include/llvm/Transforms/Vectorize/StridedLoopUnroll.h b/llvm/include/llvm/Transforms/Vectorize/StridedLoopUnroll.h
index baf1220a6..67ecf9df7 100644
--- a/llvm/include/llvm/Transforms/Vectorize/StridedLoopUnroll.h
+++ b/llvm/include/llvm/Transforms/Vectorize/StridedLoopUnroll.h
@@ -28,6 +28,7 @@ public:
 class StridedLoopUnrollVersioningPass
     : public PassInfoMixin<StridedLoopUnrollVersioningPass> {
   int OptLevel;
+
 public:
   StridedLoopUnrollVersioningPass(int OptLevel = 2) : OptLevel(OptLevel) {}
 
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index c3b310298..12d80bd5d 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -378,8 +378,8 @@
 #include "llvm/Transforms/Vectorize/LoopVectorize.h"
 #include "llvm/Transforms/Vectorize/SLPVectorizer.h"
 #include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h"
-#include "llvm/Transforms/Vectorize/VectorCombine.h"
 #include "llvm/Transforms/Vectorize/StridedLoopUnroll.h"
+#include "llvm/Transforms/Vectorize/VectorCombine.h"
 #include <optional>
 
 using namespace llvm;
diff --git a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
index a2dd70459..6a8047ebd 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
@@ -38,9 +38,9 @@
 #include "llvm/Target/TargetOptions.h"
 #include "llvm/Transforms/IPO.h"
 #include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Scalar/LoopUnrollPass.h"
 #include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
 #include "llvm/Transforms/Vectorize/StridedLoopUnroll.h"
-#include "llvm/Transforms/Scalar/LoopUnrollPass.h"
 #include <optional>
 using namespace llvm;
 
@@ -658,15 +658,16 @@ void RISCVTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
       LPM.addPass(LoopIdiomVectorizePass(LoopIdiomVectorizeStyle::Predicated));
   });
 
-  PB.registerScalarOptimizerLateEPCallback([=](FunctionPassManager &LPM,
-                                               OptimizationLevel) {
-    LPM.addPass(StridedLoopUnrollVersioningPass());
-  });
-  PB.registerOptimizerLastEPCallback(
-      [=](ModulePassManager &MPM, OptimizationLevel Level, llvm::ThinOrFullLTOPhase) {
-        MPM.addPass(createModuleToFunctionPassAdaptor(
-            createFunctionToLoopPassAdaptor(StridedLoopUnrollPass())));
+  PB.registerScalarOptimizerLateEPCallback(
+      [=](FunctionPassManager &LPM, OptimizationLevel) {
+        LPM.addPass(StridedLoopUnrollVersioningPass());
       });
+  PB.registerOptimizerLastEPCallback([=](ModulePassManager &MPM,
+                                         OptimizationLevel Level,
+                                         llvm::ThinOrFullLTOPhase) {
+    MPM.addPass(createModuleToFunctionPassAdaptor(
+        createFunctionToLoopPassAdaptor(StridedLoopUnrollPass())));
+  });
 }
 
 yaml::MachineFunctionInfo *
diff --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp
index 002717946..684f58cc5 100644
--- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp
+++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp
@@ -42,10 +42,10 @@ static cl::opt<bool>
 LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI,
                                ArrayRef<RuntimePointerCheck> Checks, Loop *L,
                                LoopInfo *LI, DominatorTree *DT,
-                               ScalarEvolution *SE,
-                               bool HoistRuntimeChecks)
+                               ScalarEvolution *SE, bool HoistRuntimeChecks)
     : VersionedLoop(L), AliasChecks(Checks), Preds(LAI.getPSE().getPredicate()),
-      LAI(LAI), LI(LI), DT(DT), SE(SE), HoistRuntimeChecks(HoistRuntimeChecks) {}
+      LAI(LAI), LI(LI), DT(DT), SE(SE), HoistRuntimeChecks(HoistRuntimeChecks) {
+}
 
 void LoopVersioning::versionLoop(
     const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
@@ -64,8 +64,9 @@ void LoopVersioning::versionLoop(
   SCEVExpander Exp2(*RtPtrChecking.getSE(),
                     VersionedLoop->getHeader()->getDataLayout(),
                     "induction");
-  MemRuntimeCheck = addRuntimeChecks(RuntimeCheckBB->getTerminator(),
-                                     VersionedLoop, AliasChecks, Exp2, HoistRuntimeChecks);
+  MemRuntimeCheck =
+      addRuntimeChecks(RuntimeCheckBB->getTerminator(), VersionedLoop,
+                       AliasChecks, Exp2, HoistRuntimeChecks);
 
   SCEVExpander Exp(*SE, RuntimeCheckBB->getDataLayout(),
                    "scev.check");
diff --git a/llvm/lib/Transforms/Vectorize/StridedLoopUnroll.cpp b/llvm/lib/Transforms/Vectorize/StridedLoopUnroll.cpp
index 8f1f6cf19..35297911e 100644
--- a/llvm/lib/Transforms/Vectorize/StridedLoopUnroll.cpp
+++ b/llvm/lib/Transforms/Vectorize/StridedLoopUnroll.cpp
@@ -34,9 +34,12 @@
 
 #include "llvm/Transforms/Vectorize/StridedLoopUnroll.h"
 #include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/DomTreeUpdater.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
@@ -45,32 +48,13 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/Transforms/Scalar/EarlyCSE.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/LoopVersioning.h"
 #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
-#include "llvm/Analysis/AliasAnalysis.h"
-#include "llvm/Analysis/DomTreeUpdater.h"
-#include "llvm/Analysis/LoopInfo.h"
-#include "llvm/Analysis/LoopPass.h"
-#include "llvm/Analysis/ScalarEvolution.h"
-#include "llvm/Analysis/ScalarEvolutionExpressions.h"
-#include "llvm/Analysis/TargetTransformInfo.h"
-#include "llvm/IR/Dominators.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Intrinsics.h"
-#include "llvm/IR/MDBuilder.h"
-#include "llvm/IR/PatternMatch.h"
-#include "llvm/Transforms/Utils/BasicBlockUtils.h"
-#include "llvm/Transforms/Utils/LoopVersioning.h"
 #include "llvm/Transforms/Utils/UnrollLoop.h"
-#include "llvm/Analysis/TargetTransformInfo.h"
-#include "llvm/Analysis/OptimizationRemarkEmitter.h"
-#include "llvm/Analysis/AssumptionCache.h"
-#include "llvm/Transforms/Utils/Local.h"
-#include "llvm/Transforms/Scalar/EarlyCSE.h"
-#include "llvm/Analysis/MemorySSA.h"
-#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
-#include "llvm/Target/TargetMachine.h"
 
 using namespace llvm;
 using namespace PatternMatch;
@@ -101,9 +85,9 @@ class StridedLoopUnroll {
   const LoopAccessInfo *LAI = nullptr;
 
 public:
-  StridedLoopUnroll(DominatorTree *DT, LoopInfo *LI,
-                    TargetTransformInfo *TTI, const DataLayout *DL,
-                    ScalarEvolution *SE, AliasAnalysis *AA, AssumptionCache* AC)
+  StridedLoopUnroll(DominatorTree *DT, LoopInfo *LI, TargetTransformInfo *TTI,
+                    const DataLayout *DL, ScalarEvolution *SE,
+                    AliasAnalysis *AA, AssumptionCache *AC)
       : DT(DT), LI(LI), TTI(TTI), DL(DL), SE(SE), AA(AA),
         LAIs(*SE, *AA, *DT, *LI, TTI, nullptr, AC) {}
 
@@ -119,21 +103,22 @@ private:
 
   void transformStridedSpecialCases(BasicBlock *Header, BasicBlock *Latch,
                                     BasicBlock *Preheader, Loop *SubLoop,
-                                    SmallVectorImpl<LoadInst*> &Loads,
-                                    StoreInst* Store,
-                                    SmallVectorImpl<Value *>& PostOrder,
-                                    SmallVectorImpl<Value *>& PreOrder);
+                                    SmallVectorImpl<LoadInst *> &Loads,
+                                    StoreInst *Store,
+                                    SmallVectorImpl<Value *> &PostOrder,
+                                    SmallVectorImpl<Value *> &PreOrder);
   void changeInductionVarIncrement(Value *IncomingValue, unsigned VF);
   std::optional<Value *> getDynamicStrideFromMemOp(Value *Value,
                                                    Instruction *InsertionPt);
-  std::optional<Value*> getStrideFromAddRecExpr(const SCEVAddRecExpr* AR, Instruction *InsertionPt);
+  std::optional<Value *> getStrideFromAddRecExpr(const SCEVAddRecExpr *AR,
+                                                 Instruction *InsertionPt);
 
   /// @}
 };
 
 static cl::opt<bool>
     SkipPass("strided-loop-unroll-disable", cl::init(false), cl::Hidden,
-                 cl::desc("Skip running strided loop unroll optimization."));
+             cl::desc("Skip running strided loop unroll optimization."));
 
 class StridedLoopUnrollVersioning {
   Loop *CurLoop = nullptr;
@@ -144,9 +129,9 @@ class StridedLoopUnrollVersioning {
   ScalarEvolution *SE;
   AliasAnalysis *AA;
   LoopAccessInfoManager LAIs;
-  AssumptionCache* AC;
-  OptimizationRemarkEmitter* ORE;
-  Function* F;
+  AssumptionCache *AC;
+  OptimizationRemarkEmitter *ORE;
+  Function *F;
 
   // Blocks that will be used for inserting vectorized code.
   BasicBlock *EndBlock = nullptr;
@@ -158,10 +143,10 @@ class StridedLoopUnrollVersioning {
 
 public:
   StridedLoopUnrollVersioning(DominatorTree *DT, LoopInfo *LI,
-                          TargetTransformInfo *TTI, const DataLayout *DL,
-                          ScalarEvolution *SE, AliasAnalysis *AA, AssumptionCache *AC,
-                              OptimizationRemarkEmitter* ORE,
-                              Function* F)
+                              TargetTransformInfo *TTI, const DataLayout *DL,
+                              ScalarEvolution *SE, AliasAnalysis *AA,
+                              AssumptionCache *AC,
+                              OptimizationRemarkEmitter *ORE, Function *F)
       : DT(DT), LI(LI), TTI(TTI), DL(DL), SE(SE), AA(AA),
         LAIs(*SE, *AA, *DT, *LI, TTI, nullptr, AC), AC(AC), ORE(ORE), F(F) {}
 
@@ -173,23 +158,21 @@ private:
 
   void setNoAliasToLoop(Loop *VerLoop);
   bool recognizeStridedSpecialCases();
-  void transformStridedSpecialCases(PHINode *OuterInductionVar,
-                                    PHINode *InnerInductionVar,
-                                    StoreInst *Stores, BasicBlock *PreheaderBB,
-                                    BasicBlock *BodyBB, BasicBlock *HeaderBB,
-                                    BasicBlock *LatchBB,
-                                    SmallVectorImpl<const SCEV*>& AlignmentInfo,
-                                    unsigned UnrollSize);
+  void transformStridedSpecialCases(
+      PHINode *OuterInductionVar, PHINode *InnerInductionVar, StoreInst *Stores,
+      BasicBlock *PreheaderBB, BasicBlock *BodyBB, BasicBlock *HeaderBB,
+      BasicBlock *LatchBB, SmallVectorImpl<const SCEV *> &AlignmentInfo,
+      unsigned UnrollSize);
   void eliminateRedundantLoads(BasicBlock *BB) {
     // Map from load pointer to the first load instruction
-    DenseMap<Value*, LoadInst*> LoadMap;
-    SmallVector<LoadInst*, 16> ToDelete;
-    
+    DenseMap<Value *, LoadInst *> LoadMap;
+    SmallVector<LoadInst *, 16> ToDelete;
+
     // First pass: collect all loads and find duplicates
     for (Instruction &I : *BB) {
       if (auto *LI = dyn_cast<LoadInst>(&I)) {
         Value *Ptr = LI->getPointerOperand();
-        
+
         // Check if we've seen a load from this address
         auto It = LoadMap.find(Ptr);
         if (It != LoadMap.end() && !LI->isVolatile()) {
@@ -208,7 +191,7 @@ private:
         }
       }
     }
-    
+
     // Delete redundant loads
     for (LoadInst *LI : ToDelete) {
       LI->eraseFromParent();
@@ -219,7 +202,6 @@ private:
   /// @}
 };
 
-
 } // anonymous namespace
 
 PreservedAnalyses StridedLoopUnrollPass::run(Loop &L, LoopAnalysisManager &AM,
@@ -253,20 +235,20 @@ bool StridedLoopUnroll::run(Loop *L) {
                     << CurLoop->getHeader()->getName() << "\n");
 
   if (recognizeStridedSpecialCases()) {
-    LLVM_DEBUG(dbgs() << DEBUG_TYPE " Will transform: F[" << F.getName() << "] Loop %"
-               << CurLoop->getHeader()->getName() << "\n");
+    LLVM_DEBUG(dbgs() << DEBUG_TYPE " Will transform: F[" << F.getName()
+                      << "] Loop %" << CurLoop->getHeader()->getName() << "\n");
     return true;
   }
 
-  LLVM_DEBUG(dbgs() << DEBUG_TYPE " Will not transform: F[" << F.getName() << "] Loop %"
-             << CurLoop->getHeader()->getName() << "\n");
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE " Will not transform: F[" << F.getName()
+                    << "] Loop %" << CurLoop->getHeader()->getName() << "\n");
   return false;
 }
 
 bool StridedLoopUnrollVersioning::run(Loop *L) {
   CurLoop = L;
 
-  if(!TTI->getVScaleForTuning())
+  if (!TTI->getVScaleForTuning())
     return false;
 
   Function &F = *L->getHeader()->getParent();
@@ -282,13 +264,13 @@ bool StridedLoopUnrollVersioning::run(Loop *L) {
                     << CurLoop->getHeader()->getName() << "\n");
 
   if (recognizeStridedSpecialCases()) {
-    LLVM_DEBUG(dbgs() << DEBUG_TYPE " Will transform: F[" << F.getName() << "] Loop %"
-               << CurLoop->getHeader()->getName() << "\n");
+    LLVM_DEBUG(dbgs() << DEBUG_TYPE " Will transform: F[" << F.getName()
+                      << "] Loop %" << CurLoop->getHeader()->getName() << "\n");
 
     return true;
   }
-  LLVM_DEBUG(dbgs() << DEBUG_TYPE " Will not transform: F[" << F.getName() << "] Loop %"
-             << CurLoop->getHeader()->getName() << "\n");
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE " Will not transform: F[" << F.getName()
+                    << "] Loop %" << CurLoop->getHeader()->getName() << "\n");
   return false;
 }
 
@@ -360,7 +342,8 @@ static void findUnconnectedToLoad(Instruction *start,
 
       bool shouldBreak = isa<LoadInst>(inst);
       // If this is a load, do not proceed from here!
-      connected = isa<LoadInst>(inst) && start->getParent()->getName() == inst->getParent()->getName();
+      connected = isa<LoadInst>(inst) &&
+                  start->getParent()->getName() == inst->getParent()->getName();
       if (shouldBreak)
         break;
 
@@ -369,8 +352,7 @@ static void findUnconnectedToLoad(Instruction *start,
         if (auto I = dyn_cast<Instruction>(op.get())) {
           if (I->getParent() == start->getParent())
             innerWorklist.push_back(op.get());
-        }
-        else
+        } else
           innerWorklist.push_back(op.get());
       }
     }
@@ -434,12 +416,12 @@ Value *StridedLoopUnroll::widenVectorizedInstruction(
     return V;
   }
   case Instruction::Select: {
-    Value* V = Builder.CreateSelect(Ops[0], Ops[1], Ops[2]);
+    Value *V = Builder.CreateSelect(Ops[0], Ops[1], Ops[2]);
     return V;
   }
   case Instruction::ICmp: {
     auto Cmp = dyn_cast<ICmpInst>(I);
-    Value* V = Builder.CreateICmp(Cmp->getPredicate(), Ops[0], Ops[1]);
+    Value *V = Builder.CreateICmp(Cmp->getPredicate(), Ops[0], Ops[1]);
     return V;
   }
   case Instruction::ShuffleVector: {
@@ -453,11 +435,13 @@ Value *StridedLoopUnroll::widenVectorizedInstruction(
     }
 
     ArrayRef<int> NewMask(RepeatedMask);
-    Value *Shuffle = Builder.CreateShuffleVector(Ops[0], Ops[1], NewMask, I->getName());
+    Value *Shuffle =
+        Builder.CreateShuffleVector(Ops[0], Ops[1], NewMask, I->getName());
     return Shuffle;
   }
   case Instruction::InsertElement: {
-    Value *A = Builder.CreateInsertElement(Ops[0], Ops[1], Ops[2], I->getName());
+    Value *A =
+        Builder.CreateInsertElement(Ops[0], Ops[1], Ops[2], I->getName());
     return A;
   }
   case Instruction::Load: {
@@ -513,89 +497,90 @@ bool isInstructionDepends(Instruction *Dependant, Instruction *Target) {
 
 // InnerInductionVar will be transformed to static
 void StridedLoopUnroll::transformStridedSpecialCases(
-    BasicBlock *Header, BasicBlock *Latch, BasicBlock *Preheader,
-    Loop *SubLoop,
-    SmallVectorImpl<LoadInst*> &Loads, StoreInst* Store,
-    SmallVectorImpl<Value *>& PostOrder,
-    SmallVectorImpl<Value *>& PreOrder) {
+    BasicBlock *Header, BasicBlock *Latch, BasicBlock *Preheader, Loop *SubLoop,
+    SmallVectorImpl<LoadInst *> &Loads, StoreInst *Store,
+    SmallVectorImpl<Value *> &PostOrder, SmallVectorImpl<Value *> &PreOrder) {
 
-  //auto InnerPreheader = SubLoop->getLoopPreheader();
+  // auto InnerPreheader = SubLoop->getLoopPreheader();
 
-  auto Stride = getDynamicStrideFromMemOp(Store->getPointerOperand(), Preheader->getTerminator());
+  auto Stride = getDynamicStrideFromMemOp(Store->getPointerOperand(),
+                                          Preheader->getTerminator());
 
   SmallPtrSet<llvm::Value *, 32> Connected;
   SmallPtrSet<llvm::Value *, 32> NotConnected;
   SmallDenseMap<llvm::Value *, llvm::Value *, 32> Replacements;
 
-  auto StoredInstruction =
-      dyn_cast<Instruction>(Store->getValueOperand());
+  auto StoredInstruction = dyn_cast<Instruction>(Store->getValueOperand());
   findUnconnectedToLoad(StoredInstruction, NotConnected, Connected);
 
-  auto convertConstant = [&] (auto val) {
-        auto constVal = cast<Constant>(val);
-        unsigned numElements =
-            cast<FixedVectorType>(val->getType())->getNumElements();
-        SmallVector<Constant *, 16> elements;
-
-        // Extract original elements
-        for (unsigned i = 0; i < numElements; ++i)
-          elements.push_back(constVal->getAggregateElement(i));
-
-        auto originalElements = elements;
-        for (unsigned int copy = 0; copy != (*TTI->getVScaleForTuning())-1; ++copy)
-          elements.append(originalElements);
-        Constant *newConst = ConstantVector::get(elements);
-        return newConst;
+  auto convertConstant = [&](auto val) {
+    auto constVal = cast<Constant>(val);
+    unsigned numElements =
+        cast<FixedVectorType>(val->getType())->getNumElements();
+    SmallVector<Constant *, 16> elements;
+
+    // Extract original elements
+    for (unsigned i = 0; i < numElements; ++i)
+      elements.push_back(constVal->getAggregateElement(i));
+
+    auto originalElements = elements;
+    for (unsigned int copy = 0; copy != (*TTI->getVScaleForTuning()) - 1;
+         ++copy)
+      elements.append(originalElements);
+    Constant *newConst = ConstantVector::get(elements);
+    return newConst;
   };
 
   // Process in post-order (leafs to root)
   for (Value *val : PostOrder) {
     if (Connected.contains(val)) {
       if (auto *I = dyn_cast<Instruction>(val)) {
-          SmallVector<Value *, 4> Operands(I->operands());
-          for (auto op_it = Operands.begin(); op_it != Operands.end(); ++op_it) {
-            if (Replacements.contains(*op_it))
-              *op_it = Replacements[*op_it];
-            else if (auto OrigVecTy = llvm::dyn_cast<llvm::VectorType>((*op_it)->getType())) {
-              if (auto Iop = dyn_cast<Instruction>(*op_it)) {
-                if (Iop->getParent() != Store->getParent()) {
-                  assert(!Connected.contains(*op_it));
-
-                  IRBuilder<> Builder(I);
-
-                  std::vector<llvm::Constant*> Consts;
-                  for (unsigned int i = 0; i != *TTI->getVScaleForTuning(); i++) {
-                    for (size_t j = 0; j != OrigVecTy->getElementCount().getFixedValue(); j++) {
-                      Consts.push_back(llvm::ConstantInt::get(llvm::Type::getInt32Ty(Builder.getContext()), j));
-                    }
+        SmallVector<Value *, 4> Operands(I->operands());
+        for (auto op_it = Operands.begin(); op_it != Operands.end(); ++op_it) {
+          if (Replacements.contains(*op_it))
+            *op_it = Replacements[*op_it];
+          else if (auto OrigVecTy =
+                       llvm::dyn_cast<llvm::VectorType>((*op_it)->getType())) {
+            if (auto Iop = dyn_cast<Instruction>(*op_it)) {
+              if (Iop->getParent() != Store->getParent()) {
+                assert(!Connected.contains(*op_it));
+
+                IRBuilder<> Builder(I);
+
+                std::vector<llvm::Constant *> Consts;
+                for (unsigned int i = 0; i != *TTI->getVScaleForTuning(); i++) {
+                  for (size_t j = 0;
+                       j != OrigVecTy->getElementCount().getFixedValue(); j++) {
+                    Consts.push_back(llvm::ConstantInt::get(
+                        llvm::Type::getInt32Ty(Builder.getContext()), j));
                   }
-
-                  llvm::Constant* maskConst =
-                    llvm::ConstantVector::get(Consts);
-                  assert(maskConst != nullptr);
-
-                  llvm::Value* splat =
-                    Builder.CreateShuffleVector(Iop,
-                                                llvm::UndefValue::get(Iop->getType()),
-                                                maskConst);
-                  assert(splat != nullptr);
-                  Replacements.insert({*op_it, splat});
-                  *op_it = splat;
                 }
-              } else if (isa<Constant>(*op_it)) { // not instruction
-                  auto replacement = convertConstant(*op_it);
-                  assert(!!replacement);
-                  Replacements.insert({*op_it, replacement});
-                  *op_it = replacement;
+
+                llvm::Constant *maskConst = llvm::ConstantVector::get(Consts);
+                assert(maskConst != nullptr);
+
+                llvm::Value *splat = Builder.CreateShuffleVector(
+                    Iop, llvm::UndefValue::get(Iop->getType()), maskConst);
+                assert(splat != nullptr);
+                Replacements.insert({*op_it, splat});
+                *op_it = splat;
               }
+            } else if (isa<Constant>(*op_it)) { // not instruction
+              auto replacement = convertConstant(*op_it);
+              assert(!!replacement);
+              Replacements.insert({*op_it, replacement});
+              *op_it = replacement;
             }
           }
+        }
 
-          auto NewVecTy = getWidenedType(I->getType(), *TTI->getVScaleForTuning());
-          Value *NI = widenVectorizedInstruction(I, Operands, NewVecTy, *TTI->getVScaleForTuning());
+        auto NewVecTy =
+            getWidenedType(I->getType(), *TTI->getVScaleForTuning());
+        Value *NI = widenVectorizedInstruction(I, Operands, NewVecTy,
+                                               *TTI->getVScaleForTuning());
 
-          assert(NI != nullptr);
-          Replacements.insert({I, NI});
+        assert(NI != nullptr);
+        Replacements.insert({I, NI});
       }
     } else if (NotConnected.contains(val)) {
       if (val->getType()->isVectorTy() && isa<Constant>(val)) {
@@ -603,44 +588,45 @@ void StridedLoopUnroll::transformStridedSpecialCases(
         Replacements.insert({val, replacement});
       }
     } else if (auto Load = dyn_cast<LoadInst>(val)) {
-        auto It =
-          std::find_if(Loads.begin(), Loads.end(), [Load](auto &&LoadInstr) {
-            return LoadInstr == Load;
-          });
-        if (It != Loads.end()) {
-          auto Stride = getDynamicStrideFromMemOp((*It)->getPointerOperand(), Preheader->getTerminator());
-
-          auto GroupedVecTy = getGroupedWidenedType(Load->getType(), *TTI->getVScaleForTuning(), *DL);
-          auto VecTy = getWidenedType(Load->getType(), *TTI->getVScaleForTuning());
-          ElementCount NewElementCount = GroupedVecTy->getElementCount();
-
-          IRBuilder<> Builder(Load);
-          auto *NewInst = Builder.CreateIntrinsic(
-          Intrinsic::experimental_vp_strided_load,
-          {GroupedVecTy, Load->getPointerOperand()->getType(),
-           (*Stride)->getType()},
-          {Load->getPointerOperand(), *Stride,
-           Builder.getAllOnesMask(NewElementCount),
-           Builder.getInt32(NewElementCount.getKnownMinValue())});
-          auto Cast = Builder.CreateBitCast(NewInst, VecTy);
-          Replacements.insert({Load, Cast});
+      auto It =
+          std::find_if(Loads.begin(), Loads.end(),
+                       [Load](auto &&LoadInstr) { return LoadInstr == Load; });
+      if (It != Loads.end()) {
+        auto Stride = getDynamicStrideFromMemOp((*It)->getPointerOperand(),
+                                                Preheader->getTerminator());
+
+        auto GroupedVecTy = getGroupedWidenedType(
+            Load->getType(), *TTI->getVScaleForTuning(), *DL);
+        auto VecTy =
+            getWidenedType(Load->getType(), *TTI->getVScaleForTuning());
+        ElementCount NewElementCount = GroupedVecTy->getElementCount();
+
+        IRBuilder<> Builder(Load);
+        auto *NewInst = Builder.CreateIntrinsic(
+            Intrinsic::experimental_vp_strided_load,
+            {GroupedVecTy, Load->getPointerOperand()->getType(),
+             (*Stride)->getType()},
+            {Load->getPointerOperand(), *Stride,
+             Builder.getAllOnesMask(NewElementCount),
+             Builder.getInt32(NewElementCount.getKnownMinValue())});
+        auto Cast = Builder.CreateBitCast(NewInst, VecTy);
+        Replacements.insert({Load, Cast});
       }
     }
   }
 
   IRBuilder<> Builder(Store);
-  auto VecTy =
-      getGroupedWidenedType(Store->getValueOperand()->getType(), *TTI->getVScaleForTuning(), *DL);
+  auto VecTy = getGroupedWidenedType(Store->getValueOperand()->getType(),
+                                     *TTI->getVScaleForTuning(), *DL);
   ElementCount NewElementCount = VecTy->getElementCount();
 
   assert(Replacements.find(Store->getValueOperand()) != Replacements.end());
-  auto Cast = Builder.CreateBitCast(
-      Replacements[Store->getValueOperand()], VecTy);
+  auto Cast =
+      Builder.CreateBitCast(Replacements[Store->getValueOperand()], VecTy);
 
   Builder.CreateIntrinsic(
       Intrinsic::experimental_vp_strided_store,
-      {VecTy, Store->getPointerOperand()->getType(),
-       (*Stride)->getType()},
+      {VecTy, Store->getPointerOperand()->getType(), (*Stride)->getType()},
       {Cast, Store->getPointerOperand(), *Stride,
        Builder.getAllOnesMask(NewElementCount),
        Builder.getInt32(NewElementCount.getKnownMinValue())});
@@ -651,11 +637,13 @@ void StridedLoopUnroll::transformStridedSpecialCases(
     if (InductionDescriptor::isInductionPHI(&PN, CurLoop, SE, IndDesc)) {
       if (IndDesc.getKind() == InductionDescriptor::IK_PtrInduction)
         changeInductionVarIncrement(
-          PN.getIncomingValueForBlock(CurLoop->getLoopLatch()), *TTI->getVScaleForTuning());
+            PN.getIncomingValueForBlock(CurLoop->getLoopLatch()),
+            *TTI->getVScaleForTuning());
       else if (IndDesc.getKind() == InductionDescriptor::IK_IntInduction)
-        changeInductionVarIncrement(IndDesc.getInductionBinOp(), *TTI->getVScaleForTuning());
+        changeInductionVarIncrement(IndDesc.getInductionBinOp(),
+                                    *TTI->getVScaleForTuning());
     }
- }
+  }
 
   if (Store->use_empty())
     Store->eraseFromParent();
@@ -666,14 +654,15 @@ void StridedLoopUnroll::transformStridedSpecialCases(
         I->eraseFromParent();
 }
 
-std::optional<Value*> StridedLoopUnroll::getStrideFromAddRecExpr(const SCEVAddRecExpr* AR,
-                                                                 Instruction *InsertionPt) {
+std::optional<Value *>
+StridedLoopUnroll::getStrideFromAddRecExpr(const SCEVAddRecExpr *AR,
+                                           Instruction *InsertionPt) {
   auto Step = AR->getStepRecurrence(*SE);
   if (isa<SCEVConstant>(Step))
     return std::nullopt;
   SCEVExpander Expander(*SE, *DL, "stride");
   Value *StrideValue =
-    Expander.expandCodeFor(Step, Step->getType(), InsertionPt);
+      Expander.expandCodeFor(Step, Step->getType(), InsertionPt);
   return StrideValue;
 }
 
@@ -683,7 +672,7 @@ StridedLoopUnroll::getDynamicStrideFromMemOp(Value *V,
   const SCEV *S = SE->getSCEV(V);
   if (const SCEVAddRecExpr *InnerLoopAR = dyn_cast<SCEVAddRecExpr>(S)) {
     if (auto *constant =
-        dyn_cast<SCEVConstant>(InnerLoopAR->getStepRecurrence(*SE))) {
+            dyn_cast<SCEVConstant>(InnerLoopAR->getStepRecurrence(*SE))) {
       // We need to form 64-bit groups
       if (constant->getAPInt() != 8) {
         return std::nullopt;
@@ -692,16 +681,16 @@ StridedLoopUnroll::getDynamicStrideFromMemOp(Value *V,
       const auto *Add = dyn_cast<SCEVAddExpr>(InnerLoopAR->getStart());
       if (Add) {
         for (const SCEV *Op : Add->operands()) {
-          // Look for the outer recurrence: { %dst, +, sext(%i_dst_stride) } <outer loop>
+          // Look for the outer recurrence: { %dst, +, sext(%i_dst_stride) }
+          // <outer loop>
           const auto *AR = dyn_cast<SCEVAddRecExpr>(Op);
           if (!AR)
             continue;
 
           return getStrideFromAddRecExpr(AR, InsertionPt);
         }
-      }
-      else if (const SCEVAddRecExpr *AR =
-              dyn_cast<SCEVAddRecExpr>(InnerLoopAR->getStart())) {
+      } else if (const SCEVAddRecExpr *AR =
+                     dyn_cast<SCEVAddRecExpr>(InnerLoopAR->getStart())) {
         return getStrideFromAddRecExpr(AR, InsertionPt);
       }
     }
@@ -715,12 +704,11 @@ bool StridedLoopUnroll::recognizeStridedSpecialCases() {
     return false;
 
   auto SubLoops = CurLoop->getSubLoops();
-  
+
   if (SubLoops.size() > 2 || SubLoops.empty())
     return false;
 
-  auto SubLoop = SubLoops.size() == 2 ? SubLoops[1]
-    : SubLoops[0];
+  auto SubLoop = SubLoops.size() == 2 ? SubLoops[1] : SubLoops[0];
 
   auto Preheader = SubLoop->getLoopPreheader();
   auto Header = SubLoop->getHeader();
@@ -729,8 +717,8 @@ bool StridedLoopUnroll::recognizeStridedSpecialCases() {
   if (Header != Latch)
     return false;
 
-  SmallVector<LoadInst*> Loads;
-  SmallVector<StoreInst*> Stores;
+  SmallVector<LoadInst *> Loads;
+  SmallVector<StoreInst *> Stores;
 
   llvm::SmallPtrSet<llvm::Instruction *, 32> NotVisited;
   llvm::SmallVector<llvm::Instruction *, 8> WorkList;
@@ -775,8 +763,7 @@ bool StridedLoopUnroll::recognizeStridedSpecialCases() {
   SmallPtrSet<llvm::Value *, 32> Connected;
   SmallPtrSet<llvm::Value *, 32> NotConnected;
 
-  auto StoredInstruction =
-      dyn_cast<Instruction>(Stores[0]->getValueOperand());
+  auto StoredInstruction = dyn_cast<Instruction>(Stores[0]->getValueOperand());
   findUnconnectedToLoad(StoredInstruction, NotConnected, Connected);
 
   llvm::SmallVector<Value *, 16> PostOrder;
@@ -787,12 +774,10 @@ bool StridedLoopUnroll::recognizeStridedSpecialCases() {
 
   Worklist.push_back(StoredInstruction);
 
-  auto shouldVisit = [Header, Preheader](auto *Val)
-  {
+  auto shouldVisit = [Header, Preheader](auto *Val) {
     return !isa<PHINode>(Val) &&
-      (
-       !isa<Instruction>(Val) ||
-       dyn_cast<Instruction>(Val)->getParent() == Header);
+           (!isa<Instruction>(Val) ||
+            dyn_cast<Instruction>(Val)->getParent() == Header);
   };
   auto shouldVisitOperands = [Header, Preheader](auto *Val) {
     return !isa<PHINode>(Val) && !isa<LoadInst>(Val);
@@ -800,8 +785,9 @@ bool StridedLoopUnroll::recognizeStridedSpecialCases() {
 
   while (!Worklist.empty()) {
     Value *val = Worklist.back();
-    assert (!isa<Instruction>(val) || dyn_cast<Instruction>(val)->getParent() == Header
-            || dyn_cast<Instruction>(val)->getParent() == Preheader);
+    assert(!isa<Instruction>(val) ||
+           dyn_cast<Instruction>(val)->getParent() == Header ||
+           dyn_cast<Instruction>(val)->getParent() == Preheader);
 
     if (InStack.contains(val)) {
       // We've finished processing children, add to post-order
@@ -839,12 +825,11 @@ bool StridedLoopUnroll::recognizeStridedSpecialCases() {
             }
           }
         }
-      }
-      else { // We don't handle Non-instructions connected to Load
+      } else { // We don't handle Non-instructions connected to Load
         return false;
       }
-    } else if (NotConnected.contains(val)
-               && (!val->getType()->isVectorTy() || !isa<Constant>(val))) {
+    } else if (NotConnected.contains(val) &&
+               (!val->getType()->isVectorTy() || !isa<Constant>(val))) {
       return false;
     } else if (auto Load = dyn_cast<LoadInst>(val)) {
       if (std::find(Loads.begin(), Loads.end(), Load) == Loads.end())
@@ -852,8 +837,8 @@ bool StridedLoopUnroll::recognizeStridedSpecialCases() {
     }
   }
 
-  transformStridedSpecialCases(Header, Latch, Preheader, SubLoop, Loads, Stores[0],
-                               PostOrder, PreOrder);
+  transformStridedSpecialCases(Header, Latch, Preheader, SubLoop, Loads,
+                               Stores[0], PostOrder, PreOrder);
 
   return true;
 }
@@ -901,8 +886,8 @@ bool canHandleInstruction(Instruction *I) {
 } // anonymous namespace
 
 PreservedAnalyses
-StridedLoopUnrollVersioningPass::run(Function &F, FunctionAnalysisManager &FAM)
-{
+StridedLoopUnrollVersioningPass::run(Function &F,
+                                     FunctionAnalysisManager &FAM) {
   bool Changed = false;
 
   if (SkipPass)
@@ -927,7 +912,8 @@ StridedLoopUnrollVersioningPass::run(Function &F, FunctionAnalysisManager &FAM)
 
     const auto *DL = &L->getHeader()->getDataLayout();
 
-    StridedLoopUnrollVersioning LIV(&DT, &LI, &TTI, DL, &SE, &AA, &AC, &ORE, &F);
+    StridedLoopUnrollVersioning LIV(&DT, &LI, &TTI, DL, &SE, &AA, &AC, &ORE,
+                                    &F);
     bool ThisChanged = LIV.run(L);
     Changed |= ThisChanged;
   }
@@ -968,58 +954,56 @@ void StridedLoopUnrollVersioning::setNoAliasToLoop(Loop *VerLoop) {
 }
 
 void StridedLoopUnrollVersioning::hoistInvariantLoadsToPreheader(Loop *L) {
-    BasicBlock *Preheader = L->getLoopPreheader();
-    if (!Preheader) {
-      // If no preheader, try the header
-      Preheader = L->getHeader();
-    }
-    
-    // Find all invariant loads in the loop
-    SmallVector<LoadInst*, 8> InvariantLoads;
-    
-    for (BasicBlock *BB : L->blocks()) {
-      for (Instruction &I : *BB) {
-        if (auto *LI = dyn_cast<LoadInst>(&I)) {
-          Value *Ptr = LI->getPointerOperand();
-          
-          if (L->isLoopInvariant(Ptr)) {
-            InvariantLoads.push_back(LI);
-          }
+  BasicBlock *Preheader = L->getLoopPreheader();
+  if (!Preheader) {
+    // If no preheader, try the header
+    Preheader = L->getHeader();
+  }
+
+  // Find all invariant loads in the loop
+  SmallVector<LoadInst *, 8> InvariantLoads;
+
+  for (BasicBlock *BB : L->blocks()) {
+    for (Instruction &I : *BB) {
+      if (auto *LI = dyn_cast<LoadInst>(&I)) {
+        Value *Ptr = LI->getPointerOperand();
+
+        if (L->isLoopInvariant(Ptr)) {
+          InvariantLoads.push_back(LI);
         }
       }
     }
-    
-    // Move loads to preheader and eliminate duplicates
-    DenseMap<Value*, LoadInst*> HoistedLoads;
-    
-    for (LoadInst *LI : InvariantLoads) {
-      Value *Ptr = LI->getPointerOperand();
-      
-      if (HoistedLoads.count(Ptr)) {
-        // Already hoisted this load, replace uses
-        LI->replaceAllUsesWith(HoistedLoads[Ptr]);
-        LI->eraseFromParent();
-      } else {
-        // Move to preheader
-        LI->moveBefore(*Preheader, Preheader->getTerminator()->getIterator());
-        HoistedLoads[Ptr] = LI;
-      }
+  }
+
+  // Move loads to preheader and eliminate duplicates
+  DenseMap<Value *, LoadInst *> HoistedLoads;
+
+  for (LoadInst *LI : InvariantLoads) {
+    Value *Ptr = LI->getPointerOperand();
+
+    if (HoistedLoads.count(Ptr)) {
+      // Already hoisted this load, replace uses
+      LI->replaceAllUsesWith(HoistedLoads[Ptr]);
+      LI->eraseFromParent();
+    } else {
+      // Move to preheader
+      LI->moveBefore(*Preheader, Preheader->getTerminator()->getIterator());
+      HoistedLoads[Ptr] = LI;
     }
+  }
 }
 
 // InnerInductionVar will be transformed to static
 void StridedLoopUnrollVersioning::transformStridedSpecialCases(
-    PHINode *OuterInductionVar, PHINode *InnerInductionVar,
-    StoreInst *Store, BasicBlock *PreheaderBB,
-    BasicBlock *BodyBB, BasicBlock *HeaderBB, BasicBlock *LatchBB,
-    SmallVectorImpl<const SCEV*>& AlignmentInfo,
+    PHINode *OuterInductionVar, PHINode *InnerInductionVar, StoreInst *Store,
+    BasicBlock *PreheaderBB, BasicBlock *BodyBB, BasicBlock *HeaderBB,
+    BasicBlock *LatchBB, SmallVectorImpl<const SCEV *> &AlignmentInfo,
     unsigned UnrollSize) {
 
   PredicatedScalarEvolution PSE(*SE, *CurLoop);
 
   auto VLAI = &LAIs.getInfo(*CurLoop);
-  LoopVersioning LVer2(*VLAI,
-                       VLAI->getRuntimePointerChecking()->getChecks(),
+  LoopVersioning LVer2(*VLAI, VLAI->getRuntimePointerChecking()->getChecks(),
                        CurLoop, LI, DT, SE, true);
   LVer2.versionLoop();
 
@@ -1048,7 +1032,7 @@ void StridedLoopUnrollVersioning::transformStridedSpecialCases(
   ULO.UnrollRemainder = false;
   ULO.SCEVExpansionBudget = -1;
 
-  /*auto r =*/ UnrollLoop(NewInnerLoop, ULO, LI, SE, DT, AC, TTI, ORE, false);
+  /*auto r =*/UnrollLoop(NewInnerLoop, ULO, LI, SE, DT, AC, TTI, ORE, false);
 
   hoistInvariantLoadsToPreheader(VersionedLoop);
 
@@ -1057,9 +1041,9 @@ void StridedLoopUnrollVersioning::transformStridedSpecialCases(
   }
 
   for (BasicBlock *BB : VersionedLoop->blocks()) {
-    DenseMap<Value*, Value*> LoadCSE;
-    SmallVector<Instruction*, 16> DeadInsts;
-    
+    DenseMap<Value *, Value *> LoadCSE;
+    SmallVector<Instruction *, 16> DeadInsts;
+
     for (Instruction &I : *BB) {
       if (auto *LI = dyn_cast<LoadInst>(&I)) {
         if (!LI->isVolatile()) {
@@ -1074,7 +1058,7 @@ void StridedLoopUnrollVersioning::transformStridedSpecialCases(
         }
       }
     }
-    
+
     for (auto *I : DeadInsts) {
       I->eraseFromParent();
     }
@@ -1103,7 +1087,7 @@ void StridedLoopUnrollVersioning::transformStridedSpecialCases(
 
       Value *innerMask = Builder.getIntN(
           InnerLoopBounds->getFinalIVValue().getType()->getIntegerBitWidth(),
-          UnrollSize-1);
+          UnrollSize - 1);
       Value *innerAndResult = Builder.CreateAnd(
           &InnerLoopBounds->getFinalIVValue(), innerMask, "inner_mod_unroll");
       Value *innerIsNotDivisible =
@@ -1121,8 +1105,8 @@ void StridedLoopUnrollVersioning::transformStridedSpecialCases(
       Value *mask = Builder.getIntN(
           OuterLoopBounds->getFinalIVValue().getType()->getIntegerBitWidth(),
           *o - 1);
-      Value *andResult =
-          Builder.CreateAnd(&OuterLoopBounds->getFinalIVValue(), mask, "div_unroll");
+      Value *andResult = Builder.CreateAnd(&OuterLoopBounds->getFinalIVValue(),
+                                           mask, "div_unroll");
       Value *isNotDivisible =
           Builder.CreateICmpNE(andResult, outerZero, "is_div_unroll");
       Value *Check1 = Builder.CreateOr(innerIsNotDivisible, isNotDivisible);
@@ -1130,22 +1114,25 @@ void StridedLoopUnrollVersioning::transformStridedSpecialCases(
 
       Value *AlignmentCheck = Builder.getFalse();
 
-      for (auto && PtrSCEV : AlignmentInfo) {
+      for (auto &&PtrSCEV : AlignmentInfo) {
         const unsigned Alignment = 8;
         // Expand SCEV to get runtime value
         SCEVExpander Expander(*SE, *DL, "align.check");
-        Value *PtrValue = Expander.expandCodeFor(PtrSCEV, Builder.getPtrTy(), PHBranch);
+        Value *PtrValue =
+            Expander.expandCodeFor(PtrSCEV, Builder.getPtrTy(), PHBranch);
 
         Type *I64 = Type::getInt64Ty(PtrValue->getContext());
-        bool AllowsMisaligned = TTI->isLegalStridedLoadStore(VectorType::get(I64, ElementCount::getFixed(8)), Align(1));
+        bool AllowsMisaligned = TTI->isLegalStridedLoadStore(
+            VectorType::get(I64, ElementCount::getFixed(8)), Align(1));
 
-        if(!AllowsMisaligned) {
+        if (!AllowsMisaligned) {
           // Create alignment check: (ptr & (alignment-1)) == 0
-          Value *PtrInt = Builder.CreatePtrToInt(PtrValue, Builder.getInt64Ty());
+          Value *PtrInt =
+              Builder.CreatePtrToInt(PtrValue, Builder.getInt64Ty());
           Value *Mask = Builder.getInt64(Alignment - 1);
           Value *Masked = Builder.CreateAnd(PtrInt, Mask);
           Value *IsAligned = Builder.CreateICmpNE(Masked, Builder.getInt64(0));
-  
+
           AlignmentCheck = Builder.CreateOr(AlignmentCheck, IsAligned);
         }
       }
@@ -1272,20 +1259,20 @@ bool StridedLoopUnrollVersioning::recognizeStridedSpecialCases() {
   llvm::SmallPtrSet<llvm::Instruction *, 32> NotVisited;
   llvm::SmallVector<llvm::Instruction *, 8> WorkList;
 
-  for (auto&& BB : CurLoop->getBlocks())
-    for (auto&& V : *BB)
+  for (auto &&BB : CurLoop->getBlocks())
+    for (auto &&V : *BB)
       if (BB != ForLoop)
         if (!canHandleInstruction(&V))
           return false;
-  for (auto&& Loop : CurLoop->getSubLoops())
-    for (auto&& BB : Loop->getBlocks())
-      for (auto&& V : *BB)
+  for (auto &&Loop : CurLoop->getSubLoops())
+    for (auto &&BB : Loop->getBlocks())
+      for (auto &&V : *BB)
         if (BB != ForLoop)
           if (!canHandleInstruction(&V))
             return false;
-  
+
   // Collect pointers needing alignment
-  SmallVector<const SCEV*, 8> AlignmentInfo;
+  SmallVector<const SCEV *, 8> AlignmentInfo;
   unsigned UnrollSize = 8;
 
   for (BasicBlock *BB : CurLoop->blocks()) {
@@ -1295,57 +1282,55 @@ bool StridedLoopUnrollVersioning::recognizeStridedSpecialCases() {
 
       if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
         Ptr = LI->getPointerOperand();
-        TypeSize typeSize  = DL->getTypeAllocSize(I.getType());
+        TypeSize typeSize = DL->getTypeAllocSize(I.getType());
         if (size == 0)
           size = typeSize;
         else if (size != typeSize)
           return false;
-      }
-      else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) {
+      } else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) {
         Ptr = SI->getPointerOperand();
-        TypeSize typeSize  = DL->getTypeAllocSize(SI->getValueOperand()->getType());
+        TypeSize typeSize =
+            DL->getTypeAllocSize(SI->getValueOperand()->getType());
         if (size == 0)
           size = typeSize;
         else if (size != typeSize)
           return false;
-        UnrollSize = 8/size;
-      }
-      else
+        UnrollSize = 8 / size;
+      } else
         continue;
 
       const SCEV *S = SE->getSCEV(Ptr);
 
       if (const SCEVAddRecExpr *InnerLoopAR = dyn_cast<SCEVAddRecExpr>(S)) {
         if (auto *constant =
-            dyn_cast<SCEVConstant>(InnerLoopAR->getStepRecurrence(*SE))) {
+                dyn_cast<SCEVConstant>(InnerLoopAR->getStepRecurrence(*SE))) {
           if (constant->getAPInt() != size)
             return false; // must be contiguous
 
           if (const SCEVAddRecExpr *AR =
-              dyn_cast<SCEVAddRecExpr>(InnerLoopAR->getStart())) {
+                  dyn_cast<SCEVAddRecExpr>(InnerLoopAR->getStart())) {
             auto Step = AR->getStepRecurrence(*SE);
             if (isa<SCEVConstant>(Step))
               return false;
             else {
-              const SCEVUnknown* Unknown = nullptr;
+              const SCEVUnknown *Unknown = nullptr;
 
               if (size > 1) {
                 if (auto mul = dyn_cast<SCEVMulExpr>(Step)) {
                   if (mul->getNumOperands() == 2) {
-                    if (auto constant = dyn_cast<SCEVConstant>(mul->getOperand(0))) {
+                    if (auto constant =
+                            dyn_cast<SCEVConstant>(mul->getOperand(0))) {
                       if (constant->getAPInt() != size)
                         return false;
-                    }
-                    else
+                    } else
                       return false;
                     Unknown = dyn_cast<SCEVUnknown>(mul->getOperand(1));
-                    if (auto CastExtend = dyn_cast<SCEVCastExpr>(mul->getOperand(1)))
+                    if (auto CastExtend =
+                            dyn_cast<SCEVCastExpr>(mul->getOperand(1)))
                       Unknown = dyn_cast<SCEVUnknown>(CastExtend->getOperand());
-                  }
-                  else
+                  } else
                     return false;
-                }
-                else
+                } else
                   return false;
               }
               if (!Unknown) {
@@ -1356,20 +1341,16 @@ bool StridedLoopUnrollVersioning::recognizeStridedSpecialCases() {
               if (Unknown) { // stride should be fixed but not constant
                 if (isa<Instruction>(Unknown->getValue()))
                   return false;
-              }
-              else
+              } else
                 return false;
             }
 
             AlignmentInfo.push_back({AR->getStart()});
-          }
-          else
+          } else
             return false;
-        }
-        else
+        } else
           return false;
-      }
-      else if (!CurLoop->isLoopInvariant(Ptr))
+      } else if (!CurLoop->isLoopInvariant(Ptr))
         return false;
     }
   }
@@ -1440,9 +1421,8 @@ bool StridedLoopUnrollVersioning::recognizeStridedSpecialCases() {
     return false;
 
   transformStridedSpecialCases(OuterInductionVariable, InnerInductionVariable,
-                               Stores[0], Preheader, ForLoop,
-                               OuterLoopHeader, OuterLoopLatch, AlignmentInfo,
-                               UnrollSize);
+                               Stores[0], Preheader, ForLoop, OuterLoopHeader,
+                               OuterLoopLatch, AlignmentInfo, UnrollSize);
 
   return true;
 }

``````````

</details>


https://github.com/llvm/llvm-project/pull/157749

