[llvm] [LV] Rewrite UDiv A, B -> UDiv A, UMax(B, 1) in trip counts if needed. (PR #92177)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 20 05:03:08 PDT 2024


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/92177

>From bff4e5adb6a438947898e5852af3367b911b6635 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 14 May 2024 20:36:16 +0100
Subject: [PATCH] [LV] Don't vectorize if trip count expansion may introduce
 UB.

Introduce a utility to check if a SCEV expansion may introduce UB
(couldn't find a similar utility after a quick glance) and use it to
avoid vectorizing when expanding the trip count introduces UB.

Fixes https://github.com/llvm/llvm-project/issues/89958.

!fixup introduce SafeUDivMode to SCEVExpander.

Step
---
 .../Utils/ScalarEvolutionExpander.h           | 13 +++++++
 llvm/lib/Analysis/ScalarEvolution.cpp         | 11 ++++--
 .../Utils/ScalarEvolutionExpander.cpp         | 15 +++++++-
 .../trip-count-expansion-may-introduce-ub.ll  | 36 +++++++++++++------
 4 files changed, 61 insertions(+), 14 deletions(-)

diff --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
index 62c1e15a9a60e1..cb3b7372f7e66e 100644
--- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
+++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
@@ -124,6 +124,11 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
   /// "expanded" form.
   bool LSRMode;
 
+  /// When true, rewrite any divisors of UDiv expressions that may be 0 to
+  /// umax(Divisor, 1) to avoid introducing UB. If the divisor may be poison,
+  /// freeze it first.
+  bool SafeUDivMode = false;
+
   typedef IRBuilder<InstSimplifyFolder, IRBuilderCallbackInserter> BuilderType;
   BuilderType Builder;
 
@@ -300,6 +305,9 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
   /// location and their operands are defined at this location.
   bool isSafeToExpandAt(const SCEV *S, const Instruction *InsertionPoint) const;
 
+  static bool isSafeToExpand(const SCEV *S, bool CanonicalMode,
+                             ScalarEvolution &SE);
+
   /// Insert code to directly compute the specified SCEV expression into the
   /// program.  The code is inserted into the specified block.
   Value *expandCodeFor(const SCEV *SH, Type *Ty, BasicBlock::iterator I);
@@ -418,6 +426,11 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
   BasicBlock::iterator findInsertPointAfter(Instruction *I,
                                             Instruction *MustDominate) const;
 
+  static const SCEV *rewriteExpressionToRemoveUB(const SCEV *BTC, Loop *L,
+                                                 ScalarEvolution &SE);
+
+  void setSafeUDivMode() { SafeUDivMode = true; }
+
 private:
   LLVMContext &getContext() const { return SE.getContext(); }
 
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 1d3443588ce60d..b4d02df02694ca 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -4304,12 +4304,19 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
   }
 
   for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
+    bool MayBeUB = SCEVExprContains(Ops[i], [](const SCEV *S) {
+      auto *UDiv = dyn_cast<SCEVUDivExpr>(S);
+      return UDiv && !isa<SCEVConstant>(UDiv->getOperand(1));
+    });
+
+    if (MayBeUB)
+      continue;
     // We can replace %x umin_seq %y with %x umin %y if either:
     //  * %y being poison implies %x is also poison.
     //  * %x cannot be the saturating value (e.g. zero for umin).
     if (::impliesPoison(Ops[i], Ops[i - 1]) ||
-        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
-                                        SaturationPoint)) {
+        (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
+                                         SaturationPoint))) {
       SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
       Ops[i - 1] = getMinMaxExpr(
           SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 0927a3015818fd..ee7b2f74db6fa9 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -665,6 +665,12 @@ Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) {
   }
 
   Value *RHS = expand(S->getRHS());
+  if (SafeUDivMode && !SE.isKnownNonZero(S->getRHS())) {
+    if (!isa<SCEVConstant>(S->getRHS()))
+      RHS = Builder.CreateFreeze(RHS);
+    RHS = Builder.CreateIntrinsic(RHS->getType(), Intrinsic::umax,
+                                  {RHS, ConstantInt::get(RHS->getType(), 1)});
+  }
   return InsertBinop(Instruction::UDiv, LHS, RHS, SCEV::FlagAnyWrap,
                      /*IsSafeToHoist*/ SE.isKnownNonZero(S->getRHS()));
 }
