[llvm] 16e9315 - [IR] allow undefined elements when checking for splat constants
Sanjay Patel via llvm-commits
llvm-commits at lists.llvm.org
Tue Dec 10 14:17:15 PST 2019
Author: Sanjay Patel
Date: 2019-12-10T17:16:59-05:00
New Revision: 16e9315685bc057849eab072de6ec349b508ec1d
URL: https://github.com/llvm/llvm-project/commit/16e9315685bc057849eab072de6ec349b508ec1d
DIFF: https://github.com/llvm/llvm-project/commit/16e9315685bc057849eab072de6ec349b508ec1d.diff
LOG: [IR] allow undefined elements when checking for splat constants
This mimics the related call in SDAG. The caller is responsible
for ensuring that undef values are propagated safely.
Added:
Modified:
llvm/include/llvm/IR/Constant.h
llvm/include/llvm/IR/Constants.h
llvm/lib/IR/Constants.cpp
llvm/unittests/IR/InstructionsTest.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/Constant.h b/llvm/include/llvm/IR/Constant.h
index 3f3fa4c272c5..174e7364c524 100644
--- a/llvm/include/llvm/IR/Constant.h
+++ b/llvm/include/llvm/IR/Constant.h
@@ -133,9 +133,10 @@ class Constant : public User {
Constant *getAggregateElement(unsigned Elt) const;
Constant *getAggregateElement(Constant *Elt) const;
- /// If this is a splat vector constant, meaning that all of the elements have
- /// the same value, return that value. Otherwise return 0.
- Constant *getSplatValue() const;
+ /// If all elements of the vector constant have the same value, return that
+ /// value. Otherwise, return nullptr. If AllowUndefs is true, undef elements
+ /// are ignored, and the caller must ensure undef is propagated safely.
+ Constant *getSplatValue(bool AllowUndefs = false) const;
/// If C is a constant integer then return its value, otherwise C must be a
/// vector of constant integers, all equal, and the common value is returned.
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index 7f0687d382f0..262ab439df65 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -522,9 +522,10 @@ class ConstantVector final : public ConstantAggregate {
return cast<VectorType>(Value::getType());
}
- /// If this is a splat constant, meaning that all of the elements have the
- /// same value, return that value. Otherwise return NULL.
- Constant *getSplatValue() const;
+ /// If all elements of the vector constant have the same value, return that
+ /// value, else nullptr. If AllowUndefs is true, undef elements are ignored
+ /// (all-undef yields undef); the caller must propagate undef safely.
+ Constant *getSplatValue(bool AllowUndefs = false) const;
/// Methods for support type inquiry through isa, cast, and dyn_cast:
static bool classof(const Value *V) {
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index fc215d6bf958..cafb412b795b 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -1442,24 +1442,41 @@ void ConstantVector::destroyConstantImpl() {
getType()->getContext().pImpl->VectorConstants.remove(this);
}
-Constant *Constant::getSplatValue() const {
+Constant *Constant::getSplatValue(bool AllowUndefs) const {
assert(this->getType()->isVectorTy() && "Only valid for vectors!");
if (isa<ConstantAggregateZero>(this))
return getNullValue(this->getType()->getVectorElementType());
if (const ConstantDataVector *CV = dyn_cast<ConstantDataVector>(this))
return CV->getSplatValue();
if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
- return CV->getSplatValue();
+ return CV->getSplatValue(AllowUndefs); // Only path that uses AllowUndefs.
return nullptr;
}
-Constant *ConstantVector::getSplatValue() const {
+Constant *ConstantVector::getSplatValue(bool AllowUndefs) const {
// Check out first element.
Constant *Elt = getOperand(0);
// Then make sure all remaining elements point to the same value.
- for (unsigned I = 1, E = getNumOperands(); I < E; ++I)
- if (getOperand(I) != Elt)
+ for (unsigned I = 1, E = getNumOperands(); I < E; ++I) {
+ Constant *OpC = getOperand(I);
+ if (OpC == Elt)
+ continue;
+
+ // Strict mode: any differing element (even undef) means no splat.
+ if (!AllowUndefs)
return nullptr;
+
+ // Allow-undefs mode: an undef element never breaks the splat.
+ if (isa<UndefValue>(OpC))
+ continue;
+
+ // Elt is undef only until the first defined element; adopt that one.
+ if (isa<UndefValue>(Elt))
+ Elt = OpC;
+
+ if (OpC != Elt) // Two different defined elements.
+ return nullptr;
+ }
return Elt;
}
diff --git a/llvm/unittests/IR/InstructionsTest.cpp b/llvm/unittests/IR/InstructionsTest.cpp
index 556c41058e7d..c2f70724337c 100644
--- a/llvm/unittests/IR/InstructionsTest.cpp
+++ b/llvm/unittests/IR/InstructionsTest.cpp
@@ -995,6 +995,46 @@ TEST(InstructionsTest, ShuffleMaskQueries) {
delete Id12;
}
+TEST(InstructionsTest, GetSplat) {
+ // Create the elements for various constant vectors.
+ LLVMContext Ctx;
+ Type *Int32Ty = Type::getInt32Ty(Ctx);
+ Constant *CU = UndefValue::get(Int32Ty);
+ Constant *C0 = ConstantInt::get(Int32Ty, 0);
+ Constant *C1 = ConstantInt::get(Int32Ty, 1);
+
+ Constant *Splat0 = ConstantVector::get({C0, C0, C0, C0});
+ Constant *Splat1 = ConstantVector::get({C1, C1, C1, C1, C1});
+ Constant *Splat0Undef = ConstantVector::get({C0, CU, C0, CU});
+ Constant *Splat1Undef = ConstantVector::get({CU, CU, C1, CU});
+ Constant *NotSplat = ConstantVector::get({C1, C1, C0, C1, C1});
+ Constant *NotSplatUndef = ConstantVector::get({CU, C1, CU, CU, C0});
+
+ // Default - undefs are not allowed.
+ EXPECT_EQ(Splat0->getSplatValue(), C0);
+ EXPECT_EQ(Splat1->getSplatValue(), C1);
+ EXPECT_EQ(Splat0Undef->getSplatValue(), nullptr);
+ EXPECT_EQ(Splat1Undef->getSplatValue(), nullptr);
+ EXPECT_EQ(NotSplat->getSplatValue(), nullptr);
+ EXPECT_EQ(NotSplatUndef->getSplatValue(), nullptr);
+
+ // Disallow undefs explicitly.
+ EXPECT_EQ(Splat0->getSplatValue(false), C0);
+ EXPECT_EQ(Splat1->getSplatValue(false), C1);
+ EXPECT_EQ(Splat0Undef->getSplatValue(false), nullptr);
+ EXPECT_EQ(Splat1Undef->getSplatValue(false), nullptr);
+ EXPECT_EQ(NotSplat->getSplatValue(false), nullptr);
+ EXPECT_EQ(NotSplatUndef->getSplatValue(false), nullptr);
+
+ // Allow undefs.
+ EXPECT_EQ(Splat0->getSplatValue(true), C0);
+ EXPECT_EQ(Splat1->getSplatValue(true), C1);
+ EXPECT_EQ(Splat0Undef->getSplatValue(true), C0);
+ EXPECT_EQ(Splat1Undef->getSplatValue(true), C1);
+ EXPECT_EQ(NotSplat->getSplatValue(true), nullptr);
+ EXPECT_EQ(NotSplatUndef->getSplatValue(true), nullptr);
+}
+
TEST(InstructionsTest, SkipDebug) {
LLVMContext C;
std::unique_ptr<Module> M = parseIR(C,
More information about the llvm-commits
mailing list