[llvm] 2cbc265 - [CodeGen] Add support for reductions in ComplexDeinterleaving pass

Igor Kirillov via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 14 10:28:22 PDT 2023


Author: Igor Kirillov
Date: 2023-06-14T17:27:26Z
New Revision: 2cbc265cc947c40372b841f80649276fbf9d183f

URL: https://github.com/llvm/llvm-project/commit/2cbc265cc947c40372b841f80649276fbf9d183f
DIFF: https://github.com/llvm/llvm-project/commit/2cbc265cc947c40372b841f80649276fbf9d183f.diff

LOG: [CodeGen] Add support for reductions in ComplexDeinterleaving pass

This commit enhances the ComplexDeinterleaving pass to handle unordered
reductions in simple one-block vectorized loops, supporting both
SVE and Neon architectures.

Differential Revision: https://reviews.llvm.org/D152022

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
    llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
    llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-scalable.ll
    llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h b/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
index defa806889ae1..e7119078d8a3b 100644
--- a/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
+++ b/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
@@ -38,7 +38,9 @@ enum class ComplexDeinterleavingOperation {
   // The following 'operations' are used to represent internal states. Backends
   // are not expected to try and support these in any capacity.
   Deinterleave,
-  Symmetric
+  Symmetric,
+  ReductionPHI,
+  ReductionOperation,
 };
 
 enum class ComplexDeinterleavingRotation {

diff  --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index ec7abb298d9f9..1c61179753ba3 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -228,8 +228,53 @@ class ComplexDeinterleavingGraph {
   /// Topologically sorted root instructions
   SmallVector<Instruction *, 1> OrderedRoots;
 
+  /// When examining a basic block for complex deinterleaving, if it is a simple
+  /// one-block loop, then the only incoming block is 'Incoming' and the
+  /// 'BackEdge' block is the block itself.
+  BasicBlock *BackEdge = nullptr;
+  BasicBlock *Incoming = nullptr;
+
+  /// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
+  /// %OutsideUser as it is shown in the IR:
+  ///
+  /// vector.body:
+  ///   %PHInode = phi <vector type> [ zeroinitializer, %entry ],
+  ///                                [ %ReductionOp, %vector.body ]
+  ///   ...
+  ///   %ReductionOp = fadd <vector type> ...
+  ///   ...
+  ///   br i1 %condition, label %vector.body, %middle.block
+  ///
+  /// middle.block:
+  ///   %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
+  ///
+  /// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
+  /// `llvm.vector.reduce.fadd` when unroll factor isn't one.
+  std::map<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
+
+  /// In the process of detecting a reduction, we consider a pair of
+  /// %ReductionOp values, which we refer to as real and imag (or vice versa),
+  /// and traverse the use-tree to detect complex operations. As this is a
+  /// reduction operation, it will eventually reach RealPHI and ImagPHI, which
+  /// correspond to the %ReductionOps that we suspect to be complex.
+  /// RealPHI and ImagPHI are used by the identifyPHINode method.
+  PHINode *RealPHI = nullptr;
+  PHINode *ImagPHI = nullptr;
+
+  /// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
+  /// The new PHINode corresponds to a vector of deinterleaved complex numbers.
+  /// This mapping is populated during
+  /// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
+  /// used in the ComplexDeinterleavingOperation::ReductionOperation node
+  /// replacement process.
+  std::map<PHINode *, PHINode *> OldToNewPHI;
+
   NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
                                Instruction *R, Instruction *I) {
+    assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
+             Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
+            (R && I)) &&
+           "Reduction related nodes must have Real and Imaginary parts");
     return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
                                                                 I);
   }
