[llvm] 29f8628 - [Constant] Add containsPoisonElement

Juneyoung Lee via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 5 19:12:11 PST 2021


Author: Juneyoung Lee
Date: 2021-01-06T12:10:33+09:00
New Revision: 29f8628d1fc8d96670e13562c4d92fc916bd0ce1

URL: https://github.com/llvm/llvm-project/commit/29f8628d1fc8d96670e13562c4d92fc916bd0ce1
DIFF: https://github.com/llvm/llvm-project/commit/29f8628d1fc8d96670e13562c4d92fc916bd0ce1.diff

LOG: [Constant] Add containsPoisonElement

This patch

- Adds containsPoisonElement, which checks whether any element of a constant vector is poison,
- Renames containsUndefElement to containsUndefOrPoisonElement to clarify its behavior, and updates its uses accordingly

Because this patch makes isGuaranteedNotToBeUndefOrPoison's analysis of constant vectors more precise, tests covering constant vectors are added for it.

Thanks!

Reviewed By: nikic

Differential Revision: https://reviews.llvm.org/D94053

Added: 
    

Modified: 
    llvm/include/llvm/IR/Constant.h
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/lib/IR/ConstantFold.cpp
    llvm/lib/IR/Constants.cpp
    llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
    llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
    llvm/unittests/Analysis/ValueTrackingTest.cpp
    llvm/unittests/IR/ConstantsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/Constant.h b/llvm/include/llvm/IR/Constant.h
