[llvm] [AArch64] Generate DOT instructions from matching IR (PR #69583)

via llvm-commits (llvm-commits at lists.llvm.org)
Thu Oct 19 03:26:11 PDT 2023


github-actions[bot] wrote:


<!--LLVM CODE FORMAT COMMENT: {clang-format}-->

:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff 202de4a5c6edb82d50d4bd7586c4b1db5f51073d 2b6da683e001ba852674d0f55cc5beb95c14782f -- llvm/lib/Target/AArch64/AArch64DotProdMatcher.cpp llvm/lib/Target/AArch64/AArch64.h llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64DotProdMatcher.cpp b/llvm/lib/Target/AArch64/AArch64DotProdMatcher.cpp
index 44215efee75c..c086036eb3be 100644
--- a/llvm/lib/Target/AArch64/AArch64DotProdMatcher.cpp
+++ b/llvm/lib/Target/AArch64/AArch64DotProdMatcher.cpp
@@ -13,8 +13,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "AArch64.h"
-#include "llvm/ADT/Statistic.h"
+#include "Utils/AArch64BaseInfo.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constants.h"
@@ -33,7 +34,6 @@
 #include "llvm/Support/InstructionCost.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Utils/Local.h"
-#include "Utils/AArch64BaseInfo.h"
 #include <deque>
 #include <optional>
 #include <tuple>
@@ -68,17 +68,16 @@ struct LoopAccumulate {
                  Value *Mul, Value *ValA, Value *ValB, VectorType *VTy,
                  Type *AccTy, BasicBlock *LoopBlock, BasicBlock *PHBlock,
                  bool IsSExt)
-    : RVal(RVal), Phi(Phi), IterVals(IterVals), Predicate(Predicate),
-    Mul(Mul), ValA(ValA), ValB(ValB), VTy(VTy), AccTy(AccTy), LoopBlock(LoopBlock),
-    PHBlock(PHBlock), IsSExt(IsSExt) {}
+      : RVal(RVal), Phi(Phi), IterVals(IterVals), Predicate(Predicate),
+        Mul(Mul), ValA(ValA), ValB(ValB), VTy(VTy), AccTy(AccTy),
+        LoopBlock(LoopBlock), PHBlock(PHBlock), IsSExt(IsSExt) {}
 };
 
 // Returns true if the instruction in question is an vector integer add
 // reduction intrinsic.
 static bool isScalableIntegerSumReduction(Instruction &I) {
   auto *II = dyn_cast<IntrinsicInst>(&I);
-  return II &&
-         II->getIntrinsicID() == Intrinsic::vector_reduce_add &&
+  return II && II->getIntrinsicID() == Intrinsic::vector_reduce_add &&
          isa<ScalableVectorType>(II->getOperand(0)->getType());
 }
 
