[llvm] [InstCombine] Fix poison propagation in select of bitwise fold (PR #89701)

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 23 17:42:30 PDT 2024


https://github.com/nikic updated https://github.com/llvm/llvm-project/pull/89701

>From 2bf2afa70533d4ef8cc934cbf9c82d64e709b0ca Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Tue, 23 Apr 2024 11:29:20 +0900
Subject: [PATCH 1/2] [InstCombine] Fix poison propagation in select of bitwise
 fold

We're replacing the select with the false value here, but the false
value may be more poisonous than the select if the all-ones constant
matched by m_Not contains poison elements. Fix this by introducing an
m_NotForbidPoison matcher and using it here.

Fixes https://github.com/llvm/llvm-project/issues/89500.
---
 llvm/include/llvm/IR/PatternMatch.h           | 25 ++++++++++++++-----
 .../InstCombine/InstCombineSelect.cpp         |  8 +++---
 llvm/test/Transforms/InstCombine/select.ll    | 11 +++++---
 3 files changed, 30 insertions(+), 14 deletions(-)

diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 1fee1901fabb65d..0b13b4aad9c326a 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -350,8 +350,9 @@ template <int64_t Val> inline constantint_match<Val> m_ConstantInt() {
 
 /// This helper class is used to match constant scalars, vector splats,
 /// and fixed width vectors that satisfy a specified predicate.
-/// For fixed width vector constants, poison elements are ignored.
-template <typename Predicate, typename ConstantVal>
+/// For fixed width vector constants, poison elements are ignored if AllowPoison
+/// is true.
+template <typename Predicate, typename ConstantVal, bool AllowPoison>
 struct cstval_pred_ty : public Predicate {
   template <typename ITy> bool match(ITy *V) {
     if (const auto *CV = dyn_cast<ConstantVal>(V))
@@ -374,7 +375,7 @@ struct cstval_pred_ty : public Predicate {
           Constant *Elt = C->getAggregateElement(i);
           if (!Elt)
             return false;
-          if (isa<PoisonValue>(Elt))
+          if (AllowPoison && isa<PoisonValue>(Elt))
             continue;
           auto *CV = dyn_cast<ConstantVal>(Elt);
           if (!CV || !this->isValue(CV->getValue()))
@@ -389,12 +390,13 @@ struct cstval_pred_ty : public Predicate {
 };
 
 /// specialization of cstval_pred_ty for ConstantInt
-template <typename Predicate>
-using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt>;
+template <typename Predicate, bool AllowPoison = true>
+using cst_pred_ty = cstval_pred_ty<Predicate, ConstantInt, AllowPoison>;
 
 /// specialization of cstval_pred_ty for ConstantFP
 template <typename Predicate>
-using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP>;
+using cstfp_pred_ty = cstval_pred_ty<Predicate, ConstantFP,
+                                     /*AllowPoison=*/true>;
 
 /// This helper class is used to match scalar and vector constants that
 /// satisfy a specified predicate, and bind them to an APInt.
@@ -484,6 +486,10 @@ inline cst_pred_ty<is_all_ones> m_AllOnes() {
   return cst_pred_ty<is_all_ones>();
 }
 
+inline cst_pred_ty<is_all_ones, false> m_AllOnesForbidPoison() {
+  return cst_pred_ty<is_all_ones, false>();
+}
+
 struct is_maxsignedvalue {
   bool isValue(const APInt &C) { return C.isMaxSignedValue(); }
 };
@@ -2596,6 +2602,13 @@ m_Not(const ValTy &V) {
   return m_c_Xor(m_AllOnes(), V);
 }
 
+template <typename ValTy>
+inline BinaryOp_match<cst_pred_ty<is_all_ones, false>, ValTy, Instruction::Xor,
+                      true>
+m_NotForbidPoison(const ValTy &V) {
+  return m_c_Xor(m_AllOnesForbidPoison(), V);
+}
+
 /// Matches an SMin with LHS and RHS in either order.
 template <typename LHS, typename RHS>
 inline MaxMin_match<ICmpInst, LHS, RHS, smin_pred_ty, true>
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 73600206a55c145..117eb7a1dcc933d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1722,11 +1722,11 @@ static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI,
       return match(CmpRHS, m_Zero()) && match(FalseVal, matchInner);
 
     if (NotMask == NotInner) {
-      return match(FalseVal,
-                   m_c_BinOp(OuterOpc, m_Not(matchInner), m_Specific(CmpRHS)));
+      return match(FalseVal, m_c_BinOp(OuterOpc, m_NotForbidPoison(matchInner),
+                                       m_Specific(CmpRHS)));
     } else if (NotMask == NotRHS) {
-      return match(FalseVal,
-                   m_c_BinOp(OuterOpc, matchInner, m_Not(m_Specific(CmpRHS))));
+      return match(FalseVal, m_c_BinOp(OuterOpc, matchInner,
+                                       m_NotForbidPoison(m_Specific(CmpRHS))));
     } else {
       return match(FalseVal,
                    m_c_BinOp(OuterOpc, matchInner, m_Specific(CmpRHS)));
diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
index 2ec092a745c52c7..87e9d1779e30ba0 100644
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -3830,14 +3830,17 @@ entry:
   ret i32 %cond
 }
 
-; FIXME: This is a miscompile.
 define <2 x i32> @src_and_eq_C_xor_OrAndNotC_vec_poison(<2 x i32> %0, <2 x i32> %1, <2 x i32> %2) {
 ; CHECK-LABEL: @src_and_eq_C_xor_OrAndNotC_vec_poison(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[OR:%.*]] = or <2 x i32> [[TMP1:%.*]], [[TMP0:%.*]]
-; CHECK-NEXT:    [[NOT:%.*]] = xor <2 x i32> [[TMP2:%.*]], <i32 -1, i32 poison>
+; CHECK-NEXT:    [[AND:%.*]] = and <2 x i32> [[TMP1:%.*]], [[TMP0:%.*]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq <2 x i32> [[AND]], [[TMP2:%.*]]
+; CHECK-NEXT:    [[XOR:%.*]] = xor <2 x i32> [[TMP1]], [[TMP0]]
+; CHECK-NEXT:    [[OR:%.*]] = or <2 x i32> [[TMP1]], [[TMP0]]
+; CHECK-NEXT:    [[NOT:%.*]] = xor <2 x i32> [[TMP2]], <i32 -1, i32 poison>
 ; CHECK-NEXT:    [[AND1:%.*]] = and <2 x i32> [[OR]], [[NOT]]
-; CHECK-NEXT:    ret <2 x i32> [[AND1]]
+; CHECK-NEXT:    [[COND:%.*]] = select <2 x i1> [[CMP]], <2 x i32> [[XOR]], <2 x i32> [[AND1]]
+; CHECK-NEXT:    ret <2 x i32> [[COND]]
 ;
 entry:
   %and = and <2 x i32> %1, %0

>From a3871265cc8322bdd1ac506315b4a8aa42f21747 Mon Sep 17 00:00:00 2001
From: Nikita Popov <npopov at redhat.com>
Date: Wed, 24 Apr 2024 09:41:41 +0900
Subject: [PATCH 2/2] Add unit tests for m_NotForbidPoison

This effectively replaces the m_NotForbidUndef matcher we used
to have.
---
 llvm/unittests/IR/PatternMatch.cpp | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
index f0377eae9989fd6..a25885faa3a4422 100644
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -1995,7 +1995,7 @@ TEST_F(PatternMatchTest, VScale) {
   EXPECT_TRUE(match(PtrToInt2, m_VScale()));
 }
 
-TEST_F(PatternMatchTest, NotForbidUndef) {
+TEST_F(PatternMatchTest, NotForbidPoison) {
   Type *ScalarTy = IRB.getInt8Ty();
   Type *VectorTy = FixedVectorType::get(ScalarTy, 3);
   Constant *ScalarUndef = UndefValue::get(ScalarTy);
@@ -2020,23 +2020,33 @@ TEST_F(PatternMatchTest, NotForbidUndef) {
   Value *X;
   EXPECT_TRUE(match(Not, m_Not(m_Value(X))));
   EXPECT_TRUE(match(X, m_Zero()));
+  X = nullptr;
+  EXPECT_TRUE(match(Not, m_NotForbidPoison(m_Value(X))));
+  EXPECT_TRUE(match(X, m_Zero()));
 
   Value *NotCommute = IRB.CreateXor(VectorOnes, VectorZero);
   Value *Y;
   EXPECT_TRUE(match(NotCommute, m_Not(m_Value(Y))));
   EXPECT_TRUE(match(Y, m_Zero()));
+  Y = nullptr;
+  EXPECT_TRUE(match(NotCommute, m_NotForbidPoison(m_Value(Y))));
+  EXPECT_TRUE(match(Y, m_Zero()));
 
   Value *NotWithUndefs = IRB.CreateXor(VectorZero, VectorMixedUndef);
   EXPECT_FALSE(match(NotWithUndefs, m_Not(m_Value())));
+  EXPECT_FALSE(match(NotWithUndefs, m_NotForbidPoison(m_Value())));
 
   Value *NotWithPoisons = IRB.CreateXor(VectorZero, VectorMixedPoison);
   EXPECT_TRUE(match(NotWithPoisons, m_Not(m_Value())));
+  EXPECT_FALSE(match(NotWithPoisons, m_NotForbidPoison(m_Value())));
 
   Value *NotWithUndefsCommute = IRB.CreateXor(VectorMixedUndef, VectorZero);
   EXPECT_FALSE(match(NotWithUndefsCommute, m_Not(m_Value())));
+  EXPECT_FALSE(match(NotWithUndefsCommute, m_NotForbidPoison(m_Value())));
 
   Value *NotWithPoisonsCommute = IRB.CreateXor(VectorMixedPoison, VectorZero);
   EXPECT_TRUE(match(NotWithPoisonsCommute, m_Not(m_Value())));
+  EXPECT_FALSE(match(NotWithPoisonsCommute, m_NotForbidPoison(m_Value())));
 }
 
 template <typename T> struct MutableConstTest : PatternMatchTest { };



More information about the llvm-commits mailing list