index 97650c2051ca..0190aca27b72 100644
--- a/llvm/include/llvm/IR/Constant.h
+++ b/llvm/include/llvm/IR/Constant.h
@@ -101,11 +101,15 @@ class Constant : public User {
   /// lane, the constants still match.
   bool isElementWiseEqual(Value *Y) const;
 
-  /// Return true if this is a vector constant that includes any undefined
-  /// elements. Since it is impossible to inspect a scalable vector element-
-  /// wise at compile time, this function returns true only if the entire
-  /// vector is undef
-  bool containsUndefElement() const;
+  /// Return true if this is a vector constant that includes any undef or
+  /// poison elements. Since it is impossible to inspect a scalable vector
+  /// element-wise at compile time, this function returns true only if the
+  /// entire vector is undef or poison.
+  bool containsUndefOrPoisonElement() const;
+
+  /// Return true if this is a vector constant that includes any poison
+  /// elements.
+  bool containsPoisonElement() const;
 
   /// Return true if this is a fixed width vector constant that includes
   /// any constant expressions.

diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index e15d4f0e4b07..1c75c5fbd0db 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -4895,7 +4895,8 @@ static bool isGuaranteedNotToBeUndefOrPoison(const Value *V,
       return true;
 
     if (C->getType()->isVectorTy() && !isa<ConstantExpr>(C))
-      return (PoisonOnly || !C->containsUndefElement()) &&
+      return (PoisonOnly ? !C->containsPoisonElement()
+                         : !C->containsUndefOrPoisonElement()) &&
              !C->containsConstantExpression();
   }
 
@@ -5636,10 +5637,10 @@ static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
     // elements because those can not be back-propagated for analysis.
     Value *OutputZeroVal = nullptr;
     if (match(TrueVal, m_AnyZeroFP()) && !match(FalseVal, m_AnyZeroFP()) &&
-        !cast<Constant>(TrueVal)->containsUndefElement())
+        !cast<Constant>(TrueVal)->containsUndefOrPoisonElement())
       OutputZeroVal = TrueVal;
     else if (match(FalseVal, m_AnyZeroFP()) && !match(TrueVal, m_AnyZeroFP()) &&
-             !cast<Constant>(FalseVal)->containsUndefElement())
+             !cast<Constant>(FalseVal)->containsUndefOrPoisonElement())
       OutputZeroVal = FalseVal;
 
     if (OutputZeroVal) {

diff  --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index 47745689ba2d..03cb108cc485 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -811,7 +811,7 @@ Constant *llvm::ConstantFoldSelectInstruction(Constant *Cond,
       return true;
 
     if (C->getType()->isVectorTy())
-      return !C->containsUndefElement() && !C->containsConstantExpression();
+      return !C->containsPoisonElement() && !C->containsConstantExpression();
 
     // TODO: Recursively analyze aggregates or other constants.
     return false;

diff  --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index a38302d17937..5aa819dda2b3 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -304,31 +304,42 @@ bool Constant::isElementWiseEqual(Value *Y) const {
   return isa<UndefValue>(CmpEq) || match(CmpEq, m_One());
 }
 
-bool Constant::containsUndefElement() const {
-  if (auto *VTy = dyn_cast<VectorType>(getType())) {
-    if (isa<UndefValue>(this))
+static bool
+containsUndefinedElement(const Constant *C,
+                         function_ref<bool(const Constant *)> HasFn) {
+  if (auto *VTy = dyn_cast<VectorType>(C->getType())) {
+    if (HasFn(C))
       return true;
-    if (isa<ConstantAggregateZero>(this))
+    if (isa<ConstantAggregateZero>(C))
       return false;
-    if (isa<ScalableVectorType>(getType()))
+    if (isa<ScalableVectorType>(C->getType()))
       return false;
 
     for (unsigned i = 0, e = cast<FixedVectorType>(VTy)->getNumElements();
          i != e; ++i)
-      if (isa<UndefValue>(getAggregateElement(i)))
+      if (HasFn(C->getAggregateElement(i)))
         return true;
   }
 
   return false;
 }
 
+bool Constant::containsUndefOrPoisonElement() const {
+  return containsUndefinedElement(
+      this, [&](const auto *C) { return isa<UndefValue>(C); });
+}
+
+bool Constant::containsPoisonElement() const {
+  return containsUndefinedElement(
+      this, [&](const auto *C) { return isa<PoisonValue>(C); });
+}
+
 bool Constant::containsConstantExpression() const {
   if (auto *VTy = dyn_cast<FixedVectorType>(getType())) {
     for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i)
       if (isa<ConstantExpr>(getAggregateElement(i)))
         return true;
   }
-
   return false;
 }
 

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 87d4b40a9a64..08877797c53a 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -3370,7 +3370,7 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
   Type *OpTy = M->getType();
   auto *VecC = dyn_cast<Constant>(M);
   auto *OpVTy = dyn_cast<FixedVectorType>(OpTy);
-  if (OpVTy && VecC && VecC->containsUndefElement()) {
+  if (OpVTy && VecC && VecC->containsUndefOrPoisonElement()) {
     Constant *SafeReplacementConstant = nullptr;
     for (unsigned i = 0, e = OpVTy->getNumElements(); i != e; ++i) {
       if (!isa<UndefValue>(VecC->getAggregateElement(i))) {
@@ -5259,7 +5259,8 @@ InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
   // It may not be safe to change a compare predicate in the presence of
   // undefined elements, so replace those elements with the first safe constant
   // that we found.
-  if (C->containsUndefElement()) {
+  // TODO: in case of poison, it is safe; let's replace undefs only.
+  if (C->containsUndefOrPoisonElement()) {
     assert(SafeReplacementConstant && "Replacement constant not set");
     C = Constant::replaceUndefsWith(C, SafeReplacementConstant);
   }

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
index 494c58e706e1..7718c8b0eedd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
@@ -239,8 +239,8 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) {
     // While this is normally not behind a use-check,
     // let's consider division to be special since it's costly.
     if (auto *Op1C = dyn_cast<Constant>(I->getOperand(1))) {
-      if (!Op1C->containsUndefElement() && Op1C->isNotMinSignedValue() &&
-          Op1C->isNotOneValue()) {
+      if (!Op1C->containsUndefOrPoisonElement() &&
+          Op1C->isNotMinSignedValue() && Op1C->isNotOneValue()) {
         Value *BO =
             Builder.CreateSDiv(I->getOperand(0), ConstantExpr::getNeg(Op1C),
                                I->getName() + ".neg");

diff  --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp
index 0d6577452560..d70fd6eb0ba2 100644
--- a/llvm/unittests/Analysis/ValueTrackingTest.cpp
+++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp
@@ -888,6 +888,30 @@ TEST_F(ValueTrackingTest, isGuaranteedNotToBeUndefOrPoison) {
   EXPECT_EQ(isGuaranteedNotToBeUndefOrPoison(PoisonValue::get(IntegerType::get(Context, 8))), false);
   EXPECT_EQ(isGuaranteedNotToBePoison(UndefValue::get(IntegerType::get(Context, 8))), true);
   EXPECT_EQ(isGuaranteedNotToBePoison(PoisonValue::get(IntegerType::get(Context, 8))), false);
+
+  Type *Int32Ty = Type::getInt32Ty(Context);
+  Constant *CU = UndefValue::get(Int32Ty);
+  Constant *CP = PoisonValue::get(Int32Ty);
+  Constant *C1 = ConstantInt::get(Int32Ty, 1);
+  Constant *C2 = ConstantInt::get(Int32Ty, 2);
+
+  {
+    Constant *V1 = ConstantVector::get({C1, C2});
+    EXPECT_TRUE(isGuaranteedNotToBeUndefOrPoison(V1));
+    EXPECT_TRUE(isGuaranteedNotToBePoison(V1));
+  }
+
+  {
+    Constant *V2 = ConstantVector::get({C1, CU});
+    EXPECT_FALSE(isGuaranteedNotToBeUndefOrPoison(V2));
+    EXPECT_TRUE(isGuaranteedNotToBePoison(V2));
+  }
+
+  {
+    Constant *V3 = ConstantVector::get({C1, CP});
+    EXPECT_FALSE(isGuaranteedNotToBeUndefOrPoison(V3));
+    EXPECT_FALSE(isGuaranteedNotToBePoison(V3));
+  }
 }
 
 TEST_F(ValueTrackingTest, isGuaranteedNotToBeUndefOrPoison_assume) {

diff  --git a/llvm/unittests/IR/ConstantsTest.cpp b/llvm/unittests/IR/ConstantsTest.cpp
index 9dd1ba84cc12..afae154cca90 100644
--- a/llvm/unittests/IR/ConstantsTest.cpp
+++ b/llvm/unittests/IR/ConstantsTest.cpp
@@ -585,6 +585,43 @@ TEST(ConstantsTest, FoldGlobalVariablePtr) {
       Instruction::And, TheConstantExpr, TheConstant)->isNullValue());
 }
 
+// Check that containsUndefOrPoisonElement and containsPoisonElement work
+// correctly
+
+TEST(ConstantsTest, containsUndefElemTest) {
+  LLVMContext Context;
+
+  Type *Int32Ty = Type::getInt32Ty(Context);
+  Constant *CU = UndefValue::get(Int32Ty);
+  Constant *CP = PoisonValue::get(Int32Ty);
+  Constant *C1 = ConstantInt::get(Int32Ty, 1);
+  Constant *C2 = ConstantInt::get(Int32Ty, 2);
+
+  {
+    Constant *V1 = ConstantVector::get({C1, C2});
+    EXPECT_FALSE(V1->containsUndefOrPoisonElement());
+    EXPECT_FALSE(V1->containsPoisonElement());
+  }
+
+  {
+    Constant *V2 = ConstantVector::get({C1, CU});
+    EXPECT_TRUE(V2->containsUndefOrPoisonElement());
+    EXPECT_FALSE(V2->containsPoisonElement());
+  }
+
+  {
+    Constant *V3 = ConstantVector::get({C1, CP});
+    EXPECT_TRUE(V3->containsUndefOrPoisonElement());
+    EXPECT_TRUE(V3->containsPoisonElement());
+  }
+
+  {
+    Constant *V4 = ConstantVector::get({CU, CP});
+    EXPECT_TRUE(V4->containsUndefOrPoisonElement());
+    EXPECT_TRUE(V4->containsPoisonElement());
+  }
+}
+
 // Check that undefined elements in vector constants are matched
 // correctly for both integer and floating-point types. Just don't
 // crash on vectors of pointers (could be handled?).


        


More information about the llvm-commits mailing list