[llvm] 3ea0774 - [ConstantFold][NFC] Compile time optimization for large vectors

Thomas Raoux via llvm-commits llvm-commits at lists.llvm.org
Mon Mar 30 11:27:49 PDT 2020


Author: Thomas Raoux
Date: 2020-03-30T11:27:09-07:00
New Revision: 3ea0774b13a538759aa1a68f30130d18ddb0d3f2

URL: https://github.com/llvm/llvm-project/commit/3ea0774b13a538759aa1a68f30130d18ddb0d3f2
DIFF: https://github.com/llvm/llvm-project/commit/3ea0774b13a538759aa1a68f30130d18ddb0d3f2.diff

LOG: [ConstantFold][NFC] Compile time optimization for large vectors

Optimize the common case of splat vector constants. For large vectors,
going through all elements is expensive. For splat/broadcast cases we
can skip going through all elements.

Differential Revision: https://reviews.llvm.org/D76664

Added: 
    

Modified: 
    llvm/include/llvm/IR/Constants.h
    llvm/lib/Analysis/ValueTracking.cpp
    llvm/lib/IR/ConstantFold.cpp
    llvm/lib/IR/Constants.cpp
    llvm/lib/IR/Instructions.cpp
    llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/IR/Constants.h b/llvm/include/llvm/IR/Constants.h
