[llvm] 16e9315 - [IR] allow undefined elements when checking for splat constants

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 10 14:17:15 PST 2019


Author: Sanjay Patel
Date: 2019-12-10T17:16:59-05:00
New Revision: 16e9315685bc057849eab072de6ec349b508ec1d

URL: https://github.com/llvm/llvm-project/commit/16e9315685bc057849eab072de6ec349b508ec1d
DIFF: https://github.com/llvm/llvm-project/commit/16e9315685bc057849eab072de6ec349b508ec1d.diff

LOG: [IR] allow undefined elements when checking for splat constants

This mimics the corresponding splat check in SelectionDAG (SDAG). The caller
is responsible for ensuring that undef values are propagated safely.

Added: 
    

Modified: 
    llvm/include/llvm/IR/Constant.h
    llvm/include/llvm/IR/Constants.h
    llvm/lib/IR/Constants.cpp
    llvm/unittests/IR/InstructionsTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/Constant.h b/llvm/include/llvm/IR/Constant.h
index 3f3fa4c272c5..174e7364c524 100644
--- a/llvm/include/llvm/IR/Constant.h
+++ b/llvm/include/llvm/IR/Constant.h
@@ -133,9 +133,10 @@ class Constant : public User {
   Constant *getAggregateElement(unsigned Elt) const;
   Constant *getAggregateElement(Constant *Elt) const;
 
-  /// If this is a splat vector constant, meaning that all of the elements have
-  /// the same value, return that value. Otherwise return 0.
-  Constant *getSplatValue() const;
+  /// If all elements of the vector constant have the same value, return that
+  /// value. Otherwise, return nullptr. If AllowUndefs is true, undef elements
+  /// are ignored and the caller must ensure undef lanes are handled safely.
+  Constant *getSplatValue(bool AllowUndefs = false) const;
 
   /// If C is a constant integer then return its value, otherwise C must be a
   /// vector of constant integers, all equal, and the common value is returned.

diff  --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 7f0687d382f0..262ab439df65 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -522,9 +522,10 @@ class ConstantVector final : public ConstantAggregate {
     return cast<VectorType>(Value::getType());
   }
 
-  /// If this is a splat constant, meaning that all of the elements have the
-  /// same value, return that value. Otherwise return NULL.
-  Constant *getSplatValue() const;
+  /// If all elements of the vector constant have the same value, return that
+  /// value. Otherwise, return nullptr. If AllowUndefs is true, undef elements
+  /// are skipped and the caller must ensure undef lanes are handled safely.
+  Constant *getSplatValue(bool AllowUndefs = false) const;
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast:
   static bool classof(const Value *V) {

diff  --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index fc215d6bf958..cafb412b795b 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -1442,24 +1442,41 @@ void ConstantVector::destroyConstantImpl() {
   getType()->getContext().pImpl->VectorConstants.remove(this);
 }
 
-Constant *Constant::getSplatValue() const {
+Constant *Constant::getSplatValue(bool AllowUndefs) const {
   assert(this->getType()->isVectorTy() && "Only valid for vectors!");
   if (isa<ConstantAggregateZero>(this))
     return getNullValue(this->getType()->getVectorElementType());
   if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this))
     return CV->getSplatValue();
   if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
-    return CV->getSplatValue();
+    return CV->getSplatValue(AllowUndefs); // Only CV path uses AllowUndefs.
   return nullptr;
 }
 
-Constant *ConstantVector::getSplatValue() const {
+Constant *ConstantVector::getSplatValue(bool AllowUndefs) const {
   // Check out first element.
   Constant *Elt = getOperand(0);
   // Then make sure all remaining elements point to the same value.
-  for (unsigned I = 1, E = getNumOperands(); I < E; ++I)
-    if (getOperand(I) != Elt)
+  for (unsigned I = 1, E = getNumOperands(); I < E; ++I) {
+    Constant *OpC = getOperand(I);
+    if (OpC == Elt)
+      continue;
+
+    // Strict mode: any mismatching element, undef included, is not a splat.
+    if (!AllowUndefs)
       return nullptr;
+
+    // AllowUndefs mode: an undef element is compatible with any splat value.
+    if (isa<UndefValue>(OpC))
+      continue;
+
+    // Candidate is still undef (only undefs seen so far): adopt this operand.
+    if (isa<UndefValue>(Elt))
+      Elt = OpC;
+
+    if (OpC != Elt) // Two distinct defined values: not a splat.
+      return nullptr;
+  }
   return Elt;
 }
 

diff  --git a/llvm/unittests/IR/InstructionsTest.cpp b/llvm/unittests/IR/InstructionsTest.cpp
index 556c41058e7d..c2f70724337c 100644
--- a/llvm/unittests/IR/InstructionsTest.cpp
+++ b/llvm/unittests/IR/InstructionsTest.cpp
@@ -995,6 +995,46 @@ TEST(InstructionsTest, ShuffleMaskQueries) {
   delete Id12;
 }
 
+TEST(InstructionsTest, GetSplat) {
+  // Scalar elements (undef, 0, 1) used to build the constant vectors below.
+  LLVMContext Ctx;
+  Type *Int32Ty = Type::getInt32Ty(Ctx);
+  Constant *CU = UndefValue::get(Int32Ty);
+  Constant *C0 = ConstantInt::get(Int32Ty, 0);
+  Constant *C1 = ConstantInt::get(Int32Ty, 1);
+
+  Constant *Splat0 = ConstantVector::get({C0, C0, C0, C0});
+  Constant *Splat1 = ConstantVector::get({C1, C1, C1, C1, C1});
+  Constant *Splat0Undef = ConstantVector::get({C0, CU, C0, CU});
+  Constant *Splat1Undef = ConstantVector::get({CU, CU, C1, CU});
+  Constant *NotSplat = ConstantVector::get({C1, C1, C0, C1, C1});
+  Constant *NotSplatUndef = ConstantVector::get({CU, C1, CU, CU, C0});
+
+  // Default: undefs are not allowed, so any undef lane defeats the splat.
+  EXPECT_EQ(Splat0->getSplatValue(), C0);
+  EXPECT_EQ(Splat1->getSplatValue(), C1);
+  EXPECT_EQ(Splat0Undef->getSplatValue(), nullptr);
+  EXPECT_EQ(Splat1Undef->getSplatValue(), nullptr);
+  EXPECT_EQ(NotSplat->getSplatValue(), nullptr);
+  EXPECT_EQ(NotSplatUndef->getSplatValue(), nullptr);
+
+  // Passing false explicitly must match the default behavior.
+  EXPECT_EQ(Splat0->getSplatValue(false), C0);
+  EXPECT_EQ(Splat1->getSplatValue(false), C1);
+  EXPECT_EQ(Splat0Undef->getSplatValue(false), nullptr);
+  EXPECT_EQ(Splat1Undef->getSplatValue(false), nullptr);
+  EXPECT_EQ(NotSplat->getSplatValue(false), nullptr);
+  EXPECT_EQ(NotSplatUndef->getSplatValue(false), nullptr);
+
+  // Allow undefs: undef lanes are ignored; distinct defined lanes still fail.
+  EXPECT_EQ(Splat0->getSplatValue(true), C0);
+  EXPECT_EQ(Splat1->getSplatValue(true), C1);
+  EXPECT_EQ(Splat0Undef->getSplatValue(true), C0);
+  EXPECT_EQ(Splat1Undef->getSplatValue(true), C1);
+  EXPECT_EQ(NotSplat->getSplatValue(true), nullptr);
+  EXPECT_EQ(NotSplatUndef->getSplatValue(true), nullptr);
+}
+
 TEST(InstructionsTest, SkipDebug) {
   LLVMContext C;
   std::unique_ptr<Module> M = parseIR(C,


        


More information about the llvm-commits mailing list