@@ -99,7 +98,7 @@ static Type *getAccumulatorType(Type *EltTy, Type *ExtEltTy, ElementCount EC) {
 
 // Returns either a pair of basic block pointers corresponding to the expected
 // two incoming values for the phi, or None if one of the checks failed.
-static std::optional<std::pair<BasicBlock*, BasicBlock*>>
+static std::optional<std::pair<BasicBlock *, BasicBlock *>>
 getPHIIncomingBlocks(PHINode *Phi) {
   // Check PHI; we're only expecting the incoming value from within the loop
   // and one incoming value from a preheader.
@@ -152,8 +151,8 @@ static bool checkLoopAcc(Value *RVal, PHINode *OldPHI, Value *IterVals,
     IsSExt = true;
   else if (!match(Mul, m_OneUse(m_Mul(m_ZExt(m_OneUse(m_Value(ValA))),
                                       m_ZExt(m_OneUse(m_Value(ValB))))))) {
-    LLVM_DEBUG(dbgs() << "Couldn't match inner loop multiply: "
-                      << *Mul << "\n");
+    LLVM_DEBUG(dbgs() << "Couldn't match inner loop multiply: " << *Mul
+                      << "\n");
     return false;
   }
 
@@ -177,8 +176,8 @@ static bool checkLoopAcc(Value *RVal, PHINode *OldPHI, Value *IterVals,
   // The element count needs to be 1/4th that of the input data, since the
   // dot product instructions take four smaller elements and multiply/accumulate
   // them into one larger element.
-  Type *AccTy = getAccumulatorType(ValTy->getElementType(),
-      Mul->getType()->getScalarType(),
+  Type *AccTy = getAccumulatorType(
+      ValTy->getElementType(), Mul->getType()->getScalarType(),
       ValTy->getElementCount().divideCoefficientBy(4));
 
   if (!AccTy) {
@@ -197,15 +196,16 @@ static bool checkLoopAcc(Value *RVal, PHINode *OldPHI, Value *IterVals,
 
   // Everything looks in order, so add it to the list of accumulators to
   // transform.
-  Accumulators.emplace_back(RVal, OldPHI, IterVals, Predicate, Mul, ValA,
-                            ValB, ValTy, AccTy, PhiBlocks->first,
-                            PhiBlocks->second, IsSExt);
+  Accumulators.emplace_back(RVal, OldPHI, IterVals, Predicate, Mul, ValA, ValB,
+                            ValTy, AccTy, PhiBlocks->first, PhiBlocks->second,
+                            IsSExt);
   return true;
 }
 
-static bool findDOTAccumulatorsInLoop(Value *RVal,
-                                SmallVectorImpl<LoopAccumulate> &Accumulators,
-                                unsigned Depth = DOT_ACCUMULATOR_DEPTH) {
+static bool
+findDOTAccumulatorsInLoop(Value *RVal,
+                          SmallVectorImpl<LoopAccumulate> &Accumulators,
+                          unsigned Depth = DOT_ACCUMULATOR_DEPTH) {
   // Don't recurse too far.
   if (Depth == 0)
     return false;
@@ -215,12 +215,12 @@ static bool findDOTAccumulatorsInLoop(Value *RVal,
   // Try to match the expected pattern from a sum reduction in
   // a vectorized loop.
   if (match(RVal, m_Add(m_Value(V1), m_Value(V2)))) {
-    if (isa<PHINode>(V1) && !isa<PHINode>(V2) &&
-        V1->hasOneUse() && V2->hasOneUse())
+    if (isa<PHINode>(V1) && !isa<PHINode>(V2) && V1->hasOneUse() &&
+        V2->hasOneUse())
       return checkLoopAcc(RVal, cast<PHINode>(V1), V2, Accumulators);
 
-    if (!isa<PHINode>(V1) && isa<PHINode>(V2) &&
-        V1->hasOneUse() && V2->hasOneUse())
+    if (!isa<PHINode>(V1) && isa<PHINode>(V2) && V1->hasOneUse() &&
+        V2->hasOneUse())
       return checkLoopAcc(RVal, cast<PHINode>(V2), V1, Accumulators);
 
     // Otherwise assume this is an intermediate multi-register reduction
@@ -248,8 +248,8 @@ public:
     SmallVector<Instruction *, 4> Reductions;
     for (BasicBlock &Block : F)
       // TODO: Support non-scalable dot instructions too.
-      for (Instruction &I : make_filter_range(Block,
-                                              isScalableIntegerSumReduction))
+      for (Instruction &I :
+           make_filter_range(Block, isScalableIntegerSumReduction))
         Reductions.push_back(&I);
 
     for (auto *Rdx : Reductions)
@@ -274,10 +274,10 @@ private:
 
 char AArch64DotProdMatcher::ID = 0;
 INITIALIZE_PASS_BEGIN(AArch64DotProdMatcher, DEBUG_TYPE,
-                "AArch64 Dot Product Instruction Matcher", false, false)
+                      "AArch64 Dot Product Instruction Matcher", false, false)
 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
 INITIALIZE_PASS_END(AArch64DotProdMatcher, DEBUG_TYPE,
-                "AArch64 Dot Product Instruction Matcher", false, false)
+                    "AArch64 Dot Product Instruction Matcher", false, false)
 
 FunctionPass *llvm::createAArch64DotProdMatcherPass() {
   return new AArch64DotProdMatcher();
@@ -344,10 +344,10 @@ bool AArch64DotProdMatcher::trySimpleDotReplacement(Instruction &I) {
                                 MTy->getElementCount().divideCoefficientBy(4));
   Value *Zeroes = ConstantAggregateZero::get(AccTy);
 
-  Intrinsic::ID IntID = IsSExt ? Intrinsic::aarch64_sve_sdot :
-                                 Intrinsic::aarch64_sve_udot;
-  Value *DotProd = Builder.CreateIntrinsic(IntID, {AccTy},
-                                           {Zeroes, ValA, ValB});
+  Intrinsic::ID IntID =
+      IsSExt ? Intrinsic::aarch64_sve_sdot : Intrinsic::aarch64_sve_udot;
+  Value *DotProd =
+      Builder.CreateIntrinsic(IntID, {AccTy}, {Zeroes, ValA, ValB});
   Builder.SetInsertPoint(&I);
   Value *Reduce = Builder.CreateAddReduce(DotProd);
   I.replaceAllUsesWith(Reduce);
@@ -437,8 +437,8 @@ bool AArch64DotProdMatcher::tryLoopDotReplacement(Instruction &I) {
 
     // Need to generate selects for ValA and ValB if there was one before the
     // accumulate before.
-    // Hopefully we can fold away some extra selects (e.g. if the data originally
-    // came from masked loads with the same predicate).
+    // Hopefully we can fold away some extra selects (e.g. if the data
+    // originally came from masked loads with the same predicate).
     if (Acc.Predicate) {
       Value *Zeroes = ConstantAggregateZero::get(Acc.VTy);
       Acc.ValA = Builder.CreateSelect(Acc.Predicate, Acc.ValA, Zeroes);
@@ -446,8 +446,8 @@ bool AArch64DotProdMatcher::tryLoopDotReplacement(Instruction &I) {
     }
 
     // Now plant the dot instruction.
-    Intrinsic::ID IntID = Acc.IsSExt ? Intrinsic::aarch64_sve_sdot :
-                                          Intrinsic::aarch64_sve_udot;
+    Intrinsic::ID IntID =
+        Acc.IsSExt ? Intrinsic::aarch64_sve_sdot : Intrinsic::aarch64_sve_udot;
     Value *DotProd = Builder.CreateIntrinsic(IntID, {Acc.AccTy},
                                              {DotAcc, Acc.ValA, Acc.ValB});
     DotAcc->addIncoming(DotProd, Acc.LoopBlock);
@@ -457,9 +457,7 @@ bool AArch64DotProdMatcher::tryLoopDotReplacement(Instruction &I) {
     NumDOTInstrs++;
   }
 
-  assert(!RdxVals.empty() &&
-         "We found accumulators but generated no RdxVals");
-
+  assert(!RdxVals.empty() && "We found accumulators but generated no RdxVals");
 
   Builder.SetInsertPoint(cast<Instruction>(RVal));
 
@@ -477,10 +475,8 @@ bool AArch64DotProdMatcher::tryLoopDotReplacement(Instruction &I) {
   Value *Trunc = Builder.CreateTrunc(Reduce, I.getType(), "dot.trunc");
   I.replaceAllUsesWith(Trunc);
 
-
   // Delete the original reduction, since it's no longer required
   RecursivelyDeleteTriviallyDeadInstructions(&I);
   NumLoopDOTReplacements++;
   return true;
 }
-
diff --git a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
index 4a76d2f705a5..7a95a0dd898f 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetMachine.cpp
@@ -165,10 +165,9 @@ static cl::opt<bool>
                            cl::desc("Enable SVE intrinsic opts"),
                            cl::init(true));
 
-static cl::opt<bool>
-EnableAArch64DotProdMatch("aarch64-enable-dotprodmatch", cl::Hidden,
-                          cl::desc("Enable matching dot product instructions"),
-                          cl::init(true));
+static cl::opt<bool> EnableAArch64DotProdMatch(
+    "aarch64-enable-dotprodmatch", cl::Hidden,
+    cl::desc("Enable matching dot product instructions"), cl::init(true));
 
 static cl::opt<bool> EnableFalkorHWPFFix("aarch64-enable-falkor-hwpf-fix",
                                          cl::init(true), cl::Hidden);

``````````

</details>


https://github.com/llvm/llvm-project/pull/69583


More information about the llvm-commits mailing list