@@ -324,8 +369,17 @@ class ComplexDeinterleavingGraph {
   /// intrinsic (for both fixed and scalable vectors)
   NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
 
+  NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
+
   Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
 
+  /// Complete IR modifications after producing new reduction operation:
+  /// * Populate the PHINode generated for
+  /// ComplexDeinterleavingOperation::ReductionPHI
+  /// * Deinterleave the final value outside of the loop and repurpose original
+  /// reduction users
+  void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
+
 public:
   void dump() { dump(dbgs()); }
   void dump(raw_ostream &OS) {
@@ -337,6 +391,13 @@ class ComplexDeinterleavingGraph {
   /// current graph.
   bool identifyNodes(Instruction *RootI);
 
+  /// In case \p B is a one-block loop, this function seeks potential
+  /// reductions and populates ReductionInfo. Returns true if any reductions
+  /// were identified.
+  bool collectPotentialReductions(BasicBlock *B);
+
+  void identifyReductionNodes();
+
   /// Check that every instruction, from the roots to the leaves, has internal
   /// uses.
   bool checkNodes();
@@ -439,6 +500,9 @@ static bool isDeinterleavingMask(ArrayRef<int> Mask) {
 
 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
   ComplexDeinterleavingGraph Graph(TL, TLI);
+  if (Graph.collectPotentialReductions(B))
+    Graph.identifyReductionNodes();
+
   for (auto &I : *B)
     Graph.identifyNodes(&I);
 
@@ -822,6 +886,9 @@ ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
   if (NodePtr CN = identifyDeinterleave(Real, Imag))
     return CN;
 
+  if (NodePtr CN = identifyPHINode(Real, Imag))
+    return CN;
+
   auto *VTy = cast<VectorType>(Real->getType());
   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
 
@@ -1293,6 +1360,16 @@ ComplexDeinterleavingGraph::extractPositiveAddend(
 }
 
 bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
+  // This potential root instruction might already have been recognized as a
+  // reduction. Because RootToNode maps both the Real and Imaginary parts to
+  // the same CompositeNode, we use only one of them (the Real instruction) as
+  // the anchor for generating the complex instruction.
+  auto It = RootToNode.find(RootI);
+  if (It != RootToNode.end() && It->second->Real == RootI) {
+    OrderedRoots.push_back(RootI);
+    return true;
+  }
+
   auto RootNode = identifyRoot(RootI);
   if (!RootNode)
     return false;
@@ -1310,12 +1387,113 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
   return true;
 }
 
+bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
+  bool FoundPotentialReduction = false;
+
+  auto *Br = dyn_cast<BranchInst>(B->getTerminator());
+  if (!Br || Br->getNumSuccessors() != 2)
+    return false;
+
+  // Identify simple one-block loop
+  if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
+    return false;
+
+  SmallVector<PHINode *> PHIs;
+  for (auto &PHI : B->phis()) {
+    if (PHI.getNumIncomingValues() != 2)
+      continue;
+
+    if (!PHI.getType()->isVectorTy())
+      continue;
+
+    auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
+    if (!ReductionOp)
+      continue;
+
+    // Check if final instruction is reduced outside of current block
+    Instruction *FinalReduction = nullptr;
+    auto NumUsers = 0u;
+    for (auto *U : ReductionOp->users()) {
+      ++NumUsers;
+      if (U == &PHI)
+        continue;
+      FinalReduction = dyn_cast<Instruction>(U);
+    }
+
+    if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B)
+      continue;
+
+    ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
+    BackEdge = B;
+    auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
+    auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
+    Incoming = PHI.getIncomingBlock(IncomingIdx);
+    FoundPotentialReduction = true;
+
+    // If the initial value of PHINode is an Instruction, consider it a leaf
+    // value of a complex deinterleaving graph.
+    if (auto *InitPHI =
+            dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
+      FinalInstructions.insert(InitPHI);
+  }
+  return FoundPotentialReduction;
+}
+
+void ComplexDeinterleavingGraph::identifyReductionNodes() {
+  SmallVector<bool> Processed(ReductionInfo.size(), false);
+  SmallVector<Instruction *> OperationInstruction;
+  for (auto &P : ReductionInfo)
+    OperationInstruction.push_back(P.first);
+
+  // Identify a complex computation by evaluating two reduction operations that
+  // potentially could be involved
+  for (size_t i = 0; i < OperationInstruction.size(); ++i) {
+    if (Processed[i])
+      continue;
+    for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
+      if (Processed[j])
+        continue;
+
+      auto *Real = OperationInstruction[i];
+      auto *Imag = OperationInstruction[j];
+
+      RealPHI = ReductionInfo[Real].first;
+      ImagPHI = ReductionInfo[Imag].first;
+      auto Node = identifyNode(Real, Imag);
+      if (!Node) {
+        std::swap(Real, Imag);
+        std::swap(RealPHI, ImagPHI);
+        Node = identifyNode(Real, Imag);
+      }
+
+      // If a node is identified, mark its operation instructions as used to
+      // prevent re-identification and attach the node to the real part
+      if (Node) {
+        LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
+                          << *Real << " / " << *Imag << "\n");
+        Processed[i] = true;
+        Processed[j] = true;
+        auto RootNode = prepareCompositeNode(
+            ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
+        RootNode->addOperand(Node);
+        RootToNode[Real] = RootNode;
+        RootToNode[Imag] = RootNode;
+        submitCompositeNode(RootNode);
+        break;
+      }
+    }
+  }
+
+  RealPHI = nullptr;
+  ImagPHI = nullptr;
+}
+
 bool ComplexDeinterleavingGraph::checkNodes() {
   // Collect all instructions from roots to leaves
   SmallPtrSet<Instruction *, 16> AllInstructions;
   SmallVector<Instruction *, 8> Worklist;
-  for (auto *I : OrderedRoots)
-    Worklist.push_back(I);
+  for (auto &Pair : RootToNode)
+    Worklist.push_back(Pair.first);
 
   // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
   // chains
@@ -1524,6 +1702,17 @@ ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
   return submitCompositeNode(PlaceholderNode);
 }
 
+ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
+                                            Instruction *Imag) {
+  if (Real != RealPHI || Imag != ImagPHI)
+    return nullptr;
+
+  NodePtr PlaceholderNode = prepareCompositeNode(
+      ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
+  return submitCompositeNode(PlaceholderNode);
+}
+
 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
                                    FastMathFlags Flags, Value *InputA,
                                    Value *InputB) {
@@ -1553,27 +1742,100 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
   if (Node->ReplacementNode)
     return Node->ReplacementNode;
 
-  Value *Input0 = replaceNode(Builder, Node->Operands[0]);
-  Value *Input1 = Node->Operands.size() > 1
-                      ? replaceNode(Builder, Node->Operands[1])
-                      : nullptr;
-  Value *Accumulator = Node->Operands.size() > 2
-                           ? replaceNode(Builder, Node->Operands[2])
-                           : nullptr;
-  if (Input1)
-    assert(Input0->getType() == Input1->getType() &&
-           "Node inputs need to be of the same type");
-
-  if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
-    Node->ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode,
-                                                 Node->Flags, Input0, Input1);
-  else
-    Node->ReplacementNode = TL->createComplexDeinterleavingIR(
-        Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
+  auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
+    return Node->Operands.size() > Idx
+               ? replaceNode(Builder, Node->Operands[Idx])
+               : nullptr;
+  };
+
+  Value *ReplacementNode;
+  switch (Node->Operation) {
+  case ComplexDeinterleavingOperation::CAdd:
+  case ComplexDeinterleavingOperation::CMulPartial:
+  case ComplexDeinterleavingOperation::Symmetric: {
+    Value *Input0 = ReplaceOperandIfExist(Node, 0);
+    Value *Input1 = ReplaceOperandIfExist(Node, 1);
+    Value *Accumulator = ReplaceOperandIfExist(Node, 2);
+    assert(!Input1 || (Input0->getType() == Input1->getType() &&
+                       "Node inputs need to be of the same type"));
+    assert(!Accumulator ||
+           (Input0->getType() == Accumulator->getType() &&
+            "Accumulator and input need to be of the same type"));
+    if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
+      ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
+                                             Input0, Input1);
+    else
+      ReplacementNode = TL->createComplexDeinterleavingIR(
+          Builder, Node->Operation, Node->Rotation, Input0, Input1,
+          Accumulator);
+    break;
+  }
+  case ComplexDeinterleavingOperation::Deinterleave:
+    llvm_unreachable("Deinterleave node should already have ReplacementNode");
+    break;
+  case ComplexDeinterleavingOperation::ReductionPHI: {
+    // If Operation is ReductionPHI, a new empty PHINode is created.
+    // It is filled later when the ReductionOperation is processed.
+    auto *VTy = cast<VectorType>(Node->Real->getType());
+    auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
+    auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHI());
+    OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;
+    ReplacementNode = NewPHI;
+    break;
+  }
+  case ComplexDeinterleavingOperation::ReductionOperation:
+    ReplacementNode = replaceNode(Builder, Node->Operands[0]);
+    processReductionOperation(ReplacementNode, Node);
+    break;
+  default:
+    llvm_unreachable(
+        "Unhandled case in ComplexDeinterleavingGraph::replaceNode");
+    break;
+  }
 
