[llvm] 9d1c8c0 - [InstCombine] Fix select operand simplification with undef (PR47696)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Thu Oct 1 12:16:31 PDT 2020


Author: Nikita Popov
Date: 2020-10-01T21:15:48+02:00
New Revision: 9d1c8c0ba94a273c53829f0800335045e547db88

URL: https://github.com/llvm/llvm-project/commit/9d1c8c0ba94a273c53829f0800335045e547db88
DIFF: https://github.com/llvm/llvm-project/commit/9d1c8c0ba94a273c53829f0800335045e547db88.diff

LOG: [InstCombine] Fix select operand simplification with undef (PR47696)

When replacing X == Y ? f(X) : Z with X == Y ? f(Y) : Z, make sure
that Y cannot be undef. If it may be undef, we might end up picking
a different value for undef in the comparison and the select
operand.

Added: 
    

Modified: 
    llvm/lib/Transforms/InstCombine/InstCombineInternal.h
    llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
    llvm/test/Transforms/InstCombine/select-binop-cmp.ll
    llvm/test/Transforms/InstCombine/select.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 62ee7d00780e..eef56c8645f8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -711,6 +711,7 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
                             Value *A, Value *B, Instruction &Outer,
                             SelectPatternFlavor SPF2, Value *C);
   Instruction *foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI);
+  Instruction *foldSelectValueEquivalence(SelectInst &SI, ICmpInst &ICI);
 
   Instruction *OptAndOp(BinaryOperator *Op, ConstantInt *OpRHS,
                         ConstantInt *AndRHS, BinaryOperator &TheAnd);

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index ce473410f4ca..087586ede808 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1165,9 +1165,8 @@ static Instruction *canonicalizeAbsNabs(SelectInst &Sel, ICmpInst &Cmp,
 ///
 /// We can't replace %sel with %add unless we strip away the flags.
 /// TODO: Wrapping flags could be preserved in some cases with better analysis.
-static Instruction *foldSelectValueEquivalence(SelectInst &Sel, ICmpInst &Cmp,
-                                               const SimplifyQuery &Q,
-                                               InstCombiner &IC) {
+Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
+                                                          ICmpInst &Cmp) {
   if (!Cmp.isEquality())
     return nullptr;
 
@@ -1179,18 +1178,20 @@ static Instruction *foldSelectValueEquivalence(SelectInst &Sel, ICmpInst &Cmp,
     Swapped = true;
   }
 
-  // In X == Y ? f(X) : Z, try to evaluate f(X) and replace the operand.
-  // Take care to avoid replacing X == Y ? X : Z with X == Y ? Y : Z, as that
-  // would lead to an infinite replacement cycle.
+  // In X == Y ? f(X) : Z, try to evaluate f(Y) and replace the operand.
+  // Make sure Y cannot be undef though, as we might pick different values for
+  // undef in the icmp and in f(Y). Additionally, take care to avoid replacing
+  // X == Y ? X : Z with X == Y ? Y : Z, as that would lead to an infinite
+  // replacement cycle.
   Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
-  if (TrueVal != CmpLHS)
-    if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
+  if (TrueVal != CmpLHS && isGuaranteedNotToBeUndefOrPoison(CmpRHS, &Sel, &DT))
+    if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ,
                                           /* AllowRefinement */ true))
-      return IC.replaceOperand(Sel, Swapped ? 2 : 1, V);
-  if (TrueVal != CmpRHS)
-    if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q,
+      return replaceOperand(Sel, Swapped ? 2 : 1, V);
+  if (TrueVal != CmpRHS && isGuaranteedNotToBeUndefOrPoison(CmpLHS, &Sel, &DT))
+    if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ,
                                           /* AllowRefinement */ true))
-      return IC.replaceOperand(Sel, Swapped ? 2 : 1, V);
+      return replaceOperand(Sel, Swapped ? 2 : 1, V);
 
   auto *FalseInst = dyn_cast<Instruction>(FalseVal);
   if (!FalseInst)
@@ -1215,11 +1216,11 @@ static Instruction *foldSelectValueEquivalence(SelectInst &Sel, ICmpInst &Cmp,
   // We have an 'EQ' comparison, so the select's false value will propagate.
   // Example:
   // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1
-  if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
+  if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ,
                              /* AllowRefinement */ false) == TrueVal ||
-      SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q,
+      SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ,
                              /* AllowRefinement */ false) == TrueVal) {
-    return IC.replaceInstUsesWith(Sel, FalseVal);
+    return replaceInstUsesWith(Sel, FalseVal);
   }
 
   // Restore poison-generating flags if the transform did not apply.
@@ -1455,7 +1456,7 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
 /// Visit a SelectInst that has an ICmpInst as its first operand.
 Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
                                                       ICmpInst *ICI) {
-  if (Instruction *NewSel = foldSelectValueEquivalence(SI, *ICI, SQ, *this))
+  if (Instruction *NewSel = foldSelectValueEquivalence(SI, *ICI))
     return NewSel;
 
   if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, *this))

