[llvm] 41dde5d - [InstSimplify] Support vectors in simplifyWithOpReplaced()

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 22 01:46:02 PDT 2022


Author: Nikita Popov
Date: 2022-09-22T10:45:42+02:00
New Revision: 41dde5d858cdfcddbd83dd6e17dc30594d47f2bd

URL: https://github.com/llvm/llvm-project/commit/41dde5d858cdfcddbd83dd6e17dc30594d47f2bd
DIFF: https://github.com/llvm/llvm-project/commit/41dde5d858cdfcddbd83dd6e17dc30594d47f2bd.diff

LOG: [InstSimplify] Support vectors in simplifyWithOpReplaced()

We can handle vectors inside simplifyWithOpReplaced(), as long as
cross-lane operations are excluded. The equality can hold (or not
hold) for each vector lane independently, so a replacement value that
is known to be valid in one lane must not be used to simplify a
different lane.

I believe the only potentially cross-lane operations relevant here are
shufflevector (where all previous bugs were seen) and calls (which
might be shuffle-like intrinsics; for now all calls are rejected, and
allowing the lane-wise ones would require more careful classification).

Differential Revision: https://reviews.llvm.org/D134348

Added: 
    

Modified: 
    llvm/lib/Analysis/InstructionSimplify.cpp
    llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
    llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll
    llvm/test/Transforms/InstCombine/select.ll
    llvm/test/Transforms/InstSimplify/select.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index eebb5f1f09154..fe0e6388d2178 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -4082,8 +4082,6 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
                                      const SimplifyQuery &Q,
                                      bool AllowRefinement,
                                      unsigned MaxRecurse) {
-  assert(!Op->getType()->isVectorTy() && "This is not safe for vectors");
-
   // Trivial replacement.
   if (V == Op)
     return RepOp;
@@ -4096,6 +4094,14 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
   if (!I || !is_contained(I->operands(), Op))
     return nullptr;
 
+  if (Op->getType()->isVectorTy()) {
+    // For vector types, the simplification must hold per-lane, so forbid
+    // potentially cross-lane operations like shufflevector.
+    assert(I->getType()->isVectorTy() && "Vector type mismatch");
+    if (isa<ShuffleVectorInst>(I) || isa<CallBase>(I))
+      return nullptr;
+  }
+
   // Replace Op with RepOp in instruction operands.
   SmallVector<Value *, 8> NewOps(I->getNumOperands());
   transform(I->operands(), NewOps.begin(),
@@ -4325,9 +4331,7 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
   // If we have a scalar equality comparison, then we know the value in one of
   // the arms of the select. See if substituting this value into the arm and
   // simplifying the result yields the same value as the other arm.
-  // Note that the equivalence/replacement opportunity does not hold for vectors
-  // because each element of a vector select is chosen independently.
-  if (Pred == ICmpInst::ICMP_EQ && !CondVal->getType()->isVectorTy()) {
+  if (Pred == ICmpInst::ICMP_EQ) {
     if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
                                /* AllowRefinement */ false,
                                MaxRecurse) == TrueVal ||

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index a0631381f12f8..b316be4f92633 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1164,10 +1164,7 @@ static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp,
 /// TODO: Wrapping flags could be preserved in some cases with better analysis.
 Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
                                                           ICmpInst &Cmp) {
-  // Value equivalence substitution requires an all-or-nothing replacement.
-  // It does not make sense for a vector compare where each lane is chosen
-  // independently.
-  if (!Cmp.isEquality() || Cmp.getType()->isVectorTy())
+  if (!Cmp.isEquality())
     return nullptr;
 
   // Canonicalize the pattern to ICMP_EQ by swapping the select operands.
@@ -1197,7 +1194,9 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
     // undefined behavior). Only do this if CmpRHS is a constant, as
     // profitability is not clear for other cases.
     // FIXME: The replacement could be performed recursively.
-    if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()))
+    // FIXME: Support vectors.
+    if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()) &&
+        !Cmp.getType()->isVectorTy())
       if (auto *I = dyn_cast<Instruction>(TrueVal))
         if (I->hasOneUse() && isSafeToSpeculativelyExecute(I))
           for (Use &U : I->operands())
@@ -2795,14 +2794,12 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
       return BinaryOperator::CreateAnd(FalseVal, Builder.CreateOr(C, TrueVal));
     }
 
