[llvm] 65715ac - [SCEV] Generalize umin_seq matching

Roman Lebedev via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 11 10:58:41 PST 2022


Author: Roman Lebedev
Date: 2022-02-11T21:58:19+03:00
New Revision: 65715ac72aedd2219e06815983e20d60986c9c48

URL: https://github.com/llvm/llvm-project/commit/65715ac72aedd2219e06815983e20d60986c9c48
DIFF: https://github.com/llvm/llvm-project/commit/65715ac72aedd2219e06815983e20d60986c9c48.diff

LOG: [SCEV] Generalize umin_seq matching

Since we don't greedily flatten `umin_seq(a, umin(b, c))` into `umin_seq(a, b, c)`,
looking only at the operands of the outer-level `umin` is not sufficient;
we need to recurse into all nested `umin`/`umin_seq` expressions of the same kind.

Added: 
    

Modified: 
    llvm/lib/Analysis/ScalarEvolution.cpp
    llvm/test/Analysis/ScalarEvolution/logical-operations.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index b0d5f0ccfbbf..b50e6f5697aa 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -5887,6 +5887,44 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
   return getUnknown(PN);
 }
 
+/// Returns true if \p OperandToFind is \p Root itself or is reachable from it
+/// through nested nodes of kind \p RootKind (which must be a sequential
+/// min/max kind) or of its non-sequential equivalent.
+static bool SCEVMinMaxExprContains(const SCEV *Root, const SCEV *OperandToFind,
+                                   SCEVTypes RootKind) {
+  struct FindClosure {
+    const SCEV *OperandToFind;
+    const SCEVTypes RootKind; // Must be a sequential min/max expression.
+    const SCEVTypes NonSequentialRootKind; // Non-seq variant of RootKind.
+    bool Found = false;
+
+    FindClosure(const SCEV *OperandToFind, SCEVTypes RootKind)
+        : OperandToFind(OperandToFind), RootKind(RootKind),
+          NonSequentialRootKind(
+              SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(
+                  RootKind)) {}
+
+    // We can only recurse into the SCEV expression of the same effective type
+    // as the type of our root SCEV expression.
+    bool canRecurseInto(SCEVTypes Kind) const {
+      return RootKind == Kind || NonSequentialRootKind == Kind;
+    }
+
+    bool follow(const SCEV *S) {
+      if (S == OperandToFind)
+        Found = true; // Sticky: a later visit must never clear the result.
+      // Stop descending once found; otherwise only look through nodes of the
+      // root's (sequential or non-sequential) min/max kind.
+      return !Found && canRecurseInto(S->getSCEVType());
+    }
+    bool isDone() const { return Found; }
+  };
+
+  FindClosure FC(OperandToFind, RootKind);
+  visitAll(Root, FC);
+  return FC.Found;
+}
+
 const SCEV *ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(
     Instruction *I, ICmpInst *Cond, Value *TrueVal, Value *FalseVal) {
   // Try to match some simple smax or umax patterns.
@@ -5969,15 +6007,14 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHIInstWithICmpInstCond(
     }
     // x == 0 ? 0 : umin    (..., x, ...)  ->  umin_seq(x, umin    (...))
     // x == 0 ? 0 : umin_seq(..., x, ...)  ->  umin_seq(x, umin_seq(...))
+    // x == 0 ? 0 : umin    (..., umin_seq(..., x, ...), ...)
+    //                    ->  umin_seq(x, umin (..., umin_seq(...), ...))
     if (getTypeSizeInBits(LHS->getType()) == getTypeSizeInBits(I->getType()) &&
         isa<ConstantInt>(RHS) && cast<ConstantInt>(RHS)->isZero() &&
         isa<ConstantInt>(TrueVal) && cast<ConstantInt>(TrueVal)->isZero()) {
       const SCEV *X = getSCEV(LHS);
-      auto *FalseValExpr = dyn_cast<SCEVNAryExpr>(getSCEV(FalseVal));
-      if (FalseValExpr &&
-          (FalseValExpr->getSCEVType() == scUMinExpr ||
-           FalseValExpr->getSCEVType() == scSequentialUMinExpr) &&
-          is_contained(FalseValExpr->operands(), X))
+      const SCEV *FalseValExpr = getSCEV(FalseVal);
+      if (SCEVMinMaxExprContains(FalseValExpr, X, scSequentialUMinExpr))
         return getUMinExpr(X, FalseValExpr, /*Sequential=*/true);
     }
     break;

diff  --git a/llvm/test/Analysis/ScalarEvolution/logical-operations.ll b/llvm/test/Analysis/ScalarEvolution/logical-operations.ll
index 4456d2eb6a4b..682d99c2e349 100644
--- a/llvm/test/Analysis/ScalarEvolution/logical-operations.ll
+++ b/llvm/test/Analysis/ScalarEvolution/logical-operations.ll
@@ -608,7 +608,7 @@ define i32 @umin_seq_x_y_z(i32 %x, i32 %y, i32 %z) {
 ; CHECK-NEXT:    %r0 = select i1 %y.is.zero, i32 0, i32 %umin
 ; CHECK-NEXT:    --> (%y umin_seq (%x umin %z)) U: full-set S: full-set
 ; CHECK-NEXT:    %r = select i1 %x.is.zero, i32 0, i32 %r0
-; CHECK-NEXT:    --> %r U: full-set S: full-set
+; CHECK-NEXT:    --> (%x umin_seq %y umin_seq %z) U: full-set S: full-set
 ; CHECK-NEXT:  Determining loop execution counts for: @umin_seq_x_y_z
 ;
   %umin0 = call i32 @llvm.umin(i32 %z, i32 %x)
@@ -632,7 +632,7 @@ define i32 @umin_seq_a_b_c_d(i32 %a, i32 %b, i32 %c, i32 %d) {
 ; CHECK-NEXT:    %umin = call i32 @llvm.umin.i32(i32 %umin0, i32 %r1)
 ; CHECK-NEXT:    --> ((%c umin_seq %d) umin %a umin %b) U: full-set S: full-set
 ; CHECK-NEXT:    %r = select i1 %d.is.zero, i32 0, i32 %umin
-; CHECK-NEXT:    --> %r U: full-set S: full-set
+; CHECK-NEXT:    --> (%d umin_seq (%a umin %b umin %c)) U: full-set S: full-set
 ; CHECK-NEXT:  Determining loop execution counts for: @umin_seq_a_b_c_d
 ;
   %umin1 = call i32 @llvm.umin(i32 %c, i32 %d)


        


More information about the llvm-commits mailing list