[llvm] 7bb8bfa - [InstCombine] fix miscompile from vector select substitution

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Sun May 30 04:12:07 PDT 2021


Author: Sanjay Patel
Date: 2021-05-30T07:11:58-04:00
New Revision: 7bb8bfa0622b8ee55c3f748004dcf4d83d48cf97

URL: https://github.com/llvm/llvm-project/commit/7bb8bfa0622b8ee55c3f748004dcf4d83d48cf97
DIFF: https://github.com/llvm/llvm-project/commit/7bb8bfa0622b8ee55c3f748004dcf4d83d48cf97.diff

LOG: [InstCombine] fix miscompile from vector select substitution

This is similar to the fix in c590a9880d7a (PR49832), but
we missed handling the pattern for select of bools (no compare
inst).

We can't substitute a vector value because the equality condition
replacement that we are attempting requires that the condition
is true/false for the entire value. A vector select chooses each
element independently, so its condition can be true for some
elements and false for others.

I added an assert for vector types, so we shouldn't hit this again.
Fixed formatting while auditing the callers.

https://llvm.org/PR50500

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/InstructionSimplify.h
    llvm/lib/Analysis/InstructionSimplify.cpp
    llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
    llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
index e1e7da14376e6..75ce4e38df166 100644
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -299,7 +299,7 @@ Value *SimplifyInstruction(Instruction *I, const SimplifyQuery &Q,
 /// return null.
 /// AllowRefinement specifies whether the simplification can be a refinement
 /// (e.g. 0 instead of poison), or whether it needs to be strictly identical.
-Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
                               const SimplifyQuery &Q, bool AllowRefinement);
 
 /// Replace all uses of 'I' with 'SimpleV' and simplify the uses recursively.

diff  --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index c72670b901fe5..e6baed1779cd1 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3852,10 +3852,12 @@ Value *llvm::SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
   return ::SimplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit);
 }
 
-static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
                                      const SimplifyQuery &Q,
                                      bool AllowRefinement,
                                      unsigned MaxRecurse) {
+  assert(!Op->getType()->isVectorTy() && "This is not safe for vectors");
+
   // Trivial replacement.
   if (V == Op)
     return RepOp;
@@ -3965,10 +3967,10 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
   return ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI);
 }
 
-Value *llvm::SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+Value *llvm::simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
                                     const SimplifyQuery &Q,
                                     bool AllowRefinement) {
-  return ::SimplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement,
+  return ::simplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement,
                                   RecursionLimit);
 }
 
@@ -4090,17 +4092,17 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
   // Note that the equivalence/replacement opportunity does not hold for vectors
   // because each element of a vector select is chosen independently.
   if (Pred == ICmpInst::ICMP_EQ && !CondVal->getType()->isVectorTy()) {
-    if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
+    if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
                                /* AllowRefinement */ false, MaxRecurse) ==
             TrueVal ||
-        SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q,
+        simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q,
                                /* AllowRefinement */ false, MaxRecurse) ==
             TrueVal)
       return FalseVal;
-    if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
+    if (simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
                                /* AllowRefinement */ true, MaxRecurse) ==
             FalseVal ||
-        SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q,
+        simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q,
                                /* AllowRefinement */ true, MaxRecurse) ==
             FalseVal)
       return FalseVal;

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 68124a188f716..f48cc901b7428 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1142,7 +1142,7 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
   Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
   if (TrueVal != CmpLHS &&
       isGuaranteedNotToBeUndefOrPoison(CmpRHS, SQ.AC, &Sel, &DT)) {
-    if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ,
+    if (Value *V = simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, SQ,
                                           /* AllowRefinement */ true))
       return replaceOperand(Sel, Swapped ? 2 : 1, V);
 
@@ -1164,7 +1164,7 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
   }
   if (TrueVal != CmpRHS &&
       isGuaranteedNotToBeUndefOrPoison(CmpLHS, SQ.AC, &Sel, &DT))