diff --git a/llvm/test/Transforms/InstCombine/select-binop-cmp.ll b/llvm/test/Transforms/InstCombine/select-binop-cmp.ll
index aa450f8af8b7..c4a9d0941b96 100644
--- a/llvm/test/Transforms/InstCombine/select-binop-cmp.ll
+++ b/llvm/test/Transforms/InstCombine/select-binop-cmp.ll
@@ -564,10 +564,12 @@ define <2 x i8> @select_xor_icmp_vec_bad(<2 x i8> %x, <2 x i8> %y, <2 x i8> %z)
   ret <2 x i8>  %C
 }
 
+; Folding this would only be legal if we sanitized undef to 0.
 define <2 x i8> @select_xor_icmp_vec_undef(<2 x i8> %x, <2 x i8> %y, <2 x i8> %z) {
 ; CHECK-LABEL: @select_xor_icmp_vec_undef(
 ; CHECK-NEXT:    [[A:%.*]] = icmp eq <2 x i8> [[X:%.*]], <i8 0, i8 undef>
-; CHECK-NEXT:    [[C:%.*]] = select <2 x i1> [[A]], <2 x i8> [[Z:%.*]], <2 x i8> [[Y:%.*]]
+; CHECK-NEXT:    [[B:%.*]] = xor <2 x i8> [[X]], [[Z:%.*]]
+; CHECK-NEXT:    [[C:%.*]] = select <2 x i1> [[A]], <2 x i8> [[B]], <2 x i8> [[Y:%.*]]
 ; CHECK-NEXT:    ret <2 x i8> [[C]]
 ;
   %A = icmp eq <2 x i8>  %x, <i8 0, i8 undef>

diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index b7c4cb5c6420..df506477eed1 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -2641,10 +2641,24 @@ define i8 @select_replacement_add_nuw(i8 %x, i8 %y) {
   ret i8 %sel
 }
 
+define i8 @select_replacement_sub_noundef(i8 %x, i8 noundef %y, i8 %z) {
+; CHECK-LABEL: @select_replacement_sub_noundef(
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[CMP]], i8 0, i8 [[Z:%.*]]
+; CHECK-NEXT:    ret i8 [[SEL]]
+;
+  %cmp = icmp eq i8 %x, %y
+  %sub = sub i8 %x, %y
+  %sel = select i1 %cmp, i8 %sub, i8 %z
+  ret i8 %sel
+}
+
+; TODO: The transform is also safe without noundef.
 define i8 @select_replacement_sub(i8 %x, i8 %y, i8 %z) {
 ; CHECK-LABEL: @select_replacement_sub(
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[CMP]], i8 0, i8 [[Z:%.*]]
+; CHECK-NEXT:    [[SUB:%.*]] = sub i8 [[X]], [[Y]]
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[CMP]], i8 [[SUB]], i8 [[Z:%.*]]
 ; CHECK-NEXT:    ret i8 [[SEL]]
 ;
   %cmp = icmp eq i8 %x, %y
@@ -2653,11 +2667,29 @@ define i8 @select_replacement_sub(i8 %x, i8 %y, i8 %z) {
   ret i8 %sel
 }
 
+define i8 @select_replacement_shift_noundef(i8 %x, i8 %y, i8 %z) {
+; CHECK-LABEL: @select_replacement_shift_noundef(
+; CHECK-NEXT:    [[SHR:%.*]] = lshr exact i8 [[X:%.*]], 1
+; CHECK-NEXT:    call void @use_i8(i8 noundef [[SHR]])
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[SHR]], [[Y:%.*]]
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[CMP]], i8 [[X]], i8 [[Z:%.*]]
+; CHECK-NEXT:    ret i8 [[SEL]]
+;
+  %shr = lshr exact i8 %x, 1
+  call void @use_i8(i8 noundef %shr)
+  %cmp = icmp eq i8 %shr, %y
+  %shl = shl i8 %y, 1
+  %sel = select i1 %cmp, i8 %shl, i8 %z
+  ret i8 %sel
+}
+
+; TODO: The transform is also safe without noundef.
 define i8 @select_replacement_shift(i8 %x, i8 %y, i8 %z) {
 ; CHECK-LABEL: @select_replacement_shift(
 ; CHECK-NEXT:    [[SHR:%.*]] = lshr exact i8 [[X:%.*]], 1
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[SHR]], [[Y:%.*]]
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[CMP]], i8 [[X]], i8 [[Z:%.*]]
+; CHECK-NEXT:    [[SHL:%.*]] = shl i8 [[Y]], 1
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[CMP]], i8 [[SHL]], i8 [[Z:%.*]]
 ; CHECK-NEXT:    ret i8 [[SEL]]
 ;
   %shr = lshr exact i8 %x, 1
@@ -2694,4 +2726,5 @@ define i32 @select_replacement_loop2(i32 %arg, i32 %arg2) {
 }
 
 declare void @use(i1)
+declare void @use_i8(i8)
 declare i32 @llvm.cttz.i32(i32, i1 immarg)


        


More information about the llvm-commits mailing list