[llvm] b267d3c - [InstCombine] avoid infinite loops from min/max canonicalization
Sanjay Patel via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 10 11:42:43 PDT 2021
Author: Sanjay Patel
Date: 2021-08-10T14:42:37-04:00
New Revision: b267d3ce8defa092600bda717ff18440d002f316
URL: https://github.com/llvm/llvm-project/commit/b267d3ce8defa092600bda717ff18440d002f316
DIFF: https://github.com/llvm/llvm-project/commit/b267d3ce8defa092600bda717ff18440d002f316.diff
LOG: [InstCombine] avoid infinite loops from min/max canonicalization
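The min/max intrinsics get an extra chunk of known-bits logic
compared to the normal cmp+select idiom. That allows
folding the icmp in each case to another compare (for example,
an equality test), but the new compare then conflicts with the
canonical min/max form that we try to create for a select, so
the icmp fold and the select canonicalization fight each other.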
I'm carving out a narrow exception to preserve all
existing regression tests while avoiding the infinite loop.
It seems unlikely that this is the only bug like this
left, but this should fix:
https://llvm.org/PR51419
Added:
Modified:
llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
llvm/test/Transforms/InstCombine/select-min-max.ll
Removed:
################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 8514952e9241..2e20bca300d3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5184,6 +5184,83 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
if (!isa<Constant>(Op1) && Op1Min == Op1Max)
return new ICmpInst(Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Min));
+ // Don't break up a clamp pattern -- min(max(X, Y), Z) -- by replacing a
+ // min/max canonical compare with some other compare. That could lead to
+ // conflict with select canonicalization and infinite looping.
+ // FIXME: This constraint may go away if min/max intrinsics are canonical.
+ auto isMinMaxCmp = [&](Instruction &Cmp) {
+ if (!Cmp.hasOneUse())
+ return false;
+ Value *A, *B;
+ SelectPatternFlavor SPF = matchSelectPattern(Cmp.user_back(), A, B).Flavor;
+ if (!SelectPatternResult::isMinOrMax(SPF))
+ return false;
+ return match(Op0, m_MaxOrMin(m_Value(), m_Value())) ||
+ match(Op1, m_MaxOrMin(m_Value(), m_Value()));
+ };
+ if (!isMinMaxCmp(I)) {
+ switch (Pred) {
+ default:
+ break;
+ case ICmpInst::ICMP_ULT: {
+ if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
+ return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+ const APInt *CmpC;
+ if (match(Op1, m_APInt(CmpC))) {
+ // A <u C -> A == C-1 if min(A)+1 == C
+ if (*CmpC == Op0Min + 1)
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), *CmpC - 1));
+ // X <u C --> X == 0, if the number of zero bits in the bottom of X
+ // exceeds the log2 of C.
+ if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ Constant::getNullValue(Op1->getType()));
+ }
+ break;
+ }
+ case ICmpInst::ICMP_UGT: {
+ if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
+ return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+ const APInt *CmpC;
+ if (match(Op1, m_APInt(CmpC))) {
+ // A >u C -> A == C+1 if max(A)-1 == C
+ if (*CmpC == Op0Max - 1)
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), *CmpC + 1));
+ // X >u C --> X != 0, if the number of zero bits in the bottom of X
+ // exceeds the log2 of C.
+ if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
+ return new ICmpInst(ICmpInst::ICMP_NE, Op0,
+ Constant::getNullValue(Op1->getType()));
+ }
+ break;
+ }
+ case ICmpInst::ICMP_SLT: {
+ if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
+ return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+ const APInt *CmpC;
+ if (match(Op1, m_APInt(CmpC))) {
+ if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), *CmpC - 1));
+ }
+ break;
+ }
+ case ICmpInst::ICMP_SGT: {
+ if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
+ return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+ const APInt *CmpC;
+ if (match(Op1, m_APInt(CmpC))) {
+ if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), *CmpC + 1));
+ }
+ break;
+ }
+ }
+ }
+
// Based on the range information we know about the LHS, see if we can
// simplify this comparison. For example, (x&4) < 8 is always true.
switch (Pred) {
@@ -5245,21 +5322,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
- return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- // A <u C -> A == C-1 if min(A)+1 == C
- if (*CmpC == Op0Min + 1)
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC - 1));
- // X <u C --> X == 0, if the number of zero bits in the bottom of X
- // exceeds the log2 of C.
- if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- Constant::getNullValue(Op1->getType()));
- }
break;
}
case ICmpInst::ICMP_UGT: {
@@ -5267,21 +5329,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
- return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- // A >u C -> A == C+1 if max(a)-1 == C
- if (*CmpC == Op0Max - 1)
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC + 1));
- // X >u C --> X != 0, if the number of zero bits in the bottom of X
- // exceeds the log2 of C.
- if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
- return new ICmpInst(ICmpInst::ICMP_NE, Op0,
- Constant::getNullValue(Op1->getType()));
- }
break;
}
case ICmpInst::ICMP_SLT: {
@@ -5289,14 +5336,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
- return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC - 1));
- }
break;
}
case ICmpInst::ICMP_SGT: {
@@ -5304,14 +5343,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
- return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC + 1));
- }
break;
}
case ICmpInst::ICMP_SGE:
diff --git a/llvm/test/Transforms/InstCombine/select-min-max.ll b/llvm/test/Transforms/InstCombine/select-min-max.ll
index 79dbe40f20c6..caf0f7320d94 100644
--- a/llvm/test/Transforms/InstCombine/select-min-max.ll
+++ b/llvm/test/Transforms/InstCombine/select-min-max.ll
@@ -192,3 +192,64 @@ define <3 x i5> @umax_select_const(<3 x i1> %b, <3 x i5> %x) {
%c = call <3 x i5> @llvm.umax.v3i5(<3 x i5> <i5 5, i5 8, i5 1>, <3 x i5> %s)
ret <3 x i5> %c
}
+
+declare i32 @llvm.smax.i32(i32, i32);
+declare i32 @llvm.smin.i32(i32, i32);
+declare i8 @llvm.umax.i8(i8, i8);
+declare i8 @llvm.umin.i8(i8, i8);
+
+; Each of the following 4 tests would infinite loop because
+; we had conflicting transforms for icmp and select using
+; known bits.
+
+define i32 @smax_smin(i32 %x) {
+; CHECK-LABEL: @smax_smin(
+; CHECK-NEXT: [[M:%.*]] = call i32 @llvm.smax.i32(i32 [[X:%.*]], i32 0)
+; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i32 [[M]], 1
+; CHECK-NEXT: [[S:%.*]] = select i1 [[TMP1]], i32 [[M]], i32 1
+; CHECK-NEXT: ret i32 [[S]]
+;
+ %m = call i32 @llvm.smax.i32(i32 %x, i32 0)
+ %c = icmp slt i32 %x, 1
+ %s = select i1 %c, i32 %m, i32 1
+ ret i32 %s
+}
+
+define i32 @smin_smax(i32 %x) {
+; CHECK-LABEL: @smin_smax(
+; CHECK-NEXT: [[M:%.*]] = call i32 @llvm.smin.i32(i32 [[X:%.*]], i32 -1)
+; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i32 [[M]], -2
+; CHECK-NEXT: [[S:%.*]] = select i1 [[TMP1]], i32 [[M]], i32 -2
+; CHECK-NEXT: ret i32 [[S]]
+;
+ %m = call i32 @llvm.smin.i32(i32 %x, i32 -1)
+ %c = icmp sgt i32 %x, -2
+ %s = select i1 %c, i32 %m, i32 -2
+ ret i32 %s
+}
+
+define i8 @umax_umin(i8 %x) {
+; CHECK-LABEL: @umax_umin(
+; CHECK-NEXT: [[M:%.*]] = call i8 @llvm.umax.i8(i8 [[X:%.*]], i8 -128)
+; CHECK-NEXT: [[TMP1:%.*]] = icmp ult i8 [[M]], -127
+; CHECK-NEXT: [[S:%.*]] = select i1 [[TMP1]], i8 [[M]], i8 -127
+; CHECK-NEXT: ret i8 [[S]]
+;
+ %m = call i8 @llvm.umax.i8(i8 %x, i8 128)
+ %c = icmp ult i8 %x, 129
+ %s = select i1 %c, i8 %m, i8 129
+ ret i8 %s
+}
+
+define i8 @umin_umax(i8 %x) {
+; CHECK-LABEL: @umin_umax(
+; CHECK-NEXT: [[M:%.*]] = call i8 @llvm.umin.i8(i8 [[X:%.*]], i8 127)
+; CHECK-NEXT: [[TMP1:%.*]] = icmp ugt i8 [[M]], 126
+; CHECK-NEXT: [[S:%.*]] = select i1 [[TMP1]], i8 [[M]], i8 126
+; CHECK-NEXT: ret i8 [[S]]
+;
+ %m = call i8 @llvm.umin.i8(i8 %x, i8 127)
+ %c = icmp ugt i8 %x, 126
+ %s = select i1 %c, i8 %m, i8 126
+ ret i8 %s
+}
More information about the llvm-commits
mailing list