index d0ea1adbbf18..a345795ff54a 100644
--- a/llvm/include/llvm/IR/Constants.h
+++ b/llvm/include/llvm/IR/Constants.h
@@ -766,7 +766,12 @@ class ConstantDataVector final : public ConstantDataSequential {
   friend class ConstantDataSequential;
 
   explicit ConstantDataVector(Type *ty, const char *Data)
-      : ConstantDataSequential(ty, ConstantDataVectorVal, Data) {}
+      : ConstantDataSequential(ty, ConstantDataVectorVal, Data),
+        IsSplatSet(false) {}
+  // Cache whether or not the constant is a splat.
+  mutable bool IsSplatSet : 1;
+  mutable bool IsSplat : 1;
+  bool isSplatData() const;
 
 public:
   ConstantDataVector(const ConstantDataVector &) = delete;

diff  --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index ab6afa1a81dc..007513020769 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -174,7 +174,13 @@ static bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf,
   int NumElts = Shuf->getOperand(0)->getType()->getVectorNumElements();
   int NumMaskElts = Shuf->getMask()->getType()->getVectorNumElements();
   DemandedLHS = DemandedRHS = APInt::getNullValue(NumElts);
-
+  if (DemandedElts.isNullValue())
+    return true;
+  // Simple case of a shuffle with zeroinitializer.
+  if (isa<ConstantAggregateZero>(Shuf->getMask())) {
+    DemandedLHS.setBit(0);
+    return true;
+  }
   for (int i = 0; i != NumMaskElts; ++i) {
     if (!DemandedElts[i])
       continue;

diff  --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index 3e2e74c31fc0..07292e50fc8d 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -60,6 +60,11 @@ static Constant *BitCastConstantVector(Constant *CV, VectorType *DstTy) {
     return nullptr;
 
   Type *DstEltTy = DstTy->getElementType();
+  // Fast path for splatted constants.
+  if (Constant *Splat = CV->getSplatValue()) {
+    return ConstantVector::getSplat(DstTy->getVectorElementCount(),
+                                    ConstantExpr::getBitCast(Splat, DstEltTy));
+  }
 
   SmallVector<Constant*, 16> Result;
   Type *Ty = IntegerType::get(CV->getContext(), 32);
@@ -577,9 +582,15 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
   if ((isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) &&
       DestTy->isVectorTy() &&
       DestTy->getVectorNumElements() == V->getType()->getVectorNumElements()) {
-    SmallVector<Constant*, 16> res;
     VectorType *DestVecTy = cast<VectorType>(DestTy);
     Type *DstEltTy = DestVecTy->getElementType();
+    // Fast path for splatted constants.
+    if (Constant *Splat = V->getSplatValue()) {
+      return ConstantVector::getSplat(
+          DestTy->getVectorElementCount(),
+          ConstantExpr::getCast(opc, Splat, DstEltTy));
+    }
+    SmallVector<Constant *, 16> res;
     Type *Ty = IntegerType::get(V->getContext(), 32);
     for (unsigned i = 0, e = V->getType()->getVectorNumElements(); i != e; ++i) {
       Constant *C =
@@ -878,6 +889,14 @@ Constant *llvm::ConstantFoldShuffleVectorInstruction(Constant *V1,
   // Don't break the bitcode reader hack.
   if (isa<ConstantExpr>(Mask)) return nullptr;
 
+  // If the mask is all zeros this is a splat, no need to go through all
+  // elements.
+  if (isa<ConstantAggregateZero>(Mask) && !MaskEltCount.Scalable) {
+    Type *Ty = IntegerType::get(V1->getContext(), 32);
+    Constant *Elt =
+        ConstantExpr::getExtractElement(V1, ConstantInt::get(Ty, 0));
+    return ConstantVector::getSplat(MaskEltCount, Elt);
+  }
   // Do not iterate on scalable vector. The num of elements is unknown at
   // compile-time.
   VectorType *ValTy = cast<VectorType>(V1->getType());
@@ -993,10 +1012,15 @@ Constant *llvm::ConstantFoldUnaryInstruction(unsigned Opcode, Constant *C) {
     // compile-time.
     if (IsScalableVector)
       return nullptr;
+    Type *Ty = IntegerType::get(VTy->getContext(), 32);
+    // Fast path for splatted constants.
+    if (Constant *Splat = C->getSplatValue()) {
+      Constant *Elt = ConstantExpr::get(Opcode, Splat);
+      return ConstantVector::getSplat(VTy->getElementCount(), Elt);
+    }
 
     // Fold each element and create a vector constant from those constants.
     SmallVector<Constant*, 16> Result;
-    Type *Ty = IntegerType::get(VTy->getContext(), 32);
     for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) {
       Constant *ExtractIdx = ConstantInt::get(Ty, i);
       Constant *Elt = ConstantExpr::getExtractElement(C, ExtractIdx);
@@ -1357,6 +1381,16 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
     // compile-time.
     if (IsScalableVector)
       return nullptr;
+    // Fast path for splatted constants.
+    if (Constant *C2Splat = C2->getSplatValue()) {
+      if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue())
+        return UndefValue::get(VTy);
+      if (Constant *C1Splat = C1->getSplatValue()) {
+        return ConstantVector::getSplat(
+            VTy->getVectorElementCount(),
+            ConstantExpr::get(Opcode, C1Splat, C2Splat));
+      }
+    }
 
     // Fold each element and create a vector constant from those constants.
     SmallVector<Constant*, 16> Result;
@@ -1975,6 +2009,12 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
     // compile-time.
     if (C1->getType()->getVectorIsScalable())
       return nullptr;
+    // Fast path for splatted constants.
+    if (Constant *C1Splat = C1->getSplatValue())
+      if (Constant *C2Splat = C2->getSplatValue())
+        return ConstantVector::getSplat(
+            C1->getType()->getVectorElementCount(),
+            ConstantExpr::getCompare(pred, C1Splat, C2Splat));
 
     // If we can constant fold the comparison of each element, constant fold
     // the whole vector comparison.

diff  --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp
index bde4c07e15a3..e001b5230751 100644
--- a/llvm/lib/IR/Constants.cpp
+++ b/llvm/lib/IR/Constants.cpp
@@ -2891,7 +2891,7 @@ bool ConstantDataSequential::isCString() const {
   return Str.drop_back().find(0) == StringRef::npos;
 }
 
-bool ConstantDataVector::isSplat() const {
+bool ConstantDataVector::isSplatData() const {
   const char *Base = getRawDataValues().data();
 
   // Compare elements 1+ to the 0'th element.
@@ -2903,6 +2903,14 @@ bool ConstantDataVector::isSplat() const {
   return true;
 }
 
+bool ConstantDataVector::isSplat() const {
+  if (!IsSplatSet) {
+    IsSplatSet = true;
+    IsSplat = isSplatData();
+  }
+  return IsSplat;
+}
+
 Constant *ConstantDataVector::getSplatValue() const {
   // If they're all the same, return the 0th one as a representative.
   return isSplat() ? getElementAsConstant(0) : nullptr;

diff  --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 0884a24a709e..f6748c8d864e 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -1958,7 +1958,11 @@ void ShuffleVectorInst::getShuffleMask(const Constant *Mask,
   assert(!Mask->getType()->getVectorElementCount().Scalable &&
     "Length of scalable vectors unknown at compile time");
   unsigned NumElts = Mask->getType()->getVectorNumElements();
-
+  if (isa<ConstantAggregateZero>(Mask)) {
+    Result.resize(NumElts, 0);
+    return;
+  }
+  Result.reserve(NumElts);
   if (auto *CDS = dyn_cast<ConstantDataSequential>(Mask)) {
     for (unsigned i = 0; i != NumElts; ++i)
       Result.push_back(CDS->getElementAsInteger(i));

diff  --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index be5135e68ffd..90b00536b471 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -1387,6 +1387,24 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
            "Expected shuffle operands to have same type");
     unsigned OpWidth =
         Shuffle->getOperand(0)->getType()->getVectorNumElements();
+    // Handle trivial case of a splat. Only check the first element of LHS
+    // operand.
+    if (isa<ConstantAggregateZero>(Shuffle->getMask()) &&
+        DemandedElts.isAllOnesValue()) {
+      if (!isa<UndefValue>(I->getOperand(1))) {
+        I->setOperand(1, UndefValue::get(I->getOperand(1)->getType()));
+        MadeChange = true;
+      }
+      APInt LeftDemanded(OpWidth, 1);
+      APInt LHSUndefElts(OpWidth, 0);
+      simplifyAndSetOp(I, 0, LeftDemanded, LHSUndefElts);
+      if (LHSUndefElts[0])
+        UndefElts = EltMask;
+      else
+        UndefElts.clearAllBits();
+      break;
+    }
+
     APInt LeftDemanded(OpWidth, 0), RightDemanded(OpWidth, 0);
     for (unsigned i = 0; i < VWidth; i++) {
       if (DemandedElts[i]) {


        


More information about the llvm-commits mailing list