[llvm] 3ea0774 - [ConstantFold][NFC] Compile time optimization for large vectors
Thomas Raoux via llvm-commits
llvm-commits at lists.llvm.org
Mon Mar 30 11:27:49 PDT 2020
Author: Thomas Raoux
Date: 2020-03-30T11:27:09-07:00
New Revision: 3ea0774b13a538759aa1a68f30130d18ddb0d3f2
URL: https://github.com/llvm/llvm-project/commit/3ea0774b13a538759aa1a68f30130d18ddb0d3f2
DIFF: https://github.com/llvm/llvm-project/commit/3ea0774b13a538759aa1a68f30130d18ddb0d3f2.diff
LOG: [ConstantFold][NFC] Compile time optimization for large vectors
Optimize the common case of splat vector constants. For large vectors,
going through all elements is expensive. For splat/broadcast cases we
can skip going through all elements.
Differential Revision: https://reviews.llvm.org/D76664
Added:
Modified:
llvm/include/llvm/IR/Constants.h
llvm/lib/Analysis/ValueTracking.cpp
llvm/lib/IR/ConstantFold.cpp
llvm/lib/IR/Constants.cpp
llvm/lib/IR/Instructions.cpp
llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
Removed:
################################################################################
diff --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index d0ea1adbbf18..a345795ff54a 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -766,7 +766,12 @@ class ConstantDataVector final : public ConstantDataSequential {
friend class ConstantDataSequential;
explicit ConstantDataVector(Type *ty, const char *Data)
- : ConstantDataSequential(ty, ConstantDataVectorVal, Data) {}
+ : ConstantDataSequential(ty, ConstantDataVectorVal, Data),
+ IsSplatSet(false) {}
+ // Cache whether or not the constant is a splat.
+ mutable bool IsSplatSet : 1;
+ mutable bool IsSplat : 1;
+ bool isSplatData() const;
public:
ConstantDataVector(const ConstantDataVector &) = delete;
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index ab6afa1a81dc..007513020769 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -174,7 +174,13 @@ static bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf,
int NumElts = Shuf->getOperand(0)->getType()->getVectorNumElements();
int NumMaskElts = Shuf->getMask()->getType()->getVectorNumElements();
DemandedLHS = DemandedRHS = APInt::getNullValue(NumElts);
-
+ if (DemandedElts.isNullValue())
+ return true;
+ // Simple case of a shuffle with zeroinitializer.
+ if (isa<ConstantAggregateZero>(Shuf->getMask())) {
+ DemandedLHS.setBit(0);
+ return true;
+ }
for (int i = 0; i != NumMaskElts; ++i) {
if (!DemandedElts[i])
continue;
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index 3e2e74c31fc0..07292e50fc8d 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -60,6 +60,11 @@ static Constant *BitCastConstantVector(Constant *CV, VectorType *DstTy) {
return nullptr;
Type *DstEltTy = DstTy->getElementType();
+ // Fast path for splatted constants.
+ if (Constant *Splat = CV->getSplatValue()) {
+ return ConstantVector::getSplat(DstTy->getVectorElementCount(),
+ ConstantExpr::getBitCast(Splat, DstEltTy));
+ }
SmallVector<Constant*, 16> Result;
Type *Ty = IntegerType::get(CV->getContext(), 32);
@@ -577,9 +582,15 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
if ((isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) &&
DestTy->isVectorTy() &&
DestTy->getVectorNumElements() == V->getType()->getVectorNumElements()) {
- SmallVector<Constant*, 16> res;
VectorType *DestVecTy = cast<VectorType>(DestTy);
Type *DstEltTy = DestVecTy->getElementType();
+ // Fast path for splatted constants.
+ if (Constant *Splat = V->getSplatValue()) {
+ return ConstantVector::getSplat(
+ DestTy->getVectorElementCount(),
+ ConstantExpr::getCast(opc, Splat, DstEltTy));
+ }
+ SmallVector<Constant *, 16> res;
Type *Ty = IntegerType::get(V->getContext(), 32);
for (unsigned i = 0, e = V->getType()->getVectorNumElements(); i != e; ++i) {
Constant *C =
@@ -878,6 +889,14 @@ Constant *llvm::ConstantFoldShuffleVectorInstruction(Constant *V1,
// Don't break the bitcode reader hack.
if (isa<ConstantExpr>(Mask)) return nullptr;
+ // If the mask is all zeros this is a splat, no need to go through all
+ // elements.
+ if (isa<ConstantAggregateZero>(Mask) && !MaskEltCount.Scalable) {
+ Type *Ty = IntegerType::get(V1->getContext(), 32);
+ Constant *Elt =
+ ConstantExpr::getExtractElement(V1, ConstantInt::get(Ty, 0));
+ return ConstantVector::getSplat(MaskEltCount, Elt);
+ }
// Do not iterate on scalable vector. The num of elements is unknown at
// compile-time.
VectorType *ValTy = cast<VectorType>(V1->getType());
@@ -993,10 +1012,15 @@ Constant *llvm::ConstantFoldUnaryInstruction(unsigned Opcode, Constant *C) {
// compile-time.
if (IsScalableVector)
return nullptr;
+ Type *Ty = IntegerType::get(VTy->getContext(), 32);
+ // Fast path for splatted constants.
+ if (Constant *Splat = C->getSplatValue()) {
+ Constant *Elt = ConstantExpr::get(Opcode, Splat);
+ return ConstantVector::getSplat(VTy->getElementCount(), Elt);
+ }
// Fold each element and create a vector constant from those constants.
SmallVector<Constant*, 16> Result;
- Type *Ty = IntegerType::get(VTy->getContext(), 32);
for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) {
Constant *ExtractIdx = ConstantInt::get(Ty, i);
Constant *Elt = ConstantExpr::getExtractElement(C, ExtractIdx);
@@ -1357,6 +1381,16 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
// compile-time.
if (IsScalableVector)
return nullptr;
+ // Fast path for splatted constants.
+ if (Constant *C2Splat = C2->getSplatValue()) {
+ if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue())
+ return UndefValue::get(VTy);
+ if (Constant *C1Splat = C1->getSplatValue()) {
+ return ConstantVector::getSplat(
+ VTy->getVectorElementCount(),
+ ConstantExpr::get(Opcode, C1Splat, C2Splat));
+ }
+ }
// Fold each element and create a vector constant from those constants.
SmallVector<Constant*, 16> Result;
@@ -1975,6 +2009,12 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
// compile-time.
if (C1->getType()->getVectorIsScalable())
return nullptr;
+ // Fast path for splatted constants.
+ if (Constant *C1Splat = C1->getSplatValue())
+ if (Constant *C2Splat = C2->getSplatValue())
+ return ConstantVector::getSplat(
+ C1->getType()->getVectorElementCount(),
+ ConstantExpr::getCompare(pred, C1Splat, C2Splat));
// If we can constant fold the comparison of each element, constant fold
// the whole vector comparison.
diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index bde4c07e15a3..e001b5230751 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -2891,7 +2891,7 @@ bool ConstantDataSequential::isCString() const {
return Str.drop_back().find(0) == StringRef::npos;
}
-bool ConstantDataVector::isSplat() const {
+bool ConstantDataVector::isSplatData() const {
const char *Base = getRawDataValues().data();
// Compare elements 1+ to the 0'th element.
@@ -2903,6 +2903,14 @@ bool ConstantDataVector::isSplat() const {
return true;
}
+bool ConstantDataVector::isSplat() const {
+ if (!IsSplatSet) {
+ IsSplatSet = true;
+ IsSplat = isSplatData();
+ }
+ return IsSplat;
+}
+
Constant *ConstantDataVector::getSplatValue() const {
// If they're all the same, return the 0th one as a representative.
return isSplat() ? getElementAsConstant(0) : nullptr;
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 0884a24a709e..f6748c8d864e 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -1958,7 +1958,11 @@ void ShuffleVectorInst::getShuffleMask(const Constant *Mask,
assert(!Mask->getType()->getVectorElementCount().Scalable &&
"Length of scalable vectors unknown at compile time");
unsigned NumElts = Mask->getType()->getVectorNumElements();
-
+ if (isa<ConstantAggregateZero>(Mask)) {
+ Result.resize(NumElts, 0);
+ return;
+ }
+ Result.reserve(NumElts);
if (auto *CDS = dyn_cast<ConstantDataSequential>(Mask)) {
for (unsigned i = 0; i != NumElts; ++i)
Result.push_back(CDS->getElementAsInteger(i));
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index be5135e68ffd..90b00536b471 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -1387,6 +1387,24 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
"Expected shuffle operands to have same type");
unsigned OpWidth =
Shuffle->getOperand(0)->getType()->getVectorNumElements();
+ // Handle trivial case of a splat. Only check the first element of LHS
+ // operand.
+ if (isa<ConstantAggregateZero>(Shuffle->getMask()) &&
+ DemandedElts.isAllOnesValue()) {
+ if (!isa<UndefValue>(I->getOperand(1))) {
+ I->setOperand(1, UndefValue::get(I->getOperand(1)->getType()));
+ MadeChange = true;
+ }
+ APInt LeftDemanded(OpWidth, 1);
+ APInt LHSUndefElts(OpWidth, 0);
+ simplifyAndSetOp(I, 0, LeftDemanded, LHSUndefElts);
+ if (LHSUndefElts[0])
+ UndefElts = EltMask;
+ else
+ UndefElts.clearAllBits();
+ break;
+ }
+
APInt LeftDemanded(OpWidth, 0), RightDemanded(OpWidth, 0);
for (unsigned i = 0; i < VWidth; i++) {
if (DemandedElts[i]) {
More information about the llvm-commits
mailing list