-  assert(Node->ReplacementNode && "Target failed to create Intrinsic call.");
+  assert(ReplacementNode && "Target failed to create Intrinsic call.");
   NumComplexTransformations += 1;
-  return Node->ReplacementNode;
+  Node->ReplacementNode = ReplacementNode;
+  return ReplacementNode;
+}
+
+void ComplexDeinterleavingGraph::processReductionOperation(
+    Value *OperationReplacement, RawNodePtr Node) {
+  auto *OldPHIReal = ReductionInfo[Node->Real].first;
+  auto *OldPHIImag = ReductionInfo[Node->Imag].first;
+  auto *NewPHI = OldToNewPHI[OldPHIReal];
+
+  auto *VTy = cast<VectorType>(Node->Real->getType());
+  auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
+
+  // We have to interleave the initial values coming from the Incoming block
+  Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
+  Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
+
+  IRBuilder<> Builder(Incoming->getTerminator());
+  auto *NewInit = Builder.CreateIntrinsic(
+      Intrinsic::experimental_vector_interleave2, NewVTy, {InitReal, InitImag});
+
+  NewPHI->addIncoming(NewInit, Incoming);
+  NewPHI->addIncoming(OperationReplacement, BackEdge);
+
+  // Deinterleave complex vector outside of loop so that it can be finally
+  // reduced
+  auto *FinalReductionReal = ReductionInfo[Node->Real].second;
+  auto *FinalReductionImag = ReductionInfo[Node->Imag].second;
+
+  Builder.SetInsertPoint(
+      &*FinalReductionReal->getParent()->getFirstInsertionPt());
+  auto *Deinterleave = Builder.CreateIntrinsic(
+      Intrinsic::experimental_vector_deinterleave2,
+      OperationReplacement->getType(), OperationReplacement);
+
+  auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
+  FinalReductionReal->replaceUsesOfWith(Node->Real, NewReal);
+
+  Builder.SetInsertPoint(FinalReductionImag);
+  auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
+  FinalReductionImag->replaceUsesOfWith(Node->Imag, NewImag);
 }
 
 void ComplexDeinterleavingGraph::replaceNodes() {
@@ -1587,9 +1849,18 @@ void ComplexDeinterleavingGraph::replaceNodes() {
     IRBuilder<> Builder(RootInstruction);
     auto RootNode = RootToNode[RootInstruction];
     Value *R = replaceNode(Builder, RootNode.get());
-    assert(R && "Unable to find replacement for RootInstruction");
-    DeadInstrRoots.push_back(RootInstruction);
-    RootInstruction->replaceAllUsesWith(R);
+
+    if (RootNode->Operation ==
+        ComplexDeinterleavingOperation::ReductionOperation) {
+      ReductionInfo[RootNode->Real].first->removeIncomingValue(BackEdge);
+      ReductionInfo[RootNode->Imag].first->removeIncomingValue(BackEdge);
+      DeadInstrRoots.push_back(RootNode->Real);
+      DeadInstrRoots.push_back(RootNode->Imag);
+    } else {
+      assert(R && "Unable to find replacement for RootInstruction");
+      DeadInstrRoots.push_back(RootInstruction);
+      RootInstruction->replaceAllUsesWith(R);
+    }
   }
 
   for (auto *I : DeadInstrRoots)

diff  --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-scalable.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-scalable.ll
index 1da41b37df2e6..88c59d1f9fa26 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-scalable.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-scalable.ll
@@ -20,8 +20,9 @@ define %"class.std::complex" @complex_mul_v2f64(ptr %a, ptr %b) {
 ; CHECK-NEXT:    mov x8, xzr
 ; CHECK-NEXT:    and x10, x10, x11
 ; CHECK-NEXT:    mov z1.d, #0 // =0x0
-; CHECK-NEXT:    mov z0.d, z1.d
 ; CHECK-NEXT:    rdvl x11, #2
+; CHECK-NEXT:    zip2 z0.d, z1.d, z1.d
+; CHECK-NEXT:    zip1 z1.d, z1.d, z1.d
 ; CHECK-NEXT:    ptrue p1.b
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:  .LBB0_1: // %vector.body
@@ -34,18 +35,16 @@ define %"class.std::complex" @complex_mul_v2f64(ptr %a, ptr %b) {
 ; CHECK-NEXT:    ld1b { z4.b }, p1/z, [x1, x8]
 ; CHECK-NEXT:    ld1d { z5.d }, p0/z, [x13, #1, mul vl]
 ; CHECK-NEXT:    add x8, x8, x11
-; CHECK-NEXT:    uzp2 z6.d, z2.d, z3.d
-; CHECK-NEXT:    uzp1 z2.d, z2.d, z3.d
-; CHECK-NEXT:    uzp1 z3.d, z4.d, z5.d
-; CHECK-NEXT:    fmla z0.d, p0/m, z3.d, z2.d
-; CHECK-NEXT:    fmla z1.d, p0/m, z3.d, z6.d
-; CHECK-NEXT:    uzp2 z3.d, z4.d, z5.d
-; CHECK-NEXT:    fmls z0.d, p0/m, z3.d, z6.d
-; CHECK-NEXT:    fmla z1.d, p0/m, z3.d, z2.d
+; CHECK-NEXT:    fcmla z1.d, p0/m, z4.d, z2.d, #0
+; CHECK-NEXT:    fcmla z0.d, p0/m, z5.d, z3.d, #0
+; CHECK-NEXT:    fcmla z1.d, p0/m, z4.d, z2.d, #90
+; CHECK-NEXT:    fcmla z0.d, p0/m, z5.d, z3.d, #90
 ; CHECK-NEXT:    b.ne .LBB0_1
 ; CHECK-NEXT:  // %bb.2: // %exit.block
+; CHECK-NEXT:    uzp2 z2.d, z1.d, z0.d
+; CHECK-NEXT:    uzp1 z0.d, z1.d, z0.d
 ; CHECK-NEXT:    faddv d0, p0, z0.d
-; CHECK-NEXT:    faddv d1, p0, z1.d
+; CHECK-NEXT:    faddv d1, p0, z2.d
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 killed $z1
 ; CHECK-NEXT:    ret
@@ -103,17 +102,19 @@ define %"class.std::complex" @complex_mul_nonzero_init_v2f64(ptr %a, ptr %b) {
 ; CHECK-LABEL: complex_mul_nonzero_init_v2f64:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    cntd x9
-; CHECK-NEXT:    fmov d0, #2.00000000
+; CHECK-NEXT:    fmov d0, #1.00000000
+; CHECK-NEXT:    fmov d1, #2.00000000
 ; CHECK-NEXT:    neg x10, x9
 ; CHECK-NEXT:    mov w11, #100 // =0x64
-; CHECK-NEXT:    fmov d1, #1.00000000
 ; CHECK-NEXT:    mov x8, xzr
 ; CHECK-NEXT:    and x10, x10, x11
 ; CHECK-NEXT:    mov z2.d, #0 // =0x0
 ; CHECK-NEXT:    ptrue p0.d, vl1
 ; CHECK-NEXT:    rdvl x11, #2
-; CHECK-NEXT:    sel z0.d, p0, z0.d, z2.d
+; CHECK-NEXT:    sel z3.d, p0, z0.d, z2.d
 ; CHECK-NEXT:    sel z1.d, p0, z1.d, z2.d
+; CHECK-NEXT:    zip2 z0.d, z1.d, z3.d
+; CHECK-NEXT:    zip1 z1.d, z1.d, z3.d
 ; CHECK-NEXT:    ptrue p1.b
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:  .LBB1_1: // %vector.body
@@ -126,18 +127,16 @@ define %"class.std::complex" @complex_mul_nonzero_init_v2f64(ptr %a, ptr %b) {
 ; CHECK-NEXT:    ld1b { z4.b }, p1/z, [x1, x8]
 ; CHECK-NEXT:    ld1d { z5.d }, p0/z, [x13, #1, mul vl]
 ; CHECK-NEXT:    add x8, x8, x11
-; CHECK-NEXT:    uzp2 z6.d, z2.d, z3.d
-; CHECK-NEXT:    uzp1 z2.d, z2.d, z3.d
-; CHECK-NEXT:    uzp1 z3.d, z4.d, z5.d
-; CHECK-NEXT:    fmla z0.d, p0/m, z3.d, z2.d
-; CHECK-NEXT:    fmla z1.d, p0/m, z3.d, z6.d
-; CHECK-NEXT:    uzp2 z3.d, z4.d, z5.d
-; CHECK-NEXT:    fmls z0.d, p0/m, z3.d, z6.d
-; CHECK-NEXT:    fmla z1.d, p0/m, z3.d, z2.d
+; CHECK-NEXT:    fcmla z1.d, p0/m, z4.d, z2.d, #0
+; CHECK-NEXT:    fcmla z0.d, p0/m, z5.d, z3.d, #0
+; CHECK-NEXT:    fcmla z1.d, p0/m, z4.d, z2.d, #90
+; CHECK-NEXT:    fcmla z0.d, p0/m, z5.d, z3.d, #90
 ; CHECK-NEXT:    b.ne .LBB1_1
 ; CHECK-NEXT:  // %bb.2: // %exit.block
+; CHECK-NEXT:    uzp2 z2.d, z1.d, z0.d
+; CHECK-NEXT:    uzp1 z0.d, z1.d, z0.d
 ; CHECK-NEXT:    faddv d0, p0, z0.d
-; CHECK-NEXT:    faddv d1, p0, z1.d
+; CHECK-NEXT:    faddv d1, p0, z2.d
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 killed $z1
 ; CHECK-NEXT:    ret
@@ -195,10 +194,11 @@ define %"class.std::complex" @complex_mul_v2f64_unrolled(ptr %a, ptr %b) {
 ; CHECK-NEXT:    neg x10, x9
 ; CHECK-NEXT:    mov x8, xzr
 ; CHECK-NEXT:    and x10, x10, x11
-; CHECK-NEXT:    mov z0.d, #0 // =0x0
+; CHECK-NEXT:    mov z1.d, #0 // =0x0
+; CHECK-NEXT:    zip2 z0.d, z1.d, z1.d
+; CHECK-NEXT:    zip1 z1.d, z1.d, z1.d
 ; CHECK-NEXT:    rdvl x11, #4
-; CHECK-NEXT:    mov z1.d, z0.d
-; CHECK-NEXT:    mov z2.d, z0.d
+; CHECK-NEXT:    mov z2.d, z1.d
 ; CHECK-NEXT:    mov z3.d, z0.d
 ; CHECK-NEXT:    addvl x12, x1, #2
 ; CHECK-NEXT:    addvl x13, x0, #2
@@ -220,27 +220,23 @@ define %"class.std::complex" @complex_mul_v2f64_unrolled(ptr %a, ptr %b) {
 ; CHECK-NEXT:    ld1b { z18.b }, p1/z, [x12, x8]
 ; CHECK-NEXT:    ld1d { z19.d }, p0/z, [x17, #1, mul vl]
 ; CHECK-NEXT:    add x8, x8, x11
-; CHECK-NEXT:    uzp2 z20.d, z4.d, z5.d
-; CHECK-NEXT:    uzp1 z4.d, z4.d, z5.d
-; CHECK-NEXT:    uzp2 z5.d, z6.d, z7.d
-; CHECK-NEXT:    uzp1 z6.d, z6.d, z7.d
-; CHECK-NEXT:    uzp1 z7.d, z16.d, z17.d
-; CHECK-NEXT:    uzp1 z21.d, z18.d, z19.d
-; CHECK-NEXT:    fmla z2.d, p0/m, z7.d, z4.d
-; CHECK-NEXT:    fmla z3.d, p0/m, z21.d, z6.d
-; CHECK-NEXT:    fmla z0.d, p0/m, z7.d, z20.d
-; CHECK-NEXT:    fmla z1.d, p0/m, z21.d, z5.d
-; CHECK-NEXT:    uzp2 z7.d, z16.d, z17.d
-; CHECK-NEXT:    uzp2 z16.d, z18.d, z19.d
-; CHECK-NEXT:    fmls z2.d, p0/m, z7.d, z20.d
-; CHECK-NEXT:    fmls z3.d, p0/m, z16.d, z5.d
-; CHECK-NEXT:    fmla z0.d, p0/m, z7.d, z4.d
-; CHECK-NEXT:    fmla z1.d, p0/m, z16.d, z6.d
+; CHECK-NEXT:    fcmla z1.d, p0/m, z16.d, z4.d, #0
+; CHECK-NEXT:    fcmla z0.d, p0/m, z17.d, z5.d, #0
+; CHECK-NEXT:    fcmla z2.d, p0/m, z18.d, z6.d, #0
+; CHECK-NEXT:    fcmla z3.d, p0/m, z19.d, z7.d, #0
+; CHECK-NEXT:    fcmla z1.d, p0/m, z16.d, z4.d, #90
+; CHECK-NEXT:    fcmla z0.d, p0/m, z17.d, z5.d, #90
+; CHECK-NEXT:    fcmla z2.d, p0/m, z18.d, z6.d, #90
+; CHECK-NEXT:    fcmla z3.d, p0/m, z19.d, z7.d, #90
 ; CHECK-NEXT:    b.ne .LBB2_1
 ; CHECK-NEXT:  // %bb.2: // %exit.block
-; CHECK-NEXT:    fadd z2.d, z3.d, z2.d
-; CHECK-NEXT:    fadd z1.d, z1.d, z0.d
-; CHECK-NEXT:    faddv d0, p0, z2.d
+; CHECK-NEXT:    uzp2 z4.d, z2.d, z3.d
+; CHECK-NEXT:    uzp1 z2.d, z2.d, z3.d
+; CHECK-NEXT:    uzp2 z3.d, z1.d, z0.d
+; CHECK-NEXT:    uzp1 z0.d, z1.d, z0.d
+; CHECK-NEXT:    fadd z0.d, z2.d, z0.d
+; CHECK-NEXT:    fadd z1.d, z4.d, z3.d
+; CHECK-NEXT:    faddv d0, p0, z0.d
 ; CHECK-NEXT:    faddv d1, p0, z1.d
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 killed $z1

diff  --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions.ll
index a47c410b67299..675b1b8948d11 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions.ll
@@ -14,25 +14,27 @@ target triple = "aarch64-arm-none-eabi"
 define dso_local %"struct.std::complex" @complex_mul_v2f64(ptr %a, ptr %b) {
 ; CHECK-LABEL: complex_mul_v2f64:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    movi v1.2d, #0000000000000000
-; CHECK-NEXT:    mov x8, xzr
 ; CHECK-NEXT:    movi v0.2d, #0000000000000000
+; CHECK-NEXT:    mov x8, xzr
+; CHECK-NEXT:    movi v1.2d, #0000000000000000
 ; CHECK-NEXT:  .LBB0_1: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    add x9, x0, x8
-; CHECK-NEXT:    ld2 { v2.2d, v3.2d }, [x9]
-; CHECK-NEXT:    add x9, x1, x8
+; CHECK-NEXT:    add x10, x1, x8
 ; CHECK-NEXT:    add x8, x8, #32
 ; CHECK-NEXT:    cmp x8, #1600
-; CHECK-NEXT:    ld2 { v4.2d, v5.2d }, [x9]
-; CHECK-NEXT:    fmla v0.2d, v2.2d, v4.2d
-; CHECK-NEXT:    fmla v1.2d, v3.2d, v4.2d
-; CHECK-NEXT:    fmls v0.2d, v3.2d, v5.2d
-; CHECK-NEXT:    fmla v1.2d, v2.2d, v5.2d
+; CHECK-NEXT:    ldp q3, q2, [x9]
+; CHECK-NEXT:    ldp q4, q5, [x10]
+; CHECK-NEXT:    fcmla v0.2d, v3.2d, v4.2d, #0
+; CHECK-NEXT:    fcmla v1.2d, v2.2d, v5.2d, #0
+; CHECK-NEXT:    fcmla v0.2d, v3.2d, v4.2d, #90
+; CHECK-NEXT:    fcmla v1.2d, v2.2d, v5.2d, #90
 ; CHECK-NEXT:    b.ne .LBB0_1
 ; CHECK-NEXT:  // %bb.2: // %middle.block
+; CHECK-NEXT:    zip2 v2.2d, v0.2d, v1.2d
+; CHECK-NEXT:    zip1 v0.2d, v0.2d, v1.2d
+; CHECK-NEXT:    faddp d1, v2.2d
 ; CHECK-NEXT:    faddp d0, v0.2d
-; CHECK-NEXT:    faddp d1, v1.2d
 ; CHECK-NEXT:    ret
 entry:
   br label %vector.body
@@ -79,26 +81,27 @@ define %"struct.std::complex" @complex_mul_nonzero_init_v2f64(ptr %a, ptr %b) {
 ; CHECK-LABEL: complex_mul_nonzero_init_v2f64:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    adrp x9, .LCPI1_0
-; CHECK-NEXT:    adrp x10, .LCPI1_1
 ; CHECK-NEXT:    mov x8, xzr
-; CHECK-NEXT:    ldr q0, [x9, :lo12:.LCPI1_0]
-; CHECK-NEXT:    ldr q1, [x10, :lo12:.LCPI1_1]
+; CHECK-NEXT:    movi v0.2d, #0000000000000000
+; CHECK-NEXT:    ldr q1, [x9, :lo12:.LCPI1_0]
 ; CHECK-NEXT:  .LBB1_1: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    add x9, x0, x8
-; CHECK-NEXT:    ld2 { v2.2d, v3.2d }, [x9]
-; CHECK-NEXT:    add x9, x1, x8
+; CHECK-NEXT:    add x10, x1, x8
 ; CHECK-NEXT:    add x8, x8, #32
 ; CHECK-NEXT:    cmp x8, #1600
-; CHECK-NEXT:    ld2 { v4.2d, v5.2d }, [x9]
-; CHECK-NEXT:    fmla v0.2d, v2.2d, v4.2d
-; CHECK-NEXT:    fmla v1.2d, v3.2d, v4.2d
-; CHECK-NEXT:    fmls v0.2d, v3.2d, v5.2d
-; CHECK-NEXT:    fmla v1.2d, v2.2d, v5.2d
+; CHECK-NEXT:    ldp q3, q2, [x9]
+; CHECK-NEXT:    ldp q4, q5, [x10]
+; CHECK-NEXT:    fcmla v1.2d, v3.2d, v4.2d, #0
+; CHECK-NEXT:    fcmla v0.2d, v2.2d, v5.2d, #0
+; CHECK-NEXT:    fcmla v1.2d, v3.2d, v4.2d, #90
+; CHECK-NEXT:    fcmla v0.2d, v2.2d, v5.2d, #90
 ; CHECK-NEXT:    b.ne .LBB1_1
 ; CHECK-NEXT:  // %bb.2: // %middle.block
+; CHECK-NEXT:    zip2 v2.2d, v1.2d, v0.2d
+; CHECK-NEXT:    zip1 v0.2d, v1.2d, v0.2d
+; CHECK-NEXT:    faddp d1, v2.2d
 ; CHECK-NEXT:    faddp d0, v0.2d
-; CHECK-NEXT:    faddp d1, v1.2d
 ; CHECK-NEXT:    ret
 entry:
   br label %vector.body
@@ -141,36 +144,39 @@ define %"struct.std::complex" @complex_mul_v2f64_unrolled(ptr %a, ptr %b) {
 ; CHECK-LABEL: complex_mul_v2f64_unrolled:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    adrp x9, .LCPI2_0
-; CHECK-NEXT:    adrp x10, .LCPI2_1
-; CHECK-NEXT:    movi v0.2d, #0000000000000000
 ; CHECK-NEXT:    mov x8, xzr
+; CHECK-NEXT:    movi v0.2d, #0000000000000000
 ; CHECK-NEXT:    movi v2.2d, #0000000000000000
-; CHECK-NEXT:    ldr q3, [x9, :lo12:.LCPI2_0]
-; CHECK-NEXT:    ldr q1, [x10, :lo12:.LCPI2_1]
+; CHECK-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-NEXT:    ldr q1, [x9, :lo12:.LCPI2_0]
 ; CHECK-NEXT:  .LBB2_1: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    add x9, x0, x8
-; CHECK-NEXT:    ld2 { v4.2d, v5.2d }, [x9], #32
-; CHECK-NEXT:    ld2 { v6.2d, v7.2d }, [x9]
-; CHECK-NEXT:    add x9, x1, x8
+; CHECK-NEXT:    add x10, x1, x8
 ; CHECK-NEXT:    add x8, x8, #64
 ; CHECK-NEXT:    cmp x8, #1600
-; CHECK-NEXT:    ld2 { v16.2d, v17.2d }, [x9], #32
-; CHECK-NEXT:    fmla v3.2d, v4.2d, v16.2d
-; CHECK-NEXT:    fmla v1.2d, v5.2d, v16.2d
-; CHECK-NEXT:    fmls v3.2d, v5.2d, v17.2d
-; CHECK-NEXT:    fmla v1.2d, v4.2d, v17.2d
-; CHECK-NEXT:    ld2 { v18.2d, v19.2d }, [x9]
-; CHECK-NEXT:    fmla v2.2d, v6.2d, v18.2d
-; CHECK-NEXT:    fmla v0.2d, v7.2d, v18.2d
-; CHECK-NEXT:    fmls v2.2d, v7.2d, v19.2d
-; CHECK-NEXT:    fmla v0.2d, v6.2d, v19.2d
+; CHECK-NEXT:    ldp q5, q4, [x9]
+; CHECK-NEXT:    ldp q7, q6, [x9, #32]
+; CHECK-NEXT:    ldp q17, q16, [x10]
+; CHECK-NEXT:    fcmla v1.2d, v5.2d, v17.2d, #0
+; CHECK-NEXT:    ldp q19, q18, [x10, #32]
+; CHECK-NEXT:    fcmla v0.2d, v4.2d, v16.2d, #0
+; CHECK-NEXT:    fcmla v1.2d, v5.2d, v17.2d, #90
+; CHECK-NEXT:    fcmla v2.2d, v7.2d, v19.2d, #0
+; CHECK-NEXT:    fcmla v0.2d, v4.2d, v16.2d, #90
+; CHECK-NEXT:    fcmla v3.2d, v6.2d, v18.2d, #0
+; CHECK-NEXT:    fcmla v2.2d, v7.2d, v19.2d, #90
+; CHECK-NEXT:    fcmla v3.2d, v6.2d, v18.2d, #90
 ; CHECK-NEXT:    b.ne .LBB2_1
 ; CHECK-NEXT:  // %bb.2: // %middle.block
-; CHECK-NEXT:    fadd v2.2d, v2.2d, v3.2d
-; CHECK-NEXT:    fadd v1.2d, v0.2d, v1.2d
-; CHECK-NEXT:    faddp d0, v2.2d
-; CHECK-NEXT:    faddp d1, v1.2d
+; CHECK-NEXT:    zip2 v4.2d, v2.2d, v3.2d
+; CHECK-NEXT:    zip1 v2.2d, v2.2d, v3.2d
+; CHECK-NEXT:    zip1 v3.2d, v1.2d, v0.2d
+; CHECK-NEXT:    zip2 v0.2d, v1.2d, v0.2d
+; CHECK-NEXT:    fadd v1.2d, v2.2d, v3.2d
+; CHECK-NEXT:    fadd v2.2d, v4.2d, v0.2d
+; CHECK-NEXT:    faddp d0, v1.2d
+; CHECK-NEXT:    faddp d1, v2.2d
 ; CHECK-NEXT:    ret
 entry:
   %scevgep = getelementptr i8, ptr %a, i64 32


        


More information about the llvm-commits mailing list