[llvm] [IR][PatternMatch] Only accept poison in getSplatValue() (PR #89159)

via llvm-commits llvm-commits at lists.llvm.org
Wed Apr 17 17:22:18 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-backend-x86

Author: Nikita Popov (nikic)

<details>
<summary>Changes</summary>

In #88217, a large set of matchers was changed to accept only poison elements in splats, not undef elements. This is because we now use poison for non-demanded vector elements, and allowing undef can cause correctness issues.

This patch covers the remaining matchers by renaming the AllowUndef parameter of getSplatValue() to AllowPoison and making it ignore only poison elements. The matchers built on top of it are renamed accordingly (e.g. m_APIntAllowUndef becomes m_APIntAllowPoison).
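The new getSplatValue() semantics can be illustrated without LLVM. The sketch below is a simplified, hypothetical model, not LLVM code: a vector constant is a list of lanes, each a concrete integer, an undef element, or a poison element, and `Kind`, `Lane`, and `splatValue` are illustrative names. Its loop follows the structure of `ConstantVector::getSplatValue(bool AllowPoison)` in the diff: with AllowPoison set, poison lanes are skipped, but an undef lane is now an ordinary mismatch.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical model of a vector constant lane (not an LLVM type).
enum class Kind { Int, Undef, Poison };

struct Lane {
  Kind kind;
  int value = 0; // Only meaningful when kind == Kind::Int.

  bool operator==(const Lane &Other) const {
    return kind == Other.kind && (kind != Kind::Int || value == Other.value);
  }
  bool operator!=(const Lane &Other) const { return !(*this == Other); }
};

// Models ConstantVector::getSplatValue(bool AllowPoison) after this patch.
// Returns the splatted lane, or std::nullopt if the lanes do not form a splat.
std::optional<Lane> splatValue(const std::vector<Lane> &Lanes,
                               bool AllowPoison) {
  assert(!Lanes.empty() && "vector constants have at least one lane");
  Lane Elt = Lanes[0];
  for (std::size_t I = 1; I < Lanes.size(); ++I) {
    const Lane &Op = Lanes[I];
    if (Op == Elt)
      continue;
    // Strict mode: any mismatch is not a splat.
    if (!AllowPoison)
      return std::nullopt;
    // Allow-poison mode: ignore poison lanes (undef lanes are NOT ignored).
    if (Op.kind == Kind::Poison)
      continue;
    // If only poison has been seen so far, adopt the current lane.
    if (Elt.kind == Kind::Poison)
      Elt = Op;
    if (Op != Elt)
      return std::nullopt;
  }
  return Elt;
}
```

In this model, lanes `{7, poison, 7}` yield 7 when AllowPoison is set, while `{7, undef, 7}` now yield no splat; under the old AllowUndefs behavior the undef lane would have been skipped and 7 returned.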

As a follow-up, we may want to change the default for things like m_APInt to m_APIntAllowPoison (as this is much less risky when only allowing poison), but this patch does not do so.

There is one caveat here: a single place (X86FixupVectorConstants) still needs to handle vector splats containing undef. This is because it works on backend constant pool entries, which currently still use undef instead of poison for non-demanded elements (SDAG as a whole has no explicit poison representation). As this is the only use, I've open-coded a file-local getSplatValueAllowUndef() helper there, to discourage its use anywhere else.
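The behavior of that open-coded helper can be sketched with the same kind of simplified, hypothetical lane model (`Kind`, `Lane`, and `splatValueAllowUndef` here are illustrative stand-ins, not LLVM code, and only integer lanes are modeled). One detail worth noting: in LLVM, `PoisonValue` is a subclass of `UndefValue`, so the helper's `isa<UndefValue>` check skips poison lanes as well as undef lanes. Unlike getSplatValue(), a vector consisting only of undef/poison lanes yields no splat, because `Res` is never set.

```cpp
#include <optional>
#include <vector>

// Hypothetical lane model (not an LLVM type).
enum class Kind { Int, Undef, Poison };

struct Lane {
  Kind kind;
  int value = 0; // Only meaningful when kind == Kind::Int.

  bool operator==(const Lane &Other) const {
    return kind == Other.kind && (kind != Kind::Int || value == Other.value);
  }
};

// Models the static helper added to X86FixupVectorConstants.cpp.
// isa<UndefValue> matches both undef and poison in LLVM, so every
// non-integer lane is skipped here.
std::optional<Lane> splatValueAllowUndef(const std::vector<Lane> &Lanes) {
  std::optional<Lane> Res; // Plays the role of the nullptr-initialized Res.
  for (const Lane &Op : Lanes) {
    if (Op.kind != Kind::Int)
      continue;
    if (!Res)
      Res = Op;
    else if (!(*Res == Op))
      return std::nullopt; // Two different defined lanes: not a splat.
  }
  return Res; // Empty if every lane was undef or poison.
}
```

Keeping this logic local to the X86 pass, rather than restoring an AllowUndef flag on getSplatValue(), means IR-level code has no convenient way to reintroduce undef-tolerant splat matching.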

---

Patch is 114.94 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89159.diff


42 Files Affected:

- (modified) llvm/include/llvm/IR/Constant.h (+3-3) 
- (modified) llvm/include/llvm/IR/Constants.h (+3-3) 
- (modified) llvm/include/llvm/IR/PatternMatch.h (+33-33) 
- (modified) llvm/lib/Analysis/CmpInstAnalysis.cpp (+1-1) 
- (modified) llvm/lib/Analysis/InstructionSimplify.cpp (+7-7) 
- (modified) llvm/lib/Analysis/ValueTracking.cpp (+2-2) 
- (modified) llvm/lib/IR/Constants.cpp (+7-7) 
- (modified) llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp (+3-3) 
- (modified) llvm/lib/Target/X86/X86FixupVectorConstants.cpp (+18-1) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp (+2-2) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+11-11) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+2-6) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp (+1-1) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+19-18) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineInternal.h (+2-2) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp (+3-3) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp (+1-1) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp (+8-8) 
- (modified) llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp (+4-4) 
- (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+2-2) 
- (modified) llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h (+1-1) 
- (modified) llvm/test/Transforms/InstCombine/and-or-icmp-const-icmp.ll (+7-7) 
- (modified) llvm/test/Transforms/InstCombine/binop-itofp.ll (+10-11) 
- (modified) llvm/test/Transforms/InstCombine/bswap-fold.ll (+4-4) 
- (modified) llvm/test/Transforms/InstCombine/bswap.ll (+23-23) 
- (modified) llvm/test/Transforms/InstCombine/compare-signs.ll (+7-7) 
- (modified) llvm/test/Transforms/InstCombine/fcmp-range-check-idiom.ll (+4-4) 
- (modified) llvm/test/Transforms/InstCombine/icmp-fsh.ll (+16-16) 
- (modified) llvm/test/Transforms/InstCombine/icmp-power2-and-icmp-shifted-mask.ll (+8-2) 
- (modified) llvm/test/Transforms/InstCombine/icmp-vec-inseltpoison.ll (+10-10) 
- (modified) llvm/test/Transforms/InstCombine/icmp-vec.ll (+25-25) 
- (modified) llvm/test/Transforms/InstCombine/low-bit-splat.ll (+13-10) 
- (modified) llvm/test/Transforms/InstCombine/lshr-trunc-sext-to-ashr-sext.ll (+17-17) 
- (modified) llvm/test/Transforms/InstCombine/select.ll (+9-9) 
- (modified) llvm/test/Transforms/InstCombine/signed-truncation-check.ll (+37-37) 
- (modified) llvm/test/Transforms/InstCombine/unsigned-add-lack-of-overflow-check.ll (+1-1) 
- (modified) llvm/test/Transforms/InstCombine/xor-ashr.ll (+5-5) 
- (modified) llvm/test/Transforms/InstSimplify/cast-unsigned-icmp-cmp-0.ll (+18-18) 
- (modified) llvm/test/Transforms/InstSimplify/icmp-constant.ll (+30-30) 
- (modified) llvm/test/Transforms/InstSimplify/maxmin_intrinsics.ll (+26-26) 
- (modified) llvm/unittests/IR/InstructionsTest.cpp (+18-5) 
- (modified) llvm/unittests/IR/PatternMatch.cpp (+58-27) 


``````````diff
diff --git a/llvm/include/llvm/IR/Constant.h b/llvm/include/llvm/IR/Constant.h
index 778764062227cb..d3171acf7b9ac2 100644
--- a/llvm/include/llvm/IR/Constant.h
+++ b/llvm/include/llvm/IR/Constant.h
@@ -146,9 +146,9 @@ class Constant : public User {
   Constant *getAggregateElement(Constant *Elt) const;
 
   /// If all elements of the vector constant have the same value, return that
-  /// value. Otherwise, return nullptr. Ignore undefined elements by setting
-  /// AllowUndefs to true.
-  Constant *getSplatValue(bool AllowUndefs = false) const;
+  /// value. Otherwise, return nullptr. Ignore poison elements by setting
+  /// AllowPoison to true.
+  Constant *getSplatValue(bool AllowPoison = false) const;
 
   /// If C is a constant integer then return its value, otherwise C must be a
   /// vector of constant integers, all equal, and the common value is returned.
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 4290ef4486c6f4..9ec81903f09c96 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -532,9 +532,9 @@ class ConstantVector final : public ConstantAggregate {
   }
 
   /// If all elements of the vector constant have the same value, return that
-  /// value. Otherwise, return nullptr. Ignore undefined elements by setting
-  /// AllowUndefs to true.
-  Constant *getSplatValue(bool AllowUndefs = false) const;
+  /// value. Otherwise, return nullptr. Ignore poison elements by setting
+  /// AllowPoison to true.
+  Constant *getSplatValue(bool AllowPoison = false) const;
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const Value *V) {
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
index 98cc0e50376981..30e9c3900c8661 100644
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -243,10 +243,10 @@ inline match_combine_and<LTy, RTy> m_CombineAnd(const LTy &L, const RTy &R) {
 
 struct apint_match {
   const APInt *&Res;
-  bool AllowUndef;
+  bool AllowPoison;
 
-  apint_match(const APInt *&Res, bool AllowUndef)
-      : Res(Res), AllowUndef(AllowUndef) {}
+  apint_match(const APInt *&Res, bool AllowPoison)
+      : Res(Res), AllowPoison(AllowPoison) {}
 
   template <typename ITy> bool match(ITy *V) {
     if (auto *CI = dyn_cast<ConstantInt>(V)) {
@@ -256,7 +256,7 @@ struct apint_match {
     if (V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
         if (auto *CI =
-                dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndef))) {
+                dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowPoison))) {
           Res = &CI->getValue();
           return true;
         }
@@ -268,10 +268,10 @@ struct apint_match {
 // function for both apint/apfloat.
 struct apfloat_match {
   const APFloat *&Res;
-  bool AllowUndef;
+  bool AllowPoison;
 
-  apfloat_match(const APFloat *&Res, bool AllowUndef)
-      : Res(Res), AllowUndef(AllowUndef) {}
+  apfloat_match(const APFloat *&Res, bool AllowPoison)
+      : Res(Res), AllowPoison(AllowPoison) {}
 
   template <typename ITy> bool match(ITy *V) {
     if (auto *CI = dyn_cast<ConstantFP>(V)) {
@@ -281,7 +281,7 @@ struct apfloat_match {
     if (V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
         if (auto *CI =
-                dyn_cast_or_null<ConstantFP>(C->getSplatValue(AllowUndef))) {
+                dyn_cast_or_null<ConstantFP>(C->getSplatValue(AllowPoison))) {
           Res = &CI->getValueAPF();
           return true;
         }
@@ -292,35 +292,35 @@ struct apfloat_match {
 /// Match a ConstantInt or splatted ConstantVector, binding the
 /// specified pointer to the contained APInt.
 inline apint_match m_APInt(const APInt *&Res) {
-  // Forbid undefs by default to maintain previous behavior.
-  return apint_match(Res, /* AllowUndef */ false);
+  // Forbid poison by default to maintain previous behavior.
+  return apint_match(Res, /* AllowPoison */ false);
 }
 
-/// Match APInt while allowing undefs in splat vector constants.
-inline apint_match m_APIntAllowUndef(const APInt *&Res) {
-  return apint_match(Res, /* AllowUndef */ true);
+/// Match APInt while allowing poison in splat vector constants.
+inline apint_match m_APIntAllowPoison(const APInt *&Res) {
+  return apint_match(Res, /* AllowPoison */ true);
 }
 
-/// Match APInt while forbidding undefs in splat vector constants.
-inline apint_match m_APIntForbidUndef(const APInt *&Res) {
-  return apint_match(Res, /* AllowUndef */ false);
+/// Match APInt while forbidding poison in splat vector constants.
+inline apint_match m_APIntForbidPoison(const APInt *&Res) {
+  return apint_match(Res, /* AllowPoison */ false);
 }
 
 /// Match a ConstantFP or splatted ConstantVector, binding the
 /// specified pointer to the contained APFloat.
 inline apfloat_match m_APFloat(const APFloat *&Res) {
   // Forbid undefs by default to maintain previous behavior.
-  return apfloat_match(Res, /* AllowUndef */ false);
+  return apfloat_match(Res, /* AllowPoison */ false);
 }
 
-/// Match APFloat while allowing undefs in splat vector constants.
-inline apfloat_match m_APFloatAllowUndef(const APFloat *&Res) {
-  return apfloat_match(Res, /* AllowUndef */ true);
+/// Match APFloat while allowing poison in splat vector constants.
+inline apfloat_match m_APFloatAllowPoison(const APFloat *&Res) {
+  return apfloat_match(Res, /* AllowPoison */ true);
 }
 
-/// Match APFloat while forbidding undefs in splat vector constants.
-inline apfloat_match m_APFloatForbidUndef(const APFloat *&Res) {
-  return apfloat_match(Res, /* AllowUndef */ false);
+/// Match APFloat while forbidding poison in splat vector constants.
+inline apfloat_match m_APFloatForbidPoison(const APFloat *&Res) {
+  return apfloat_match(Res, /* AllowPoison */ false);
 }
 
 template <int64_t Val> struct constantint_match {
@@ -418,7 +418,7 @@ template <typename Predicate> struct api_pred_ty : public Predicate {
 
 /// This helper class is used to match scalar and vector constants that
 /// satisfy a specified predicate, and bind them to an APFloat.
-/// Undefs are allowed in splat vector constants.
+/// Poison is allowed in splat vector constants.
 template <typename Predicate> struct apf_pred_ty : public Predicate {
   const APFloat *&Res;
 
@@ -433,7 +433,7 @@ template <typename Predicate> struct apf_pred_ty : public Predicate {
     if (V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
         if (auto *CI = dyn_cast_or_null<ConstantFP>(
-                C->getSplatValue(/* AllowUndef */ true)))
+                C->getSplatValue(/* AllowPoison */ true)))
           if (this->isValue(CI->getValue())) {
             Res = &CI->getValue();
             return true;
@@ -883,7 +883,7 @@ struct bind_const_intval_ty {
 
 /// Match a specified integer value or vector of all elements of that
 /// value.
-template <bool AllowUndefs> struct specific_intval {
+template <bool AllowPoison> struct specific_intval {
   const APInt &Val;
 
   specific_intval(const APInt &V) : Val(V) {}
@@ -892,13 +892,13 @@ template <bool AllowUndefs> struct specific_intval {
     const auto *CI = dyn_cast<ConstantInt>(V);
     if (!CI && V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndefs));
+        CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowPoison));
 
     return CI && APInt::isSameValue(CI->getValue(), Val);
   }
 };
 
-template <bool AllowUndefs> struct specific_intval64 {
+template <bool AllowPoison> struct specific_intval64 {
   uint64_t Val;
 
   specific_intval64(uint64_t V) : Val(V) {}
@@ -907,7 +907,7 @@ template <bool AllowUndefs> struct specific_intval64 {
     const auto *CI = dyn_cast<ConstantInt>(V);
     if (!CI && V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowUndefs));
+        CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue(AllowPoison));
 
     return CI && CI->getValue() == Val;
   }
@@ -923,11 +923,11 @@ inline specific_intval64<false> m_SpecificInt(uint64_t V) {
   return specific_intval64<false>(V);
 }
 
-inline specific_intval<true> m_SpecificIntAllowUndef(const APInt &V) {
+inline specific_intval<true> m_SpecificIntAllowPoison(const APInt &V) {
   return specific_intval<true>(V);
 }
 
-inline specific_intval64<true> m_SpecificIntAllowUndef(uint64_t V) {
+inline specific_intval64<true> m_SpecificIntAllowPoison(uint64_t V) {
   return specific_intval64<true>(V);
 }
 
@@ -1699,9 +1699,9 @@ struct m_SpecificMask {
   bool match(ArrayRef<int> Mask) { return MaskRef == Mask; }
 };
 
-struct m_SplatOrUndefMask {
+struct m_SplatOrPoisonMask {
   int &SplatIndex;
-  m_SplatOrUndefMask(int &SplatIndex) : SplatIndex(SplatIndex) {}
+  m_SplatOrPoisonMask(int &SplatIndex) : SplatIndex(SplatIndex) {}
   bool match(ArrayRef<int> Mask) {
     const auto *First = find_if(Mask, [](int Elem) { return Elem != -1; });
     if (First == Mask.end())
diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index d6407e8750737b..a1fa7857764d98 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -79,7 +79,7 @@ bool llvm::decomposeBitTestICmp(Value *LHS, Value *RHS,
   using namespace PatternMatch;
 
   const APInt *C;
-  if (!match(RHS, m_APIntAllowUndef(C)))
+  if (!match(RHS, m_APIntAllowPoison(C)))
     return false;
 
   switch (Pred) {
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 06ba5ca4c6b352..eb4e789b59b154 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -3023,7 +3023,7 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
 
   Value *X;
   const APInt *C;
-  if (!match(RHS, m_APIntAllowUndef(C)))
+  if (!match(RHS, m_APIntAllowPoison(C)))
     return nullptr;
 
   // Sign-bit checks can be optimized to true/false after unsigned
@@ -3056,9 +3056,9 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
   // (mul nuw/nsw X, MulC) == C --> false (if C is not a multiple of MulC)
   const APInt *MulC;
   if (IIQ.UseInstrInfo && ICmpInst::isEquality(Pred) &&
-      ((match(LHS, m_NUWMul(m_Value(), m_APIntAllowUndef(MulC))) &&
+      ((match(LHS, m_NUWMul(m_Value(), m_APIntAllowPoison(MulC))) &&
         *MulC != 0 && C->urem(*MulC) != 0) ||
-       (match(LHS, m_NSWMul(m_Value(), m_APIntAllowUndef(MulC))) &&
+       (match(LHS, m_NSWMul(m_Value(), m_APIntAllowPoison(MulC))) &&
         *MulC != 0 && C->srem(*MulC) != 0)))
     return ConstantInt::get(ITy, Pred == ICmpInst::ICMP_NE);
 
@@ -3203,7 +3203,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
 
   // (sub C, X) == X, C is odd  --> false
   // (sub C, X) != X, C is odd  --> true
-  if (match(LBO, m_Sub(m_APIntAllowUndef(C), m_Specific(RHS))) &&
+  if (match(LBO, m_Sub(m_APIntAllowPoison(C), m_Specific(RHS))) &&
       (*C & 1) == 1 && ICmpInst::isEquality(Pred))
     return (Pred == ICmpInst::ICMP_EQ) ? getFalse(ITy) : getTrue(ITy);
 
@@ -3354,7 +3354,7 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
   //   (C2 << X) != C --> true
   const APInt *C;
   if (match(LHS, m_Shl(m_Power2(), m_Value())) &&
-      match(RHS, m_APIntAllowUndef(C)) && !C->isPowerOf2()) {
+      match(RHS, m_APIntAllowPoison(C)) && !C->isPowerOf2()) {
     // C2 << X can equal zero in some circumstances.
     // This simplification might be unsafe if C is zero.
     //
@@ -4105,7 +4105,7 @@ static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
   }
 
   const APFloat *C = nullptr;
-  match(RHS, m_APFloatAllowUndef(C));
+  match(RHS, m_APFloatAllowPoison(C));
   std::optional<KnownFPClass> FullKnownClassLHS;
 
   // Lazily compute the possible classes for LHS. Avoid computing it twice if
@@ -6459,7 +6459,7 @@ Value *llvm::simplifyBinaryIntrinsic(Intrinsic::ID IID, Type *ReturnType,
           ReturnType, MinMaxIntrinsic::getSaturationPoint(IID, BitWidth));
 
     const APInt *C;
-    if (match(Op1, m_APIntAllowUndef(C))) {
+    if (match(Op1, m_APIntAllowPoison(C))) {
       // Clamp to limit value. For example:
       // umax(i8 %x, i8 255) --> 255
       if (*C == MinMaxIntrinsic::getSaturationPoint(IID, BitWidth))
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index ab2f43e1033fa1..349077afef05c1 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -4116,7 +4116,7 @@ std::pair<Value *, FPClassTest> llvm::fcmpToClassTest(FCmpInst::Predicate Pred,
                                                       Value *LHS, Value *RHS,
                                                       bool LookThroughSrc) {
   const APFloat *ConstRHS;
-  if (!match(RHS, m_APFloatAllowUndef(ConstRHS)))
+  if (!match(RHS, m_APFloatAllowPoison(ConstRHS)))
     return {nullptr, fcAllFlags};
 
   return fcmpToClassTest(Pred, F, LHS, ConstRHS, LookThroughSrc);
@@ -4517,7 +4517,7 @@ std::tuple<Value *, FPClassTest, FPClassTest>
 llvm::fcmpImpliesClass(CmpInst::Predicate Pred, const Function &F, Value *LHS,
                        Value *RHS, bool LookThroughSrc) {
   const APFloat *ConstRHS;
-  if (!match(RHS, m_APFloatAllowUndef(ConstRHS)))
+  if (!match(RHS, m_APFloatAllowPoison(ConstRHS)))
     return {nullptr, fcAllFlags, fcAllFlags};
 
   // TODO: Just call computeKnownFPClass for RHS to handle non-constants.
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index 45b359a94b3ab7..5268eccf701442 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -1696,14 +1696,14 @@ void ConstantVector::destroyConstantImpl() {
   getType()->getContext().pImpl->VectorConstants.remove(this);
 }
 
-Constant *Constant::getSplatValue(bool AllowUndefs) const {
+Constant *Constant::getSplatValue(bool AllowPoison) const {
   assert(this->getType()->isVectorTy() && "Only valid for vectors!");
   if (isa<ConstantAggregateZero>(this))
     return getNullValue(cast<VectorType>(getType())->getElementType());
   if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this))
     return CV->getSplatValue();
   if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
-    return CV->getSplatValue(AllowUndefs);
+    return CV->getSplatValue(AllowPoison);
 
   // Check if this is a constant expression splat of the form returned by
   // ConstantVector::getSplat()
@@ -1728,7 +1728,7 @@ Constant *Constant::getSplatValue(bool AllowUndefs) const {
   return nullptr;
 }
 
-Constant *ConstantVector::getSplatValue(bool AllowUndefs) const {
+Constant *ConstantVector::getSplatValue(bool AllowPoison) const {
   // Check out first element.
   Constant *Elt = getOperand(0);
   // Then make sure all remaining elements point to the same value.
@@ -1738,15 +1738,15 @@ Constant *ConstantVector::getSplatValue(bool AllowUndefs) const {
       continue;
 
     // Strict mode: any mismatch is not a splat.
-    if (!AllowUndefs)
+    if (!AllowPoison)
       return nullptr;
 
-    // Allow undefs mode: ignore undefined elements.
-    if (isa<UndefValue>(OpC))
+    // Allow poison mode: ignore poison elements.
+    if (isa<PoisonValue>(OpC))
       continue;
 
     // If we do not have a defined element yet, use the current operand.
-    if (isa<UndefValue>(Elt))
+    if (isa<PoisonValue>(Elt))
       Elt = OpC;
 
     if (OpC != Elt)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
index 5aa35becd842c3..978a2d49b08bc2 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPULibCalls.cpp
@@ -906,8 +906,8 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
 
   const APFloat *CF = nullptr;
   const APInt *CINT = nullptr;
-  if (!match(opr1, m_APFloatAllowUndef(CF)))
-    match(opr1, m_APIntAllowUndef(CINT));
+  if (!match(opr1, m_APFloatAllowPoison(CF)))
+    match(opr1, m_APIntAllowPoison(CINT));
 
   // 0x1111111 means that we don't do anything for this call.
   int ci_opr1 = (CINT ? (int)CINT->getSExtValue() : 0x1111111);
@@ -1039,7 +1039,7 @@ bool AMDGPULibCalls::fold_pow(FPMathOperator *FPOp, IRBuilder<> &B,
   Constant *cnval = nullptr;
   if (getVecSize(FInfo) == 1) {
     CF = nullptr;
-    match(opr0, m_APFloatAllowUndef(CF));
+    match(opr0, m_APFloatAllowPoison(CF));
 
     if (CF) {
       double V = (getArgType(FInfo) == AMDGPULibFunc::F32)
diff --git a/llvm/lib/Target/X86/X86FixupVectorConstants.cpp b/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
index da7dcbb25a9577..c9f79e1645f58b 100644
--- a/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
+++ b/llvm/lib/Target/X86/X86FixupVectorConstants.cpp
@@ -64,6 +64,23 @@ FunctionPass *llvm::createX86FixupVectorConstants() {
   return new X86FixupVectorConstantsPass();
 }
 
+/// Normally, we only allow poison in vector splats. However, as this is part
+/// of the backend, and working with the DAG representation, which currently
+/// only natively represents undef values, we need to accept undefs here.
+static Constant *getSplatValueAllowUndef(const ConstantVector *C) {
+  Constant *Res = nullptr;
+  for (Value *Op : C->operands()) {
+    Constant *OpC = cast<Constant>(Op);
+    if (isa<UndefValue>(OpC))
+      continue;
+    if (!Res)
+      Res = OpC;
+    else if (Res != OpC)
+      return nullptr;
+  }
+  return Res;
+}
+
 // Attempt to extract the full width of bits data from the constant.
 static std::optional<APInt> extractConstantBits(const Constant *C) {
   unsigned NumBits = C->getType()->getPrimitiveSizeInBits();
@@ -78,7 +95,7 @@ static std::optional<APInt> extractConstantBits(const Constant *C) {
     return CFP->getValue().bitcastToAPInt();
 
   if (auto *CV = dyn_cast<ConstantVector>(C)) {
-    if (auto *CVSplat = CV->getSplatValue(/*AllowUndefs*/ true)) {
+    if (auto *CVSplat = getSplatValueAllowUndef(CV)) {
       if (std::optional<APInt> Bits = extractConstantBits(CVSplat)) {
         assert((NumBits % Bits->getBitWidth()) == 0 && "Illegal splat");
         return APInt::getSplat(NumBits, *Bits);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index c59b867b10e7d1..b853dc678877e0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -896,7 +896,7 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
   const APInt *C;
   unsigned BitWidth = Ty->getScalarSizeInBits();
   if (match(Op0, m_OneUse(m_AShr(m_Value(X),
-                                 m_SpecificIntAllowUndef(BitWidth - 1)))) &&
+                                 m_SpecificIntAllowPoison(BitWidth - 1)))) &&
       match(Op1, m_One()))
     return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty);
 
@@ -1656,7 +1656,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   // (A s>> (BW - 1)) + (zext (A s> 0)) --> (A s>> (BW - 1)) | (zext (A != 0))
   ICmpInst::Predicate Pred;
   uint64_t BitWidth = Ty->getScalarSizeInBits();
-  if (match(LHS, m_AShr(m_Value(A), m_SpecificIntAllowUndef(BitWidth - 1))) &&
+  if (match(LHS, m_AShr(m_Value(A), m_SpecificIntAllowPoison(BitWidth - 1))) &&
       match(RHS, m_OneUse(m_ZExt(
                      m_OneUse(m_ICmp(Pred, m_Specific(A), m_ZeroInt()))))) &&
       Pred == CmpInst::ICMP_SGT) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 0f4fbf5bbfbbdc..bf7c0074a38f05 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -947,9 +947,9 @@ static Value *foldNegativePower2AndShiftedMask(
   // bits (0).
   auto isReducible = [](const Value *B, const Value *D, const Value *E) {
     const APInt *BCst, *DCst, *ECst;
-    return match(B, m_APIntAllowUndef(BCst)) && match(D, m_APInt(DCst)) &&
+    return match(B, m_APIntAllowPoison(BCst)) && match(D, m_APInt(DCst)) &&
            match(E, m_APInt(ECst)) && *DCst == *ECst &&
-           (isa<UndefValue>(B) ||
+           (isa<PoisonValue>(B) ||
             (BCst->countLeadingOnes() == DCst->countLe...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/89159

