[llvm] b267d3c - [InstCombine] avoid infinite loops from min/max canonicalization

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 10 11:42:43 PDT 2021


Author: Sanjay Patel
Date: 2021-08-10T14:42:37-04:00
New Revision: b267d3ce8defa092600bda717ff18440d002f316

URL: https://github.com/llvm/llvm-project/commit/b267d3ce8defa092600bda717ff18440d002f316
DIFF: https://github.com/llvm/llvm-project/commit/b267d3ce8defa092600bda717ff18440d002f316.diff

LOG: [InstCombine] avoid infinite loops from min/max canonicalization

The intrinsics have an extra chunk of known bits logic
compared to the normal cmp+select idiom. That allows
folding the icmp in each case to something better, but
that then opposes the canonical form of min/max that
we try to form for a select.

I'm carving out a narrow exception to preserve all
existing regression tests while avoiding the inf-loop.
It seems unlikely that this is the only bug like this
left, but this should fix:
https://llvm.org/PR51419

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/test/Transforms/InstCombine/select-min-max.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 8514952e9241..2e20bca300d3 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5184,6 +5184,83 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
   if (!isa<Constant>(Op1) && Op1Min == Op1Max)
     return new ICmpInst(Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Min));
 
+  // Don't break up a clamp pattern -- min(max(X, Y), Z) -- by replacing a
+  // min/max canonical compare with some other compare. That could lead to
+  // conflict with select canonicalization and infinite looping.
+  // FIXME: This constraint may go away if min/max intrinsics are canonical.
+  auto isMinMaxCmp = [&](Instruction &Cmp) {
+    if (!Cmp.hasOneUse())
+      return false;
+    Value *A, *B;
+    SelectPatternFlavor SPF = matchSelectPattern(Cmp.user_back(), A, B).Flavor;
+    if (!SelectPatternResult::isMinOrMax(SPF))
+      return false;
+    return match(Op0, m_MaxOrMin(m_Value(), m_Value())) ||
+           match(Op1, m_MaxOrMin(m_Value(), m_Value()));
+  };
+  if (!isMinMaxCmp(I)) {
+    switch (Pred) {
+    default:
+      break;
+    case ICmpInst::ICMP_ULT: {
+      if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
+        return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+      const APInt *CmpC;
+      if (match(Op1, m_APInt(CmpC))) {
+        // A <u C -> A == C-1 if min(A)+1 == C
+        if (*CmpC == Op0Min + 1)
+          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                              ConstantInt::get(Op1->getType(), *CmpC - 1));
+        // X <u C --> X == 0, if the number of zero bits in the bottom of X
+        // exceeds the log2 of C.
+        if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
+          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                              Constant::getNullValue(Op1->getType()));
+      }
+      break;
+    }
+    case ICmpInst::ICMP_UGT: {
+      if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
+        return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+      const APInt *CmpC;
+      if (match(Op1, m_APInt(CmpC))) {
+        // A >u C -> A == C+1 if max(A)-1 == C
+        if (*CmpC == Op0Max - 1)
+          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                              ConstantInt::get(Op1->getType(), *CmpC + 1));
+        // X >u C --> X != 0, if the number of zero bits in the bottom of X
+        // exceeds the log2 of C.
+        if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
+          return new ICmpInst(ICmpInst::ICMP_NE, Op0,
+                              Constant::getNullValue(Op1->getType()));
+      }
+      break;
+    }
+    case ICmpInst::ICMP_SLT: {
+      if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
+        return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+      const APInt *CmpC;
+      if (match(Op1, m_APInt(CmpC))) {
+        if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
+          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                              ConstantInt::get(Op1->getType(), *CmpC - 1));
+      }
+      break;
+    }
+    case ICmpInst::ICMP_SGT: {
+      if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
+        return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
+      const APInt *CmpC;
+      if (match(Op1, m_APInt(CmpC))) {
+        if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
+          return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                              ConstantInt::get(Op1->getType(), *CmpC + 1));
+      }
+      break;
+    }
+    }
+  }
+
   // Based on the range information we know about the LHS, see if we can
   // simplify this comparison.  For example, (x&4) < 8 is always true.
   switch (Pred) {
@@ -5245,21 +5322,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
       return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
     if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B)
       return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-    if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
