[llvm] Remove redundant assumes (take 2) (PR #123480)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Sun Jan 19 06:35:45 PST 2025


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/123480
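
The first patch extends the handling of "align" assumption operand bundles in
InstCombineCalls.cpp: a bundle is dropped when it is redundant, i.e. when no
user dominated by the assume could still profit from the asserted alignment,
or (for pointers whose underlying object is not a function argument) when
computeKnownBits already proves the alignment at the assume's context. A
minimal IR sketch of the redundant case, adapted from the new tests in
assume-align.ll (the function name is illustrative):

declare void @llvm.assume(i1 noundef)

define ptr @align_assume_already_implied(ptr %p) {
  %p2 = load ptr, ptr %p, align 8, !align !0
  ; The !align metadata above already guarantees 8-byte alignment of %p2, so
  ; the assume below adds no information and can be removed.
  call void @llvm.assume(i1 true) [ "align"(ptr %p2, i32 8) ]
  ret ptr %p2
}

!0 = !{i64 8}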

>From 21d06ec3d7c8df73d2296dbb9e50a30609d78f51 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Sat, 18 Jan 2025 20:29:00 +0000
Subject: [PATCH 1/2] Remove redundant assumes

---
 .../InstCombine/InstCombineCalls.cpp          | 98 ++++++++++++++++++-
 .../Transforms/InstCombine/assume-align.ll    | 56 ++++++++++-
 llvm/test/Transforms/InstCombine/assume.ll    | 18 ++--
 .../AArch64/infer-align-from-assumption.ll    | 32 +++---
 4 files changed, 173 insertions(+), 31 deletions(-)
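
For contrast, a bundle like the one below should survive the new logic (a
sketch with illustrative names, mirroring the %src side of the new
redundant_arg_passed_to_intrinsic test): nothing establishes the alignment of
the argument, and it escapes into a call that might profit from it.

declare void @llvm.assume(i1 noundef)
declare void @foo(ptr)

define void @align_assume_still_useful(ptr %p) {
  ; The alignment of %p is unknown here and %p is passed to a call that may
  ; benefit from the stronger alignment, so the bundle is not redundant.
  call void @llvm.assume(i1 true) [ "align"(ptr %p, i64 8) ]
  call void @foo(ptr %p)
  ret void
}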

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 842881156dc67f..6d2f2dff523547 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -3199,12 +3199,13 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       // TODO: apply range metadata for range check patterns?
     }
 
-    // Separate storage assumptions apply to the underlying allocations, not any
-    // particular pointer within them. When evaluating the hints for AA purposes
-    // we getUnderlyingObject them; by precomputing the answers here we can
-    // avoid having to do so repeatedly there.
     for (unsigned Idx = 0; Idx < II->getNumOperandBundles(); Idx++) {
       OperandBundleUse OBU = II->getOperandBundleAt(Idx);
+
+      // Separate storage assumptions apply to the underlying allocations, not
+      // any particular pointer within them. When evaluating the hints for AA
+      // purposes we getUnderlyingObject them; by precomputing the answers here
+      // we can avoid having to do so repeatedly there.
       if (OBU.getTagName() == "separate_storage") {
         assert(OBU.Inputs.size() == 2);
         auto MaybeSimplifyHint = [&](const Use &U) {
@@ -3218,6 +3219,95 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
         MaybeSimplifyHint(OBU.Inputs[0]);
         MaybeSimplifyHint(OBU.Inputs[1]);
       }
+
+      // Try to fold alignment assumption into a load's !align metadata, if the
+      // assumption is valid in the load's context and remove redundant ones.
+      if (OBU.getTagName() == "align" && OBU.Inputs.size() == 2) {
+        RetainedKnowledge RK = getKnowledgeFromBundle(
+            *cast<AssumeInst>(II), II->bundle_op_info_begin()[Idx]);
+        if (!RK || RK.AttrKind != Attribute::Alignment ||
+            !isPowerOf2_64(RK.ArgValue))
+          continue;
+        auto *C = dyn_cast<Constant>(RK.WasOn);
+        if (C && C->isNullValue()) {
+        } else {
+          Value *UO = getUnderlyingObject(RK.WasOn);
+
+          bool CanUseAlign = false;
+          SetVector<const Instruction *> WorkList;
+          for (const User *U : RK.WasOn->users())
+            if (auto *I = dyn_cast<Instruction>(U))
+              WorkList.insert(I);
+
+          for (unsigned I = 0; I != WorkList.size(); ++I) {
+            auto *Curr = WorkList[I];
+            if (!DT.dominates(II, Curr))
+              continue;
+            if (auto *LI = dyn_cast<LoadInst>(Curr)) {
+              if (LI->getAlign().value() < RK.ArgValue) {
+                CanUseAlign = true;
+                break;
+              }
+              continue;
+            }
+            if (auto *SI = dyn_cast<StoreInst>(Curr)) {
+              auto *PtrOpI = dyn_cast<Instruction>(SI->getPointerOperand());
+              if ((SI->getPointerOperand() == RK.WasOn || (PtrOpI && WorkList.contains(PtrOpI))) &&
+                  SI->getAlign().value() < RK.ArgValue) {
+                CanUseAlign = true;
+                break;
+              }
+              continue;
+            }
+            if (auto *II = dyn_cast<IntrinsicInst>(Curr)) {
+              for (const auto &[Idx, Arg] : enumerate(II->args())) {
+                if (Arg != RK.WasOn)
+                  continue;
+                if (II->getParamAlign(Idx) >= RK.ArgValue)
+                  continue;
+                CanUseAlign = true;
+                break;
+              }
+              if (CanUseAlign)
+                break;
+              continue;
+            }
+            if (isa<ReturnInst, CallBase>(Curr)) {
+              CanUseAlign = true;
+              break;
+            }
+            if (isa<ICmpInst>(Curr) &&
+                !isa<Constant>(cast<Instruction>(Curr)->getOperand(0)) &&
+                !isa<Constant>(cast<Instruction>(Curr)->getOperand(1))) {
+              CanUseAlign = true;
+              break;
+            }
+            if (!Curr->getType()->isPointerTy())
+              continue;
+
+            if (WorkList.size() > 16) {
+              CanUseAlign = true;
+              break;
+            }
+            for (const User *U : Curr->users())
+              WorkList.insert(cast<Instruction>(U));
+          }
+          if (CanUseAlign && (!UO || isa<Argument>(UO)))
+            continue;
+          // Try to get the instruction before the assumption to use as
+          // context.
+          Instruction *CtxI = nullptr;
+          if (II->getParent()->begin() != II->getIterator())
+            CtxI = II->getPrevNode();
+
+          auto Known = computeKnownBits(RK.WasOn, 1, CtxI);
+          unsigned KnownAlign = 1 << Known.countMinTrailingZeros();
+          if (CanUseAlign && KnownAlign < RK.ArgValue)
+            continue;
+        }
+        auto *New = CallBase::removeOperandBundle(II, OBU.getTagID());
+        return New;
+      }
     }
 
     // Convert nonnull assume like:
diff --git a/llvm/test/Transforms/InstCombine/assume-align.ll b/llvm/test/Transforms/InstCombine/assume-align.ll
index f0e02574330861..1d117369db9b50 100644
--- a/llvm/test/Transforms/InstCombine/assume-align.ll
+++ b/llvm/test/Transforms/InstCombine/assume-align.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --check-globals smart
-; RUN: opt -S -passes=instcombine,simplifycfg < %s 2>&1 | FileCheck %s
+; RUN: opt -S -passes='instcombine<no-verify-fixpoint>,simplifycfg' < %s 2>&1 | FileCheck %s
 
 declare void @llvm.assume(i1 noundef)
 
@@ -87,7 +87,6 @@ if.end:                                           ; preds = %if.else, %if.then
 define void @f3(i64 %a, ptr %b) {
 ; CHECK-LABEL: @f3(
 ; CHECK-NEXT:    [[C:%.*]] = ptrtoint ptr [[B:%.*]] to i64
-; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[B]], i64 4294967296) ]
 ; CHECK-NEXT:    [[D:%.*]] = add i64 [[A:%.*]], [[C]]
 ; CHECK-NEXT:    call void @g(i64 [[D]])
 ; CHECK-NEXT:    ret void
@@ -135,6 +134,17 @@ define ptr @fold_assume_align_pow2_of_loaded_pointer_into_align_metadata(ptr %p)
   ret ptr %p2
 }
 
+define ptr @fold_assume_align_i32_pow2_of_loaded_pointer_into_align_metadata(ptr %p) {
+; CHECK-LABEL: @fold_assume_align_i32_pow2_of_loaded_pointer_into_align_metadata(
+; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr [[P:%.*]], align 8
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[P2]], i32 8) ]
+; CHECK-NEXT:    ret ptr [[P2]]
+;
+  %p2 = load ptr, ptr %p
+  call void @llvm.assume(i1 true) [ "align"(ptr %p2, i32 8) ]
+  ret ptr %p2
+}
+
 define ptr @dont_fold_assume_align_pow2_of_loaded_pointer_into_align_metadata_due_to_call(ptr %p) {
 ; CHECK-LABEL: @dont_fold_assume_align_pow2_of_loaded_pointer_into_align_metadata_due_to_call(
 ; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr [[P:%.*]], align 8
@@ -175,7 +185,6 @@ define ptr @dont_fold_assume_align_zero_of_loaded_pointer_into_align_metadata(pt
 define ptr @redundant_assume_align_1(ptr %p) {
 ; CHECK-LABEL: @redundant_assume_align_1(
 ; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr [[P:%.*]], align 8
-; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[P2]], i32 1) ]
 ; CHECK-NEXT:    call void @foo(ptr [[P2]])
 ; CHECK-NEXT:    ret ptr [[P2]]
 ;
@@ -189,7 +198,6 @@ define ptr @redundant_assume_align_1(ptr %p) {
 define ptr @redundant_assume_align_8_via_align_metadata(ptr %p) {
 ; CHECK-LABEL: @redundant_assume_align_8_via_align_metadata(
 ; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr [[P:%.*]], align 8, !align [[META0:![0-9]+]]
-; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[P2]], i32 8) ]
 ; CHECK-NEXT:    call void @foo(ptr [[P2]])
 ; CHECK-NEXT:    ret ptr [[P2]]
 ;
@@ -249,7 +257,47 @@ define ptr @redundant_assume_align_8_via_asume(ptr %p) {
   ret ptr %p
 }
 
+define void @redundant_arg_passed_to_intrinsic(ptr %dst, ptr %src) {
+; CHECK-LABEL: @redundant_arg_passed_to_intrinsic(
+; CHECK-NEXT:    call void @bar()
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[SRC:%.*]], i32 8) ]
+; CHECK-NEXT:    call void @llvm.memmove.p0.p0.i64(ptr noundef nonnull align 8 dereferenceable(16) [[DST:%.*]], ptr noundef nonnull align 8 dereferenceable(16) [[SRC]], i64 16, i1 false)
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.assume(i1 true) [ "align"(ptr %dst, i32 8) ]
+  call void @bar()
+  call void @llvm.assume(i1 true) [ "align"(ptr %src, i32 8) ]
+  call void @llvm.memmove.p0.p0.i64(ptr align 8 %dst, ptr %src, i64 16, i1 false)
+  ret void
+}
+
+define void @test_store(ptr %ptr) {
+; CHECK-LABEL: @test_store(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[PTR:%.*]], i64 2) ]
+; CHECK-NEXT:    store i16 0, ptr [[PTR]], align 1
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @llvm.assume(i1 true) [ "align"(ptr %ptr, i64 2) ]
+  store i16 0, ptr %ptr, align 1
+  ret void
+}
+
 declare void @foo(ptr)
+declare void @bar()
+
+; !align must have a constant integer alignment.
+define ptr @dont_fold_assume_align_not_constant_of_loaded_pointer_into_align_metadata(ptr %p, i64 %align) {
+; CHECK-LABEL: @dont_fold_assume_align_not_constant_of_loaded_pointer_into_align_metadata(
+; CHECK-NEXT:    [[P2:%.*]] = load ptr, ptr [[P:%.*]], align 8
+; CHECK-NEXT:    ret ptr [[P2]]
+;
+  %p2 = load ptr, ptr %p
+  call void @llvm.assume(i1 true) [ "align"(ptr %p2, i64 %align) ]
+  ret ptr %p2
+}
+
 ;.
 ; CHECK: [[META0]] = !{i64 8}
 ;.
diff --git a/llvm/test/Transforms/InstCombine/assume.ll b/llvm/test/Transforms/InstCombine/assume.ll
index c21f8457e82d14..2f07a68c25526a 100644
--- a/llvm/test/Transforms/InstCombine/assume.ll
+++ b/llvm/test/Transforms/InstCombine/assume.ll
@@ -1,9 +1,9 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt < %s -passes=instcombine -S | FileCheck --check-prefixes=CHECK,DEFAULT %s
-; RUN: opt < %s -passes=instcombine --enable-knowledge-retention -S | FileCheck --check-prefixes=CHECK,BUNDLES %s
+; RUN: opt < %s -passes='instcombine<no-verify-fixpoint>' -S | FileCheck --check-prefixes=CHECK,DEFAULT %s
+; RUN: opt < %s -passes='instcombine<no-verify-fixpoint>' --enable-knowledge-retention -S | FileCheck --check-prefixes=CHECK,BUNDLES %s
 
-; RUN: opt < %s -passes=instcombine -S --try-experimental-debuginfo-iterators | FileCheck --check-prefixes=CHECK,DEFAULT %s
-; RUN: opt < %s -passes=instcombine --enable-knowledge-retention -S --try-experimental-debuginfo-iterators | FileCheck --check-prefixes=CHECK,BUNDLES %s
+; RUN: opt < %s -passes='instcombine<no-verify-fixpoint>' -S --try-experimental-debuginfo-iterators | FileCheck --check-prefixes=CHECK,DEFAULT %s
+; RUN: opt < %s -passes='instcombine<no-verify-fixpoint>' --enable-knowledge-retention -S --try-experimental-debuginfo-iterators | FileCheck --check-prefixes=CHECK,BUNDLES %s
 
 target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
 target triple = "x86_64-unknown-linux-gnu"
@@ -795,14 +795,8 @@ exit:
 }
 
 define void @canonicalize_assume(ptr %0) {
-; DEFAULT-LABEL: @canonicalize_assume(
-; DEFAULT-NEXT:    [[TMP2:%.*]] = getelementptr inbounds nuw i8, ptr [[TMP0:%.*]], i64 8
-; DEFAULT-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[TMP2]], i64 16) ]
-; DEFAULT-NEXT:    ret void
-;
-; BUNDLES-LABEL: @canonicalize_assume(
-; BUNDLES-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[TMP0:%.*]], i64 8) ]
-; BUNDLES-NEXT:    ret void
+; CHECK-LABEL: @canonicalize_assume(
+; CHECK-NEXT:    ret void
 ;
   %2 = getelementptr inbounds i32, ptr %0, i64 2
   call void @llvm.assume(i1 true) [ "align"(ptr %2, i64 16) ]
diff --git a/llvm/test/Transforms/PhaseOrdering/AArch64/infer-align-from-assumption.ll b/llvm/test/Transforms/PhaseOrdering/AArch64/infer-align-from-assumption.ll
index 632e3a56aacac7..51d149e9798e97 100644
--- a/llvm/test/Transforms/PhaseOrdering/AArch64/infer-align-from-assumption.ll
+++ b/llvm/test/Transforms/PhaseOrdering/AArch64/infer-align-from-assumption.ll
@@ -15,7 +15,6 @@ define i32 @earlycse_entry(ptr %p) {
 ; CHECK-NEXT:    [[L_2_I:%.*]] = load ptr, ptr [[P]], align 8
 ; CHECK-NEXT:    [[GEP_I:%.*]] = getelementptr i8, ptr [[L_2_I]], i64 4
 ; CHECK-NEXT:    store ptr [[GEP_I]], ptr [[P]], align 8
-; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[GEP_I]], i64 4) ]
 ; CHECK-NEXT:    [[L_ASSUME_ALIGNED_I_I2:%.*]] = load i32, ptr [[GEP_I]], align 4
 ; CHECK-NEXT:    [[R_I_I3:%.*]] = tail call i32 @swap(i32 [[L_ASSUME_ALIGNED_I_I2]])
 ; CHECK-NEXT:    [[L_2_I4:%.*]] = load ptr, ptr [[P]], align 8
@@ -51,7 +50,6 @@ define i32 @earlycse_fn1(ptr %p) {
 define i32 @load_assume_aligned(ptr %p) {
 ; CHECK-LABEL: define i32 @load_assume_aligned(
 ; CHECK-SAME: ptr [[P:%.*]]) local_unnamed_addr {
-; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[P]], i64 4) ]
 ; CHECK-NEXT:    [[DOT0_COPYLOAD:%.*]] = load i32, ptr [[P]], align 4
 ; CHECK-NEXT:    [[TMP2:%.*]] = tail call i32 @swap(i32 [[DOT0_COPYLOAD]])
 ; CHECK-NEXT:    ret i32 [[TMP2]]
@@ -66,8 +64,7 @@ declare i32 @swap(i32)
 
 define void @sroa_align_entry(ptr %p) {
 ; CHECK-LABEL: define void @sroa_align_entry(
-; CHECK-SAME: ptr [[P:%.*]]) local_unnamed_addr #[[ATTR1:[0-9]+]] {
-; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[P]], i64 8) ]
+; CHECK-SAME: ptr [[P:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
 ; CHECK-NEXT:    [[DOT0_COPYLOAD_I_I_I:%.*]] = load i64, ptr [[P]], align 8
 ; CHECK-NEXT:    [[TMP2:%.*]] = inttoptr i64 [[DOT0_COPYLOAD_I_I_I]] to ptr
 ; CHECK-NEXT:    store i32 0, ptr [[TMP2]], align 4
@@ -82,9 +79,6 @@ define void @sroa_align_entry(ptr %p) {
 
 define ptr @sroa_fn1(ptr %p) {
 ; CHECK-LABEL: define ptr @sroa_fn1(
-; CHECK-SAME: ptr nocapture readonly [[P:%.*]]) local_unnamed_addr #[[ATTR2:[0-9]+]] {
-; CHECK-NEXT:    [[L:%.*]] = load ptr, ptr [[P]], align 8
-; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[L]], i64 8) ]
 ; CHECK-NEXT:    [[L_FN3_I_I:%.*]] = load i64, ptr [[L]], align 8
 ; CHECK-NEXT:    [[I_I:%.*]] = inttoptr i64 [[L_FN3_I_I]] to ptr
 ; CHECK-NEXT:    ret ptr [[I_I]]
@@ -96,8 +90,7 @@ define ptr @sroa_fn1(ptr %p) {
 
 define ptr @sroa_fn2(ptr %p) {
 ; CHECK-LABEL: define ptr @sroa_fn2(
-; CHECK-SAME: ptr [[P:%.*]]) local_unnamed_addr #[[ATTR3:[0-9]+]] {
-; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[P]], i64 8) ]
+; CHECK-SAME: ptr [[P:%.*]]) local_unnamed_addr #[[ATTR2:[0-9]+]] {
 ; CHECK-NEXT:    [[DOT0_COPYLOAD_I_I:%.*]] = load i64, ptr [[P]], align 8
 ; CHECK-NEXT:    [[TMP3:%.*]] = inttoptr i64 [[DOT0_COPYLOAD_I_I]] to ptr
 ; CHECK-NEXT:    ret ptr [[TMP3]]
@@ -109,8 +102,7 @@ define ptr @sroa_fn2(ptr %p) {
 
 define i64 @sroa_fn3(ptr %0) {
 ; CHECK-LABEL: define i64 @sroa_fn3(
-; CHECK-SAME: ptr [[TMP0:%.*]]) local_unnamed_addr #[[ATTR3]] {
-; CHECK-NEXT:    call void @llvm.assume(i1 true) [ "align"(ptr [[TMP0]], i64 8) ]
+; CHECK-SAME: ptr [[TMP0:%.*]]) local_unnamed_addr #[[ATTR2]] {
 ; CHECK-NEXT:    [[DOT0_COPYLOAD_I:%.*]] = load i64, ptr [[TMP0]], align 8
 ; CHECK-NEXT:    ret i64 [[DOT0_COPYLOAD_I]]
 ;
@@ -118,3 +110,21 @@ define i64 @sroa_fn3(ptr %0) {
   %l.fn3 = load i64, ptr %0, align 1
   ret i64 %l.fn3
 }
+
+define void @test_store(ptr %ptr) {
+; CHECK-LABEL: define void @test_store(
+; CHECK-SAME: ptr [[PTR:%.*]]) local_unnamed_addr #[[ATTR3:[0-9]+]] {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    store i16 0, ptr [[PTR]], align 2
+; CHECK-NEXT:    ret void
+;
+entry:
+  call void @llvm.assume(i1 true) [ "align"(ptr %ptr, i64 2) ]
+  store i16 0, ptr %ptr, align 1
+  ret void
+}
+;.
+; CHECK: [[META0]] = !{i64 4}
+; CHECK: [[META1]] = !{}
+; CHECK: [[META2]] = !{i64 8}
+;.

>From 9240176da136a5e107e1b3e3c9d8d6b584981a98 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Sun, 19 Jan 2025 14:35:03 +0000
Subject: [PATCH 2/2] ValueTracking.

---
 llvm/include/llvm/Analysis/SimplifyQuery.h    |  7 +++
 llvm/include/llvm/Analysis/ValueTracking.h    |  4 +-
 llvm/lib/Analysis/ScalarEvolution.cpp         |  2 +-
 llvm/lib/Analysis/ValueTracking.cpp           | 53 ++++++++++++++-----
 .../Target/ARM/MVEGatherScatterLowering.cpp   |  4 +-
 .../RISCV/RISCVGatherScatterLowering.cpp      |  4 +-
 .../Transforms/Scalar/LoopStrengthReduce.cpp  |  2 +-
 7 files changed, 57 insertions(+), 19 deletions(-)
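
The second patch generalizes matchSimpleRecurrence from BinaryOperator to
Instruction so that pointer recurrences stepped by a two-operand i8 GEP are
matched as well, and teaches the PHI handling in computeKnownBitsFromOperator
to combine the known trailing zero bits (i.e. the alignment) of the start
pointer and of the step across such recurrences. A minimal sketch of the kind
of loop this appears to be aimed at (names are made up; this assumes a target
where pointer and index width match):

define void @ptr_recurrence(ptr align 8 %start, i64 %n) {
entry:
  br label %loop

loop:
  %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
  %p = phi ptr [ %start, %entry ], [ %p.next, %loop ]
  store i64 0, ptr %p, align 8
  ; A two-operand i8 GEP stepping the pointer phi: the updated
  ; matchSimpleRecurrence matches it, and because %start is 8-byte aligned
  ; and the step is a multiple of 8, %p stays known 8-byte aligned.
  %p.next = getelementptr inbounds i8, ptr %p, i64 8
  %iv.next = add nuw i64 %iv, 1
  %done = icmp eq i64 %iv.next, %n
  br i1 %done, label %exit, label %loop

exit:
  ret void
}

Presumably this is what lets the InstCombine change in the first patch treat
align assumptions on such byte-stepped pointers as redundant.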

diff --git a/llvm/include/llvm/Analysis/SimplifyQuery.h b/llvm/include/llvm/Analysis/SimplifyQuery.h
index e8f43c8c2e91f8..902e2ab7bb4312 100644
--- a/llvm/include/llvm/Analysis/SimplifyQuery.h
+++ b/llvm/include/llvm/Analysis/SimplifyQuery.h
@@ -18,6 +18,7 @@ class AssumptionCache;
 class DomConditionCache;
 class DominatorTree;
 class TargetLibraryInfo;
+class GetElementPtrInst;
 
 /// InstrInfoQuery provides an interface to query additional information for
 /// instructions like metadata or keywords like nsw, which provides conservative
@@ -45,6 +46,12 @@ struct InstrInfoQuery {
     return false;
   }
 
+  template <class InstT> bool isInBounds(const InstT *Op) const {
+    if (UseInstrInfo)
+      return Op->isInBounds();
+    return false;
+  }
+
   bool isExact(const BinaryOperator *Op) const {
     if (UseInstrInfo && isa<PossiblyExactOperator>(Op))
       return cast<PossiblyExactOperator>(Op)->isExact();
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
index b4918c2d1e8a18..e4978a4ef476a8 100644
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -1241,11 +1241,11 @@ canConvertToMinOrMaxIntrinsic(ArrayRef<Value *> VL);
 ///
 /// NOTE: This is intentional simple.  If you want the ability to analyze
 /// non-trivial loop conditons, see ScalarEvolution instead.
-bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, Value *&Start,
+bool matchSimpleRecurrence(const PHINode *P, Instruction *&BO, Value *&Start,
                            Value *&Step);
 
 /// Analogous to the above, but starting from the binary operator
-bool matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P, Value *&Start,
+bool matchSimpleRecurrence(const Instruction *I, PHINode *&P, Value *&Start,
                            Value *&Step);
 
 /// Return true if RHS is known to be implied true by LHS.  Return false if
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 7673c354817579..f3282424d9d5e7 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -6465,7 +6465,7 @@ getRangeForUnknownRecurrence(const SCEVUnknown *U) {
     if (!DT.isReachableFromEntry(Pred))
       return FullSet;
 
-  BinaryOperator *BO;
+  Instruction *BO;
   Value *Start, *Step;
   if (!matchSimpleRecurrence(P, BO, Start, Step))
     return FullSet;
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 6e2f0ebde9bb6c..e46dec42738211 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -1489,7 +1489,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
   }
   case Instruction::PHI: {
     const PHINode *P = cast<PHINode>(I);
-    BinaryOperator *BO = nullptr;
+    Instruction *BO = nullptr;
     Value *R = nullptr, *L = nullptr;
     if (matchSimpleRecurrence(P, BO, R, L)) {
       // Handle the case of a simple two-predecessor recurrence PHI.
@@ -1553,6 +1553,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
       case Instruction::Sub:
       case Instruction::And:
       case Instruction::Or:
+      case Instruction::GetElementPtr:
       case Instruction::Mul: {
         // Change the context instruction to the "edge" that flows into the
         // phi. This is important because that is where the value is actually
@@ -1571,6 +1572,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
 
         // We need to take the minimum number of known bits
         KnownBits Known3(BitWidth);
+        if (BitWidth != getBitWidth(L->getType(), Q.DL)) {
+          assert(isa<GetElementPtrInst>(BO));
+          break;
+        }
         RecQ.CxtI = LInst;
         computeKnownBits(L, DemandedElts, Known3, Depth + 1, RecQ);
 
@@ -1578,7 +1583,9 @@ static void computeKnownBitsFromOperator(const Operator *I,
                                        Known3.countMinTrailingZeros()));
 
         auto *OverflowOp = dyn_cast<OverflowingBinaryOperator>(BO);
-        if (!OverflowOp || !Q.IIQ.hasNoSignedWrap(OverflowOp))
+        if (!isa<GetElementPtrInst, OverflowingBinaryOperator>(BO))
+          break;
+        if ((isa<GetElementPtrInst>(BO) && !Q.IIQ.isInBounds(cast<GetElementPtrInst>(BO))) || (OverflowOp && !Q.IIQ.hasNoSignedWrap(OverflowOp)))
           break;
 
         switch (Opcode) {
@@ -1737,6 +1744,14 @@ static void computeKnownBitsFromOperator(const Operator *I,
           Known.resetAll();
       }
     }
+
+  // Aligned pointers have trailing zeros - refine Known.Zero set
+  if (isa<PointerType>(CB->getType())) {
+    Align Alignment = CB->getPointerAlignment(Q.DL);
+    Known.Zero.setLowBits(Log2(Alignment));
+  }
+
+
     if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
       switch (II->getIntrinsicID()) {
       default:
@@ -2270,7 +2285,7 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts,
 /// always a power of two (or zero).
 static bool isPowerOfTwoRecurrence(const PHINode *PN, bool OrZero,
                                    unsigned Depth, SimplifyQuery &Q) {
-  BinaryOperator *BO = nullptr;
+  Instruction *BO = nullptr;
   Value *Start = nullptr, *Step = nullptr;
   if (!matchSimpleRecurrence(PN, BO, Start, Step))
     return false;
@@ -2308,7 +2323,7 @@ static bool isPowerOfTwoRecurrence(const PHINode *PN, bool OrZero,
     // Divisor must be a power of two.
     // If OrZero is false, cannot guarantee induction variable is non-zero after
     // division, same for Shr, unless it is exact division.
-    return (OrZero || Q.IIQ.isExact(BO)) &&
+    return (OrZero || Q.IIQ.isExact(cast<BinaryOperator>(BO))) &&
            isKnownToBeAPowerOfTwo(Step, false, Depth, Q);
   case Instruction::Shl:
     return OrZero || Q.IIQ.hasNoUnsignedWrap(BO) || Q.IIQ.hasNoSignedWrap(BO);
@@ -2317,7 +2332,7 @@ static bool isPowerOfTwoRecurrence(const PHINode *PN, bool OrZero,
       return false;
     [[fallthrough]];
   case Instruction::LShr:
-    return OrZero || Q.IIQ.isExact(BO);
+    return OrZero || Q.IIQ.isExact(cast<BinaryOperator>(BO));
   default:
     return false;
   }
@@ -2727,7 +2742,7 @@ static bool rangeMetadataExcludesValue(const MDNode* Ranges, const APInt& Value)
 /// Try to detect a recurrence that monotonically increases/decreases from a
 /// non-zero starting value. These are common as induction variables.
 static bool isNonZeroRecurrence(const PHINode *PN) {
-  BinaryOperator *BO = nullptr;
+  Instruction *BO = nullptr;
   Value *Start = nullptr, *Step = nullptr;
   const APInt *StartC, *StepC;
   if (!matchSimpleRecurrence(PN, BO, Start, Step) ||
@@ -3560,9 +3575,9 @@ getInvertibleOperands(const Operator *Op1,
     // If PN1 and PN2 are both recurrences, can we prove the entire recurrences
     // are a single invertible function of the start values? Note that repeated
     // application of an invertible function is also invertible
-    BinaryOperator *BO1 = nullptr;
+    Instruction *BO1 = nullptr;
     Value *Start1 = nullptr, *Step1 = nullptr;
-    BinaryOperator *BO2 = nullptr;
+    Instruction *BO2 = nullptr;
     Value *Start2 = nullptr, *Step2 = nullptr;
     if (PN1->getParent() != PN2->getParent() ||
         !matchSimpleRecurrence(PN1, BO1, Start1, Step1) ||
@@ -9197,7 +9212,7 @@ llvm::canConvertToMinOrMaxIntrinsic(ArrayRef<Value *> VL) {
   return {Intrinsic::not_intrinsic, false};
 }
 
-bool llvm::matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO,
+bool llvm::matchSimpleRecurrence(const PHINode *P, Instruction *&BO,
                                  Value *&Start, Value *&Step) {
   // Handle the case of a simple two-predecessor recurrence PHI.
   // There's a lot more that could theoretically be done here, but
@@ -9208,7 +9223,7 @@ bool llvm::matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO,
   for (unsigned i = 0; i != 2; ++i) {
     Value *L = P->getIncomingValue(i);
     Value *R = P->getIncomingValue(!i);
-    auto *LU = dyn_cast<BinaryOperator>(L);
+    auto *LU = dyn_cast<Instruction>(L);
     if (!LU)
       continue;
     unsigned Opcode = LU->getOpcode();
@@ -9240,6 +9255,20 @@ bool llvm::matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO,
 
       break; // Match!
     }
+    case Instruction::GetElementPtr: {
+      if (LU->getNumOperands() != 2 || !cast<GetElementPtrInst>(L)->getSourceElementType()->isIntegerTy(8))
+        continue;
+
+      Value *LL = LU->getOperand(0);
+      Value *LR = LU->getOperand(1);
+      // Find a recurrence.
+      if (LL == P) {
+        // Found a match
+        L = LR;
+        break;
+      }
+      continue;
+    }
     };
 
     // We have matched a recurrence of the form:
@@ -9256,9 +9285,9 @@ bool llvm::matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO,
   return false;
 }
 
-bool llvm::matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
+bool llvm::matchSimpleRecurrence(const Instruction *I, PHINode *&P,
                                  Value *&Start, Value *&Step) {
-  BinaryOperator *BO = nullptr;
+  Instruction *BO = nullptr;
   P = dyn_cast<PHINode>(I->getOperand(0));
   if (!P)
     P = dyn_cast<PHINode>(I->getOperand(1));
diff --git a/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp b/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp
index 7efd2989aa7fa4..b4b3536ddcf21f 100644
--- a/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp
+++ b/llvm/lib/Target/ARM/MVEGatherScatterLowering.cpp
@@ -1025,10 +1025,10 @@ bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
     return false;
 
   // We're looking for a simple add recurrence.
-  BinaryOperator *IncInstruction;
+  Instruction *IncInstruction;
   Value *Start, *IncrementPerRound;
   if (!matchSimpleRecurrence(Phi, IncInstruction, Start, IncrementPerRound) ||
-      IncInstruction->getOpcode() != Instruction::Add)
+      IncInstruction->getOpcode() != Instruction::Add || !isa<BinaryOperator>(IncInstruction))
     return false;
 
   int IncrementingBlock = Phi->getIncomingValue(0) == IncInstruction ? 0 : 1;
diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
index 39c0af79859719..59f1d8292908b6 100644
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -203,9 +203,11 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
       return false;
 
     Value *Step, *Start;
-    if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
+    Instruction *Inc2;
+    if (!matchSimpleRecurrence(Phi, Inc2, Start, Step) ||
-        Inc->getOpcode() != Instruction::Add)
+        Inc2->getOpcode() != Instruction::Add)
       return false;
+    Inc = cast<BinaryOperator>(Inc2);
     assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
     unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
     assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index d51d043f9f0d9b..cb32a8ff07cd14 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -6108,7 +6108,7 @@ void LSRInstance::ImplementSolution(
   // chosen a non-optimal result for the actual schedule.  (And yes, this
   // scheduling decision does impact later codegen.)
   for (PHINode &PN : L->getHeader()->phis()) {
-    BinaryOperator *BO = nullptr;
+    Instruction *BO = nullptr;
     Value *Start = nullptr, *Step = nullptr;
     if (!matchSimpleRecurrence(&PN, BO, Start, Step))
       continue;


