[llvm] Add support for single reductions in ComplexDeinterleavingPass (PR #112875)

via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 18 03:10:18 PDT 2024


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff 3764d0ff15ef281974879002e27857a041bd5b9c d77a3e402d05290903f75ce2dc7478f14a3943ef --extensions h,cpp -- llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index 08287a4d5e..a79105ab0c 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -201,9 +201,7 @@ public:
     }
   }
 
-  bool AreOperandsValid() {
-    return OperandsValid;
-  }
+  bool AreOperandsValid() { return OperandsValid; }
 };
 
 class ComplexDeinterleavingGraph {
@@ -918,11 +916,11 @@ ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I, bool &FromCache) {
     return It->second;
   }
 
-  if(NodePtr CN = identifyPartialReduction(R, I))
+  if (NodePtr CN = identifyPartialReduction(R, I))
     return CN;
 
   bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
-  if(!IsReduction && R->getType() != I->getType())
+  if (!IsReduction && R->getType() != I->getType())
     return nullptr;
 
   if (NodePtr CN = identifySplat(R, I))
@@ -1450,18 +1448,20 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
   if (It != RootToNode.end()) {
     auto RootNode = It->second;
     assert(RootNode->Operation ==
-           ComplexDeinterleavingOperation::ReductionOperation || RootNode->Operation == ComplexDeinterleavingOperation::ReductionSingle);
+               ComplexDeinterleavingOperation::ReductionOperation ||
+           RootNode->Operation ==
+               ComplexDeinterleavingOperation::ReductionSingle);
     // Find out which part, Real or Imag, comes later, and only if we come to
     // the latest part, add it to OrderedRoots.
     auto *R = cast<Instruction>(RootNode->Real);
     auto *I = RootNode->Imag ? cast<Instruction>(RootNode->Imag) : nullptr;
 
     Instruction *ReplacementAnchor;
-    if(I) 
+    if (I)
       ReplacementAnchor = R->comesBefore(I) ? I : R;
-    else 
+    else
       ReplacementAnchor = R;
-    
+
     if (ReplacementAnchor != RootI)
       return false;
     OrderedRoots.push_back(RootI);
@@ -1553,7 +1553,7 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
     for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
       if (Processed[j])
         continue;
-      
+
       auto *Imag = OperationInstruction[j];
       if (Real->getType() != Imag->getType())
         continue;
@@ -1588,18 +1588,20 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
 
     // We want to check that we have 2 operands, but the function attributes
     // being counted as operands bloats this value.
-    if(Real->getNumOperands() < 2)
+    if (Real->getNumOperands() < 2)
       continue;
 
     RealPHI = ReductionInfo[Real].first;
     ImagPHI = nullptr;
     PHIsFound = false;
     auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
