[llvm] r258404 - [SLP] Truncate expressions to minimum required bit width
Matthew Simpson via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 21 08:31:58 PST 2016
Author: mssimpso
Date: Thu Jan 21 10:31:55 2016
New Revision: 258404
URL: http://llvm.org/viewvc/llvm-project?rev=258404&view=rev
Log:
[SLP] Truncate expressions to minimum required bit width
This change attempts to produce vectorized integer expressions in bit widths
that are narrower than their scalar counterparts. The need for demotion arises
especially on architectures in which the small integer types (e.g., i8 and i16)
are not legal for scalar operations but can still be used in vectors. Like
similar work done within the loop vectorizer, we rely on InstCombine to perform
the actual type-shrinking. We use the DemandedBits analysis and
ComputeNumSignBits from ValueTracking to determine the minimum required bit
width of an expression.
Differential revision: http://reviews.llvm.org/D15815
Modified:
llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp
llvm/trunk/test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll
Modified: llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp?rev=258404&r1=258403&r2=258404&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp (original)
+++ llvm/trunk/lib/Transforms/Vectorize/SLPVectorizer.cpp Thu Jan 21 10:31:55 2016
@@ -15,21 +15,22 @@
// "Loop-Aware SLP in GCC" by Ira Rosen, Dorit Nuzman, Ayal Zaks.
//
//===----------------------------------------------------------------------===//
-#include "llvm/Transforms/Vectorize.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
-#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CodeMetrics.h"
+#include "llvm/Analysis/DemandedBits.h"
+#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
@@ -44,7 +45,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
-#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/Transforms/Vectorize.h"
#include <algorithm>
#include <map>
#include <memory>
@@ -363,11 +364,12 @@ public:
BoUpSLP(Function *Func, ScalarEvolution *Se, TargetTransformInfo *Tti,
TargetLibraryInfo *TLi, AliasAnalysis *Aa, LoopInfo *Li,
- DominatorTree *Dt, AssumptionCache *AC)
+ DominatorTree *Dt, AssumptionCache *AC, DemandedBits *DB)
: NumLoadsWantToKeepOrder(0), NumLoadsWantToChangeOrder(0), F(Func),
- SE(Se), TTI(Tti), TLI(TLi), AA(Aa), LI(Li), DT(Dt),
+ SE(Se), TTI(Tti), TLI(TLi), AA(Aa), LI(Li), DT(Dt), AC(AC), DB(DB),
Builder(Se->getContext()) {
CodeMetrics::collectEphemeralValues(F, AC, EphValues);
+ MaxRequiredIntegerTy = nullptr;
}
/// \brief Vectorize the tree that starts with the elements in \p VL.
@@ -399,6 +401,7 @@ public:
BlockScheduling *BS = Iter.second.get();
BS->clear();
}
+ MaxRequiredIntegerTy = nullptr;
}
/// \returns true if the memory operations A and B are consecutive.
@@ -419,6 +422,10 @@ public:
/// vectorization factors.
unsigned getVectorElementSize(Value *V);
+ /// Compute the maximum width integer type required to represent the result
+ /// of a scalar expression, if such a type exists.
+ void computeMaxRequiredIntegerTy();
+
private:
struct TreeEntry;
@@ -924,8 +931,13 @@ private:
AliasAnalysis *AA;
LoopInfo *LI;
DominatorTree *DT;
+ AssumptionCache *AC;
+ DemandedBits *DB;
/// Instruction builder to construct the vectorized tree.
IRBuilder<> Builder;
+
+ // The maximum width integer type required to represent a scalar expression.
+ IntegerType *MaxRequiredIntegerTy;
};
#ifndef NDEBUG
@@ -1481,6 +1493,15 @@ int BoUpSLP::getEntryCost(TreeEntry *E)
ScalarTy = SI->getValueOperand()->getType();
VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
+ // If we have computed a smaller type for the expression, update VecTy so
+ // that the costs will be accurate.
+ if (MaxRequiredIntegerTy) {
+ auto *IT = dyn_cast<IntegerType>(ScalarTy);
+ assert(IT && "Computed smaller type for non-integer value?");
+ if (MaxRequiredIntegerTy->getBitWidth() < IT->getBitWidth())
+ VecTy = VectorType::get(MaxRequiredIntegerTy, VL.size());
+ }
+
if (E->NeedToGather) {
if (allConstant(VL))
return 0;
@@ -1809,9 +1830,17 @@ int BoUpSLP::getTreeCost() {
if (EphValues.count(EU.User))
continue;
- VectorType *VecTy = VectorType::get(EU.Scalar->getType(), BundleWidth);
- ExtractCost += TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy,
- EU.Lane);
+ // If we plan to rewrite the tree in a smaller type, we will need to sign
+ // extend the extracted value back to the original type. Here, we account
+ // for the extract and the added cost of the sign extend if needed.
+ auto *VecTy = VectorType::get(EU.Scalar->getType(), BundleWidth);
+ if (MaxRequiredIntegerTy) {
+ VecTy = VectorType::get(MaxRequiredIntegerTy, BundleWidth);
+ ExtractCost += TTI->getCastInstrCost(
+ Instruction::SExt, EU.Scalar->getType(), MaxRequiredIntegerTy);
+ }
+ ExtractCost +=
+ TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, EU.Lane);
}
Cost += getSpillCost();
@@ -2566,7 +2595,19 @@ Value *BoUpSLP::vectorizeTree() {
}
Builder.SetInsertPoint(&F->getEntryBlock().front());
- vectorizeTree(&VectorizableTree[0]);
+ auto *VectorRoot = vectorizeTree(&VectorizableTree[0]);
+
+ // If the vectorized tree can be rewritten in a smaller type, we truncate the
+ // vectorized root. InstCombine will then rewrite the entire expression. We
+ // sign extend the extracted values below.
+ if (MaxRequiredIntegerTy) {
+ BasicBlock::iterator I(cast<Instruction>(VectorRoot));
+ Builder.SetInsertPoint(&*++I);
+ auto BundleWidth = VectorizableTree[0].Scalars.size();
+ auto *SmallerTy = VectorType::get(MaxRequiredIntegerTy, BundleWidth);
+ auto *Trunc = Builder.CreateTrunc(VectorRoot, SmallerTy);
+ VectorizableTree[0].VectorizedValue = Trunc;
+ }
DEBUG(dbgs() << "SLP: Extracting " << ExternalUses.size() << " values .\n");
@@ -2599,6 +2640,8 @@ Value *BoUpSLP::vectorizeTree() {
if (PH->getIncomingValue(i) == Scalar) {
Builder.SetInsertPoint(PH->getIncomingBlock(i)->getTerminator());
Value *Ex = Builder.CreateExtractElement(Vec, Lane);
+ if (MaxRequiredIntegerTy)
+ Ex = Builder.CreateSExt(Ex, Scalar->getType());
CSEBlocks.insert(PH->getIncomingBlock(i));
PH->setOperand(i, Ex);
}
@@ -2606,12 +2649,16 @@ Value *BoUpSLP::vectorizeTree() {
} else {
Builder.SetInsertPoint(cast<Instruction>(User));
Value *Ex = Builder.CreateExtractElement(Vec, Lane);
+ if (MaxRequiredIntegerTy)
+ Ex = Builder.CreateSExt(Ex, Scalar->getType());
CSEBlocks.insert(cast<Instruction>(User)->getParent());
User->replaceUsesOfWith(Scalar, Ex);
}
} else {
Builder.SetInsertPoint(&F->getEntryBlock().front());
Value *Ex = Builder.CreateExtractElement(Vec, Lane);
+ if (MaxRequiredIntegerTy)
+ Ex = Builder.CreateSExt(Ex, Scalar->getType());
CSEBlocks.insert(&F->getEntryBlock());
User->replaceUsesOfWith(Scalar, Ex);
}
@@ -3180,7 +3227,7 @@ unsigned BoUpSLP::getVectorElementSize(V
// If the current instruction is a load, update MaxWidth to reflect the
// width of the loaded value.
else if (isa<LoadInst>(I))
- MaxWidth = std::max(MaxWidth, (unsigned)DL.getTypeSizeInBits(Ty));
+ MaxWidth = std::max<unsigned>(MaxWidth, DL.getTypeSizeInBits(Ty));
// Otherwise, we need to visit the operands of the instruction. We only
// handle the interesting cases from buildTree here. If an operand is an
@@ -3207,6 +3254,85 @@ unsigned BoUpSLP::getVectorElementSize(V
return MaxWidth;
}
+void BoUpSLP::computeMaxRequiredIntegerTy() {
+
+ // If there are no external uses, the expression tree must be rooted by a
+ // store. We can't demote in-memory values, so there is nothing to do here.
+ if (ExternalUses.empty())
+ return;
+
+ // If the expression is not rooted by a store, these roots should have
+ // external uses. We will rely on InstCombine to rewrite the expression in
+ // the narrower type. However, InstCombine only rewrites single-use values.
+ // This means that if a tree entry other than a root is used externally, it
+ // must have multiple uses and InstCombine will not rewrite it. The code
+ // below ensures that only the roots are used externally.
+ auto &TreeRoot = VectorizableTree[0].Scalars;
+ SmallPtrSet<Value *, 16> ScalarRoots(TreeRoot.begin(), TreeRoot.end());
+ for (auto &EU : ExternalUses)
+ if (!ScalarRoots.erase(EU.Scalar))
+ return;
+ if (!ScalarRoots.empty())
+ return;
+
+ // The maximum bit width required to represent all the instructions in the
+ // tree without loss of precision. It would be safe to truncate the
+ // expression to this width.
+ auto MaxBitWidth = 8u;
+
+ // We first check if all the bits of the root are demanded. If they're not,
+ // we can truncate the root to this narrower type.
+ auto *Root = dyn_cast<Instruction>(TreeRoot[0]);
+ if (!Root || !isa<IntegerType>(Root->getType()) || !Root->hasOneUse())
+ return;
+ auto Mask = DB->getDemandedBits(Root);
+ if (Mask.countLeadingZeros() > 0)
+ MaxBitWidth = Mask.getBitWidth() - Mask.countLeadingZeros();
+
+ // If all the bits of the root are demanded, we can try a little harder to
+ // compute a narrower type. This can happen, for example, if the roots are
+ // getelementptr indices. InstCombine promotes these indices to the pointer
+ // width. Thus, all their bits are technically demanded even though the
+ // address computation might be vectorized in a smaller type. We start by
+ // looking at each entry in the tree.
+ else
+ for (auto &Entry : VectorizableTree) {
+
+ // Get a representative value for the vectorizable bundle. All values in
+ // Entry.Scalars should be isomorphic.
+ auto *Scalar = Entry.Scalars[0];
+
+ // If the scalar is used more than once, InstCombine will not rewrite it,
+ // so we should give up.
+ if (!Scalar->hasOneUse())
+ return;
+
+ // We only compute smaller integer types. If the scalar has a different
+ // type, give up.
+ auto *IT = dyn_cast<IntegerType>(Scalar->getType());
+ if (!IT)
+ return;
+
+ // Compute the maximum bit width required to store the scalar. We use
+ // ValueTracking to compute the number of high-order bits we can
+ // truncate. Rounding up to a power-of-two is done after the loop.
+ auto &DL = F->getParent()->getDataLayout();
+ auto NumSignBits = ComputeNumSignBits(Scalar, DL, 0, AC, 0, DT);
+ auto NumTypeBits = IT->getBitWidth();
+ MaxBitWidth = std::max<unsigned>(NumTypeBits - NumSignBits, MaxBitWidth);
+ }
+
+ // Round up to the next power-of-two.
+ if (!isPowerOf2_64(MaxBitWidth))
+ MaxBitWidth = NextPowerOf2(MaxBitWidth);
+
+ // If the maximum bit width we compute is less than the width of the roots'
+ // type, we can proceed with the narrowing. Otherwise, do nothing.
+ auto *RootIT = cast<IntegerType>(TreeRoot[0]->getType());
+ if (MaxBitWidth > 0 && MaxBitWidth < RootIT->getBitWidth())
+ MaxRequiredIntegerTy = IntegerType::get(F->getContext(), MaxBitWidth);
+}
+
/// The SLPVectorizer Pass.
struct SLPVectorizer : public FunctionPass {
typedef SmallVector<StoreInst *, 8> StoreList;
@@ -3228,6 +3354,7 @@ struct SLPVectorizer : public FunctionPa
LoopInfo *LI;
DominatorTree *DT;
AssumptionCache *AC;
+ DemandedBits *DB;
bool runOnFunction(Function &F) override {
if (skipOptnoneFunction(F))
@@ -3241,6 +3368,7 @@ struct SLPVectorizer : public FunctionPa
LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
+ DB = &getAnalysis<DemandedBits>();
Stores.clear();
GEPs.clear();
@@ -3270,7 +3398,7 @@ struct SLPVectorizer : public FunctionPa
// Use the bottom up slp vectorizer to construct chains that start with
// store instructions.
- BoUpSLP R(&F, SE, TTI, TLI, AA, LI, DT, AC);
+ BoUpSLP R(&F, SE, TTI, TLI, AA, LI, DT, AC, DB);
// A general note: the vectorizer must use BoUpSLP::eraseInstruction() to
// delete instructions.
@@ -3313,6 +3441,7 @@ struct SLPVectorizer : public FunctionPa
AU.addRequired<TargetTransformInfoWrapperPass>();
AU.addRequired<LoopInfoWrapperPass>();
AU.addRequired<DominatorTreeWrapperPass>();
+ AU.addRequired<DemandedBits>();
AU.addPreserved<LoopInfoWrapperPass>();
AU.addPreserved<DominatorTreeWrapperPass>();
AU.addPreserved<AAResultsWrapperPass>();
@@ -3417,6 +3546,7 @@ bool SLPVectorizer::vectorizeStoreChain(
ArrayRef<Value *> Operands = Chain.slice(i, VF);
R.buildTree(Operands);
+ R.computeMaxRequiredIntegerTy();
int Cost = R.getTreeCost();
@@ -3616,6 +3746,7 @@ bool SLPVectorizer::tryToVectorizeList(A
Value *ReorderedOps[] = { Ops[1], Ops[0] };
R.buildTree(ReorderedOps, None);
}
+ R.computeMaxRequiredIntegerTy();
int Cost = R.getTreeCost();
if (Cost < -SLPCostThreshold) {
@@ -3882,6 +4013,7 @@ public:
for (; i < NumReducedVals - ReduxWidth + 1; i += ReduxWidth) {
V.buildTree(makeArrayRef(&ReducedVals[i], ReduxWidth), ReductionOps);
+ V.computeMaxRequiredIntegerTy();
// Estimate cost.
int Cost = V.getTreeCost() + getReductionCost(TTI, ReducedVals[i]);
Modified: llvm/trunk/test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll?rev=258404&r1=258403&r2=258404&view=diff
==============================================================================
--- llvm/trunk/test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll (original)
+++ llvm/trunk/test/Transforms/SLPVectorizer/AArch64/gather-reduce.ll Thu Jan 21 10:31:55 2016
@@ -1,4 +1,5 @@
-; RUN: opt -S -slp-vectorizer -dce -instcombine < %s | FileCheck %s
+; RUN: opt -S -slp-vectorizer -dce -instcombine < %s | FileCheck %s --check-prefix=PROFITABLE
+; RUN: opt -S -slp-vectorizer -slp-threshold=-12 -dce -instcombine < %s | FileCheck %s --check-prefix=UNPROFITABLE
target datalayout = "e-m:e-i64:64-i128:128-n32:64-S128"
target triple = "aarch64--linux-gnu"
@@ -18,13 +19,13 @@ target triple = "aarch64--linux-gnu"
; return sum;
; }
-; CHECK-LABEL: @gather_reduce_8x16_i32
+; PROFITABLE-LABEL: @gather_reduce_8x16_i32
;
-; CHECK: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16>
-; CHECK: zext <8 x i16> [[L]] to <8 x i32>
-; CHECK: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32>
-; CHECK: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]]
-; CHECK: sext i32 [[X]] to i64
+; PROFITABLE: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16>
+; PROFITABLE: zext <8 x i16> [[L]] to <8 x i32>
+; PROFITABLE: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32>
+; PROFITABLE: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]]
+; PROFITABLE: sext i32 [[X]] to i64
;
define i32 @gather_reduce_8x16_i32(i16* nocapture readonly %a, i16* nocapture readonly %b, i16* nocapture readonly %g, i32 %n) {
entry:
@@ -137,14 +138,18 @@ for.body:
br i1 %exitcond, label %for.cond.cleanup.loopexit, label %for.body
}
-; CHECK-LABEL: @gather_reduce_8x16_i64
+; UNPROFITABLE-LABEL: @gather_reduce_8x16_i64
;
-; CHECK-NOT: load <8 x i16>
-;
-; FIXME: We are currently unable to vectorize the case with i64 subtraction
-; because the zero extensions are too expensive. The solution here is to
-; convert the i64 subtractions to i32 subtractions during vectorization.
-; This would then match the case above.
+; UNPROFITABLE: [[L:%[a-zA-Z0-9.]+]] = load <8 x i16>
+; UNPROFITABLE: zext <8 x i16> [[L]] to <8 x i32>
+; UNPROFITABLE: [[S:%[a-zA-Z0-9.]+]] = sub nsw <8 x i32>
+; UNPROFITABLE: [[X:%[a-zA-Z0-9.]+]] = extractelement <8 x i32> [[S]]
+; UNPROFITABLE: sext i32 [[X]] to i64
+;
+; TODO: Although we can now vectorize this case while converting the i64
+; subtractions to i32, the cost model currently finds vectorization to be
+; unprofitable. The cost model is penalizing the sign and zero
+; extensions in the vectorized version, but they are actually free.
;
define i32 @gather_reduce_8x16_i64(i16* nocapture readonly %a, i16* nocapture readonly %b, i16* nocapture readonly %g, i32 %n) {
entry:
More information about the llvm-commits
mailing list