@@ -1358,11 +1364,13 @@ Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) {
 Value *SCEVExpander::expandMinMaxExpr(const SCEVNAryExpr *S,
                                       Intrinsic::ID IntrinID, Twine Name,
                                       bool IsSequential) {
+  SafeUDivMode = true;
   Value *LHS = expand(S->getOperand(S->getNumOperands() - 1));
   Type *Ty = LHS->getType();
   if (IsSequential)
     LHS = Builder.CreateFreeze(LHS);
   for (int i = S->getNumOperands() - 2; i >= 0; --i) {
+    SafeUDivMode = i != 0;
     Value *RHS = expand(S->getOperand(i));
     if (IsSequential && i != 0)
       RHS = Builder.CreateFreeze(RHS);
@@ -2315,12 +2323,17 @@ struct SCEVFindUnsafe {
 };
 } // namespace
 
-bool SCEVExpander::isSafeToExpand(const SCEV *S) const {
+bool SCEVExpander::isSafeToExpand(const SCEV *S, bool CanonicalMode,
+                                  ScalarEvolution &SE) {
   SCEVFindUnsafe Search(SE, CanonicalMode);
   visitAll(S, Search);
   return !Search.IsUnsafe;
 }
 
+bool SCEVExpander::isSafeToExpand(const SCEV *S) const {
+  return isSafeToExpand(S, CanonicalMode, SE);
+}
+
 bool SCEVExpander::isSafeToExpandAt(const SCEV *S,
                                     const Instruction *InsertionPoint) const {
   if (!isSafeToExpand(S))
diff --git a/llvm/test/Transforms/LoopVectorize/trip-count-expansion-may-introduce-ub.ll b/llvm/test/Transforms/LoopVectorize/trip-count-expansion-may-introduce-ub.ll
index 6fba8ccd590c62..69806880ce46ee 100644
--- a/llvm/test/Transforms/LoopVectorize/trip-count-expansion-may-introduce-ub.ll
+++ b/llvm/test/Transforms/LoopVectorize/trip-count-expansion-may-introduce-ub.ll
@@ -463,9 +463,12 @@ define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch(ptr %dst, i64 %N
 ; CHECK-LABEL: define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch(
 ; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
 ; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP9:%.*]] = freeze i64 [[N]]
+; CHECK-NEXT:    [[TMP10:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP9]], i64 1)
+; CHECK-NEXT:    [[TMP0:%.*]] = udiv i64 42, [[TMP10]]
+; CHECK-NEXT:    [[TMP8:%.*]] = freeze i64 [[TMP0]]
 ; CHECK-NEXT:    [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
-; CHECK-NEXT:    [[TMP0:%.*]] = udiv i64 42, [[N]]
-; CHECK-NEXT:    [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[SMAX]], i64 [[TMP0]])
+; CHECK-NEXT:    [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP8]], i64 [[SMAX]])
 ; CHECK-NEXT:    [[TMP1:%.*]] = add nuw nsw i64 [[UMIN]], 1
 ; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ule i64 [[TMP1]], 4
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
@@ -529,7 +532,9 @@ define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch_different_bounds
 ; CHECK-LABEL: define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch_different_bounds(
 ; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]], i64 [[M:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP2:%.*]] = udiv i64 42, [[M]]
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze i64 [[M]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP0]], i64 1)
+; CHECK-NEXT:    [[TMP2:%.*]] = udiv i64 42, [[TMP1]]
 ; CHECK-NEXT:    [[TMP3:%.*]] = freeze i64 [[TMP2]]
 ; CHECK-NEXT:    [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
 ; CHECK-NEXT:    [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP3]], i64 [[SMAX]])
