[llvm] [AArch64] Generate DOT instructions from matching IR (PR #69583)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 19 03:26:11 PDT 2023
github-actions[bot] wrote:
:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff 202de4a5c6edb82d50d4bd7586c4b1db5f51073d 2b6da683e001ba852674d0f55cc5beb95c14782f -- llvm/lib/Target/AArch64/AArch64DotProdMatcher.cpp llvm/lib/Target/AArch64/AArch64.h llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
``````````
</details>
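If the suggested changes look right, here is a rough sketch of two ways to apply them locally, assuming the PR branch is checked out with a clean working tree (the commit hashes and file list are the same ones used in the command above; `clang-format.patch` is just an arbitrary scratch file name):

``````````bash
# Option 1: capture the diff from the command above and apply it to the
# checked-out PR branch (assumes the working tree matches the second commit).
git-clang-format --diff 202de4a5c6edb82d50d4bd7586c4b1db5f51073d 2b6da683e001ba852674d0f55cc5beb95c14782f -- \
  llvm/lib/Target/AArch64/AArch64DotProdMatcher.cpp \
  llvm/lib/Target/AArch64/AArch64.h \
  llvm/lib/Target/AArch64/AArch64TargetMachine.cpp > clang-format.patch
git apply clang-format.patch

# Option 2: with the PR branch checked out, reformat only the lines that
# differ from the base commit, modifying the working tree in place.
git-clang-format 202de4a5c6edb82d50d4bd7586c4b1db5f51073d
``````````

Either way, re-running the `--diff` command afterwards should report no remaining formatting changes.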
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64DotProdMatcher.cpp b/llvm/lib/Target/AArch64/AArch64DotProdMatcher.cpp
index 44215efee75c..c086036eb3be 100644
--- a/llvm/lib/Target/AArch64/AArch64DotProdMatcher.cpp
+++ b/llvm/lib/Target/AArch64/AArch64DotProdMatcher.cpp
@@ -13,8 +13,9 @@
//===----------------------------------------------------------------------===//
#include "AArch64.h"
-#include "llvm/ADT/Statistic.h"
+#include "Utils/AArch64BaseInfo.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
@@ -33,7 +34,6 @@
#include "llvm/Support/InstructionCost.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Local.h"
-#include "Utils/AArch64BaseInfo.h"
#include <deque>
#include <optional>
#include <tuple>
@@ -68,17 +68,16 @@ struct LoopAccumulate {
Value *Mul, Value *ValA, Value *ValB, VectorType *VTy,
Type *AccTy, BasicBlock *LoopBlock, BasicBlock *PHBlock,
bool IsSExt)
- : RVal(RVal), Phi(Phi), IterVals(IterVals), Predicate(Predicate),
- Mul(Mul), ValA(ValA), ValB(ValB), VTy(VTy), AccTy(AccTy), LoopBlock(LoopBlock),
- PHBlock(PHBlock), IsSExt(IsSExt) {}
+ : RVal(RVal), Phi(Phi), IterVals(IterVals), Predicate(Predicate),
+ Mul(Mul), ValA(ValA), ValB(ValB), VTy(VTy), AccTy(AccTy),
+ LoopBlock(LoopBlock), PHBlock(PHBlock), IsSExt(IsSExt) {}
};
// Returns true if the instruction in question is an vector integer add
// reduction intrinsic.
static bool isScalableIntegerSumReduction(Instruction &I) {
auto *II = dyn_cast<IntrinsicInst>(&I);
- return II &&
- II->getIntrinsicID() == Intrinsic::vector_reduce_add &&
+ return II && II->getIntrinsicID() == Intrinsic::vector_reduce_add &&
isa<ScalableVectorType>(II->getOperand(0)->getType());
}
@@ -99,7 +98,7 @@ static Type *getAccumulatorType(Type *EltTy, Type *ExtEltTy, ElementCount EC) {
// Returns either a pair of basic block pointers corresponding to the expected
// two incoming values for the phi, or None if one of the checks failed.
-static std::optional<std::pair<BasicBlock*, BasicBlock*>>
+static std::optional<std::pair<BasicBlock *, BasicBlock *>>
getPHIIncomingBlocks(PHINode *Phi) {
// Check PHI; we're only expecting the incoming value from within the loop
// and one incoming value from a preheader.
@@ -152,8 +151,8 @@ static bool checkLoopAcc(Value *RVal, PHINode *OldPHI, Value *IterVals,
IsSExt = true;
else if (!match(Mul, m_OneUse(m_Mul(m_ZExt(m_OneUse(m_Value(ValA))),
m_ZExt(m_OneUse(m_Value(ValB))))))) {
- LLVM_DEBUG(dbgs() << "Couldn't match inner loop multiply: "
- << *Mul << "\n");
+ LLVM_DEBUG(dbgs() << "Couldn't match inner loop multiply: " << *Mul
+ << "\n");
return false;
}
@@ -177,8 +176,8 @@ static bool checkLoopAcc(Value *RVal, PHINode *OldPHI, Value *IterVals,
// The element count needs to be 1/4th that of the input data, since the
// dot product instructions take four smaller elements and multiply/accumulate
// them into one larger element.
- Type *AccTy = getAccumulatorType(ValTy->getElementType(),
- Mul->getType()->getScalarType(),
+ Type *AccTy = getAccumulatorType(
+ ValTy->getElementType(), Mul->getType()->getScalarType(),
ValTy->getElementCount().divideCoefficientBy(4));
if (!AccTy) {
@@ -197,15 +196,16 @@ static bool checkLoopAcc(Value *RVal, PHINode *OldPHI, Value *IterVals,
// Everything looks in order, so add it to the list of accumulators to
// transform.
- Accumulators.emplace_back(RVal, OldPHI, IterVals, Predicate, Mul, ValA,
- ValB, ValTy, AccTy, PhiBlocks->first,
- PhiBlocks->second, IsSExt);
+ Accumulators.emplace_back(RVal, OldPHI, IterVals, Predicate, Mul, ValA, ValB,
+ ValTy, AccTy, PhiBlocks->first, PhiBlocks->second,
+ IsSExt);
return true;
}
-static bool findDOTAccumulatorsInLoop(Value *RVal,
- SmallVectorImpl<LoopAccumulate> &Accumulators,
- unsigned Depth = DOT_ACCUMULATOR_DEPTH) {
+static bool
+findDOTAccumulatorsInLoop(Value *RVal,
+ SmallVectorImpl<LoopAccumulate> &Accumulators,
+ unsigned Depth = DOT_ACCUMULATOR_DEPTH) {
// Don't recurse too far.
if (Depth == 0)
return false;
@@ -215,12 +215,12 @@ static bool findDOTAccumulatorsInLoop(Value *RVal,
// Try to match the expected pattern from a sum reduction in
// a vectorized loop.
if (match(RVal, m_Add(m_Value(V1), m_Value(V2)))) {
- if (isa<PHINode>(V1) && !isa<PHINode>(V2) &&
- V1->hasOneUse() && V2->hasOneUse())
+ if (isa<PHINode>(V1) && !isa<PHINode>(V2) && V1->hasOneUse() &&
+ V2->hasOneUse())
return checkLoopAcc(RVal, cast<PHINode>(V1), V2, Accumulators);
- if (!isa<PHINode>(V1) && isa<PHINode>(V2) &&
- V1->hasOneUse() && V2->hasOneUse())
+ if (!isa<PHINode>(V1) && isa<PHINode>(V2) && V1->hasOneUse() &&
+ V2->hasOneUse())
return checkLoopAcc(RVal, cast<PHINode>(V2), V1, Accumulators);
// Otherwise assume this is an intermediate multi-register reduction
@@ -248,8 +248,8 @@ public:
SmallVector<Instruction *, 4> Reductions;
for (BasicBlock &Block : F)
// TODO: Support non-scalable dot instructions too.
- for (Instruction &I : make_filter_range(Block,
- isScalableIntegerSumReduction))
+ for (Instruction &I :
+ make_filter_range(Block, isScalableIntegerSumReduction))
Reductions.push_back(&I);
for (auto *Rdx : Reductions)
@@ -274,10 +274,10 @@ private:
char AArch64DotProdMatcher::ID = 0;
INITIALIZE_PASS_BEGIN(AArch64DotProdMatcher, DEBUG_TYPE,
- "AArch64 Dot Product Instruction Matcher", false, false)
+ "AArch64 Dot Product Instruction Matcher", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(AArch64DotProdMatcher, DEBUG_TYPE,
- "AArch64 Dot Product Instruction Matcher", false, false)
+ "AArch64 Dot Product Instruction Matcher", false, false)
FunctionPass *llvm::createAArch64DotProdMatcherPass() {
return new AArch64DotProdMatcher();
@@ -344,10 +344,10 @@ bool AArch64DotProdMatcher::trySimpleDotReplacement(Instruction &I) {
MTy->getElementCount().divideCoefficientBy(4));
Value *Zeroes = ConstantAggregateZero::get(AccTy);
- Intrinsic::ID IntID = IsSExt ? Intrinsic::aarch64_sve_sdot :
- Intrinsic::aarch64_sve_udot;
- Value *DotProd = Builder.CreateIntrinsic(IntID, {AccTy},
- {Zeroes, ValA, ValB});
+ Intrinsic::ID IntID =
+ IsSExt ? Intrinsic::aarch64_sve_sdot : Intrinsic::aarch64_sve_udot;
+ Value *DotProd =
+ Builder.CreateIntrinsic(IntID, {AccTy}, {Zeroes, ValA, ValB});
Builder.SetInsertPoint(&I);
Value *Reduce = Builder.CreateAddReduce(DotProd);
I.replaceAllUsesWith(Reduce);
@@ -437,8 +437,8 @@ bool AArch64DotProdMatcher::tryLoopDotReplacement(Instruction &I) {
// Need to generate selects for ValA and ValB if there was one before the
// accumulate before.
- // Hopefully we can fold away some extra selects (e.g. if the data originally
- // came from masked loads with the same predicate).
+ // Hopefully we can fold away some extra selects (e.g. if the data
+ // originally came from masked loads with the same predicate).
if (Acc.Predicate) {
Value *Zeroes = ConstantAggregateZero::get(Acc.VTy);
Acc.ValA = Builder.CreateSelect(Acc.Predicate, Acc.ValA, Zeroes);
@@ -446,8 +446,8 @@ bool AArch64DotProdMatcher::tryLoopDotReplacement(Instruction &I) {
}
// Now plant the dot instruction.
- Intrinsic::ID IntID = Acc.IsSExt ? Intrinsic::aarch64_sve_sdot :
- Intrinsic::aarch64_sve_udot;
+ Intrinsic::ID IntID =
+ Acc.IsSExt ? Intrinsic::aarch64_sve_sdot : Intrinsic::aarch64_sve_udot;
Value *DotProd = Builder.CreateIntrinsic(IntID, {Acc.AccTy},
{DotAcc, Acc.ValA, Acc.ValB});
DotAcc->addIncoming(DotProd, Acc.LoopBlock);
@@ -457,9 +457,7 @@ bool AArch64DotProdMatcher::tryLoopDotReplacement(Instruction &I) {
NumDOTInstrs++;
}
- assert(!RdxVals.empty() &&
- "We found accumulators but generated no RdxVals");
-
+ assert(!RdxVals.empty() && "We found accumulators but generated no RdxVals");
Builder.SetInsertPoint(cast<Instruction>(RVal));
@@ -477,10 +475,8 @@ bool AArch64DotProdMatcher::tryLoopDotReplacement(Instruction &I) {
Value *Trunc = Builder.CreateTrunc(Reduce, I.getType(), "dot.trunc");
I.replaceAllUsesWith(Trunc);
-
// Delete the original reduction, since it's no longer required
RecursivelyDeleteTriviallyDeadInstructions(&I);
NumLoopDOTReplacements++;
return true;
}
-
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
index 4a76d2f705a5..7a95a0dd898f 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -165,10 +165,9 @@ static cl::opt<bool>
cl::desc("Enable SVE intrinsic opts"),
cl::init(true));
-static cl::opt<bool>
-EnableAArch64DotProdMatch("aarch64-enable-dotprodmatch", cl::Hidden,
- cl::desc("Enable matching dot product instructions"),
- cl::init(true));
+static cl::opt<bool> EnableAArch64DotProdMatch(
+ "aarch64-enable-dotprodmatch", cl::Hidden,
+ cl::desc("Enable matching dot product instructions"), cl::init(true));
static cl::opt<bool> EnableFalkorHWPFFix("aarch64-enable-falkor-hwpf-fix",
cl::init(true), cl::Hidden);
``````````
</details>
https://github.com/llvm/llvm-project/pull/69583