[llvm] 8e5aa96 - [SCEV] Preserve divisibility and min/max information in applyLoopGuards
via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 20 03:06:58 PDT 2023
Author: Alon Kom
Date: 2023-03-20T12:04:05+02:00
New Revision: 8e5aa969d0e9960bfc3d4e14144899076895e1b4
URL: https://github.com/llvm/llvm-project/commit/8e5aa969d0e9960bfc3d4e14144899076895e1b4
DIFF: https://github.com/llvm/llvm-project/commit/8e5aa969d0e9960bfc3d4e14144899076895e1b4.diff
LOG: [SCEV] Preserve divisibility and min/max information in applyLoopGuards
applyLoopGuards doesn't always preserve information when there are multiple assumes.
This patch tries to handle multiple assumes about a SCEV's divisibility and min/max values, and to rewrite the SCEV into one that still preserves all of this information.
For example, let the trip count of the loop be TC, and consider the following 3 assumes:
1. __builtin_assume(TC % 8 == 0);
2. __builtin_assume(TC > 0);
3. __builtin_assume(TC < 100);
Before this patch, depending on the assume processing order, applyLoopGuards could create the following SCEV:
max(min(8 * (TC / 8), 99), 1)
This SCEV does not preserve the information that TC is divisible by 8.
After this patch, depending on the assume processing order, applyLoopGuards could create the following SCEV:
max(min(8 * (TC / 8), 96), 8)
By aligning 1 up to 8 and 99 down to 96, the new SCEV still preserves all of the original assumes.
Differential Revision: https://reviews.llvm.org/D144947
Added:
Modified:
llvm/lib/Analysis/ScalarEvolution.cpp
llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
llvm/unittests/Analysis/ScalarEvolutionTest.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index e3c4fc57c202..df525f4d6be7 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -15023,6 +15023,93 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
if (MatchRangeCheckIdiom())
return;
+ // Return true if \p Expr is a two-operand MinMax SCEV expression whose first
+ // operand is a non-negative constant. If so, set \p SCTy to its SCEV type,
+ // \p LHS to the constant (first) operand and \p RHS to the second operand.
+ auto IsMinMaxSCEVWithNonNegativeConstant =
+ [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
+ const SCEV *&RHS) {
+ if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
+ if (MinMax->getNumOperands() != 2)
+ return false;
+ if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
+ if (C->getAPInt().isNegative())
+ return false;
+ SCTy = MinMax->getSCEVType();
+ LHS = MinMax->getOperand(0);
+ RHS = MinMax->getOperand(1);
+ return true;
+ }
+ }
+ return false;
+ };
+
+ // Return true if \p Expr is a non-negative constant and \p Divisor is a positive
+ // constant. If both are constants, their values are stored in \p ExprVal and \p DivisorVal.
+ auto GetNonNegExprAndPosDivisor = [&](const SCEV *Expr, const SCEV *Divisor,
+ APInt &ExprVal, APInt &DivisorVal) {
+ auto *ConstExpr = dyn_cast<SCEVConstant>(Expr);
+ auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
+ if (!ConstExpr || !ConstDivisor)
+ return false;
+ ExprVal = ConstExpr->getAPInt();
+ DivisorVal = ConstDivisor->getAPInt();
+ return ExprVal.isNonNegative() && !DivisorVal.isNonPositive();
+ };
+
+ // Return the smallest value that is divisible by \p Divisor and greater than
+ // or equal to \p Expr, i.e. align \p Expr up. Only a non-negative constant
+ // Expr with a positive constant Divisor is handled; otherwise \p Expr is returned unchanged.
+ auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
+ const SCEV *Divisor) {
+ APInt ExprVal;
+ APInt DivisorVal;
+ if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
+ return Expr;
+ APInt Rem = ExprVal.urem(DivisorVal);
+ if (!Rem.isZero())
+ // Not yet aligned: return the SCEV Expr + Divisor - Expr % Divisor.
+ return getConstant(ExprVal + DivisorVal - Rem);
+ return Expr;
+ };
+
+ // Return the largest value that is divisible by \p Divisor and less than or
+ // equal to \p Expr, i.e. align \p Expr down. Only a non-negative constant
+ // Expr with a positive constant Divisor is handled; otherwise \p Expr is returned unchanged.
+ auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
+ const SCEV *Divisor) {
+ APInt ExprVal;
+ APInt DivisorVal;
+ if (!GetNonNegExprAndPosDivisor(Expr, Divisor, ExprVal, DivisorVal))
+ return Expr;
+ APInt Rem = ExprVal.urem(DivisorVal);
+ // Return the SCEV Expr - Expr % Divisor (Expr itself when Rem is zero).
+ return getConstant(ExprVal - Rem);
+ };
+
+ // Apply divisibility by \p Divisor to a MinMaxExpr with a constant operand,
+ // recursively: the constant of a min is aligned down and the constant of a max
+ // is aligned up to a multiple of \p Divisor. Any other expression is returned unchanged.
+ std::function<const SCEV *(const SCEV *, const SCEV *)>
+ ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
+ const SCEV *Divisor) {
+ const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
+ SCEVTypes SCTy;
+ if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
+ MinMaxRHS))
+ return MinMaxExpr;
+ auto IsMin =
+ isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
+ assert(isKnownNonNegative(MinMaxLHS) &&
+ "Expected non-negative operand!");
+ auto *DivisibleExpr =
+ IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
+ : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
+ SmallVector<const SCEV *> Ops = {
+ ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
+ return getMinMaxExpr(SCTy, Ops);
+ };
+
// If we have LHS == 0, check if LHS is computing a property of some unknown
// SCEV %v which we can rewrite %v to express explicitly.
const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS);
@@ -15034,7 +15121,12 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
const SCEV *URemRHS = nullptr;
if (matchURem(LHS, URemLHS, URemRHS)) {
if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
- const auto *Multiple = getMulExpr(getUDivExpr(URemLHS, URemRHS), URemRHS);
+ auto I = RewriteMap.find(LHSUnknown);
+ const SCEV *RewrittenLHS =
+ I != RewriteMap.end() ? I->second : LHSUnknown;
+ RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
+ const auto *Multiple =
+ getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
RewriteMap[LHSUnknown] = Multiple;
ExprsToRewrite.push_back(LHSUnknown);
return;
@@ -15071,6 +15163,52 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
return I != RewriteMap.end() ? I->second : S;
};
+ // Check for the SCEV expression (A /u B) * B, where B is a constant, inside
+ // \p Expr. The check is done recursively on \p Expr, which is assumed to
+ // be a composition of Min/Max SCEVs. Return whether the SCEV expression (A
+ // /u B) * B was found, and return the divisor B in \p DividesBy. For
+ // example, if Expr = umin (umax ((A /u 8) * 8, 16), 64), return true since
+ // (A /u 8) * 8 matched the pattern, and return the constant SCEV 8 in \p
+ // DividesBy. Every operand of an n-ary (flattened) min/max is searched.
+ std::function<bool(const SCEV *, const SCEV *&)> HasDivisibiltyInfo =
+ [&](const SCEV *Expr, const SCEV *&DividesBy) {
+ if (auto *Mul = dyn_cast<SCEVMulExpr>(Expr)) {
+ if (Mul->getNumOperands() != 2)
+ return false;
+ auto *MulLHS = Mul->getOperand(0);
+ auto *MulRHS = Mul->getOperand(1);
+ if (isa<SCEVConstant>(MulLHS))
+ std::swap(MulLHS, MulRHS);
+ if (auto *Div = dyn_cast<SCEVUDivExpr>(MulLHS))
+ if (isa<SCEVConstant>(MulRHS) && Div->getOperand(1) == MulRHS) {
+ DividesBy = MulRHS;
+ return true;
+ }
+ }
+ if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
+ return llvm::any_of(MinMax->operands(), [&](const SCEV *Op) {
+ return HasDivisibiltyInfo(Op, DividesBy); });
+ return false;
+ };
+
+ // Return true if \p Expr is known to be divisible by \p DividesBy. A min/max
+ // qualifies only if ALL of its operands do (it may have more than two).
+ std::function<bool(const SCEV *, const SCEV *)> IsKnownToDivideBy =
+ [&](const SCEV *Expr, const SCEV *DividesBy) {
+ if (getURemExpr(Expr, DividesBy)->isZero())
+ return true;
+ if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
+ return llvm::all_of(MinMax->operands(), [&](const SCEV *Op) { return IsKnownToDivideBy(Op, DividesBy); });
+ return false;
+ };
+
+ const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
+ const SCEV *DividesBy = nullptr;
+ if (HasDivisibiltyInfo(RewrittenLHS, DividesBy))
+ // Keep DividesBy only if the entire rewritten LHS is known divisible by it.
+ DividesBy =
+ IsKnownToDivideBy(RewrittenLHS, DividesBy) ? DividesBy : nullptr;
+
// Collect rewrites for LHS and its transitive operands based on the
// condition.
// For min/max expressions, also apply the guard to its operands:
@@ -15091,11 +15229,21 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
LLVM_FALLTHROUGH;
case CmpInst::ICMP_SLT: {
RHS = getMinusSCEV(RHS, One);
+ RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
break;
}
case CmpInst::ICMP_UGT:
case CmpInst::ICMP_SGT:
RHS = getAddExpr(RHS, One);
+ RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
+ break;
+ case CmpInst::ICMP_ULE:
+ case CmpInst::ICMP_SLE:
+ RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
+ break;
+ case CmpInst::ICMP_UGE:
+ case CmpInst::ICMP_SGE:
+ RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
break;
default:
break;
@@ -15148,8 +15296,11 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
break;
case CmpInst::ICMP_NE:
if (isa<SCEVConstant>(RHS) &&
- cast<SCEVConstant>(RHS)->getValue()->isNullValue())
- To = getUMaxExpr(FromRewritten, One);
+ cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
+ const SCEV *OneAlignedUp =
+ DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
+ To = getUMaxExpr(FromRewritten, OneAlignedUp);
+ }
break;
default:
break;
diff --git a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
index cfa91e3cc747..492ed9c4d265 100644
--- a/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
+++ b/llvm/test/Analysis/ScalarEvolution/trip-multiple-guard-info.ll
@@ -125,7 +125,7 @@ define void @test_trip_multiple_4_ugt_5_order_swapped(i32 %num) {
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
; CHECK-NEXT: Predicates:
-; CHECK: Loop %for.body: Trip multiple is 2
+; CHECK: Loop %for.body: Trip multiple is 4
;
entry:
%u = urem i32 %num, 4
@@ -196,7 +196,7 @@ define void @test_trip_multiple_4_sgt_5_order_swapped(i32 %num) {
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
; CHECK-NEXT: Predicates:
-; CHECK: Loop %for.body: Trip multiple is 2
+; CHECK: Loop %for.body: Trip multiple is 4
;
entry:
%u = urem i32 %num, 4
@@ -267,7 +267,7 @@ define void @test_trip_multiple_4_uge_5_order_swapped(i32 %num) {
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
; CHECK-NEXT: Predicates:
-; CHECK: Loop %for.body: Trip multiple is 1
+; CHECK: Loop %for.body: Trip multiple is 4
;
entry:
%u = urem i32 %num, 4
@@ -338,7 +338,7 @@ define void @test_trip_multiple_4_sge_5_order_swapped(i32 %num) {
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
; CHECK-NEXT: Predicates:
-; CHECK: Loop %for.body: Trip multiple is 1
+; CHECK: Loop %for.body: Trip multiple is 4
;
entry:
%u = urem i32 %num, 4
@@ -409,7 +409,7 @@ define void @test_trip_multiple_4_upper_lower_bounds(i32 %num) {
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
; CHECK-NEXT: Predicates:
-; CHECK: Loop %for.body: Trip multiple is 1
+; CHECK: Loop %for.body: Trip multiple is 4
;
entry:
%cmp.1 = icmp uge i32 %num, 5
@@ -446,7 +446,7 @@ define void @test_trip_multiple_4_upper_lower_bounds_swapped1(i32 %num) {
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
; CHECK-NEXT: Predicates:
-; CHECK: Loop %for.body: Trip multiple is 1
+; CHECK: Loop %for.body: Trip multiple is 4
;
entry:
%cmp.1 = icmp uge i32 %num, 5
@@ -483,7 +483,7 @@ define void @test_trip_multiple_4_upper_lower_bounds_swapped2(i32 %num) {
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is (-1 + %num)
; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is (-1 + %num)
; CHECK-NEXT: Predicates:
-; CHECK: Loop %for.body: Trip multiple is 1
+; CHECK: Loop %for.body: Trip multiple is 4
;
entry:
%cmp.1 = icmp uge i32 %num, 5
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index 985d1cbc642a..1834e8cad56f 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1760,4 +1760,42 @@ TEST_F(ScalarEvolutionsTest, CheckGetPowerOfTwo) {
->equalsInt(1ULL << i));
}
+TEST_F(ScalarEvolutionsTest, ApplyLoopGuards) {
+ LLVMContext C;
+ SMDiagnostic Err;
+ std::unique_ptr<Module> M = parseAssemblyString(
+ "declare void @llvm.assume(i1)\n"
+ "define void @test(i32 %num) {\n"
+ "entry:\n"
+ " %u = urem i32 %num, 4\n"
+ " %cmp = icmp eq i32 %u, 0\n"
+ " tail call void @llvm.assume(i1 %cmp)\n"
+ " %cmp.1 = icmp ugt i32 %num, 0\n"
+ " tail call void @llvm.assume(i1 %cmp.1)\n"
+ " br label %for.body\n"
+ "for.body:\n"
+ " %i.010 = phi i32 [ 0, %entry ], [ %inc, %for.body ]\n"
+ " %inc = add nuw nsw i32 %i.010, 1\n"
+ " %cmp2 = icmp ult i32 %inc, %num\n"
+ " br i1 %cmp2, label %for.body, label %exit\n"
+ "exit:\n"
+ " ret void\n"
+ "}\n",
+ Err, C);
+
+ ASSERT_TRUE(M && "Could not parse module?");
+ ASSERT_TRUE(!verifyModule(*M) && "Must have been well formed!");
+
+ runWithSE(*M, "test", [](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+ auto *TCScev = SE.getSCEV(getArgByName(F, "num"));
+ auto *ApplyLoopGuardsTC = SE.applyLoopGuards(TCScev, *LI.begin());
+ // Expect (4 * ((4 umax %num) /u 4)): the lower bound 1 implied by %num >u 0 is aligned up to 4 by %num % 4 == 0.
+ APInt Four(32, 4);
+ auto *Constant4 = SE.getConstant(Four);
+ auto *Max = SE.getUMaxExpr(TCScev, Constant4);
+ auto *Mul = SE.getMulExpr(SE.getUDivExpr(Max, Constant4), Constant4);
+ EXPECT_EQ(Mul, ApplyLoopGuardsTC);
+ });
+}
+
} // end namespace llvm
More information about the llvm-commits
mailing list