@@ -598,9 +603,12 @@ define i64 @multi_exit_4_exit_count_with_udiv_by_frozen_value_in_latch(ptr %dst,
 ; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[FR_N:%.*]] = freeze i64 [[N]]
-; CHECK-NEXT:    [[TMP2:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
-; CHECK-NEXT:    [[TMP0:%.*]] = udiv i64 42, [[FR_N]]
-; CHECK-NEXT:    [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP2]], i64 [[TMP0]])
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze i64 [[FR_N]]
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP0]], i64 1)
+; CHECK-NEXT:    [[TMP2:%.*]] = udiv i64 42, [[TMP1]]
+; CHECK-NEXT:    [[TMP10:%.*]] = freeze i64 [[TMP2]]
+; CHECK-NEXT:    [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
+; CHECK-NEXT:    [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP10]], i64 [[SMAX]])
 ; CHECK-NEXT:    [[TMP3:%.*]] = add nuw nsw i64 [[UMIN]], 1
 ; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ule i64 [[TMP3]], 4
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
@@ -786,12 +794,15 @@ define i64 @multi_exit_4_exit_count_with_urem_by_value_in_latch(ptr %dst, i64 %N
 ; CHECK-LABEL: define i64 @multi_exit_4_exit_count_with_urem_by_value_in_latch(
 ; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
-; CHECK-NEXT:    [[TMP0:%.*]] = udiv i64 42, [[N]]
+; CHECK-NEXT:    [[TMP11:%.*]] = freeze i64 [[N]]
+; CHECK-NEXT:    [[TMP12:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP11]], i64 1)
+; CHECK-NEXT:    [[TMP0:%.*]] = udiv i64 42, [[TMP12]]
 ; CHECK-NEXT:    [[TMP1:%.*]] = mul nuw i64 [[N]], [[TMP0]]
 ; CHECK-NEXT:    [[TMP2:%.*]] = sub i64 42, [[TMP1]]
 ; CHECK-NEXT:    [[SMAX1:%.*]] = call i64 @llvm.smax.i64(i64 [[TMP2]], i64 0)
-; CHECK-NEXT:    [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[SMAX]], i64 [[SMAX1]])
+; CHECK-NEXT:    [[TMP10:%.*]] = freeze i64 [[SMAX1]]
+; CHECK-NEXT:    [[SMAX2:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
+; CHECK-NEXT:    [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP10]], i64 [[SMAX2]])
 ; CHECK-NEXT:    [[TMP3:%.*]] = add nuw i64 [[UMIN]], 1
 ; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ule i64 [[TMP3]], 4
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
@@ -1004,9 +1015,12 @@ define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch1(ptr %dst, i64 %
 ; CHECK-LABEL: define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch1(
 ; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
 ; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = freeze i64 [[N]]
+; CHECK-NEXT:    [[TMP8:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP0]], i64 1)
+; CHECK-NEXT:    [[TMP9:%.*]] = udiv i64 42, [[TMP8]]
+; CHECK-NEXT:    [[TMP10:%.*]] = freeze i64 [[TMP9]]
 ; CHECK-NEXT:    [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
-; CHECK-NEXT:    [[TMP0:%.*]] = udiv i64 42, [[N]]
-; CHECK-NEXT:    [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[SMAX]], i64 [[TMP0]])
+; CHECK-NEXT:    [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP10]], i64 [[SMAX]])
 ; CHECK-NEXT:    [[TMP1:%.*]] = add nuw nsw i64 [[UMIN]], 1
 ; CHECK-NEXT:    [[MIN_ITERS_CHECK:%.*]] = icmp ule i64 [[TMP1]], 4
 ; CHECK-NEXT:    br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]



More information about the llvm-commits mailing list