-      return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-
-    const APInt *CmpC;
-    if (match(Op1, m_APInt(CmpC))) {
-      // A <u C -> A == C-1 if min(A)+1 == C
-      if (*CmpC == Op0Min + 1)
-        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                            ConstantInt::get(Op1->getType(), *CmpC - 1));
-      // X <u C --> X == 0, if the number of zero bits in the bottom of X
-      // exceeds the log2 of C.
-      if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
-        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                            Constant::getNullValue(Op1->getType()));
-    }
     break;
   }
   case ICmpInst::ICMP_UGT: {
@@ -5267,21 +5329,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
       return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
     if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= min(B)
       return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-    if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
-      return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-
-    const APInt *CmpC;
-    if (match(Op1, m_APInt(CmpC))) {
-      // A >u C -> A == C+1 if max(a)-1 == C
-      if (*CmpC == Op0Max - 1)
-        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                            ConstantInt::get(Op1->getType(), *CmpC + 1));
-      // X >u C --> X != 0, if the number of zero bits in the bottom of X
-      // exceeds the log2 of C.
-      if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
-        return new ICmpInst(ICmpInst::ICMP_NE, Op0,
-                            Constant::getNullValue(Op1->getType()));
-    }
     break;
   }
   case ICmpInst::ICMP_SLT: {
@@ -5289,14 +5336,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
       return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
     if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(B)
       return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-    if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
-      return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-    const APInt *CmpC;
-    if (match(Op1, m_APInt(CmpC))) {
-      if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
-        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                            ConstantInt::get(Op1->getType(), *CmpC - 1));
-    }
     break;
   }
   case ICmpInst::ICMP_SGT: {
@@ -5304,14 +5343,6 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
       return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
     if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B)
       return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-    if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
-      return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-    const APInt *CmpC;
-    if (match(Op1, m_APInt(CmpC))) {
-      if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
-        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                            ConstantInt::get(Op1->getType(), *CmpC + 1));
-    }
     break;
   }
   case ICmpInst::ICMP_SGE:

diff  --git a/llvm/test/Transforms/InstCombine/select-min-max.ll b/llvm/test/Transforms/InstCombine/select-min-max.ll
index 79dbe40f20c6..caf0f7320d94 100644
--- a/llvm/test/Transforms/InstCombine/select-min-max.ll
+++ b/llvm/test/Transforms/InstCombine/select-min-max.ll
@@ -192,3 +192,64 @@ define <3 x i5> @umax_select_const(<3 x i1> %b, <3 x i5> %x) {
   %c = call <3 x i5> @llvm.umax.v3i5(<3 x i5> <i5 5, i5 8, i5 1>, <3 x i5> %s)
   ret <3 x i5> %c
 }
+
+declare i32 @llvm.smax.i32(i32, i32);
+declare i32 @llvm.smin.i32(i32, i32);
+declare i8 @llvm.umax.i8(i8, i8);
+declare i8 @llvm.umin.i8(i8, i8);
+
+; Each of the following 4 tests would infinite loop because
+; we had conflicting transforms for icmp and select using
+; known bits.
+
+define i32 @smax_smin(i32 %x) {
+; CHECK-LABEL: @smax_smin(
+; CHECK-NEXT:    [[M:%.*]] = call i32 @llvm.smax.i32(i32 [[X:%.*]], i32 0)
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ult i32 [[M]], 1
+; CHECK-NEXT:    [[S:%.*]] = select i1 [[TMP1]], i32 [[M]], i32 1
+; CHECK-NEXT:    ret i32 [[S]]
+;
+  %m = call i32 @llvm.smax.i32(i32 %x, i32 0)
+  %c = icmp slt i32 %x, 1
+  %s = select i1 %c, i32 %m, i32 1
+  ret i32 %s
+}
+
+define i32 @smin_smax(i32 %x) {
+; CHECK-LABEL: @smin_smax(
+; CHECK-NEXT:    [[M:%.*]] = call i32 @llvm.smin.i32(i32 [[X:%.*]], i32 -1)
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ugt i32 [[M]], -2
+; CHECK-NEXT:    [[S:%.*]] = select i1 [[TMP1]], i32 [[M]], i32 -2
+; CHECK-NEXT:    ret i32 [[S]]
+;
+  %m = call i32 @llvm.smin.i32(i32 %x, i32 -1)
+  %c = icmp sgt i32 %x, -2
+  %s = select i1 %c, i32 %m, i32 -2
+  ret i32 %s
+}
+
+define i8 @umax_umin(i8 %x) {
+; CHECK-LABEL: @umax_umin(
+; CHECK-NEXT:    [[M:%.*]] = call i8 @llvm.umax.i8(i8 [[X:%.*]], i8 -128)
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ult i8 [[M]], -127
+; CHECK-NEXT:    [[S:%.*]] = select i1 [[TMP1]], i8 [[M]], i8 -127
+; CHECK-NEXT:    ret i8 [[S]]
+;
+  %m = call i8 @llvm.umax.i8(i8 %x, i8 128)
+  %c = icmp ult i8 %x, 129
+  %s = select i1 %c, i8 %m, i8 129
+  ret i8 %s
+}
+
+define i8 @umin_umax(i8 %x) {
+; CHECK-LABEL: @umin_umax(
+; CHECK-NEXT:    [[M:%.*]] = call i8 @llvm.umin.i8(i8 [[X:%.*]], i8 127)
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ugt i8 [[M]], 126
+; CHECK-NEXT:    [[S:%.*]] = select i1 [[TMP1]], i8 [[M]], i8 126
+; CHECK-NEXT:    ret i8 [[S]]
+;
+  %m = call i8 @llvm.umin.i8(i8 %x, i8 127)
+  %c = icmp ugt i8 %x, 126
+  %s = select i1 %c, i8 %m, i8 126
+  ret i8 %s
+}


        


More information about the llvm-commits mailing list