[llvm] [SCEV] Match both (-1)b + a and a + (-1)b as a - b (PR #84247)

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 6 14:24:37 PST 2024


https://github.com/preames created https://github.com/llvm/llvm-project/pull/84247

In our analysis of guarding conditions, we were converting a - b == 0 into the alternate form a == b, but we were only checking for one of the two operand orders in which the subtraction can appear: we matched (-1)*b + a but not a + (-1)*b.  There's no requirement that the multiply appear only on the LHS of the add.

From 617316cd2bc7b112fba0b5f7fe9239efbaab761f Mon Sep 17 00:00:00 2001
From: Philip Reames <preames at rivosinc.com>
Date: Wed, 6 Mar 2024 14:21:17 -0800
Subject: [PATCH] [SCEV] Match both (-1)b + a and a + (-1)b as a - b

In our analysis of guarding conditions, we were converting a-b == 0
into a == b alternate form, but we were only checking for one of the
two forms for the sub.  There's no requirement that the multiply only
be on the LHS of the add.
---
 llvm/lib/Analysis/ScalarEvolution.cpp         | 32 ++++++++++++-------
 .../Analysis/ScalarEvolution/trip-count.ll    |  2 +-
 2 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 15c2965aede1a0..acc0aa23107bb5 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -10577,6 +10577,25 @@ static bool HasSameValue(const SCEV *A, const SCEV *B) {
   return false;
 }
 
+static bool MatchBinarySub(const SCEV *S, const SCEV *&LHS, const SCEV *&RHS) {
+  const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S);
+  if (!Add || Add->getNumOperands() != 2)
+    return false;
+  if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
+      ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
+    LHS = Add->getOperand(1);
+    RHS = ME->getOperand(1);
+    return true;
+  }
+  if (auto *ME = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
+      ME && ME->getNumOperands() == 2 && ME->getOperand(0)->isAllOnesValue()) {
+    LHS = Add->getOperand(0);
+    RHS = ME->getOperand(1);
+    return true;
+  }
+  return false;
+}
+
 bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
                                            const SCEV *&LHS, const SCEV *&RHS,
                                            unsigned Depth) {
@@ -10652,19 +10671,10 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
       case ICmpInst::ICMP_EQ:
       case ICmpInst::ICMP_NE:
         // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b.
-        if (!RA)
-          if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(LHS))
-            if (const SCEVMulExpr *ME =
-                    dyn_cast<SCEVMulExpr>(AE->getOperand(0)))
-              if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 &&
-                  ME->getOperand(0)->isAllOnesValue()) {
-                RHS = AE->getOperand(1);
-                LHS = ME->getOperand(1);
-                Changed = true;
-              }
+        if (RA.isZero() && MatchBinarySub(LHS, LHS, RHS))
+          Changed = true;
         break;
 
-
         // The "Should have been caught earlier!" messages refer to the fact
         // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
         // should have fired on the corresponding cases, and canonicalized the
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-count.ll b/llvm/test/Analysis/ScalarEvolution/trip-count.ll
index cbe07effdeb265..8fc5b9b4096127 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-count.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-count.ll
@@ -145,7 +145,7 @@ define void @dual_sext_ne_with_slt_guard(i8 %s, i8 %n) {
 ; CHECK-LABEL: 'dual_sext_ne_with_slt_guard'
 ; CHECK-NEXT:  Determining loop execution counts for: @dual_sext_ne_with_slt_guard
 ; CHECK-NEXT:  Loop %for.body: backedge-taken count is (-1 + (sext i8 %n to i64) + (-1 * (sext i8 %s to i64))<nsw>)
-; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is i64 -1
+; CHECK-NEXT:  Loop %for.body: constant max backedge-taken count is i64 -2
 ; CHECK-NEXT:  Loop %for.body: symbolic max backedge-taken count is (-1 + (sext i8 %n to i64) + (-1 * (sext i8 %s to i64))<nsw>)
 ; CHECK-NEXT:  Loop %for.body: Trip multiple is 1
 ;



More information about the llvm-commits mailing list