-    if(Node && PHIsFound) {
-      LLVM_DEBUG(dbgs() << "Identified single reduction starting from instruction: "
-                          << *Real << "/" << *ReductionInfo[Real].second << "\n");
+    if (Node && PHIsFound) {
+      LLVM_DEBUG(
+          dbgs() << "Identified single reduction starting from instruction: "
+                 << *Real << "/" << *ReductionInfo[Real].second << "\n");
       Processed[i] = true;
-      auto RootNode = prepareCompositeNode(ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
+      auto RootNode = prepareCompositeNode(
+          ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
       RootNode->addOperand(Node);
       RootToNode[Real] = RootNode;
       submitCompositeNode(RootNode);
@@ -2059,7 +2061,8 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
   return ReplacementNode;
 }
 
-void ComplexDeinterleavingGraph::processReductionSingle(Value *OperationReplacement, RawNodePtr Node) {
+void ComplexDeinterleavingGraph::processReductionSingle(
+    Value *OperationReplacement, RawNodePtr Node) {
   auto *Real = cast<Instruction>(Node->Real);
   auto *OldPHI = ReductionInfo[Real].first;
   auto *NewPHI = OldToNewPHI[OldPHI];
@@ -2071,21 +2074,22 @@ void ComplexDeinterleavingGraph::processReductionSingle(Value *OperationReplacem
   IRBuilder<> Builder(Incoming->getTerminator());
 
   Value *NewInit = nullptr;
-  if(auto *C = dyn_cast<Constant>(Init)) {
-    if(C->isZeroValue())
+  if (auto *C = dyn_cast<Constant>(Init)) {
+    if (C->isZeroValue())
       NewInit = Constant::getNullValue(NewVTy);
   }
 
   if (!NewInit)
     NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
-                                          {Init, Constant::getNullValue(VTy)});
+                                      {Init, Constant::getNullValue(VTy)});
 
   NewPHI->addIncoming(NewInit, Incoming);
   NewPHI->addIncoming(OperationReplacement, BackEdge);
 
   auto *FinalReduction = ReductionInfo[Real].second;
   Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
-  // TODO Ensure that the `AddReduce` here matches the original, found in `FinalReduction`
+  // TODO Ensure that the `AddReduce` here matches the original, found in
+  // `FinalReduction`
   auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
   FinalReduction->replaceAllUsesWith(AddReduce);
 }
@@ -2151,7 +2155,8 @@ void ComplexDeinterleavingGraph::replaceNodes() {
       ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
       DeadInstrRoots.push_back(RootReal);
       DeadInstrRoots.push_back(RootImag);
-    } else if(RootNode->Operation == ComplexDeinterleavingOperation::ReductionSingle) {
+    } else if (RootNode->Operation ==
+               ComplexDeinterleavingOperation::ReductionSingle) {
       auto *RootInst = cast<Instruction>(RootNode->Real);
       ReductionInfo[RootInst].first->removeIncomingValue(BackEdge);
       DeadInstrRoots.push_back(ReductionInfo[RootInst].second);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8068bb6740..6f58ed4f6c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -29184,7 +29184,10 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
 
   if (TyWidth > 128) {
     int Stride = Ty->getElementCount().getKnownMinValue() / 2;
-    int AccStride = cast<VectorType>(Accumulator->getType())->getElementCount().getKnownMinValue() / 2;
+    int AccStride = cast<VectorType>(Accumulator->getType())
+                        ->getElementCount()
+                        .getKnownMinValue() /
+                    2;
     auto *HalfTy = VectorType::getHalfElementsVectorType(Ty);
     auto *LowerSplitA = B.CreateExtractVector(HalfTy, InputA, B.getInt64(0));
     auto *LowerSplitB = B.CreateExtractVector(HalfTy, InputB, B.getInt64(0));
@@ -29195,19 +29198,22 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
     Value *LowerSplitAcc = nullptr;
     Value *UpperSplitAcc = nullptr;
     Type *FullTy = Ty;
-      FullTy = Accumulator->getType();
-      auto *HalfAccTy = VectorType::getHalfElementsVectorType(cast<VectorType>(Accumulator->getType()));
-      LowerSplitAcc = B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(0));
-      UpperSplitAcc =
-          B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(AccStride));
+    FullTy = Accumulator->getType();
+    auto *HalfAccTy = VectorType::getHalfElementsVectorType(
+        cast<VectorType>(Accumulator->getType()));
+    LowerSplitAcc =
+        B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(0));
+    UpperSplitAcc =
+        B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(AccStride));
     auto *LowerSplitInt = createComplexDeinterleavingIR(
         B, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc);
     auto *UpperSplitInt = createComplexDeinterleavingIR(
         B, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc);
 
-    auto *Result = B.CreateInsertVector(FullTy, PoisonValue::get(FullTy), LowerSplitInt,
-                                        B.getInt64(0));
-    return B.CreateInsertVector(FullTy, Result, UpperSplitInt, B.getInt64(AccStride));
+    auto *Result = B.CreateInsertVector(FullTy, PoisonValue::get(FullTy),
+                                        LowerSplitInt, B.getInt64(0));
+    return B.CreateInsertVector(FullTy, Result, UpperSplitInt,
+                                B.getInt64(AccStride));
   }
 
   if (OperationType == ComplexDeinterleavingOperation::CMulPartial) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/112875


More information about the llvm-commits mailing list