-    if (!SelType->isVectorTy()) {
-      if (Value *S = simplifyWithOpReplaced(TrueVal, CondVal, One, SQ,
-                                            /* AllowRefinement */ true))
-        return replaceOperand(SI, 1, S);
-      if (Value *S = simplifyWithOpReplaced(FalseVal, CondVal, Zero, SQ,
-                                            /* AllowRefinement */ true))
-        return replaceOperand(SI, 2, S);
-    }
+    if (Value *S = simplifyWithOpReplaced(TrueVal, CondVal, One, SQ,
+                                          /* AllowRefinement */ true))
+      return replaceOperand(SI, 1, S);
+    if (Value *S = simplifyWithOpReplaced(FalseVal, CondVal, Zero, SQ,
+                                          /* AllowRefinement */ true))
+      return replaceOperand(SI, 2, S);
 
     if (match(FalseVal, m_Zero()) || match(TrueVal, m_One())) {
       Use *Y = nullptr;

diff  --git a/llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll b/llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll
index e723ad9505f3c..3366775ee4ea6 100644
--- a/llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll
+++ b/llvm/test/Transforms/InstCombine/select-safe-bool-transforms.ll
@@ -308,9 +308,7 @@ define <2 x i1> @land_lor_right1_vec(<2 x i1> %A, <2 x i1> %B) {
 }
 define <2 x i1> @land_lor_right2_vec(<2 x i1> %A, <2 x i1> %B) {
 ; CHECK-LABEL: @land_lor_right2_vec(
-; CHECK-NEXT:    [[C:%.*]] = select <2 x i1> [[B:%.*]], <2 x i1> [[A:%.*]], <2 x i1> zeroinitializer
-; CHECK-NEXT:    [[RES:%.*]] = select <2 x i1> [[A]], <2 x i1> <i1 true, i1 true>, <2 x i1> [[C]]
-; CHECK-NEXT:    ret <2 x i1> [[RES]]
+; CHECK-NEXT:    ret <2 x i1> [[A:%.*]]
 ;
   %c = select <2 x i1> %B, <2 x i1> %A, <2 x i1> zeroinitializer
   %res = select <2 x i1> %A, <2 x i1> <i1 true, i1 true>, <2 x i1> %c

diff  --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index 59a777df82447..57aada15b936b 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -1460,10 +1460,8 @@ define i32 @PR27817_nsw(i32 %x) {
 
 define <2 x i32> @PR27817_nsw_vec(<2 x i32> %x) {
 ; CHECK-LABEL: @PR27817_nsw_vec(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[X:%.*]], <i32 -2147483648, i32 -2147483648>
-; CHECK-NEXT:    [[SUB:%.*]] = sub nsw <2 x i32> zeroinitializer, [[X]]
-; CHECK-NEXT:    [[SEL:%.*]] = select <2 x i1> [[CMP]], <2 x i32> <i32 -2147483648, i32 -2147483648>, <2 x i32> [[SUB]]
-; CHECK-NEXT:    ret <2 x i32> [[SEL]]
+; CHECK-NEXT:    [[SUB:%.*]] = sub <2 x i32> zeroinitializer, [[X:%.*]]
+; CHECK-NEXT:    ret <2 x i32> [[SUB]]
 ;
   %cmp = icmp eq <2 x i32> %x, <i32 -2147483648, i32 -2147483648>
   %sub = sub nsw <2 x i32> zeroinitializer, %x
@@ -2785,8 +2783,7 @@ define i8 @select_replacement_add_eq(i8 %x, i8 %y) {
 define <2 x i8> @select_replacement_add_eq_vec(<2 x i8> %x, <2 x i8> %y) {
 ; CHECK-LABEL: @select_replacement_add_eq_vec(
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i8> [[X:%.*]], <i8 1, i8 1>
-; CHECK-NEXT:    [[ADD:%.*]] = add <2 x i8> [[X]], <i8 1, i8 1>
-; CHECK-NEXT:    [[SEL:%.*]] = select <2 x i1> [[CMP]], <2 x i8> [[ADD]], <2 x i8> [[Y:%.*]]
+; CHECK-NEXT:    [[SEL:%.*]] = select <2 x i1> [[CMP]], <2 x i8> <i8 2, i8 2>, <2 x i8> [[Y:%.*]]
 ; CHECK-NEXT:    ret <2 x i8> [[SEL]]
 ;
   %cmp = icmp eq <2 x i8> %x, <i8 1, i8 1>
@@ -2798,8 +2795,7 @@ define <2 x i8> @select_replacement_add_eq_vec(<2 x i8> %x, <2 x i8> %y) {
 define <2 x i8> @select_replacement_add_eq_vec_nonuniform(<2 x i8> %x, <2 x i8> %y) {
 ; CHECK-LABEL: @select_replacement_add_eq_vec_nonuniform(
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i8> [[X:%.*]], <i8 1, i8 2>
-; CHECK-NEXT:    [[ADD:%.*]] = add <2 x i8> [[X]], <i8 3, i8 4>
-; CHECK-NEXT:    [[SEL:%.*]] = select <2 x i1> [[CMP]], <2 x i8> [[ADD]], <2 x i8> [[Y:%.*]]
+; CHECK-NEXT:    [[SEL:%.*]] = select <2 x i1> [[CMP]], <2 x i8> <i8 4, i8 6>, <2 x i8> [[Y:%.*]]
 ; CHECK-NEXT:    ret <2 x i8> [[SEL]]
 ;
   %cmp = icmp eq <2 x i8> %x, <i8 1, i8 2>

diff  --git a/llvm/test/Transforms/InstSimplify/select.ll b/llvm/test/Transforms/InstSimplify/select.ll
index 756a019a9a0cd..b34d4840e72c0 100644
--- a/llvm/test/Transforms/InstSimplify/select.ll
+++ b/llvm/test/Transforms/InstSimplify/select.ll
@@ -988,10 +988,8 @@ define i32 @select_neutral_add_lhs(i32 %x, i32 %y) {
 
 define <2 x i32> @select_neutral_add_rhs_vec(<2 x i32> %x, <2 x i32> %y) {
 ; CHECK-LABEL: @select_neutral_add_rhs_vec(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <2 x i32> [[Y:%.*]], zeroinitializer
-; CHECK-NEXT:    [[ADD:%.*]] = add <2 x i32> [[X:%.*]], [[Y]]
-; CHECK-NEXT:    [[SEL:%.*]] = select <2 x i1> [[CMP]], <2 x i32> [[ADD]], <2 x i32> [[X]]
-; CHECK-NEXT:    ret <2 x i32> [[SEL]]
+; CHECK-NEXT:    [[ADD:%.*]] = add <2 x i32> [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    ret <2 x i32> [[ADD]]
 ;
   %cmp = icmp ne <2 x i32> %y, zeroinitializer
   %add = add <2 x i32> %x, %y
@@ -1001,10 +999,8 @@ define <2 x i32> @select_neutral_add_rhs_vec(<2 x i32> %x, <2 x i32> %y) {
 
 define <2 x i32> @select_neutral_add_lhs_vec(<2 x i32> %x, <2 x i32> %y) {
 ; CHECK-LABEL: @select_neutral_add_lhs_vec(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne <2 x i32> [[Y:%.*]], zeroinitializer
-; CHECK-NEXT:    [[ADD:%.*]] = add <2 x i32> [[Y]], [[X:%.*]]
-; CHECK-NEXT:    [[SEL:%.*]] = select <2 x i1> [[CMP]], <2 x i32> [[ADD]], <2 x i32> [[X]]
-; CHECK-NEXT:    ret <2 x i32> [[SEL]]
+; CHECK-NEXT:    [[ADD:%.*]] = add <2 x i32> [[Y:%.*]], [[X:%.*]]
+; CHECK-NEXT:    ret <2 x i32> [[ADD]]
 ;
   %cmp = icmp ne <2 x i32> %y, zeroinitializer
   %add = add <2 x i32> %y, %x
@@ -1047,6 +1043,7 @@ define i32 @select_ctpop_zero(i32 %x) {
   ret i32 %sel
 }
 
+; FIXME: This is safe to fold.
 define <2 x i32> @select_ctpop_zero_vec(<2 x i32> %x) {
 ; CHECK-LABEL: @select_ctpop_zero_vec(
 ; CHECK-NEXT:    [[T0:%.*]] = icmp eq <2 x i32> [[X:%.*]], zeroinitializer
@@ -1060,6 +1057,7 @@ define <2 x i32> @select_ctpop_zero_vec(<2 x i32> %x) {
   ret <2 x i32> %sel
 }
 
+; Negative test: Cannot fold due to cross-lane intrinsic.
 define <2 x i32> @select_vector_reverse(<2 x i32> %x) {
 ; CHECK-LABEL: @select_vector_reverse(
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[X:%.*]], zeroinitializer


        


More information about the llvm-commits mailing list