[llvm] [LoopIdiom] Improve code; use SCEVPatternMatch (NFC) (PR #139540)

via llvm-commits llvm-commits at lists.llvm.org
Mon May 12 05:04:42 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Ramkumar Ramachandra (artagnon)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/139540.diff


1 File Affected:

- (modified) llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp (+25-36) 


``````````diff
diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
index 8f5d1ecba982d..5f59ec6daaba5 100644
--- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp
@@ -32,7 +32,6 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/MapVector.h"
-#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
@@ -49,6 +48,7 @@
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
@@ -91,6 +91,7 @@
 #include <vector>
 
 using namespace llvm;
+using namespace SCEVPatternMatch;
 
 #define DEBUG_TYPE "loop-idiom"
 
@@ -340,9 +341,8 @@ bool LoopIdiomRecognize::runOnCountableLoop() {
 
   // If this loop executes exactly one time, then it should be peeled, not
   // optimized by this pass.
-  if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount))
-    if (BECst->getAPInt() == 0)
-      return false;
+  if (match(BECount, m_scev_SpecificInt(0)))
+    return false;
 
   SmallVector<BasicBlock *, 8> ExitBlocks;
   CurLoop->getUniqueExitBlocks(ExitBlocks);
@@ -805,20 +805,17 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
 
   // Check if the stride matches the size of the memcpy. If so, then we know
   // that every byte is touched in the loop.
-  const SCEVConstant *ConstStoreStride =
-      dyn_cast<SCEVConstant>(StoreEv->getOperand(1));
-  const SCEVConstant *ConstLoadStride =
-      dyn_cast<SCEVConstant>(LoadEv->getOperand(1));
-  if (!ConstStoreStride || !ConstLoadStride)
+  const APInt *StoreStrideValue, *LoadStrideValue;
+  if (!match(StoreEv->getOperand(1), m_scev_APInt(StoreStrideValue)) ||
+      !match(LoadEv->getOperand(1), m_scev_APInt(LoadStrideValue)))
     return false;
 
-  APInt StoreStrideValue = ConstStoreStride->getAPInt();
-  APInt LoadStrideValue = ConstLoadStride->getAPInt();
   // Huge stride value - give up
-  if (StoreStrideValue.getBitWidth() > 64 || LoadStrideValue.getBitWidth() > 64)
+  if (StoreStrideValue->getBitWidth() > 64 ||
+      LoadStrideValue->getBitWidth() > 64)
     return false;
 
-  if (SizeInBytes != StoreStrideValue && SizeInBytes != -StoreStrideValue) {
+  if (SizeInBytes != *StoreStrideValue && SizeInBytes != -*StoreStrideValue) {
     ORE.emit([&]() {
       return OptimizationRemarkMissed(DEBUG_TYPE, "SizeStrideUnequal", MCI)
              << ore::NV("Inst", "memcpy") << " in "
@@ -829,8 +826,8 @@ bool LoopIdiomRecognize::processLoopMemCpy(MemCpyInst *MCI,
     return false;
   }
 
-  int64_t StoreStrideInt = StoreStrideValue.getSExtValue();
-  int64_t LoadStrideInt = LoadStrideValue.getSExtValue();
+  int64_t StoreStrideInt = StoreStrideValue->getSExtValue();
+  int64_t LoadStrideInt = LoadStrideValue->getSExtValue();
   // Check if the load stride matches the store stride.
   if (StoreStrideInt != LoadStrideInt)
     return false;
@@ -879,15 +876,14 @@ bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI,
     // we know that every byte is touched in the loop.
     LLVM_DEBUG(dbgs() << "  memset size is constant\n");
     uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue();
-    const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1));
-    if (!ConstStride)
+    const APInt *Stride;
+    if (!match(Ev->getOperand(1), m_scev_APInt(Stride)))
       return false;
 
-    APInt Stride = ConstStride->getAPInt();
-    if (SizeInBytes != Stride && SizeInBytes != -Stride)
+    if (SizeInBytes != *Stride && SizeInBytes != -*Stride)
       return false;
 
-    IsNegStride = SizeInBytes == -Stride;
+    IsNegStride = SizeInBytes == -*Stride;
   } else {
     // Memset size is non-constant.
     // Check if the pointer stride matches the memset size.
@@ -963,11 +959,11 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
 
   // If the loop iterates a fixed number of times, we can refine the access size
   // to be exactly the size of the memset, which is (BECount+1)*StoreSize
-  const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount);
-  const SCEVConstant *ConstSize = dyn_cast<SCEVConstant>(StoreSizeSCEV);
-  if (BECst && ConstSize) {
-    std::optional<uint64_t> BEInt = BECst->getAPInt().tryZExtValue();
-    std::optional<uint64_t> SizeInt = ConstSize->getAPInt().tryZExtValue();
+  const APInt *BECst, *ConstSize;
+  if (match(BECount, m_scev_APInt(BECst)) &&
+      match(StoreSizeSCEV, m_scev_APInt(ConstSize))) {
+    std::optional<uint64_t> BEInt = BECst->tryZExtValue();
+    std::optional<uint64_t> SizeInt = ConstSize->tryZExtValue();
     // FIXME: Should this check for overflow?
     if (BEInt && SizeInt)
       AccessSize = LocationSize::precise((*BEInt + 1) * *SizeInt);
@@ -1605,16 +1601,11 @@ class StrlenVerifier {
 
     LLVM_DEBUG(dbgs() << "pointer load scev: " << *LoadEv << "\n");
 
-    const SCEVConstant *Step =
-        dyn_cast<SCEVConstant>(LoadEv->getStepRecurrence(*SE));
-    if (!Step)
+    const APInt *Step;
+    if (!match(LoadEv->getStepRecurrence(*SE), m_scev_APInt(Step)))
       return false;
 
-    unsigned StepSize = 0;
-    StepSizeCI = dyn_cast<ConstantInt>(Step->getValue());
-    if (!StepSizeCI)
-      return false;
-    StepSize = StepSizeCI->getZExtValue();
+    unsigned StepSize = Step->getZExtValue();
 
     // Verify that StepSize is consistent with platform char width.
     OpWidth = OperandType->getIntegerBitWidth();
@@ -3294,9 +3285,7 @@ bool LoopIdiomRecognize::recognizeShiftUntilZero() {
   // Ok, transform appears worthwhile.
   MadeChange = true;
 
-  bool OffsetIsZero = false;
-  if (auto *ExtraOffsetExprC = dyn_cast<SCEVConstant>(ExtraOffsetExpr))
-    OffsetIsZero = ExtraOffsetExprC->isZero();
+  bool OffsetIsZero = match(ExtraOffsetExpr, m_scev_SpecificInt(0));
 
   // Step 1: Compute the loop's final IV value / trip count.
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/139540


More information about the llvm-commits mailing list