[llvm] [LV] Rewrite UDiv A, B -> UDiv A, UMax(B, 1) in trip counts if needed. (PR #92177)
Florian Hahn via llvm-commits
llvm-commits at lists.llvm.org
Fri Sep 20 05:03:08 PDT 2024
https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/92177
>From bff4e5adb6a438947898e5852af3367b911b6635 Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 14 May 2024 20:36:16 +0100
Subject: [PATCH] [LV] Don't vectorize if trip count expansion may introduce
UB.
Introduce a utility to check if a SCEV expansion may introduce UB
(couldn't find a similar utility after a quick glance) and use it to
avoid vectorizing when expanding the trip count introduces UB.
Fixes https://github.com/llvm/llvm-project/issues/89958.
!fixup introduce SafeUDivMode to SCEVExpander.
Step
---
.../Utils/ScalarEvolutionExpander.h | 13 +++++++
llvm/lib/Analysis/ScalarEvolution.cpp | 11 ++++--
.../Utils/ScalarEvolutionExpander.cpp | 15 +++++++-
.../trip-count-expansion-may-introduce-ub.ll | 36 +++++++++++++------
4 files changed, 61 insertions(+), 14 deletions(-)
diff --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
index 62c1e15a9a60e1..cb3b7372f7e66e 100644
--- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
+++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h
@@ -124,6 +124,11 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
/// "expanded" form.
bool LSRMode;
+ /// When true, rewrite any divisors of UDiv expressions that may be 0 to
+ /// umax(Divisor, 1) to avoid introducing UB. If the divisor may be poison,
+ /// freeze it first.
+ bool SafeUDivMode = false;
+
typedef IRBuilder<InstSimplifyFolder, IRBuilderCallbackInserter> BuilderType;
BuilderType Builder;
@@ -300,6 +305,9 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
/// location and their operands are defined at this location.
bool isSafeToExpandAt(const SCEV *S, const Instruction *InsertionPoint) const;
+ static bool isSafeToExpand(const SCEV *S, bool CanonicalMode,
+ ScalarEvolution &SE);
+
/// Insert code to directly compute the specified SCEV expression into the
/// program. The code is inserted into the specified block.
Value *expandCodeFor(const SCEV *SH, Type *Ty, BasicBlock::iterator I);
@@ -418,6 +426,11 @@ class SCEVExpander : public SCEVVisitor<SCEVExpander, Value *> {
BasicBlock::iterator findInsertPointAfter(Instruction *I,
Instruction *MustDominate) const;
+ static const SCEV *rewriteExpressionToRemoveUB(const SCEV *BTC, Loop *L,
+ ScalarEvolution &SE);
+
+ void setSafeUDivMode() { SafeUDivMode = true; }
+
private:
LLVMContext &getContext() const { return SE.getContext(); }
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 1d3443588ce60d..b4d02df02694ca 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -4304,12 +4304,19 @@ ScalarEvolution::getSequentialMinMaxExpr(SCEVTypes Kind,
}
for (unsigned i = 1, e = Ops.size(); i != e; ++i) {
+ bool MayBeUB = SCEVExprContains(Ops[i], [](const SCEV *S) {
+ auto *UDiv = dyn_cast<SCEVUDivExpr>(S);
+ return UDiv && !isa<SCEVConstant>(UDiv->getOperand(1));
+ });
+
+ if (MayBeUB)
+ continue;
// We can replace %x umin_seq %y with %x umin %y if either:
// * %y being poison implies %x is also poison.
// * %x cannot be the saturating value (e.g. zero for umin).
if (::impliesPoison(Ops[i], Ops[i - 1]) ||
- isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
- SaturationPoint)) {
+ (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, Ops[i - 1],
+ SaturationPoint))) {
SmallVector<const SCEV *> SeqOps = {Ops[i - 1], Ops[i]};
Ops[i - 1] = getMinMaxExpr(
SCEVSequentialMinMaxExpr::getEquivalentNonSequentialSCEVType(Kind),
diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
index 0927a3015818fd..ee7b2f74db6fa9 100644
--- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
+++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp
@@ -665,6 +665,12 @@ Value *SCEVExpander::visitUDivExpr(const SCEVUDivExpr *S) {
}
Value *RHS = expand(S->getRHS());
+ if (SafeUDivMode && !SE.isKnownNonZero(S->getRHS())) {
+ if (!isa<SCEVConstant>(S->getRHS()))
+ RHS = Builder.CreateFreeze(RHS);
+ RHS = Builder.CreateIntrinsic(RHS->getType(), Intrinsic::umax,
+ {RHS, ConstantInt::get(RHS->getType(), 1)});
+ }
return InsertBinop(Instruction::UDiv, LHS, RHS, SCEV::FlagAnyWrap,
/*IsSafeToHoist*/ SE.isKnownNonZero(S->getRHS()));
}
@@ -1358,11 +1364,13 @@ Value *SCEVExpander::visitSignExtendExpr(const SCEVSignExtendExpr *S) {
Value *SCEVExpander::expandMinMaxExpr(const SCEVNAryExpr *S,
Intrinsic::ID IntrinID, Twine Name,
bool IsSequential) {
+ SafeUDivMode = true;
Value *LHS = expand(S->getOperand(S->getNumOperands() - 1));
Type *Ty = LHS->getType();
if (IsSequential)
LHS = Builder.CreateFreeze(LHS);
for (int i = S->getNumOperands() - 2; i >= 0; --i) {
+ SafeUDivMode = i != 0;
Value *RHS = expand(S->getOperand(i));
if (IsSequential && i != 0)
RHS = Builder.CreateFreeze(RHS);
@@ -2315,12 +2323,17 @@ struct SCEVFindUnsafe {
};
} // namespace
-bool SCEVExpander::isSafeToExpand(const SCEV *S) const {
+bool SCEVExpander::isSafeToExpand(const SCEV *S, bool CanonicalMode,
+ ScalarEvolution &SE) {
SCEVFindUnsafe Search(SE, CanonicalMode);
visitAll(S, Search);
return !Search.IsUnsafe;
}
+bool SCEVExpander::isSafeToExpand(const SCEV *S) const {
+ return isSafeToExpand(S, CanonicalMode, SE);
+}
+
bool SCEVExpander::isSafeToExpandAt(const SCEV *S,
const Instruction *InsertionPoint) const {
if (!isSafeToExpand(S))
diff --git a/llvm/test/Transforms/LoopVectorize/trip-count-expansion-may-introduce-ub.ll b/llvm/test/Transforms/LoopVectorize/trip-count-expansion-may-introduce-ub.ll
index 6fba8ccd590c62..69806880ce46ee 100644
--- a/llvm/test/Transforms/LoopVectorize/trip-count-expansion-may-introduce-ub.ll
+++ b/llvm/test/Transforms/LoopVectorize/trip-count-expansion-may-introduce-ub.ll
@@ -463,9 +463,12 @@ define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch(ptr %dst, i64 %N
; CHECK-LABEL: define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch(
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP9:%.*]] = freeze i64 [[N]]
+; CHECK-NEXT: [[TMP10:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP9]], i64 1)
+; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[TMP10]]
+; CHECK-NEXT: [[TMP8:%.*]] = freeze i64 [[TMP0]]
; CHECK-NEXT: [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
-; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[N]]
-; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[SMAX]], i64 [[TMP0]])
+; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP8]], i64 [[SMAX]])
; CHECK-NEXT: [[TMP1:%.*]] = add nuw nsw i64 [[UMIN]], 1
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ule i64 [[TMP1]], 4
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
@@ -529,7 +532,9 @@ define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch_different_bounds
; CHECK-LABEL: define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch_different_bounds(
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]], i64 [[M:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[TMP2:%.*]] = udiv i64 42, [[M]]
+; CHECK-NEXT: [[TMP0:%.*]] = freeze i64 [[M]]
+; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP0]], i64 1)
+; CHECK-NEXT: [[TMP2:%.*]] = udiv i64 42, [[TMP1]]
; CHECK-NEXT: [[TMP3:%.*]] = freeze i64 [[TMP2]]
; CHECK-NEXT: [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP3]], i64 [[SMAX]])
@@ -598,9 +603,12 @@ define i64 @multi_exit_4_exit_count_with_udiv_by_frozen_value_in_latch(ptr %dst,
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[FR_N:%.*]] = freeze i64 [[N]]
-; CHECK-NEXT: [[TMP2:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
-; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[FR_N]]
-; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP2]], i64 [[TMP0]])
+; CHECK-NEXT: [[TMP0:%.*]] = freeze i64 [[FR_N]]
+; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP0]], i64 1)
+; CHECK-NEXT: [[TMP2:%.*]] = udiv i64 42, [[TMP1]]
+; CHECK-NEXT: [[TMP10:%.*]] = freeze i64 [[TMP2]]
+; CHECK-NEXT: [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
+; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP10]], i64 [[SMAX]])
; CHECK-NEXT: [[TMP3:%.*]] = add nuw nsw i64 [[UMIN]], 1
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ule i64 [[TMP3]], 4
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
@@ -786,12 +794,15 @@ define i64 @multi_exit_4_exit_count_with_urem_by_value_in_latch(ptr %dst, i64 %N
; CHECK-LABEL: define i64 @multi_exit_4_exit_count_with_urem_by_value_in_latch(
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
; CHECK-NEXT: entry:
-; CHECK-NEXT: [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
-; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[N]]
+; CHECK-NEXT: [[TMP11:%.*]] = freeze i64 [[N]]
+; CHECK-NEXT: [[TMP12:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP11]], i64 1)
+; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[TMP12]]
; CHECK-NEXT: [[TMP1:%.*]] = mul nuw i64 [[N]], [[TMP0]]
; CHECK-NEXT: [[TMP2:%.*]] = sub i64 42, [[TMP1]]
; CHECK-NEXT: [[SMAX1:%.*]] = call i64 @llvm.smax.i64(i64 [[TMP2]], i64 0)
-; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[SMAX]], i64 [[SMAX1]])
+; CHECK-NEXT: [[TMP10:%.*]] = freeze i64 [[SMAX1]]
+; CHECK-NEXT: [[SMAX2:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
+; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP10]], i64 [[SMAX2]])
; CHECK-NEXT: [[TMP3:%.*]] = add nuw i64 [[UMIN]], 1
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ule i64 [[TMP3]], 4
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
@@ -1004,9 +1015,12 @@ define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch1(ptr %dst, i64 %
; CHECK-LABEL: define i64 @multi_exit_4_exit_count_with_udiv_by_value_in_latch1(
; CHECK-SAME: ptr [[DST:%.*]], i64 [[N:%.*]]) {
; CHECK-NEXT: entry:
+; CHECK-NEXT: [[TMP0:%.*]] = freeze i64 [[N]]
+; CHECK-NEXT: [[TMP8:%.*]] = call i64 @llvm.umax.i64(i64 [[TMP0]], i64 1)
+; CHECK-NEXT: [[TMP9:%.*]] = udiv i64 42, [[TMP8]]
+; CHECK-NEXT: [[TMP10:%.*]] = freeze i64 [[TMP9]]
; CHECK-NEXT: [[SMAX:%.*]] = call i64 @llvm.smax.i64(i64 [[N]], i64 0)
-; CHECK-NEXT: [[TMP0:%.*]] = udiv i64 42, [[N]]
-; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[SMAX]], i64 [[TMP0]])
+; CHECK-NEXT: [[UMIN:%.*]] = call i64 @llvm.umin.i64(i64 [[TMP10]], i64 [[SMAX]])
; CHECK-NEXT: [[TMP1:%.*]] = add nuw nsw i64 [[UMIN]], 1
; CHECK-NEXT: [[MIN_ITERS_CHECK:%.*]] = icmp ule i64 [[TMP1]], 4
; CHECK-NEXT: br i1 [[MIN_ITERS_CHECK]], label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
More information about the llvm-commits
mailing list