-    if (Value *V = SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ,
+    if (Value *V = simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, SQ,
                                           /* AllowRefinement */ true))
       return replaceOperand(Sel, Swapped ? 2 : 1, V);
 
@@ -1195,9 +1195,9 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
   // We have an 'EQ' comparison, so the select's false value will propagate.
   // Example:
   // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1
-  if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ,
+  if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ,
                              /* AllowRefinement */ false) == TrueVal ||
-      SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ,
+      simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ,
                              /* AllowRefinement */ false) == TrueVal) {
     return replaceInstUsesWith(Sel, FalseVal);
   }
@@ -2714,12 +2714,14 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
         match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero()))
       return replaceOperand(SI, 0, A);
 
-    if (Value *S = SimplifyWithOpReplaced(TrueVal, CondVal, One, SQ,
-                                          /* AllowRefinement */ true))
-      return replaceOperand(SI, 1, S);
-    if (Value *S = SimplifyWithOpReplaced(FalseVal, CondVal, Zero, SQ,
-                                          /* AllowRefinement */ true))
-      return replaceOperand(SI, 2, S);
+    if (!SelType->isVectorTy()) {
+      if (Value *S = simplifyWithOpReplaced(TrueVal, CondVal, One, SQ,
+                                            /* AllowRefinement */ true))
+        return replaceOperand(SI, 1, S);
+      if (Value *S = simplifyWithOpReplaced(FalseVal, CondVal, Zero, SQ,
+                                            /* AllowRefinement */ true))
+        return replaceOperand(SI, 2, S);
+    }
 
     if (match(FalseVal, m_Zero()) || match(TrueVal, m_One())) {
       Use *Y = nullptr;

diff  --git a/llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll b/llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll
index c15a64ee73152..fef4081c0bb6e 100644
--- a/llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll
+++ b/llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll
@@ -468,3 +468,28 @@ define i1 @bor_lor_right2(i1 %A, i1 %B) {
   ret i1 %res
 }
 
+; Value equivalence substitution does not account for vector
+; transforms, so it needs a scalar condition operand.
+; For example, this would miscompile if %a = {1, 0}.
+
+define <2 x i1> @PR50500_trueval(<2 x i1> %a, <2 x i1> %b) {
+; CHECK-LABEL: @PR50500_trueval(
+; CHECK-NEXT:    [[S:%.*]] = shufflevector <2 x i1> [[A:%.*]], <2 x i1> poison, <2 x i32> <i32 1, i32 0>
+; CHECK-NEXT:    [[R:%.*]] = select <2 x i1> [[A]], <2 x i1> [[S]], <2 x i1> [[B:%.*]]
+; CHECK-NEXT:    ret <2 x i1> [[R]]
+;
+  %s = shufflevector <2 x i1> %a, <2 x i1> poison, <2 x i32> <i32 1, i32 0>
+  %r = select <2 x i1> %a, <2 x i1> %s, <2 x i1> %b
+  ret <2 x i1> %r
+}
+
+define <2 x i1> @PR50500_falseval(<2 x i1> %a, <2 x i1> %b) {
+; CHECK-LABEL: @PR50500_falseval(
+; CHECK-NEXT:    [[S:%.*]] = shufflevector <2 x i1> [[A:%.*]], <2 x i1> poison, <2 x i32> <i32 1, i32 0>
+; CHECK-NEXT:    [[R:%.*]] = select <2 x i1> [[A]], <2 x i1> [[B:%.*]], <2 x i1> [[S]]
+; CHECK-NEXT:    ret <2 x i1> [[R]]
+;
+  %s = shufflevector <2 x i1> %a, <2 x i1> poison, <2 x i32> <i32 1, i32 0>
+  %r = select <2 x i1> %a, <2 x i1> %b, <2 x i1> %s
+  ret <2 x i1> %r
+}


        


More information about the llvm-commits mailing list