[llvm] [LoopIdiom] Improve code; use SCEVPatternMatch (NFC) (PR #139540)
via llvm-commits
llvm-commits at lists.llvm.org
Mon May 12 05:04:42 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-transforms
Author: Ramkumar Ramachandra (artagnon)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/139540.diff
1 File Affected:
- (modified) llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp (+25-36)
``````````diff
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 8f5d1ecba982d..5f59ec6daaba5 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -32,7 +32,6 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
-#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
@@ -49,6 +48,7 @@
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
@@ -91,6 +91,7 @@
#include <vector>
using namespace llvm;
+using namespace SCEVPatternMatch;
#define DEBUG_TYPE "loop-idiom"
@@ -340,9 +341,8 @@ bool LoopIdiomRecognize::runOnCountableLoop() {
// If this loop executes exactly one time, then it should be peeled, not
// optimized by this pass.
- if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount))
- if (BECst->getAPInt() == 0)
- return false;
+ if (match(BECount, m_scev_SpecificInt(0)))
+ return false;
SmallVector<BasicBlock *, 8> ExitBlocks;
CurLoop->getUniqueExitBlocks(ExitBlocks);
@@ -805,20 +805,17 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
// Check if the stride matches the size of the memcpy. If so, then we know
// that every byte is touched in the loop.
- const SCEVConstant *ConstStoreStride =
- dyn_cast<SCEVConstant>(StoreEv->getOperand(1));
- const SCEVConstant *ConstLoadStride =
- dyn_cast<SCEVConstant>(LoadEv->getOperand(1));
- if (!ConstStoreStride || !ConstLoadStride)
+ const APInt *StoreStrideValue, *LoadStrideValue;
+ if (!match(StoreEv->getOperand(1), m_scev_APInt(StoreStrideValue)) ||
+ !match(LoadEv->getOperand(1), m_scev_APInt(LoadStrideValue)))
return false;
- APInt StoreStrideValue = ConstStoreStride->getAPInt();
- APInt LoadStrideValue = ConstLoadStride->getAPInt();
// Huge stride value - give up
- if (StoreStrideValue.getBitWidth() > 64 || LoadStrideValue.getBitWidth() > 64)
+ if (StoreStrideValue->getBitWidth() > 64 ||
+ LoadStrideValue->getBitWidth() > 64)
return false;
- if (SizeInBytes != StoreStrideValue && SizeInBytes != -StoreStrideValue) {
+ if (SizeInBytes != *StoreStrideValue && SizeInBytes != -*StoreStrideValue) {
ORE.emit([&]() {
return OptimizationRemarkMissed(DEBUG_TYPE, "SizeStrideUnequal", MCI)
<< ore::NV("Inst", "memcpy") << " in "
@@ -829,8 +826,8 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
return false;
}
- int64_t StoreStrideInt = StoreStrideValue.getSExtValue();
- int64_t LoadStrideInt = LoadStrideValue.getSExtValue();
+ int64_t StoreStrideInt = StoreStrideValue->getSExtValue();
+ int64_t LoadStrideInt = LoadStrideValue->getSExtValue();
// Check if the load stride matches the store stride.
if (StoreStrideInt != LoadStrideInt)
return false;
@@ -879,15 +876,14 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
// we know that every byte is touched in the loop.
LLVM_DEBUG(dbgs() << " memset size is constant\n");
uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
- const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1));
- if (!ConstStride)
+ const APInt *Stride;
+ if (!match(Ev->getOperand(1), m_scev_APInt(Stride)))
return false;
- APInt Stride = ConstStride->getAPInt();
- if (SizeInBytes != Stride && SizeInBytes != -Stride)
+ if (SizeInBytes != *Stride && SizeInBytes != -*Stride)
return false;
- IsNegStride = SizeInBytes == -Stride;
+ IsNegStride = SizeInBytes == -*Stride;
} else {
// Memset size is non-constant.
// Check if the pointer stride matches the memset size.
@@ -963,11 +959,11 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
// If the loop iterates a fixed number of times, we can refine the access size
// to be exactly the size of the memset, which is (BECount+1)*StoreSize
- const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount);
- const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
- if (BECst && ConstSize) {
- std::optional<uint64_t> BEInt = BECst->getAPInt().tryZExtValue();
- std::optional<uint64_t> SizeInt = ConstSize->getAPInt().tryZExtValue();
+ const APInt *BECst, *ConstSize;
+ if (match(BECount, m_scev_APInt(BECst)) &&
+ match(StoreSizeSCEV, m_scev_APInt(ConstSize))) {
+ std::optional<uint64_t> BEInt = BECst->tryZExtValue();
+ std::optional<uint64_t> SizeInt = ConstSize->tryZExtValue();
// FIXME: Should this check for overflow?
if (BEInt && SizeInt)
AccessSize = LocationSize::precise((*BEInt + 1) * *SizeInt);
@@ -1605,16 +1601,11 @@ class StrlenVerifier {
LLVM_DEBUG(dbgs() << "pointer load scev: " << *LoadEv << "\n");
- const SCEVConstant *Step =
- dyn_cast<SCEVConstant>(LoadEv->getStepRecurrence(*SE));
- if (!Step)
+ const APInt *Step;
+ if (!match(LoadEv->getStepRecurrence(*SE), m_scev_APInt(Step)))
return false;
- unsigned StepSize = 0;
- StepSizeCI = dyn_cast<ConstantInt>(Step->getValue());
- if (!StepSizeCI)
- return false;
- StepSize = StepSizeCI->getZExtValue();
+ unsigned StepSize = Step->getZExtValue();
// Verify that StepSize is consistent with platform char width.
OpWidth = OperandType->getIntegerBitWidth();
@@ -3294,9 +3285,7 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() {
// Ok, transform appears worthwhile.
MadeChange = true;
- bool OffsetIsZero = false;
- if (auto *ExtraOffsetExprC = dyn_cast<SCEVConstant>(ExtraOffsetExpr))
- OffsetIsZero = ExtraOffsetExprC->isZero();
+ bool OffsetIsZero = match(ExtraOffsetExpr, m_scev_SpecificInt(0));
// Step 1: Compute the loop's final IV value / trip count.
``````````
</details>
https://github.com/llvm/llvm-project/pull/139540
More information about the llvm-commits mailing list