[llvm] [Vectorize] Add StridedLoopUnroll + Versioning for 2-D strided loop nests (PR #157749)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 9 14:04:47 PDT 2025
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff origin/main HEAD --extensions h,cpp -- llvm/include/llvm/Transforms/Vectorize/StridedLoopUnroll.h llvm/lib/Transforms/Vectorize/StridedLoopUnroll.cpp llvm/include/llvm/Transforms/Utils/LoopVersioning.h llvm/lib/Passes/PassBuilder.cpp llvm/lib/Target/RISCV/RISCVTargetMachine.cpp llvm/lib/Transforms/Utils/LoopVersioning.cpp
``````````
:warning:
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing `origin/main` to the base branch/commit you want to compare against.
:warning:
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/llvm/include/llvm/Transforms/Utils/LoopVersioning.h b/llvm/include/llvm/Transforms/Utils/LoopVersioning.h
index d58995b76..f6ccc8803 100644
--- a/llvm/include/llvm/Transforms/Utils/LoopVersioning.h
+++ b/llvm/include/llvm/Transforms/Utils/LoopVersioning.h
@@ -46,7 +46,8 @@ public:
/// object having no checks and we expect the user to add them.
LoopVersioning(const LoopAccessInfo &LAI,
ArrayRef<RuntimePointerCheck> Checks, Loop *L, LoopInfo *LI,
- DominatorTree *DT, ScalarEvolution *SE, bool HoistRuntimeChecks = false);
+ DominatorTree *DT, ScalarEvolution *SE,
+ bool HoistRuntimeChecks = false);
/// Performs the CFG manipulation part of versioning the loop including
/// the DominatorTree and LoopInfo updates.
diff --git a/llvm/include/llvm/Transforms/Vectorize/StridedLoopUnroll.h b/llvm/include/llvm/Transforms/Vectorize/StridedLoopUnroll.h
index baf1220a6..67ecf9df7 100644
--- a/llvm/include/llvm/Transforms/Vectorize/StridedLoopUnroll.h
+++ b/llvm/include/llvm/Transforms/Vectorize/StridedLoopUnroll.h
@@ -28,6 +28,7 @@ public:
class StridedLoopUnrollVersioningPass
: public PassInfoMixin<StridedLoopUnrollVersioningPass> {
int OptLevel;
+
public:
StridedLoopUnrollVersioningPass(int OptLevel = 2) : OptLevel(OptLevel) {}
diff --git a/llvm/lib/Passes/PassBuilder.cpp b/llvm/lib/Passes/PassBuilder.cpp
index c3b310298..12d80bd5d 100644
--- a/llvm/lib/Passes/PassBuilder.cpp
+++ b/llvm/lib/Passes/PassBuilder.cpp
@@ -378,8 +378,8 @@
#include "llvm/Transforms/Vectorize/LoopVectorize.h"
#include "llvm/Transforms/Vectorize/SLPVectorizer.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizer.h"
-#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/Transforms/Vectorize/StridedLoopUnroll.h"
+#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include <optional>
using namespace llvm;
diff --git a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
index a2dd70459..6a8047ebd 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp
@@ -38,9 +38,9 @@
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Scalar/LoopUnrollPass.h"
#include "llvm/Transforms/Vectorize/LoopIdiomVectorize.h"
#include "llvm/Transforms/Vectorize/StridedLoopUnroll.h"
-#include "llvm/Transforms/Scalar/LoopUnrollPass.h"
#include <optional>
using namespace llvm;
@@ -658,15 +658,16 @@ void RISCVTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
LPM.addPass(LoopIdiomVectorizePass(LoopIdiomVectorizeStyle::Predicated));
});
- PB.registerScalarOptimizerLateEPCallback([=](FunctionPassManager &LPM,
- OptimizationLevel) {
- LPM.addPass(StridedLoopUnrollVersioningPass());
- });
- PB.registerOptimizerLastEPCallback(
- [=](ModulePassManager &MPM, OptimizationLevel Level, llvm::ThinOrFullLTOPhase) {
- MPM.addPass(createModuleToFunctionPassAdaptor(
- createFunctionToLoopPassAdaptor(StridedLoopUnrollPass())));
+ PB.registerScalarOptimizerLateEPCallback(
+ [=](FunctionPassManager &LPM, OptimizationLevel) {
+ LPM.addPass(StridedLoopUnrollVersioningPass());
});
+ PB.registerOptimizerLastEPCallback([=](ModulePassManager &MPM,
+ OptimizationLevel Level,
+ llvm::ThinOrFullLTOPhase) {
+ MPM.addPass(createModuleToFunctionPassAdaptor(
+ createFunctionToLoopPassAdaptor(StridedLoopUnrollPass())));
+ });
}
yaml::MachineFunctionInfo *
diff --git a/llvm/lib/Transforms/Utils/LoopVersioning.cpp b/llvm/lib/Transforms/Utils/LoopVersioning.cpp
index 002717946..684f58cc5 100644
--- a/llvm/lib/Transforms/Utils/LoopVersioning.cpp
+++ b/llvm/lib/Transforms/Utils/LoopVersioning.cpp
@@ -42,10 +42,10 @@ static cl::opt<bool>
LoopVersioning::LoopVersioning(const LoopAccessInfo &LAI,
ArrayRef<RuntimePointerCheck> Checks, Loop *L,
LoopInfo *LI, DominatorTree *DT,
- ScalarEvolution *SE,
- bool HoistRuntimeChecks)
+ ScalarEvolution *SE, bool HoistRuntimeChecks)
: VersionedLoop(L), AliasChecks(Checks), Preds(LAI.getPSE().getPredicate()),
- LAI(LAI), LI(LI), DT(DT), SE(SE), HoistRuntimeChecks(HoistRuntimeChecks) {}
+ LAI(LAI), LI(LI), DT(DT), SE(SE), HoistRuntimeChecks(HoistRuntimeChecks) {
+}
void LoopVersioning::versionLoop(
const SmallVectorImpl<Instruction *> &DefsUsedOutside) {
@@ -64,8 +64,9 @@ void LoopVersioning::versionLoop(
SCEVExpander Exp2(*RtPtrChecking.getSE(),
VersionedLoop->getHeader()->getDataLayout(),
"induction");
- MemRuntimeCheck = addRuntimeChecks(RuntimeCheckBB->getTerminator(),
- VersionedLoop, AliasChecks, Exp2, HoistRuntimeChecks);
+ MemRuntimeCheck =
+ addRuntimeChecks(RuntimeCheckBB->getTerminator(), VersionedLoop,
+ AliasChecks, Exp2, HoistRuntimeChecks);
SCEVExpander Exp(*SE, RuntimeCheckBB->getDataLayout(),
"scev.check");
diff --git a/llvm/lib/Transforms/Vectorize/StridedLoopUnroll.cpp b/llvm/lib/Transforms/Vectorize/StridedLoopUnroll.cpp
index 8f1f6cf19..35297911e 100644
--- a/llvm/lib/Transforms/Vectorize/StridedLoopUnroll.cpp
+++ b/llvm/lib/Transforms/Vectorize/StridedLoopUnroll.cpp
@@ -34,9 +34,12 @@
#include "llvm/Transforms/Vectorize/StridedLoopUnroll.h"
#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetTransformInfo.h"
@@ -45,32 +48,13 @@
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/Transforms/Scalar/EarlyCSE.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopVersioning.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
-#include "llvm/Analysis/AliasAnalysis.h"
-#include "llvm/Analysis/DomTreeUpdater.h"
-#include "llvm/Analysis/LoopInfo.h"
-#include "llvm/Analysis/LoopPass.h"
-#include "llvm/Analysis/ScalarEvolution.h"
-#include "llvm/Analysis/ScalarEvolutionExpressions.h"
-#include "llvm/Analysis/TargetTransformInfo.h"
-#include "llvm/IR/Dominators.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/Intrinsics.h"
-#include "llvm/IR/MDBuilder.h"
-#include "llvm/IR/PatternMatch.h"
-#include "llvm/Transforms/Utils/BasicBlockUtils.h"
-#include "llvm/Transforms/Utils/LoopVersioning.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"
-#include "llvm/Analysis/TargetTransformInfo.h"
-#include "llvm/Analysis/OptimizationRemarkEmitter.h"
-#include "llvm/Analysis/AssumptionCache.h"
-#include "llvm/Transforms/Utils/Local.h"
-#include "llvm/Transforms/Scalar/EarlyCSE.h"
-#include "llvm/Analysis/MemorySSA.h"
-#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
-#include "llvm/Target/TargetMachine.h"
using namespace llvm;
using namespace PatternMatch;
@@ -101,9 +85,9 @@ class StridedLoopUnroll {
const LoopAccessInfo *LAI = nullptr;
public:
- StridedLoopUnroll(DominatorTree *DT, LoopInfo *LI,
- TargetTransformInfo *TTI, const DataLayout *DL,
- ScalarEvolution *SE, AliasAnalysis *AA, AssumptionCache* AC)
+ StridedLoopUnroll(DominatorTree *DT, LoopInfo *LI, TargetTransformInfo *TTI,
+ const DataLayout *DL, ScalarEvolution *SE,
+ AliasAnalysis *AA, AssumptionCache *AC)
: DT(DT), LI(LI), TTI(TTI), DL(DL), SE(SE), AA(AA),
LAIs(*SE, *AA, *DT, *LI, TTI, nullptr, AC) {}
@@ -119,21 +103,22 @@ private:
void transformStridedSpecialCases(BasicBlock *Header, BasicBlock *Latch,
BasicBlock *Preheader, Loop *SubLoop,
- SmallVectorImpl<LoadInst*> &Loads,
- StoreInst* Store,
- SmallVectorImpl<Value *>& PostOrder,
- SmallVectorImpl<Value *>& PreOrder);
+ SmallVectorImpl<LoadInst *> &Loads,
+ StoreInst *Store,
+ SmallVectorImpl<Value *> &PostOrder,
+ SmallVectorImpl<Value *> &PreOrder);
void changeInductionVarIncrement(Value *IncomingValue, unsigned VF);
std::optional<Value *> getDynamicStrideFromMemOp(Value *Value,
Instruction *InsertionPt);
- std::optional<Value*> getStrideFromAddRecExpr(const SCEVAddRecExpr* AR, Instruction *InsertionPt);
+ std::optional<Value *> getStrideFromAddRecExpr(const SCEVAddRecExpr *AR,
+ Instruction *InsertionPt);
/// @}
};
static cl::opt<bool>
SkipPass("strided-loop-unroll-disable", cl::init(false), cl::Hidden,
- cl::desc("Skip running strided loop unroll optimization."));
+ cl::desc("Skip running strided loop unroll optimization."));
class StridedLoopUnrollVersioning {
Loop *CurLoop = nullptr;
@@ -144,9 +129,9 @@ class StridedLoopUnrollVersioning {
ScalarEvolution *SE;
AliasAnalysis *AA;
LoopAccessInfoManager LAIs;
- AssumptionCache* AC;
- OptimizationRemarkEmitter* ORE;
- Function* F;
+ AssumptionCache *AC;
+ OptimizationRemarkEmitter *ORE;
+ Function *F;
// Blocks that will be used for inserting vectorized code.
BasicBlock *EndBlock = nullptr;
@@ -158,10 +143,10 @@ class StridedLoopUnrollVersioning {
public:
StridedLoopUnrollVersioning(DominatorTree *DT, LoopInfo *LI,
- TargetTransformInfo *TTI, const DataLayout *DL,
- ScalarEvolution *SE, AliasAnalysis *AA, AssumptionCache *AC,
- OptimizationRemarkEmitter* ORE,
- Function* F)
+ TargetTransformInfo *TTI, const DataLayout *DL,
+ ScalarEvolution *SE, AliasAnalysis *AA,
+ AssumptionCache *AC,
+ OptimizationRemarkEmitter *ORE, Function *F)
: DT(DT), LI(LI), TTI(TTI), DL(DL), SE(SE), AA(AA),
LAIs(*SE, *AA, *DT, *LI, TTI, nullptr, AC), AC(AC), ORE(ORE), F(F) {}
@@ -173,23 +158,21 @@ private:
void setNoAliasToLoop(Loop *VerLoop);
bool recognizeStridedSpecialCases();
- void transformStridedSpecialCases(PHINode *OuterInductionVar,
- PHINode *InnerInductionVar,
- StoreInst *Stores, BasicBlock *PreheaderBB,
- BasicBlock *BodyBB, BasicBlock *HeaderBB,
- BasicBlock *LatchBB,
- SmallVectorImpl<const SCEV*>& AlignmentInfo,
- unsigned UnrollSize);
+ void transformStridedSpecialCases(
+ PHINode *OuterInductionVar, PHINode *InnerInductionVar, StoreInst *Stores,
+ BasicBlock *PreheaderBB, BasicBlock *BodyBB, BasicBlock *HeaderBB,
+ BasicBlock *LatchBB, SmallVectorImpl<const SCEV *> &AlignmentInfo,
+ unsigned UnrollSize);
void eliminateRedundantLoads(BasicBlock *BB) {
// Map from load pointer to the first load instruction
- DenseMap<Value*, LoadInst*> LoadMap;
- SmallVector<LoadInst*, 16> ToDelete;
-
+ DenseMap<Value *, LoadInst *> LoadMap;
+ SmallVector<LoadInst *, 16> ToDelete;
+
// First pass: collect all loads and find duplicates
for (Instruction &I : *BB) {
if (auto *LI = dyn_cast<LoadInst>(&I)) {
Value *Ptr = LI->getPointerOperand();
-
+
// Check if we've seen a load from this address
auto It = LoadMap.find(Ptr);
if (It != LoadMap.end() && !LI->isVolatile()) {
@@ -208,7 +191,7 @@ private:
}
}
}
-
+
// Delete redundant loads
for (LoadInst *LI : ToDelete) {
LI->eraseFromParent();
@@ -219,7 +202,6 @@ private:
/// @}
};
-
} // anonymous namespace
PreservedAnalyses StridedLoopUnrollPass::run(Loop &L, LoopAnalysisManager &AM,
@@ -253,20 +235,20 @@ bool StridedLoopUnroll::run(Loop *L) {
<< CurLoop->getHeader()->getName() << "\n");
if (recognizeStridedSpecialCases()) {
- LLVM_DEBUG(dbgs() << DEBUG_TYPE " Will transform: F[" << F.getName() << "] Loop %"
- << CurLoop->getHeader()->getName() << "\n");
+ LLVM_DEBUG(dbgs() << DEBUG_TYPE " Will transform: F[" << F.getName()
+ << "] Loop %" << CurLoop->getHeader()->getName() << "\n");
return true;
}
- LLVM_DEBUG(dbgs() << DEBUG_TYPE " Will not transform: F[" << F.getName() << "] Loop %"
- << CurLoop->getHeader()->getName() << "\n");
+ LLVM_DEBUG(dbgs() << DEBUG_TYPE " Will not transform: F[" << F.getName()
+ << "] Loop %" << CurLoop->getHeader()->getName() << "\n");
return false;
}
bool StridedLoopUnrollVersioning::run(Loop *L) {
CurLoop = L;
- if(!TTI->getVScaleForTuning())
+ if (!TTI->getVScaleForTuning())
return false;
Function &F = *L->getHeader()->getParent();
@@ -282,13 +264,13 @@ bool StridedLoopUnrollVersioning::run(Loop *L) {
<< CurLoop->getHeader()->getName() << "\n");
if (recognizeStridedSpecialCases()) {
- LLVM_DEBUG(dbgs() << DEBUG_TYPE " Will transform: F[" << F.getName() << "] Loop %"
- << CurLoop->getHeader()->getName() << "\n");
+ LLVM_DEBUG(dbgs() << DEBUG_TYPE " Will transform: F[" << F.getName()
+ << "] Loop %" << CurLoop->getHeader()->getName() << "\n");
return true;
}
- LLVM_DEBUG(dbgs() << DEBUG_TYPE " Will not transform: F[" << F.getName() << "] Loop %"
- << CurLoop->getHeader()->getName() << "\n");
+ LLVM_DEBUG(dbgs() << DEBUG_TYPE " Will not transform: F[" << F.getName()
+ << "] Loop %" << CurLoop->getHeader()->getName() << "\n");
return false;
}
@@ -360,7 +342,8 @@ static void findUnconnectedToLoad(Instruction *start,
bool shouldBreak = isa<LoadInst>(inst);
// If this is a load, do not proceed from here!
- connected = isa<LoadInst>(inst) && start->getParent()->getName() == inst->getParent()->getName();
+ connected = isa<LoadInst>(inst) &&
+ start->getParent()->getName() == inst->getParent()->getName();
if (shouldBreak)
break;
@@ -369,8 +352,7 @@ static void findUnconnectedToLoad(Instruction *start,
if (auto I = dyn_cast<Instruction>(op.get())) {
if (I->getParent() == start->getParent())
innerWorklist.push_back(op.get());
- }
- else
+ } else
innerWorklist.push_back(op.get());
}
}
@@ -434,12 +416,12 @@ Value *StridedLoopUnroll::widenVectorizedInstruction(
return V;
}
case Instruction::Select: {
- Value* V = Builder.CreateSelect(Ops[0], Ops[1], Ops[2]);
+ Value *V = Builder.CreateSelect(Ops[0], Ops[1], Ops[2]);
return V;
}
case Instruction::ICmp: {
auto Cmp = dyn_cast<ICmpInst>(I);
- Value* V = Builder.CreateICmp(Cmp->getPredicate(), Ops[0], Ops[1]);
+ Value *V = Builder.CreateICmp(Cmp->getPredicate(), Ops[0], Ops[1]);
return V;
}
case Instruction::ShuffleVector: {
@@ -453,11 +435,13 @@ Value *StridedLoopUnroll::widenVectorizedInstruction(
}
ArrayRef<int> NewMask(RepeatedMask);
- Value *Shuffle = Builder.CreateShuffleVector(Ops[0], Ops[1], NewMask, I->getName());
+ Value *Shuffle =
+ Builder.CreateShuffleVector(Ops[0], Ops[1], NewMask, I->getName());
return Shuffle;
}
case Instruction::InsertElement: {
- Value *A = Builder.CreateInsertElement(Ops[0], Ops[1], Ops[2], I->getName());
+ Value *A =
+ Builder.CreateInsertElement(Ops[0], Ops[1], Ops[2], I->getName());
return A;
}
case Instruction::Load: {
@@ -513,89 +497,90 @@ bool isInstructionDepends(Instruction *Dependant, Instruction *Target) {
// InnerInductionVar will be transformed to static
void StridedLoopUnroll::transformStridedSpecialCases(
- BasicBlock *Header, BasicBlock *Latch, BasicBlock *Preheader,
- Loop *SubLoop,
- SmallVectorImpl<LoadInst*> &Loads, StoreInst* Store,
- SmallVectorImpl<Value *>& PostOrder,
- SmallVectorImpl<Value *>& PreOrder) {
+ BasicBlock *Header, BasicBlock *Latch, BasicBlock *Preheader, Loop *SubLoop,
+ SmallVectorImpl<LoadInst *> &Loads, StoreInst *Store,
+ SmallVectorImpl<Value *> &PostOrder, SmallVectorImpl<Value *> &PreOrder) {
- //auto InnerPreheader = SubLoop->getLoopPreheader();
+ // auto InnerPreheader = SubLoop->getLoopPreheader();
- auto Stride = getDynamicStrideFromMemOp(Store->getPointerOperand(), Preheader->getTerminator());
+ auto Stride = getDynamicStrideFromMemOp(Store->getPointerOperand(),
+ Preheader->getTerminator());
SmallPtrSet<llvm::Value *, 32> Connected;
SmallPtrSet<llvm::Value *, 32> NotConnected;
SmallDenseMap<llvm::Value *, llvm::Value *, 32> Replacements;
- auto StoredInstruction =
- dyn_cast<Instruction>(Store->getValueOperand());
+ auto StoredInstruction = dyn_cast<Instruction>(Store->getValueOperand());
findUnconnectedToLoad(StoredInstruction, NotConnected, Connected);
- auto convertConstant = [&] (auto val) {
- auto constVal = cast<Constant>(val);
- unsigned numElements =
- cast<FixedVectorType>(val->getType())->getNumElements();
- SmallVector<Constant *, 16> elements;
-
- // Extract original elements
- for (unsigned i = 0; i < numElements; ++i)
- elements.push_back(constVal->getAggregateElement(i));
-
- auto originalElements = elements;
- for (unsigned int copy = 0; copy != (*TTI->getVScaleForTuning())-1; ++copy)
- elements.append(originalElements);
- Constant *newConst = ConstantVector::get(elements);
- return newConst;
+ auto convertConstant = [&](auto val) {
+ auto constVal = cast<Constant>(val);
+ unsigned numElements =
+ cast<FixedVectorType>(val->getType())->getNumElements();
+ SmallVector<Constant *, 16> elements;
+
+ // Extract original elements
+ for (unsigned i = 0; i < numElements; ++i)
+ elements.push_back(constVal->getAggregateElement(i));
+
+ auto originalElements = elements;
+ for (unsigned int copy = 0; copy != (*TTI->getVScaleForTuning()) - 1;
+ ++copy)
+ elements.append(originalElements);
+ Constant *newConst = ConstantVector::get(elements);
+ return newConst;
};
// Process in post-order (leafs to root)
for (Value *val : PostOrder) {
if (Connected.contains(val)) {
if (auto *I = dyn_cast<Instruction>(val)) {
- SmallVector<Value *, 4> Operands(I->operands());
- for (auto op_it = Operands.begin(); op_it != Operands.end(); ++op_it) {
- if (Replacements.contains(*op_it))
- *op_it = Replacements[*op_it];
- else if (auto OrigVecTy = llvm::dyn_cast<llvm::VectorType>((*op_it)->getType())) {
- if (auto Iop = dyn_cast<Instruction>(*op_it)) {
- if (Iop->getParent() != Store->getParent()) {
- assert(!Connected.contains(*op_it));
-
- IRBuilder<> Builder(I);
-
- std::vector<llvm::Constant*> Consts;
- for (unsigned int i = 0; i != *TTI->getVScaleForTuning(); i++) {
- for (size_t j = 0; j != OrigVecTy->getElementCount().getFixedValue(); j++) {
- Consts.push_back(llvm::ConstantInt::get(llvm::Type::getInt32Ty(Builder.getContext()), j));
- }
+ SmallVector<Value *, 4> Operands(I->operands());
+ for (auto op_it = Operands.begin(); op_it != Operands.end(); ++op_it) {
+ if (Replacements.contains(*op_it))
+ *op_it = Replacements[*op_it];
+ else if (auto OrigVecTy =
+ llvm::dyn_cast<llvm::VectorType>((*op_it)->getType())) {
+ if (auto Iop = dyn_cast<Instruction>(*op_it)) {
+ if (Iop->getParent() != Store->getParent()) {
+ assert(!Connected.contains(*op_it));
+
+ IRBuilder<> Builder(I);
+
+ std::vector<llvm::Constant *> Consts;
+ for (unsigned int i = 0; i != *TTI->getVScaleForTuning(); i++) {
+ for (size_t j = 0;
+ j != OrigVecTy->getElementCount().getFixedValue(); j++) {
+ Consts.push_back(llvm::ConstantInt::get(
+ llvm::Type::getInt32Ty(Builder.getContext()), j));
}
-
- llvm::Constant* maskConst =
- llvm::ConstantVector::get(Consts);
- assert(maskConst != nullptr);
-
- llvm::Value* splat =
- Builder.CreateShuffleVector(Iop,
- llvm::UndefValue::get(Iop->getType()),
- maskConst);
- assert(splat != nullptr);
- Replacements.insert({*op_it, splat});
- *op_it = splat;
}
- } else if (isa<Constant>(*op_it)) { // not instruction
- auto replacement = convertConstant(*op_it);
- assert(!!replacement);
- Replacements.insert({*op_it, replacement});
- *op_it = replacement;
+
+ llvm::Constant *maskConst = llvm::ConstantVector::get(Consts);
+ assert(maskConst != nullptr);
+
+ llvm::Value *splat = Builder.CreateShuffleVector(
+ Iop, llvm::UndefValue::get(Iop->getType()), maskConst);
+ assert(splat != nullptr);
+ Replacements.insert({*op_it, splat});
+ *op_it = splat;
}
+ } else if (isa<Constant>(*op_it)) { // not instruction
+ auto replacement = convertConstant(*op_it);
+ assert(!!replacement);
+ Replacements.insert({*op_it, replacement});
+ *op_it = replacement;
}
}
+ }
- auto NewVecTy = getWidenedType(I->getType(), *TTI->getVScaleForTuning());
- Value *NI = widenVectorizedInstruction(I, Operands, NewVecTy, *TTI->getVScaleForTuning());
+ auto NewVecTy =
+ getWidenedType(I->getType(), *TTI->getVScaleForTuning());
+ Value *NI = widenVectorizedInstruction(I, Operands, NewVecTy,
+ *TTI->getVScaleForTuning());
- assert(NI != nullptr);
- Replacements.insert({I, NI});
+ assert(NI != nullptr);
+ Replacements.insert({I, NI});
}
} else if (NotConnected.contains(val)) {
if (val->getType()->isVectorTy() && isa<Constant>(val)) {
@@ -603,44 +588,45 @@ void StridedLoopUnroll::transformStridedSpecialCases(
Replacements.insert({val, replacement});
}
} else if (auto Load = dyn_cast<LoadInst>(val)) {
- auto It =
- std::find_if(Loads.begin(), Loads.end(), [Load](auto &&LoadInstr) {
- return LoadInstr == Load;
- });
- if (It != Loads.end()) {
- auto Stride = getDynamicStrideFromMemOp((*It)->getPointerOperand(), Preheader->getTerminator());
-
- auto GroupedVecTy = getGroupedWidenedType(Load->getType(), *TTI->getVScaleForTuning(), *DL);
- auto VecTy = getWidenedType(Load->getType(), *TTI->getVScaleForTuning());
- ElementCount NewElementCount = GroupedVecTy->getElementCount();
-
- IRBuilder<> Builder(Load);
- auto *NewInst = Builder.CreateIntrinsic(
- Intrinsic::experimental_vp_strided_load,
- {GroupedVecTy, Load->getPointerOperand()->getType(),
- (*Stride)->getType()},
- {Load->getPointerOperand(), *Stride,
- Builder.getAllOnesMask(NewElementCount),
- Builder.getInt32(NewElementCount.getKnownMinValue())});
- auto Cast = Builder.CreateBitCast(NewInst, VecTy);
- Replacements.insert({Load, Cast});
+ auto It =
+ std::find_if(Loads.begin(), Loads.end(),
+ [Load](auto &&LoadInstr) { return LoadInstr == Load; });
+ if (It != Loads.end()) {
+ auto Stride = getDynamicStrideFromMemOp((*It)->getPointerOperand(),
+ Preheader->getTerminator());
+
+ auto GroupedVecTy = getGroupedWidenedType(
+ Load->getType(), *TTI->getVScaleForTuning(), *DL);
+ auto VecTy =
+ getWidenedType(Load->getType(), *TTI->getVScaleForTuning());
+ ElementCount NewElementCount = GroupedVecTy->getElementCount();
+
+ IRBuilder<> Builder(Load);
+ auto *NewInst = Builder.CreateIntrinsic(
+ Intrinsic::experimental_vp_strided_load,
+ {GroupedVecTy, Load->getPointerOperand()->getType(),
+ (*Stride)->getType()},
+ {Load->getPointerOperand(), *Stride,
+ Builder.getAllOnesMask(NewElementCount),
+ Builder.getInt32(NewElementCount.getKnownMinValue())});
+ auto Cast = Builder.CreateBitCast(NewInst, VecTy);
+ Replacements.insert({Load, Cast});
}
}
}
IRBuilder<> Builder(Store);
- auto VecTy =
- getGroupedWidenedType(Store->getValueOperand()->getType(), *TTI->getVScaleForTuning(), *DL);
+ auto VecTy = getGroupedWidenedType(Store->getValueOperand()->getType(),
+ *TTI->getVScaleForTuning(), *DL);
ElementCount NewElementCount = VecTy->getElementCount();
assert(Replacements.find(Store->getValueOperand()) != Replacements.end());
- auto Cast = Builder.CreateBitCast(
- Replacements[Store->getValueOperand()], VecTy);
+ auto Cast =
+ Builder.CreateBitCast(Replacements[Store->getValueOperand()], VecTy);
Builder.CreateIntrinsic(
Intrinsic::experimental_vp_strided_store,
- {VecTy, Store->getPointerOperand()->getType(),
- (*Stride)->getType()},
+ {VecTy, Store->getPointerOperand()->getType(), (*Stride)->getType()},
{Cast, Store->getPointerOperand(), *Stride,
Builder.getAllOnesMask(NewElementCount),
Builder.getInt32(NewElementCount.getKnownMinValue())});
@@ -651,11 +637,13 @@ void StridedLoopUnroll::transformStridedSpecialCases(
if (InductionDescriptor::isInductionPHI(&PN, CurLoop, SE, IndDesc)) {
if (IndDesc.getKind() == InductionDescriptor::IK_PtrInduction)
changeInductionVarIncrement(
- PN.getIncomingValueForBlock(CurLoop->getLoopLatch()), *TTI->getVScaleForTuning());
+ PN.getIncomingValueForBlock(CurLoop->getLoopLatch()),
+ *TTI->getVScaleForTuning());
else if (IndDesc.getKind() == InductionDescriptor::IK_IntInduction)
- changeInductionVarIncrement(IndDesc.getInductionBinOp(), *TTI->getVScaleForTuning());
+ changeInductionVarIncrement(IndDesc.getInductionBinOp(),
+ *TTI->getVScaleForTuning());
}
- }
+ }
if (Store->use_empty())
Store->eraseFromParent();
@@ -666,14 +654,15 @@ void StridedLoopUnroll::transformStridedSpecialCases(
I->eraseFromParent();
}
-std::optional<Value*> StridedLoopUnroll::getStrideFromAddRecExpr(const SCEVAddRecExpr* AR,
- Instruction *InsertionPt) {
+std::optional<Value *>
+StridedLoopUnroll::getStrideFromAddRecExpr(const SCEVAddRecExpr *AR,
+ Instruction *InsertionPt) {
auto Step = AR->getStepRecurrence(*SE);
if (isa<SCEVConstant>(Step))
return std::nullopt;
SCEVExpander Expander(*SE, *DL, "stride");
Value *StrideValue =
- Expander.expandCodeFor(Step, Step->getType(), InsertionPt);
+ Expander.expandCodeFor(Step, Step->getType(), InsertionPt);
return StrideValue;
}
@@ -683,7 +672,7 @@ StridedLoopUnroll::getDynamicStrideFromMemOp(Value *V,
const SCEV *S = SE->getSCEV(V);
if (const SCEVAddRecExpr *InnerLoopAR = dyn_cast<SCEVAddRecExpr>(S)) {
if (auto *constant =
- dyn_cast<SCEVConstant>(InnerLoopAR->getStepRecurrence(*SE))) {
+ dyn_cast<SCEVConstant>(InnerLoopAR->getStepRecurrence(*SE))) {
// We need to form 64-bit groups
if (constant->getAPInt() != 8) {
return std::nullopt;
@@ -692,16 +681,16 @@ StridedLoopUnroll::getDynamicStrideFromMemOp(Value *V,
const auto *Add = dyn_cast<SCEVAddExpr>(InnerLoopAR->getStart());
if (Add) {
for (const SCEV *Op : Add->operands()) {
- // Look for the outer recurrence: { %dst, +, sext(%i_dst_stride) } <outer loop>
+ // Look for the outer recurrence: { %dst, +, sext(%i_dst_stride) }
+ // <outer loop>
const auto *AR = dyn_cast<SCEVAddRecExpr>(Op);
if (!AR)
continue;
return getStrideFromAddRecExpr(AR, InsertionPt);
}
- }
- else if (const SCEVAddRecExpr *AR =
- dyn_cast<SCEVAddRecExpr>(InnerLoopAR->getStart())) {
+ } else if (const SCEVAddRecExpr *AR =
+ dyn_cast<SCEVAddRecExpr>(InnerLoopAR->getStart())) {
return getStrideFromAddRecExpr(AR, InsertionPt);
}
}
@@ -715,12 +704,11 @@ bool StridedLoopUnroll::recognizeStridedSpecialCases() {
return false;
auto SubLoops = CurLoop->getSubLoops();
-
+
if (SubLoops.size() > 2 || SubLoops.empty())
return false;
- auto SubLoop = SubLoops.size() == 2 ? SubLoops[1]
- : SubLoops[0];
+ auto SubLoop = SubLoops.size() == 2 ? SubLoops[1] : SubLoops[0];
auto Preheader = SubLoop->getLoopPreheader();
auto Header = SubLoop->getHeader();
@@ -729,8 +717,8 @@ bool StridedLoopUnroll::recognizeStridedSpecialCases() {
if (Header != Latch)
return false;
- SmallVector<LoadInst*> Loads;
- SmallVector<StoreInst*> Stores;
+ SmallVector<LoadInst *> Loads;
+ SmallVector<StoreInst *> Stores;
llvm::SmallPtrSet<llvm::Instruction *, 32> NotVisited;
llvm::SmallVector<llvm::Instruction *, 8> WorkList;
@@ -775,8 +763,7 @@ bool StridedLoopUnroll::recognizeStridedSpecialCases() {
SmallPtrSet<llvm::Value *, 32> Connected;
SmallPtrSet<llvm::Value *, 32> NotConnected;
- auto StoredInstruction =
- dyn_cast<Instruction>(Stores[0]->getValueOperand());
+ auto StoredInstruction = dyn_cast<Instruction>(Stores[0]->getValueOperand());
findUnconnectedToLoad(StoredInstruction, NotConnected, Connected);
llvm::SmallVector<Value *, 16> PostOrder;
@@ -787,12 +774,10 @@ bool StridedLoopUnroll::recognizeStridedSpecialCases() {
Worklist.push_back(StoredInstruction);
- auto shouldVisit = [Header, Preheader](auto *Val)
- {
+ auto shouldVisit = [Header, Preheader](auto *Val) {
return !isa<PHINode>(Val) &&
- (
- !isa<Instruction>(Val) ||
- dyn_cast<Instruction>(Val)->getParent() == Header);
+ (!isa<Instruction>(Val) ||
+ dyn_cast<Instruction>(Val)->getParent() == Header);
};
auto shouldVisitOperands = [Header, Preheader](auto *Val) {
return !isa<PHINode>(Val) && !isa<LoadInst>(Val);
@@ -800,8 +785,9 @@ bool StridedLoopUnroll::recognizeStridedSpecialCases() {
while (!Worklist.empty()) {
Value *val = Worklist.back();
- assert (!isa<Instruction>(val) || dyn_cast<Instruction>(val)->getParent() == Header
- || dyn_cast<Instruction>(val)->getParent() == Preheader);
+ assert(!isa<Instruction>(val) ||
+ dyn_cast<Instruction>(val)->getParent() == Header ||
+ dyn_cast<Instruction>(val)->getParent() == Preheader);
if (InStack.contains(val)) {
// We've finished processing children, add to post-order
@@ -839,12 +825,11 @@ bool StridedLoopUnroll::recognizeStridedSpecialCases() {
}
}
}
- }
- else { // We don't handle Non-instructions connected to Load
+ } else { // We don't handle Non-instructions connected to Load
return false;
}
- } else if (NotConnected.contains(val)
- && (!val->getType()->isVectorTy() || !isa<Constant>(val))) {
+ } else if (NotConnected.contains(val) &&
+ (!val->getType()->isVectorTy() || !isa<Constant>(val))) {
return false;
} else if (auto Load = dyn_cast<LoadInst>(val)) {
if (std::find(Loads.begin(), Loads.end(), Load) == Loads.end())
@@ -852,8 +837,8 @@ bool StridedLoopUnroll::recognizeStridedSpecialCases() {
}
}
- transformStridedSpecialCases(Header, Latch, Preheader, SubLoop, Loads, Stores[0],
- PostOrder, PreOrder);
+ transformStridedSpecialCases(Header, Latch, Preheader, SubLoop, Loads,
+ Stores[0], PostOrder, PreOrder);
return true;
}
@@ -901,8 +886,8 @@ bool canHandleInstruction(Instruction *I) {
} // anonymous namespace
PreservedAnalyses
-StridedLoopUnrollVersioningPass::run(Function &F, FunctionAnalysisManager &FAM)
-{
+StridedLoopUnrollVersioningPass::run(Function &F,
+ FunctionAnalysisManager &FAM) {
bool Changed = false;
if (SkipPass)
@@ -927,7 +912,8 @@ StridedLoopUnrollVersioningPass::run(Function &F, FunctionAnalysisManager &FAM)
const auto *DL = &L->getHeader()->getDataLayout();
- StridedLoopUnrollVersioning LIV(&DT, &LI, &TTI, DL, &SE, &AA, &AC, &ORE, &F);
+ StridedLoopUnrollVersioning LIV(&DT, &LI, &TTI, DL, &SE, &AA, &AC, &ORE,
+ &F);
bool ThisChanged = LIV.run(L);
Changed |= ThisChanged;
}
@@ -968,58 +954,56 @@ void StridedLoopUnrollVersioning::setNoAliasToLoop(Loop *VerLoop) {
}
void StridedLoopUnrollVersioning::hoistInvariantLoadsToPreheader(Loop *L) {
- BasicBlock *Preheader = L->getLoopPreheader();
- if (!Preheader) {
- // If no preheader, try the header
- Preheader = L->getHeader();
- }
-
- // Find all invariant loads in the loop
- SmallVector<LoadInst*, 8> InvariantLoads;
-
- for (BasicBlock *BB : L->blocks()) {
- for (Instruction &I : *BB) {
- if (auto *LI = dyn_cast<LoadInst>(&I)) {
- Value *Ptr = LI->getPointerOperand();
-
- if (L->isLoopInvariant(Ptr)) {
- InvariantLoads.push_back(LI);
- }
+ BasicBlock *Preheader = L->getLoopPreheader();
+ if (!Preheader) {
+ // If no preheader, try the header
+ Preheader = L->getHeader();
+ }
+
+ // Find all invariant loads in the loop
+ SmallVector<LoadInst *, 8> InvariantLoads;
+
+ for (BasicBlock *BB : L->blocks()) {
+ for (Instruction &I : *BB) {
+ if (auto *LI = dyn_cast<LoadInst>(&I)) {
+ Value *Ptr = LI->getPointerOperand();
+
+ if (L->isLoopInvariant(Ptr)) {
+ InvariantLoads.push_back(LI);
}
}
}
-
- // Move loads to preheader and eliminate duplicates
- DenseMap<Value*, LoadInst*> HoistedLoads;
-
- for (LoadInst *LI : InvariantLoads) {
- Value *Ptr = LI->getPointerOperand();
-
- if (HoistedLoads.count(Ptr)) {
- // Already hoisted this load, replace uses
- LI->replaceAllUsesWith(HoistedLoads[Ptr]);
- LI->eraseFromParent();
- } else {
- // Move to preheader
- LI->moveBefore(*Preheader, Preheader->getTerminator()->getIterator());
- HoistedLoads[Ptr] = LI;
- }
+ }
+
+ // Move loads to preheader and eliminate duplicates
+ DenseMap<Value *, LoadInst *> HoistedLoads;
+
+ for (LoadInst *LI : InvariantLoads) {
+ Value *Ptr = LI->getPointerOperand();
+
+ if (HoistedLoads.count(Ptr)) {
+ // Already hoisted this load, replace uses
+ LI->replaceAllUsesWith(HoistedLoads[Ptr]);
+ LI->eraseFromParent();
+ } else {
+ // Move to preheader
+ LI->moveBefore(*Preheader, Preheader->getTerminator()->getIterator());
+ HoistedLoads[Ptr] = LI;
}
+ }
}
// InnerInductionVar will be transformed to static
void StridedLoopUnrollVersioning::transformStridedSpecialCases(
- PHINode *OuterInductionVar, PHINode *InnerInductionVar,
- StoreInst *Store, BasicBlock *PreheaderBB,
- BasicBlock *BodyBB, BasicBlock *HeaderBB, BasicBlock *LatchBB,
- SmallVectorImpl<const SCEV*>& AlignmentInfo,
+ PHINode *OuterInductionVar, PHINode *InnerInductionVar, StoreInst *Store,
+ BasicBlock *PreheaderBB, BasicBlock *BodyBB, BasicBlock *HeaderBB,
+ BasicBlock *LatchBB, SmallVectorImpl<const SCEV *> &AlignmentInfo,
unsigned UnrollSize) {
PredicatedScalarEvolution PSE(*SE, *CurLoop);
auto VLAI = &LAIs.getInfo(*CurLoop);
- LoopVersioning LVer2(*VLAI,
- VLAI->getRuntimePointerChecking()->getChecks(),
+ LoopVersioning LVer2(*VLAI, VLAI->getRuntimePointerChecking()->getChecks(),
CurLoop, LI, DT, SE, true);
LVer2.versionLoop();
@@ -1048,7 +1032,7 @@ void StridedLoopUnrollVersioning::transformStridedSpecialCases(
ULO.UnrollRemainder = false;
ULO.SCEVExpansionBudget = -1;
- /*auto r =*/ UnrollLoop(NewInnerLoop, ULO, LI, SE, DT, AC, TTI, ORE, false);
+ /*auto r =*/UnrollLoop(NewInnerLoop, ULO, LI, SE, DT, AC, TTI, ORE, false);
hoistInvariantLoadsToPreheader(VersionedLoop);
@@ -1057,9 +1041,9 @@ void StridedLoopUnrollVersioning::transformStridedSpecialCases(
}
for (BasicBlock *BB : VersionedLoop->blocks()) {
- DenseMap<Value*, Value*> LoadCSE;
- SmallVector<Instruction*, 16> DeadInsts;
-
+ DenseMap<Value *, Value *> LoadCSE;
+ SmallVector<Instruction *, 16> DeadInsts;
+
for (Instruction &I : *BB) {
if (auto *LI = dyn_cast<LoadInst>(&I)) {
if (!LI->isVolatile()) {
@@ -1074,7 +1058,7 @@ void StridedLoopUnrollVersioning::transformStridedSpecialCases(
}
}
}
-
+
for (auto *I : DeadInsts) {
I->eraseFromParent();
}
@@ -1103,7 +1087,7 @@ void StridedLoopUnrollVersioning::transformStridedSpecialCases(
Value *innerMask = Builder.getIntN(
InnerLoopBounds->getFinalIVValue().getType()->getIntegerBitWidth(),
- UnrollSize-1);
+ UnrollSize - 1);
Value *innerAndResult = Builder.CreateAnd(
&InnerLoopBounds->getFinalIVValue(), innerMask, "inner_mod_unroll");
Value *innerIsNotDivisible =
@@ -1121,8 +1105,8 @@ void StridedLoopUnrollVersioning::transformStridedSpecialCases(
Value *mask = Builder.getIntN(
OuterLoopBounds->getFinalIVValue().getType()->getIntegerBitWidth(),
*o - 1);
- Value *andResult =
- Builder.CreateAnd(&OuterLoopBounds->getFinalIVValue(), mask, "div_unroll");
+ Value *andResult = Builder.CreateAnd(&OuterLoopBounds->getFinalIVValue(),
+ mask, "div_unroll");
Value *isNotDivisible =
Builder.CreateICmpNE(andResult, outerZero, "is_div_unroll");
Value *Check1 = Builder.CreateOr(innerIsNotDivisible, isNotDivisible);
@@ -1130,22 +1114,25 @@ void StridedLoopUnrollVersioning::transformStridedSpecialCases(
Value *AlignmentCheck = Builder.getFalse();
- for (auto && PtrSCEV : AlignmentInfo) {
+ for (auto &&PtrSCEV : AlignmentInfo) {
const unsigned Alignment = 8;
// Expand SCEV to get runtime value
SCEVExpander Expander(*SE, *DL, "align.check");
- Value *PtrValue = Expander.expandCodeFor(PtrSCEV, Builder.getPtrTy(), PHBranch);
+ Value *PtrValue =
+ Expander.expandCodeFor(PtrSCEV, Builder.getPtrTy(), PHBranch);
Type *I64 = Type::getInt64Ty(PtrValue->getContext());
- bool AllowsMisaligned = TTI->isLegalStridedLoadStore(VectorType::get(I64, ElementCount::getFixed(8)), Align(1));
+ bool AllowsMisaligned = TTI->isLegalStridedLoadStore(
+ VectorType::get(I64, ElementCount::getFixed(8)), Align(1));
- if(!AllowsMisaligned) {
+ if (!AllowsMisaligned) {
// Create alignment check: (ptr & (alignment-1)) == 0
- Value *PtrInt = Builder.CreatePtrToInt(PtrValue, Builder.getInt64Ty());
+ Value *PtrInt =
+ Builder.CreatePtrToInt(PtrValue, Builder.getInt64Ty());
Value *Mask = Builder.getInt64(Alignment - 1);
Value *Masked = Builder.CreateAnd(PtrInt, Mask);
Value *IsAligned = Builder.CreateICmpNE(Masked, Builder.getInt64(0));
-
+
AlignmentCheck = Builder.CreateOr(AlignmentCheck, IsAligned);
}
}
@@ -1272,20 +1259,20 @@ bool StridedLoopUnrollVersioning::recognizeStridedSpecialCases() {
llvm::SmallPtrSet<llvm::Instruction *, 32> NotVisited;
llvm::SmallVector<llvm::Instruction *, 8> WorkList;
- for (auto&& BB : CurLoop->getBlocks())
- for (auto&& V : *BB)
+ for (auto &&BB : CurLoop->getBlocks())
+ for (auto &&V : *BB)
if (BB != ForLoop)
if (!canHandleInstruction(&V))
return false;
- for (auto&& Loop : CurLoop->getSubLoops())
- for (auto&& BB : Loop->getBlocks())
- for (auto&& V : *BB)
+ for (auto &&Loop : CurLoop->getSubLoops())
+ for (auto &&BB : Loop->getBlocks())
+ for (auto &&V : *BB)
if (BB != ForLoop)
if (!canHandleInstruction(&V))
return false;
-
+
// Collect pointers needing alignment
- SmallVector<const SCEV*, 8> AlignmentInfo;
+ SmallVector<const SCEV *, 8> AlignmentInfo;
unsigned UnrollSize = 8;
for (BasicBlock *BB : CurLoop->blocks()) {
@@ -1295,57 +1282,55 @@ bool StridedLoopUnrollVersioning::recognizeStridedSpecialCases() {
if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
Ptr = LI->getPointerOperand();
- TypeSize typeSize = DL->getTypeAllocSize(I.getType());
+ TypeSize typeSize = DL->getTypeAllocSize(I.getType());
if (size == 0)
size = typeSize;
else if (size != typeSize)
return false;
- }
- else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) {
+ } else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) {
Ptr = SI->getPointerOperand();
- TypeSize typeSize = DL->getTypeAllocSize(SI->getValueOperand()->getType());
+ TypeSize typeSize =
+ DL->getTypeAllocSize(SI->getValueOperand()->getType());
if (size == 0)
size = typeSize;
else if (size != typeSize)
return false;
- UnrollSize = 8/size;
- }
- else
+ UnrollSize = 8 / size;
+ } else
continue;
const SCEV *S = SE->getSCEV(Ptr);
if (const SCEVAddRecExpr *InnerLoopAR = dyn_cast<SCEVAddRecExpr>(S)) {
if (auto *constant =
- dyn_cast<SCEVConstant>(InnerLoopAR->getStepRecurrence(*SE))) {
+ dyn_cast<SCEVConstant>(InnerLoopAR->getStepRecurrence(*SE))) {
if (constant->getAPInt() != size)
return false; // must be contiguous
if (const SCEVAddRecExpr *AR =
- dyn_cast<SCEVAddRecExpr>(InnerLoopAR->getStart())) {
+ dyn_cast<SCEVAddRecExpr>(InnerLoopAR->getStart())) {
auto Step = AR->getStepRecurrence(*SE);
if (isa<SCEVConstant>(Step))
return false;
else {
- const SCEVUnknown* Unknown = nullptr;
+ const SCEVUnknown *Unknown = nullptr;
if (size > 1) {
if (auto mul = dyn_cast<SCEVMulExpr>(Step)) {
if (mul->getNumOperands() == 2) {
- if (auto constant = dyn_cast<SCEVConstant>(mul->getOperand(0))) {
+ if (auto constant =
+ dyn_cast<SCEVConstant>(mul->getOperand(0))) {
if (constant->getAPInt() != size)
return false;
- }
- else
+ } else
return false;
Unknown = dyn_cast<SCEVUnknown>(mul->getOperand(1));
- if (auto CastExtend = dyn_cast<SCEVCastExpr>(mul->getOperand(1)))
+ if (auto CastExtend =
+ dyn_cast<SCEVCastExpr>(mul->getOperand(1)))
Unknown = dyn_cast<SCEVUnknown>(CastExtend->getOperand());
- }
- else
+ } else
return false;
- }
- else
+ } else
return false;
}
if (!Unknown) {
@@ -1356,20 +1341,16 @@ bool StridedLoopUnrollVersioning::recognizeStridedSpecialCases() {
if (Unknown) { // stride should be fixed but not constant
if (isa<Instruction>(Unknown->getValue()))
return false;
- }
- else
+ } else
return false;
}
AlignmentInfo.push_back({AR->getStart()});
- }
- else
+ } else
return false;
- }
- else
+ } else
return false;
- }
- else if (!CurLoop->isLoopInvariant(Ptr))
+ } else if (!CurLoop->isLoopInvariant(Ptr))
return false;
}
}
@@ -1440,9 +1421,8 @@ bool StridedLoopUnrollVersioning::recognizeStridedSpecialCases() {
return false;
transformStridedSpecialCases(OuterInductionVariable, InnerInductionVariable,
- Stores[0], Preheader, ForLoop,
- OuterLoopHeader, OuterLoopLatch, AlignmentInfo,
- UnrollSize);
+ Stores[0], Preheader, ForLoop, OuterLoopHeader,
+ OuterLoopLatch, AlignmentInfo, UnrollSize);
return true;
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/157749
More information about the llvm-commits mailing list