[llvm] A few improvements in fcmla pattern recognition (PR #173818)

Yichao Yu via llvm-commits llvm-commits at lists.llvm.org
Mon Dec 29 14:44:15 PST 2025


https://github.com/yuyichao updated https://github.com/llvm/llvm-project/pull/173818

>From 282ea83102022a0f372809f988a20dd230b8d35e Mon Sep 17 00:00:00 2001
From: Yichao Yu <yyc1992 at gmail.com>
Date: Fri, 26 Dec 2025 17:31:42 -0500
Subject: [PATCH 01/12] Improve mixed fastmath flags handling in complex
 deinterleaving

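* Relax the requirement that fast-math flags match exactly

  It is enough to require that every flag set includes reassoc; the
  intersection of the flags is kept instead of demanding an exact match.

* Fall back to treating instructions without reassoc as opaque addends,
  which lets more deinterleaving opportunities be discovered.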
---
 .../lib/CodeGen/ComplexDeinterleavingPass.cpp |  40 +--
 ...plex-deinterleaving-add-mull-fixed-fast.ll | 213 ++++++++++++++++
 ...x-deinterleaving-add-mull-scalable-fast.ll | 236 ++++++++++++++++++
 3 files changed, 473 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index 87ada87b4d32f..93383301f78ff 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -1235,13 +1235,7 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
 
   std::optional<FastMathFlags> Flags;
   if (isa<FPMathOperator>(Real)) {
-    if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
-      LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
-                           "not identical\n");
-      return nullptr;
-    }
-
-    Flags = Real->getFastMathFlags();
+    Flags = Real->getFastMathFlags() & Imag->getFastMathFlags();
     if (!Flags->allowReassoc()) {
       LLVM_DEBUG(
           dbgs()
@@ -1250,11 +1244,23 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
     }
   }
 
+  auto UpdateFlags = [&Flags](Instruction *I) {
+    if (!Flags)
+      return true;
+    if (!isa<FPMathOperator>(I))
+      return false;
+    auto NewFlags = I->getFastMathFlags();
+    if (!NewFlags.allowReassoc())
+      return false;
+    *Flags &= NewFlags;
+    return true;
+  };
+
-  // Collect multiplications and addend instructions from the given instruction
-  // while traversing it operands. Additionally, verify that all instructions
-  // have the same fast math flags.
-  auto Collect = [&Flags](Instruction *Insn, SmallVectorImpl<Product> &Muls,
-                          AddendList &Addends) -> bool {
+  // Collect multiplications and addends reachable from Insn. For FP roots, an
+  // instruction without 'reassoc' is kept as an addend; others narrow Flags.
+  auto Collect = [&UpdateFlags](Instruction *Insn,
+                                SmallVectorImpl<Product> &Muls,
+                                AddendList &Addends) -> bool {
     SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};
     SmallPtrSet<Value *, 8> Visited;
     while (!Worklist.empty()) {
@@ -1279,6 +1285,15 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
         Addends.emplace_back(I, IsPositive);
         continue;
       }
+
+      if (!UpdateFlags(I)) {
+        LLVM_DEBUG(dbgs() << "The instruction's fast math flags miss "
+                             "the 'Reassoc' attribute: "
+                          << *I << "\n");
+        Addends.emplace_back(I, IsPositive);
+        continue;
+      }
+
       switch (I->getOpcode()) {
       case Instruction::FAdd:
       case Instruction::Add:
@@ -1323,13 +1338,6 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
         Addends.emplace_back(I, IsPositive);
         continue;
       }
-
-      if (Flags && I->getFastMathFlags() != *Flags) {
-        LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
-                             "inconsistent with the root instructions' flags: "
-                          << *I << "\n");
-        return false;
-      }
     }
     return true;
   };
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-fast.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-fast.ll
index 7692b1cf0aaae..a7a4cf3b673cb 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-fast.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-fast.ll
@@ -33,6 +33,36 @@ entry:
   ret <4 x double> %interleaved.vec
 }
 
+; a * b + c, with a mix of 'fast' and 'reassoc contract' fast-math flags
+define <4 x double> @mull_add_mixed(<4 x double> %a, <4 x double> %b, <4 x double> %c) {
+; CHECK-LABEL: mull_add_mixed:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fcmla v4.2d, v0.2d, v2.2d, #0
+; CHECK-NEXT:    fcmla v5.2d, v1.2d, v3.2d, #0
+; CHECK-NEXT:    fcmla v4.2d, v0.2d, v2.2d, #90
+; CHECK-NEXT:    fcmla v5.2d, v1.2d, v3.2d, #90
+; CHECK-NEXT:    mov v0.16b, v4.16b
+; CHECK-NEXT:    mov v1.16b, v5.16b
+; CHECK-NEXT:    ret
+entry:
+  %strided.vec = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec28 = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %strided.vec30 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec31 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %0 = fmul fast <2 x double> %strided.vec31, %strided.vec
+  %1 = fmul reassoc contract <2 x double> %strided.vec30, %strided.vec28
+  %2 = fadd reassoc contract <2 x double> %0, %1
+  %3 = fmul fast <2 x double> %strided.vec30, %strided.vec
+  %strided.vec33 = shufflevector <4 x double> %c, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec34 = shufflevector <4 x double> %c, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %4 = fadd reassoc contract <2 x double> %strided.vec33, %3
+  %5 = fmul fast <2 x double> %strided.vec31, %strided.vec28
+  %6 = fsub fast <2 x double> %4, %5
+  %7 = fadd reassoc contract <2 x double> %2, %strided.vec34
+  %interleaved.vec = shufflevector <2 x double> %6, <2 x double> %7, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
+  ret <4 x double> %interleaved.vec
+}
+
 ; a * b + c * d
 define <4 x double> @mul_add_mull(<4 x double> %a, <4 x double> %b, <4 x double> %c, <4 x double> %d) {
 ; CHECK-LABEL: mul_add_mull:
@@ -77,6 +107,50 @@ entry:
   ret <4 x double> %interleaved.vec
 }
 
+; a * b + c * d, with a mix of 'fast' and 'reassoc contract' fast-math flags
+define <4 x double> @mul_add_mull_mixed(<4 x double> %a, <4 x double> %b, <4 x double> %c, <4 x double> %d) {
+; CHECK-LABEL: mul_add_mull_mixed:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    movi v16.2d, #0000000000000000
+; CHECK-NEXT:    movi v17.2d, #0000000000000000
+; CHECK-NEXT:    fcmla v17.2d, v6.2d, v4.2d, #0
+; CHECK-NEXT:    fcmla v16.2d, v7.2d, v5.2d, #0
+; CHECK-NEXT:    fcmla v17.2d, v0.2d, v2.2d, #0
+; CHECK-NEXT:    fcmla v16.2d, v1.2d, v3.2d, #0
+; CHECK-NEXT:    fcmla v17.2d, v6.2d, v4.2d, #90
+; CHECK-NEXT:    fcmla v16.2d, v7.2d, v5.2d, #90
+; CHECK-NEXT:    fcmla v17.2d, v0.2d, v2.2d, #90
+; CHECK-NEXT:    fcmla v16.2d, v1.2d, v3.2d, #90
+; CHECK-NEXT:    mov v0.16b, v17.16b
+; CHECK-NEXT:    mov v1.16b, v16.16b
+; CHECK-NEXT:    ret
+entry:
+  %strided.vec = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec51 = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %strided.vec53 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec54 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %0 = fmul fast <2 x double> %strided.vec54, %strided.vec
+  %1 = fmul reassoc contract <2 x double> %strided.vec53, %strided.vec51
+  %2 = fmul fast <2 x double> %strided.vec53, %strided.vec
+  %3 = fmul reassoc contract <2 x double> %strided.vec54, %strided.vec51
+  %strided.vec56 = shufflevector <4 x double> %c, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec57 = shufflevector <4 x double> %c, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %strided.vec59 = shufflevector <4 x double> %d, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec60 = shufflevector <4 x double> %d, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %4 = fmul reassoc contract <2 x double> %strided.vec60, %strided.vec56
+  %5 = fmul fast <2 x double> %strided.vec59, %strided.vec57
+  %6 = fmul reassoc contract <2 x double> %strided.vec59, %strided.vec56
+  %7 = fmul fast <2 x double> %strided.vec60, %strided.vec57
+  %8 = fadd reassoc contract <2 x double> %7, %3
+  %9 = fadd fast <2 x double> %6, %2
+  %10 = fsub reassoc contract <2 x double> %9, %8
+  %11 = fadd fast <2 x double> %0, %1
+  %12 = fadd reassoc contract <2 x double> %11, %5
+  %13 = fadd reassoc contract <2 x double> %12, %4
+  %interleaved.vec = shufflevector <2 x double> %10, <2 x double> %13, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
+  ret <4 x double> %interleaved.vec
+}
+
 ; a * b - c * d
 define <4 x double> @mul_sub_mull(<4 x double> %a, <4 x double> %b, <4 x double> %c, <4 x double> %d) {
 ; CHECK-LABEL: mul_sub_mull:
@@ -121,6 +195,50 @@ entry:
   ret <4 x double> %interleaved.vec
 }
 
+; a * b - c * d, with a mix of 'fast' and 'reassoc contract' fast-math flags
+define <4 x double> @mul_sub_mull_mixed(<4 x double> %a, <4 x double> %b, <4 x double> %c, <4 x double> %d) {
+; CHECK-LABEL: mul_sub_mull_mixed:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    movi v16.2d, #0000000000000000
+; CHECK-NEXT:    movi v17.2d, #0000000000000000
+; CHECK-NEXT:    fcmla v17.2d, v6.2d, v4.2d, #270
+; CHECK-NEXT:    fcmla v16.2d, v7.2d, v5.2d, #270
+; CHECK-NEXT:    fcmla v17.2d, v0.2d, v2.2d, #0
+; CHECK-NEXT:    fcmla v16.2d, v1.2d, v3.2d, #0
+; CHECK-NEXT:    fcmla v17.2d, v6.2d, v4.2d, #180
+; CHECK-NEXT:    fcmla v16.2d, v7.2d, v5.2d, #180
+; CHECK-NEXT:    fcmla v17.2d, v0.2d, v2.2d, #90
+; CHECK-NEXT:    fcmla v16.2d, v1.2d, v3.2d, #90
+; CHECK-NEXT:    mov v0.16b, v17.16b
+; CHECK-NEXT:    mov v1.16b, v16.16b
+; CHECK-NEXT:    ret
+entry:
+  %strided.vec = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec53 = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %strided.vec55 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec56 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %0 = fmul reassoc contract <2 x double> %strided.vec56, %strided.vec
+  %1 = fmul fast <2 x double> %strided.vec55, %strided.vec53
+  %2 = fmul reassoc contract <2 x double> %strided.vec55, %strided.vec
+  %3 = fmul reassoc contract <2 x double> %strided.vec56, %strided.vec53
+  %strided.vec58 = shufflevector <4 x double> %c, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec59 = shufflevector <4 x double> %c, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %strided.vec61 = shufflevector <4 x double> %d, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec62 = shufflevector <4 x double> %d, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %4 = fmul fast <2 x double> %strided.vec62, %strided.vec59
+  %5 = fmul reassoc contract <2 x double> %strided.vec61, %strided.vec58
+  %6 = fadd reassoc contract <2 x double> %5, %3
+  %7 = fadd fast <2 x double> %4, %2
+  %8 = fsub reassoc contract <2 x double> %7, %6
+  %9 = fmul fast <2 x double> %strided.vec61, %strided.vec59
+  %10 = fmul reassoc contract <2 x double> %strided.vec62, %strided.vec58
+  %11 = fadd reassoc contract <2 x double> %10, %9
+  %12 = fadd fast <2 x double> %0, %1
+  %13 = fsub reassoc contract <2 x double> %12, %11
+  %interleaved.vec = shufflevector <2 x double> %8, <2 x double> %13, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
+  ret <4 x double> %interleaved.vec
+}
+
 ; a * b + conj(c) * d
 define <4 x double> @mul_conj_mull(<4 x double> %a, <4 x double> %b, <4 x double> %c, <4 x double> %d) {
 ; CHECK-LABEL: mul_conj_mull:
@@ -165,6 +283,50 @@ entry:
   ret <4 x double> %interleaved.vec
 }
 
+; a * b + conj(c) * d, with a mix of 'fast' and 'reassoc contract' fast-math flags
+define <4 x double> @mul_conj_mull_mixed(<4 x double> %a, <4 x double> %b, <4 x double> %c, <4 x double> %d) {
+; CHECK-LABEL: mul_conj_mull_mixed:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    movi v16.2d, #0000000000000000
+; CHECK-NEXT:    movi v17.2d, #0000000000000000
+; CHECK-NEXT:    fcmla v17.2d, v0.2d, v2.2d, #0
+; CHECK-NEXT:    fcmla v16.2d, v1.2d, v3.2d, #0
+; CHECK-NEXT:    fcmla v17.2d, v0.2d, v2.2d, #90
+; CHECK-NEXT:    fcmla v16.2d, v1.2d, v3.2d, #90
+; CHECK-NEXT:    fcmla v17.2d, v4.2d, v6.2d, #0
+; CHECK-NEXT:    fcmla v16.2d, v5.2d, v7.2d, #0
+; CHECK-NEXT:    fcmla v17.2d, v4.2d, v6.2d, #270
+; CHECK-NEXT:    fcmla v16.2d, v5.2d, v7.2d, #270
+; CHECK-NEXT:    mov v0.16b, v17.16b
+; CHECK-NEXT:    mov v1.16b, v16.16b
+; CHECK-NEXT:    ret
+entry:
+  %strided.vec = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec59 = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %strided.vec61 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec62 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %0 = fmul reassoc contract <2 x double> %strided.vec62, %strided.vec
+  %1 = fmul fast <2 x double> %strided.vec61, %strided.vec59
+  %2 = fmul reassoc contract <2 x double> %strided.vec61, %strided.vec
+  %strided.vec64 = shufflevector <4 x double> %c, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec65 = shufflevector <4 x double> %c, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %strided.vec67 = shufflevector <4 x double> %d, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec68 = shufflevector <4 x double> %d, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %3 = fmul fast <2 x double> %strided.vec68, %strided.vec64
+  %4 = fmul reassoc contract <2 x double> %strided.vec67, %strided.vec64
+  %5 = fmul fast <2 x double> %strided.vec68, %strided.vec65
+  %6 = fmul reassoc contract <2 x double> %strided.vec62, %strided.vec59
+  %7 = fsub reassoc contract <2 x double> %2, %6
+  %8 = fadd fast <2 x double> %7, %4
+  %9 = fadd reassoc contract <2 x double> %8, %5
+  %10 = fadd fast <2 x double> %0, %1
+  %11 = fmul reassoc contract <2 x double> %strided.vec67, %strided.vec65
+  %12 = fsub fast <2 x double> %10, %11
+  %13 = fadd reassoc contract <2 x double> %12, %3
+  %interleaved.vec = shufflevector <2 x double> %9, <2 x double> %13, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
+  ret <4 x double> %interleaved.vec
+}
+
 ; a + b + 1i * c * d
 define <4 x double> @mul_add_rot_mull(<4 x double> %a, <4 x double> %b, <4 x double> %c, <4 x double> %d) {
 ; CHECK-LABEL: mul_add_rot_mull:
@@ -215,3 +377,54 @@ entry:
   %interleaved.vec = shufflevector <2 x double> %9, <2 x double> %13, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
   ret <4 x double> %interleaved.vec
 }
+
+; a * b + 1i * c * d, with a mix of 'fast' and 'reassoc contract' fast-math flags
+define <4 x double> @mul_add_rot_mull_mixed(<4 x double> %a, <4 x double> %b, <4 x double> %c, <4 x double> %d) {
+; CHECK-LABEL: mul_add_rot_mull_mixed:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    zip2 v16.2d, v2.2d, v3.2d
+; CHECK-NEXT:    zip2 v17.2d, v0.2d, v1.2d
+; CHECK-NEXT:    zip1 v2.2d, v2.2d, v3.2d
+; CHECK-NEXT:    zip2 v18.2d, v4.2d, v5.2d
+; CHECK-NEXT:    zip1 v19.2d, v6.2d, v7.2d
+; CHECK-NEXT:    zip1 v0.2d, v0.2d, v1.2d
+; CHECK-NEXT:    zip1 v1.2d, v4.2d, v5.2d
+; CHECK-NEXT:    zip2 v5.2d, v6.2d, v7.2d
+; CHECK-NEXT:    fmul v3.2d, v16.2d, v17.2d
+; CHECK-NEXT:    fmul v4.2d, v2.2d, v17.2d
+; CHECK-NEXT:    fmla v3.2d, v18.2d, v19.2d
+; CHECK-NEXT:    fmla v4.2d, v0.2d, v16.2d
+; CHECK-NEXT:    fmla v3.2d, v1.2d, v5.2d
+; CHECK-NEXT:    fmla v4.2d, v1.2d, v19.2d
+; CHECK-NEXT:    fneg v3.2d, v3.2d
+; CHECK-NEXT:    fmls v4.2d, v18.2d, v5.2d
+; CHECK-NEXT:    fmla v3.2d, v0.2d, v2.2d
+; CHECK-NEXT:    zip1 v0.2d, v3.2d, v4.2d
+; CHECK-NEXT:    zip2 v1.2d, v3.2d, v4.2d
+; CHECK-NEXT:    ret
+entry:
+  %strided.vec = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec79 = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %strided.vec81 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec82 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %0 = fmul reassoc contract <2 x double> %strided.vec82, %strided.vec
+  %1 = fmul reassoc contract <2 x double> %strided.vec81, %strided.vec79
+  %2 = fmul reassoc contract <2 x double> %strided.vec81, %strided.vec
+  %3 = fmul reassoc contract <2 x double> %strided.vec82, %strided.vec79
+  %strided.vec84 = shufflevector <4 x double> %c, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec85 = shufflevector <4 x double> %c, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %strided.vec87 = shufflevector <4 x double> %d, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec88 = shufflevector <4 x double> %d, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %4 = fmul fast <2 x double> %strided.vec87, %strided.vec84
+  %5 = fmul reassoc contract <2 x double> %strided.vec87, %strided.vec85
+  %6 = fmul fast <2 x double> %strided.vec88, %strided.vec84
+  %7 = fadd reassoc contract <2 x double> %5, %3
+  %8 = fadd fast <2 x double> %7, %6
+  %9 = fsub reassoc contract <2 x double> %2, %8
+  %10 = fadd reassoc contract <2 x double> %0, %1
+  %11 = fadd fast <2 x double> %10, %4
+  %12 = fmul reassoc contract <2 x double> %strided.vec88, %strided.vec85
+  %13 = fsub fast <2 x double> %11, %12
+  %interleaved.vec = shufflevector <2 x double> %9, <2 x double> %13, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
+  ret <4 x double> %interleaved.vec
+}
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-fast.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-fast.ll
index b68c0094f84de..cc9813bc7d1fd 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-fast.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-fast.ll
@@ -37,6 +37,40 @@ entry:
   ret <vscale x 4 x double> %interleaved.vec
 }
 
+; a * b + c, with a mix of 'fast' and 'reassoc contract' fast-math flags
+define <vscale x 4 x double> @mull_add_mixed(<vscale x 4 x double> %a, <vscale x 4 x double> %b, <vscale x 4 x double> %c) {
+; CHECK-LABEL: mull_add_mixed:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fcmla z4.d, p0/m, z0.d, z2.d, #0
+; CHECK-NEXT:    fcmla z5.d, p0/m, z1.d, z3.d, #0
+; CHECK-NEXT:    fcmla z4.d, p0/m, z0.d, z2.d, #90
+; CHECK-NEXT:    fcmla z5.d, p0/m, z1.d, z3.d, #90
+; CHECK-NEXT:    mov z0.d, z4.d
+; CHECK-NEXT:    mov z1.d, z5.d
+; CHECK-NEXT:    ret
+entry:
+  %strided.vec = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %a)
+  %0 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 0
+  %1 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 1
+  %strided.vec29 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %b)
+  %2 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec29, 0
+  %3 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec29, 1
+  %4 = fmul reassoc contract <vscale x 2 x double> %3, %0
+  %5 = fmul fast <vscale x 2 x double> %2, %1
+  %6 = fadd reassoc contract <vscale x 2 x double> %4, %5
+  %7 = fmul fast <vscale x 2 x double> %2, %0
+  %strided.vec31 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %c)
+  %8 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec31, 0
+  %9 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec31, 1
+  %10 = fadd fast <vscale x 2 x double> %8, %7
+  %11 = fmul reassoc contract <vscale x 2 x double> %3, %1
+  %12 = fsub reassoc contract <vscale x 2 x double> %10, %11
+  %13 = fadd fast <vscale x 2 x double> %6, %9
+  %interleaved.vec = tail call <vscale x 4 x double> @llvm.vector.interleave2.nxv4f64(<vscale x 2 x double> %12, <vscale x 2 x double> %13)
+  ret <vscale x 4 x double> %interleaved.vec
+}
+
 ; a * b + c * d
 define <vscale x 4 x double> @mul_add_mull(<vscale x 4 x double> %a, <vscale x 4 x double> %b, <vscale x 4 x double> %c, <vscale x 4 x double> %d) {
 ; CHECK-LABEL: mul_add_mull:
@@ -86,6 +120,55 @@ entry:
   ret <vscale x 4 x double> %interleaved.vec
 }
 
+; a * b + c * d, with a mix of 'fast' and 'reassoc contract' fast-math flags
+define <vscale x 4 x double> @mul_add_mull_mixed(<vscale x 4 x double> %a, <vscale x 4 x double> %b, <vscale x 4 x double> %c, <vscale x 4 x double> %d) {
+; CHECK-LABEL: mul_add_mull_mixed:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    movi v24.2d, #0000000000000000
+; CHECK-NEXT:    movi v25.2d, #0000000000000000
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fcmla z25.d, p0/m, z6.d, z4.d, #0
+; CHECK-NEXT:    fcmla z24.d, p0/m, z7.d, z5.d, #0
+; CHECK-NEXT:    fcmla z25.d, p0/m, z0.d, z2.d, #0
+; CHECK-NEXT:    fcmla z24.d, p0/m, z1.d, z3.d, #0
+; CHECK-NEXT:    fcmla z25.d, p0/m, z6.d, z4.d, #90
+; CHECK-NEXT:    fcmla z24.d, p0/m, z7.d, z5.d, #90
+; CHECK-NEXT:    fcmla z25.d, p0/m, z0.d, z2.d, #90
+; CHECK-NEXT:    fcmla z24.d, p0/m, z1.d, z3.d, #90
+; CHECK-NEXT:    mov z0.d, z25.d
+; CHECK-NEXT:    mov z1.d, z24.d
+; CHECK-NEXT:    ret
+entry:
+  %strided.vec = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %a)
+  %0 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 0
+  %1 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 1
+  %strided.vec52 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %b)
+  %2 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec52, 0
+  %3 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec52, 1
+  %4 = fmul fast <vscale x 2 x double> %3, %0
+  %5 = fmul reassoc contract <vscale x 2 x double> %2, %1
+  %6 = fmul fast <vscale x 2 x double> %2, %0
+  %7 = fmul reassoc contract <vscale x 2 x double> %3, %1
+  %strided.vec54 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %c)
+  %8 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec54, 0
+  %9 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec54, 1
+  %strided.vec56 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %d)
+  %10 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec56, 0
+  %11 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec56, 1
+  %12 = fmul reassoc contract <vscale x 2 x double> %11, %8
+  %13 = fmul fast <vscale x 2 x double> %10, %9
+  %14 = fmul reassoc contract <vscale x 2 x double> %10, %8
+  %15 = fmul fast <vscale x 2 x double> %11, %9
+  %16 = fadd fast <vscale x 2 x double> %15, %7
+  %17 = fadd reassoc contract <vscale x 2 x double> %14, %6
+  %18 = fsub reassoc contract <vscale x 2 x double> %17, %16
+  %19 = fadd fast <vscale x 2 x double> %4, %5
+  %20 = fadd reassoc contract <vscale x 2 x double> %19, %13
+  %21 = fadd fast <vscale x 2 x double> %20, %12
+  %interleaved.vec = tail call <vscale x 4 x double> @llvm.vector.interleave2.nxv4f64(<vscale x 2 x double> %18, <vscale x 2 x double> %21)
+  ret <vscale x 4 x double> %interleaved.vec
+}
+
 ; a * b - c * d
 define <vscale x 4 x double> @mul_sub_mull(<vscale x 4 x double> %a, <vscale x 4 x double> %b, <vscale x 4 x double> %c, <vscale x 4 x double> %d) {
 ; CHECK-LABEL: mul_sub_mull:
@@ -135,6 +218,55 @@ entry:
   ret <vscale x 4 x double> %interleaved.vec
 }
 
+; a * b - c * d, with a mix of 'fast' and 'reassoc contract' fast-math flags
+define <vscale x 4 x double> @mul_sub_mull_mixed(<vscale x 4 x double> %a, <vscale x 4 x double> %b, <vscale x 4 x double> %c, <vscale x 4 x double> %d) {
+; CHECK-LABEL: mul_sub_mull_mixed:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    movi v24.2d, #0000000000000000
+; CHECK-NEXT:    movi v25.2d, #0000000000000000
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fcmla z25.d, p0/m, z6.d, z4.d, #270
+; CHECK-NEXT:    fcmla z24.d, p0/m, z7.d, z5.d, #270
+; CHECK-NEXT:    fcmla z25.d, p0/m, z0.d, z2.d, #0
+; CHECK-NEXT:    fcmla z24.d, p0/m, z1.d, z3.d, #0
+; CHECK-NEXT:    fcmla z25.d, p0/m, z6.d, z4.d, #180
+; CHECK-NEXT:    fcmla z24.d, p0/m, z7.d, z5.d, #180
+; CHECK-NEXT:    fcmla z25.d, p0/m, z0.d, z2.d, #90
+; CHECK-NEXT:    fcmla z24.d, p0/m, z1.d, z3.d, #90
+; CHECK-NEXT:    mov z0.d, z25.d
+; CHECK-NEXT:    mov z1.d, z24.d
+; CHECK-NEXT:    ret
+entry:
+  %strided.vec = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %a)
+  %0 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 0
+  %1 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 1
+  %strided.vec54 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %b)
+  %2 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec54, 0
+  %3 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec54, 1
+  %4 = fmul reassoc contract <vscale x 2 x double> %3, %0
+  %5 = fmul fast <vscale x 2 x double> %2, %1
+  %6 = fmul reassoc contract <vscale x 2 x double> %2, %0
+  %7 = fmul fast <vscale x 2 x double> %3, %1
+  %strided.vec56 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %c)
+  %8 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec56, 0
+  %9 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec56, 1
+  %strided.vec58 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %d)
+  %10 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec58, 0
+  %11 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec58, 1
+  %12 = fmul fast <vscale x 2 x double> %11, %9
+  %13 = fmul reassoc contract <vscale x 2 x double> %10, %8
+  %14 = fadd fast <vscale x 2 x double> %13, %7
+  %15 = fadd reassoc contract <vscale x 2 x double> %12, %6
+  %16 = fsub reassoc contract <vscale x 2 x double> %15, %14
+  %17 = fmul reassoc contract <vscale x 2 x double> %10, %9
+  %18 = fmul fast <vscale x 2 x double> %11, %8
+  %19 = fadd fast <vscale x 2 x double> %18, %17
+  %20 = fadd reassoc contract <vscale x 2 x double> %4, %5
+  %21 = fsub fast <vscale x 2 x double> %20, %19
+  %interleaved.vec = tail call <vscale x 4 x double> @llvm.vector.interleave2.nxv4f64(<vscale x 2 x double> %16, <vscale x 2 x double> %21)
+  ret <vscale x 4 x double> %interleaved.vec
+}
+
 ; a * b + conj(c) * d
 define <vscale x 4 x double> @mul_conj_mull(<vscale x 4 x double> %a, <vscale x 4 x double> %b, <vscale x 4 x double> %c, <vscale x 4 x double> %d) {
 ; CHECK-LABEL: mul_conj_mull:
@@ -184,6 +316,55 @@ entry:
   ret <vscale x 4 x double> %interleaved.vec
 }
 
+; a * b + conj(c) * d, with a mix of 'fast' and 'reassoc contract' fast-math flags
+define <vscale x 4 x double> @mul_conj_mull_mixed(<vscale x 4 x double> %a, <vscale x 4 x double> %b, <vscale x 4 x double> %c, <vscale x 4 x double> %d) {
+; CHECK-LABEL: mul_conj_mull_mixed:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    movi v24.2d, #0000000000000000
+; CHECK-NEXT:    movi v25.2d, #0000000000000000
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fcmla z25.d, p0/m, z0.d, z2.d, #0
+; CHECK-NEXT:    fcmla z24.d, p0/m, z1.d, z3.d, #0
+; CHECK-NEXT:    fcmla z25.d, p0/m, z0.d, z2.d, #90
+; CHECK-NEXT:    fcmla z24.d, p0/m, z1.d, z3.d, #90
+; CHECK-NEXT:    fcmla z25.d, p0/m, z4.d, z6.d, #0
+; CHECK-NEXT:    fcmla z24.d, p0/m, z5.d, z7.d, #0
+; CHECK-NEXT:    fcmla z25.d, p0/m, z4.d, z6.d, #270
+; CHECK-NEXT:    fcmla z24.d, p0/m, z5.d, z7.d, #270
+; CHECK-NEXT:    mov z0.d, z25.d
+; CHECK-NEXT:    mov z1.d, z24.d
+; CHECK-NEXT:    ret
+entry:
+  %strided.vec = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %a)
+  %0 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 0
+  %1 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 1
+  %strided.vec60 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %b)
+  %2 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec60, 0
+  %3 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec60, 1
+  %4 = fmul fast <vscale x 2 x double> %3, %0
+  %5 = fmul reassoc contract <vscale x 2 x double> %2, %1
+  %6 = fmul reassoc contract <vscale x 2 x double> %2, %0
+  %strided.vec62 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %c)
+  %7 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec62, 0
+  %8 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec62, 1
+  %strided.vec64 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %d)
+  %9 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec64, 0
+  %10 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec64, 1
+  %11 = fmul fast <vscale x 2 x double> %10, %7
+  %12 = fmul reassoc contract <vscale x 2 x double> %9, %7
+  %13 = fmul fast <vscale x 2 x double> %10, %8
+  %14 = fmul reassoc contract <vscale x 2 x double> %3, %1
+  %15 = fsub reassoc contract <vscale x 2 x double> %6, %14
+  %16 = fadd fast <vscale x 2 x double> %15, %12
+  %17 = fadd reassoc contract <vscale x 2 x double> %16, %13
+  %18 = fadd reassoc contract <vscale x 2 x double> %4, %5
+  %19 = fmul fast <vscale x 2 x double> %9, %8
+  %20 = fsub reassoc contract <vscale x 2 x double> %18, %19
+  %21 = fadd fast <vscale x 2 x double> %20, %11
+  %interleaved.vec = tail call <vscale x 4 x double> @llvm.vector.interleave2.nxv4f64(<vscale x 2 x double> %17, <vscale x 2 x double> %21)
+  ret <vscale x 4 x double> %interleaved.vec
+}
+
 ; a + b + 1i * c * d
 define <vscale x 4 x double> @mul_add_rot_mull(<vscale x 4 x double> %a, <vscale x 4 x double> %b, <vscale x 4 x double> %c, <vscale x 4 x double> %d) {
 ; CHECK-LABEL: mul_add_rot_mull:
@@ -239,5 +420,60 @@ entry:
   ret <vscale x 4 x double> %interleaved.vec
 }
 
+; a + b + 1i * c * d
+define <vscale x 4 x double> @mul_add_rot_mull_mixed(<vscale x 4 x double> %a, <vscale x 4 x double> %b, <vscale x 4 x double> %c, <vscale x 4 x double> %d) {
+; CHECK-LABEL: mul_add_rot_mull_mixed:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uzp2 z24.d, z2.d, z3.d
+; CHECK-NEXT:    uzp2 z25.d, z0.d, z1.d
+; CHECK-NEXT:    uzp1 z2.d, z2.d, z3.d
+; CHECK-NEXT:    uzp1 z0.d, z0.d, z1.d
+; CHECK-NEXT:    uzp2 z1.d, z4.d, z5.d
+; CHECK-NEXT:    uzp1 z26.d, z6.d, z7.d
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    uzp1 z4.d, z4.d, z5.d
+; CHECK-NEXT:    uzp2 z5.d, z6.d, z7.d
+; CHECK-NEXT:    fmul z3.d, z2.d, z25.d
+; CHECK-NEXT:    fmul z25.d, z24.d, z25.d
+; CHECK-NEXT:    fmla z3.d, p0/m, z24.d, z0.d
+; CHECK-NEXT:    fmla z25.d, p0/m, z26.d, z1.d
+; CHECK-NEXT:    fmla z25.d, p0/m, z5.d, z4.d
+; CHECK-NEXT:    fmla z3.d, p0/m, z26.d, z4.d
+; CHECK-NEXT:    fnmsb z2.d, p0/m, z0.d, z25.d
+; CHECK-NEXT:    fmsb z1.d, p0/m, z5.d, z3.d
+; CHECK-NEXT:    zip1 z0.d, z2.d, z1.d
+; CHECK-NEXT:    zip2 z1.d, z2.d, z1.d
+; CHECK-NEXT:    ret
+entry:
+  %strided.vec = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %a)
+  %0 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 0
+  %1 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 1
+  %strided.vec80 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %b)
+  %2 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec80, 0
+  %3 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec80, 1
+  %4 = fmul reassoc contract <vscale x 2 x double> %3, %0
+  %5 = fmul fast <vscale x 2 x double> %2, %1
+  %6 = fmul reassoc contract <vscale x 2 x double> %2, %0
+  %7 = fmul reassoc contract <vscale x 2 x double> %3, %1
+  %strided.vec82 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %c)
+  %8 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec82, 0
+  %9 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec82, 1
+  %strided.vec84 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %d)
+  %10 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec84, 0
+  %11 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec84, 1
+  %12 = fmul fast <vscale x 2 x double> %10, %8
+  %13 = fmul reassoc contract <vscale x 2 x double> %10, %9
+  %14 = fmul fast <vscale x 2 x double> %11, %8
+  %15 = fadd reassoc contract <vscale x 2 x double> %13, %7
+  %16 = fadd fast <vscale x 2 x double> %15, %14
+  %17 = fsub reassoc contract <vscale x 2 x double> %6, %16
+  %18 = fadd fast <vscale x 2 x double> %4, %5
+  %19 = fadd reassoc contract <vscale x 2 x double> %18, %12
+  %20 = fmul reassoc contract <vscale x 2 x double> %11, %9
+  %21 = fsub fast <vscale x 2 x double> %19, %20
+  %interleaved.vec = tail call <vscale x 4 x double> @llvm.vector.interleave2.nxv4f64(<vscale x 2 x double> %17, <vscale x 2 x double> %21)
+  ret <vscale x 4 x double> %interleaved.vec
+}
+
 declare { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double>)
 declare <vscale x 4 x double> @llvm.vector.interleave2.nxv4f64(<vscale x 2 x double>, <vscale x 2 x double>)

>From 4bff4f4a4de2a92f5583728b5d140bf73724a136 Mon Sep 17 00:00:00 2001
From: Yichao Yu <yyc1992 at gmail.com>
Date: Fri, 26 Dec 2025 19:59:12 -0500
Subject: [PATCH 02/12] Handle reassoc fma and fmuladd in complex
 deinterleaving pass

---
 .../lib/CodeGen/ComplexDeinterleavingPass.cpp | 61 ++++++++++++-------
 ...plex-deinterleaving-add-mull-fixed-fast.ll | 31 ++++++++++
 ...x-deinterleaving-add-mull-scalable-fast.ll | 34 +++++++++++
 3 files changed, 105 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index 93383301f78ff..5f9b3cc5bf064 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -622,6 +622,14 @@ Value *getNegOperand(Value *V) {
   return I->getOperand(1);
 }
 
+static const IntrinsicInst *getFMAOrMulAdd(const Instruction *I) {
+  if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+    auto IID = II->getIntrinsicID();
+    return IID == Intrinsic::fmuladd || IID == Intrinsic::fma ? II : nullptr;
+  }
+  return nullptr;
+}
+
 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B, unsigned Factor) {
   ComplexDeinterleavingGraph Graph(TL, TLI, Factor);
   if (Graph.collectPotentialReductions(B))
@@ -1223,14 +1231,17 @@ ComplexDeinterleavingGraph::identifyNode(ComplexValues &Vals) {
 ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
                                                  Instruction *Imag) {
-  auto IsOperationSupported = [](unsigned Opcode) -> bool {
+  auto IsOperationSupported = [](Instruction *I) -> bool {
+    if (getFMAOrMulAdd(I))
+      return true;
+    unsigned Opcode = I->getOpcode();
     return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||
            Opcode == Instruction::FNeg || Opcode == Instruction::Add ||
            Opcode == Instruction::Sub;
   };
 
-  if (!IsOperationSupported(Real->getOpcode()) ||
-      !IsOperationSupported(Imag->getOpcode()))
+  if (!IsOperationSupported(Real) ||
+      !IsOperationSupported(Imag))
     return nullptr;
 
   std::optional<FastMathFlags> Flags;
@@ -1294,6 +1305,24 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
         continue;
       }
 
+      auto addMul = [&Muls](Value *V0, Value *V1, bool IsPositive) {
+        Value *A, *B;
+        if (isNeg(V0)) {
+          A = getNegOperand(V0);
+          IsPositive = !IsPositive;
+        } else {
+          A = V0;
+        }
+
+        if (isNeg(V1)) {
+          B = getNegOperand(V1);
+          IsPositive = !IsPositive;
+        } else {
+          B = V1;
+        }
+        Muls.push_back(Product{A, B, IsPositive});
+      };
+
       switch (I->getOpcode()) {
       case Instruction::FAdd:
       case Instruction::Add:
@@ -1313,29 +1342,19 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
         }
         break;
       case Instruction::FMul:
-      case Instruction::Mul: {
-        Value *A, *B;
-        if (isNeg(I->getOperand(0))) {
-          A = getNegOperand(I->getOperand(0));
-          IsPositive = !IsPositive;
-        } else {
-          A = I->getOperand(0);
-        }
-
-        if (isNeg(I->getOperand(1))) {
-          B = getNegOperand(I->getOperand(1));
-          IsPositive = !IsPositive;
-        } else {
-          B = I->getOperand(1);
-        }
-        Muls.push_back(Product{A, B, IsPositive});
+      case Instruction::Mul:
+        addMul(I->getOperand(0), I->getOperand(1), IsPositive);
         break;
-      }
       case Instruction::FNeg:
         Worklist.emplace_back(I->getOperand(0), !IsPositive);
         break;
       default:
-        Addends.emplace_back(I, IsPositive);
+        if (auto II = getFMAOrMulAdd(I)) {
+          Worklist.emplace_back(II->getArgOperand(2), IsPositive);
+          addMul(II->getArgOperand(0), II->getArgOperand(1), IsPositive);
+        } else {
+          Addends.emplace_back(I, IsPositive);
+        }
         continue;
       }
     }
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-fast.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-fast.ll
index a7a4cf3b673cb..918e175d0a70e 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-fast.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-fast.ll
@@ -33,6 +33,34 @@ entry:
   ret <4 x double> %interleaved.vec
 }
 
+; a * b + c
+define <4 x double> @mull_add_fma(<4 x double> %a, <4 x double> %b, <4 x double> %c) {
+; CHECK-LABEL: mull_add_fma:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fcmla v4.2d, v0.2d, v2.2d, #90
+; CHECK-NEXT:    fcmla v5.2d, v1.2d, v3.2d, #90
+; CHECK-NEXT:    fcmla v4.2d, v0.2d, v2.2d, #0
+; CHECK-NEXT:    fcmla v5.2d, v1.2d, v3.2d, #0
+; CHECK-NEXT:    mov v0.16b, v4.16b
+; CHECK-NEXT:    mov v1.16b, v5.16b
+; CHECK-NEXT:    ret
+entry:
+  %strided.vec = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec28 = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %strided.vec30 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec31 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %0 = fmul fast <2 x double> %strided.vec31, %strided.vec
+  %1 = call fast <2 x double> @llvm.fma.v2f64(<2 x double> %strided.vec30, <2 x double> %strided.vec28, <2 x double> %0)
+  %strided.vec33 = shufflevector <4 x double> %c, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec34 = shufflevector <4 x double> %c, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %2 = call fast <2 x double> @llvm.fmuladd.v2f64(<2 x double> %strided.vec30, <2 x double> %strided.vec, <2 x double> %strided.vec33)
+  %3 = fneg fast <2 x double> %strided.vec31
+  %4 = call fast <2 x double> @llvm.fma.v2f64(<2 x double> %3, <2 x double> %strided.vec28, <2 x double> %2)
+  %5 = fadd fast <2 x double> %1, %strided.vec34
+  %interleaved.vec = shufflevector <2 x double> %4, <2 x double> %5, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
+  ret <4 x double> %interleaved.vec
+}
+
 ; a * b + c
 define <4 x double> @mull_add_mixed(<4 x double> %a, <4 x double> %b, <4 x double> %c) {
 ; CHECK-LABEL: mull_add_mixed:
@@ -428,3 +456,6 @@ entry:
   %interleaved.vec = shufflevector <2 x double> %9, <2 x double> %13, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
   ret <4 x double> %interleaved.vec
 }
+
+declare <2 x double> @llvm.fma.v2f64(<2 x double>, <2 x double>, <2 x double>)
+declare <2 x double> @llvm.fmuladd.v2f64(<2 x double>, <2 x double>, <2 x double>)
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-fast.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-fast.ll
index cc9813bc7d1fd..a40984426a203 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-fast.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-fast.ll
@@ -37,6 +37,38 @@ entry:
   ret <vscale x 4 x double> %interleaved.vec
 }
 
+; a * b + c
+define <vscale x 4 x double> @mull_add_fma(<vscale x 4 x double> %a, <vscale x 4 x double> %b, <vscale x 4 x double> %c) {
+; CHECK-LABEL: mull_add_fma:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fcmla z4.d, p0/m, z2.d, z0.d, #90
+; CHECK-NEXT:    fcmla z5.d, p0/m, z3.d, z1.d, #90
+; CHECK-NEXT:    fcmla z4.d, p0/m, z2.d, z0.d, #0
+; CHECK-NEXT:    fcmla z5.d, p0/m, z3.d, z1.d, #0
+; CHECK-NEXT:    mov z0.d, z4.d
+; CHECK-NEXT:    mov z1.d, z5.d
+; CHECK-NEXT:    ret
+entry:
+  %strided.vec = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %a)
+  %0 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 0
+  %1 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 1
+  %strided.vec29 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %b)
+  %2 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec29, 0
+  %3 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec29, 1
+  %4 = fmul fast <vscale x 2 x double> %2, %1
+  %5 = call fast <vscale x 2 x double> @llvm.fmuladd.nxv2f64(<vscale x 2 x double> %3, <vscale x 2 x double> %0, <vscale x 2 x double> %4)
+  %strided.vec31 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %c)
+  %6 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec31, 0
+  %7 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec31, 1
+  %8 = call fast <vscale x 2 x double> @llvm.fma.nxv2f64(<vscale x 2 x double> %2, <vscale x 2 x double> %0, <vscale x 2 x double> %6)
+  %9 = fneg fast <vscale x 2 x double> %3
+  %10 = call fast <vscale x 2 x double> @llvm.fmuladd.nxv2f64(<vscale x 2 x double> %9, <vscale x 2 x double> %1, <vscale x 2 x double> %8)
+  %11 = fadd fast <vscale x 2 x double> %5, %7
+  %interleaved.vec = tail call <vscale x 4 x double> @llvm.vector.interleave2.nxv4f64(<vscale x 2 x double> %10, <vscale x 2 x double> %11)
+  ret <vscale x 4 x double> %interleaved.vec
+}
+
 ; a * b + c
 define <vscale x 4 x double> @mull_add_mixed(<vscale x 4 x double> %a, <vscale x 4 x double> %b, <vscale x 4 x double> %c) {
 ; CHECK-LABEL: mull_add_mixed:
@@ -477,3 +509,5 @@ entry:
 
 declare { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double>)
 declare <vscale x 4 x double> @llvm.vector.interleave2.nxv4f64(<vscale x 2 x double>, <vscale x 2 x double>)
+declare <vscale x 2 x double> @llvm.fma.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>)
+declare <vscale x 2 x double> @llvm.fmuladd.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x double>)

>From 068abc17414bbb58920c075916f2938d8c91f0e1 Mon Sep 17 00:00:00 2001
From: Yichao Yu <yyc1992 at gmail.com>
Date: Sat, 27 Dec 2025 17:03:01 -0500
Subject: [PATCH 03/12] Discover partial complex multiplication pattern

---
 .../lib/CodeGen/ComplexDeinterleavingPass.cpp | 186 ++++++++++--------
 ...plex-deinterleaving-add-mull-fixed-fast.ll |  22 +++
 ...x-deinterleaving-add-mull-scalable-fast.ll |  27 +++
 .../complex-deinterleaving-mixed-cases.ll     |  21 +-
 .../mve-complex-deinterleaving-mixed-cases.ll |  24 +--
 5 files changed, 166 insertions(+), 114 deletions(-)

diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index 5f9b3cc5bf064..3c77c2ac7dfc4 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -1175,6 +1175,9 @@ ComplexDeinterleavingGraph::identifyNode(ComplexValues &Vals) {
   if (CompositeNode *CN = identifySplat(Vals))
     return CN;
 
+  if (CompositeNode *CN = identifyDeinterleave(Vals))
+    return CN;
+
   for (auto &V : Vals) {
     auto *Real = dyn_cast<Instruction>(V.Real);
     auto *Imag = dyn_cast<Instruction>(V.Imag);
@@ -1182,9 +1185,6 @@ ComplexDeinterleavingGraph::identifyNode(ComplexValues &Vals) {
       return nullptr;
   }
 
-  if (CompositeNode *CN = identifyDeinterleave(Vals))
-    return CN;
-
   if (Vals.size() == 1) {
     assert(Factor == 2 && "Can only handle interleave factors of 2");
     auto *Real = dyn_cast<Instruction>(Vals[0].Real);
@@ -1485,6 +1485,20 @@ ComplexDeinterleavingGraph::identifyMultiplications(
       CommonToNode[InfoImag->Common] = NodeFromCommon;
       Processed[I] = true;
       Processed[J] = true;
+      break;
+    }
+
+    if (!Processed[I]) {
+      auto PoisonCommon = PoisonValue::get(InfoA.Common->getType());
+      auto NodeFromCommon = identifyNode(InfoA.Common, PoisonCommon);
+      if (!NodeFromCommon) {
+        NodeFromCommon = identifyNode(PoisonCommon, InfoA.Common);
+      }
+      if (!NodeFromCommon)
+        continue;
+
+      CommonToNode[InfoA.Common] = NodeFromCommon;
+      Processed[I] = true;
     }
   }
 
@@ -1555,10 +1569,18 @@ ComplexDeinterleavingGraph::identifyMultiplications(
 
     LLVM_DEBUG({
       dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
-      dbgs().indent(4) << "X: " << *NodeA->Vals[0].Real << "\n";
-      dbgs().indent(4) << "Y: " << *NodeA->Vals[0].Imag << "\n";
-      dbgs().indent(4) << "U: " << *NodeB->Vals[0].Real << "\n";
-      dbgs().indent(4) << "V: " << *NodeB->Vals[0].Imag << "\n";
+      auto PrintValue = [](const char *Name, Value *V) {
+        auto &OS = dbgs().indent(4) << Name << ": ";
+        if (V) {
+          OS << *V << "\n";
+        } else {
+          OS << "nullptr\n";
+        }
+      };
+      PrintValue("X", NodeA->Vals[0].Real);
+      PrintValue("Y", NodeA->Vals[0].Imag);
+      PrintValue("U", NodeB->Vals[0].Real);
+      PrintValue("V", NodeB->Vals[0].Imag);
       dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
     });
 
@@ -2060,12 +2082,22 @@ ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
     return EVI;
   };
 
+  auto CheckValue = [&](Value *V, unsigned ExpectedIdx) {
+      if (isa<PoisonValue>(V))
+        return true;
+      auto EVI = CheckExtract(V, ExpectedIdx, II);
+      if (!EVI) {
+        II = nullptr;
+        return false;
+      }
+      if (!II)
+        II = cast<Instruction>(EVI->getAggregateOperand());
+      return true;
+  };
+
   for (unsigned Idx = 0; Idx < Vals.size(); Idx++) {
-    ExtractValueInst *RealEVI = CheckExtract(Vals[Idx].Real, Idx * 2, II);
-    if (RealEVI && Idx == 0)
-      II = cast<Instruction>(RealEVI->getAggregateOperand());
-    if (!RealEVI || !CheckExtract(Vals[Idx].Imag, (Idx * 2) + 1, II)) {
-      II = nullptr;
+    if (!CheckValue(Vals[Idx].Real, Idx * 2) ||
+        !CheckValue(Vals[Idx].Imag, (Idx * 2) + 1)) {
       break;
     }
   }
@@ -2080,8 +2112,12 @@ ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
         llvm::ComplexDeinterleavingOperation::Deinterleave, Vals);
     PlaceholderNode->ReplacementNode = II->getOperand(0);
     for (auto &V : Vals) {
-      FinalInstructions.insert(cast<Instruction>(V.Real));
-      FinalInstructions.insert(cast<Instruction>(V.Imag));
+      if (!isa<PoisonValue>(V.Real)) {
+        FinalInstructions.insert(cast<Instruction>(V.Real));
+      }
+      if (!isa<PoisonValue>(V.Imag)) {
+        FinalInstructions.insert(cast<Instruction>(V.Imag));
+      }
     }
     return submitCompositeNode(PlaceholderNode);
   }
@@ -2091,95 +2127,87 @@ ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
 
   Value *Real = Vals[0].Real;
   Value *Imag = Vals[0].Imag;
+  bool RealPoison = isa<PoisonValue>(Real);
+  bool ImagPoison = isa<PoisonValue>(Imag);
   auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
   auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
-  if (!RealShuffle || !ImagShuffle) {
+  if (!(RealShuffle || RealPoison) || !(ImagShuffle || ImagPoison)) {
     if (RealShuffle || ImagShuffle)
       LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
     return nullptr;
   }
-
-  Value *RealOp1 = RealShuffle->getOperand(1);
-  if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {
-    LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
-    return nullptr;
-  }
-  Value *ImagOp1 = ImagShuffle->getOperand(1);
-  if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {
-    LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
-    return nullptr;
-  }
-
-  Value *RealOp0 = RealShuffle->getOperand(0);
-  Value *ImagOp0 = ImagShuffle->getOperand(0);
-
-  if (RealOp0 != ImagOp0) {
-    LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
-    return nullptr;
-  }
-
-  ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
-  ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
-  if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
-    LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
-    return nullptr;
-  }
-
-  if (RealMask[0] != 0 || ImagMask[0] != 1) {
-    LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
-    return nullptr;
-  }
+  // Both halves may be poison, and a scalable shuffle (e.g. a splat) can
+  // never deinterleave; bail out rather than deref null or assert on cast.
+  ShuffleVectorInst *AnyShuffle = RealShuffle ? RealShuffle : ImagShuffle;
+  auto *ShuffleTy =
+      AnyShuffle ? dyn_cast<FixedVectorType>(AnyShuffle->getType()) : nullptr;
+  if (!ShuffleTy) {
+    LLVM_DEBUG(dbgs() << " - No fixed-width deinterleaving shuffle.\n");
+    return nullptr;
+  }
+  Value *Op0 = AnyShuffle->getOperand(0);
+  if (RealShuffle && ImagShuffle) {
+    if (RealShuffle->getType() != ImagShuffle->getType()) {
+      LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
+      return nullptr;
+    }
+    if (Op0 != ImagShuffle->getOperand(0)) {
+      LLVM_DEBUG(dbgs() << " - Shuffle operands aren't equal.\n");
+      return nullptr;
+    }
+  }
 
   // Type checking, the shuffle type should be a vector type of the same
   // scalar type, but half the size
-  auto CheckType = [&](ShuffleVectorInst *Shuffle) {
-    Value *Op = Shuffle->getOperand(0);
-    auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());
-    auto *OpTy = cast<FixedVectorType>(Op->getType());
+  auto *Op0Ty = cast<FixedVectorType>(Op0->getType());
+  int NumElements = Op0Ty->getNumElements();
+  if (ShuffleTy->getScalarType() != Op0Ty->getScalarType() ||
+      (ShuffleTy->getNumElements() * 2) != Op0Ty->getNumElements()) {
+    LLVM_DEBUG(dbgs() << " - Shuffle is invalid type.\n");
+    return nullptr;
+  }
 
-    if (OpTy->getScalarType() != ShuffleTy->getScalarType())
+  auto CheckShuffle = [&](ShuffleVectorInst *Shuffle, int Mask0, const char *Name) -> bool {
+    if (!Shuffle) // Poison value
+      return true;
+    Value *Op1 = Shuffle->getOperand(1);
+    if (!isa<UndefValue>(Op1) && !isa<ConstantAggregateZero>(Op1)) {
+        LLVM_DEBUG(dbgs() << " - " << Name << "Op1 is not undef or zero.\n");
       return false;
-    if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())
+    }
+    ArrayRef<int> Mask = Shuffle->getShuffleMask();
+    if (!isDeinterleavingMask(Mask)) {
+      LLVM_DEBUG(dbgs() << " - " << Name << "Masks are not deinterleaving.\n");
       return false;
-
-    return true;
-  };
-
-  auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {
-    if (!CheckType(Shuffle))
+    }
+    if (Mask[0] != Mask0) {
+      LLVM_DEBUG(dbgs() << " - " << Name << "Masks do not have the correct initial value.\n");
       return false;
-
-    ArrayRef<int> Mask = Shuffle->getShuffleMask();
-    int Last = *Mask.rbegin();
-
-    Value *Op = Shuffle->getOperand(0);
-    auto *OpTy = cast<FixedVectorType>(Op->getType());
-    int NumElements = OpTy->getNumElements();
-
+    }
     // Ensure that the deinterleaving shuffle only pulls from the first
     // shuffle operand.
-    return Last < NumElements;
+    int Last = *Mask.rbegin();
+    if (Last >= NumElements) {
+      LLVM_DEBUG(dbgs() << " - " << Name << "Masks are out of bound.\n");
+      return false;
+    }
+    return true;
   };
 
-  if (RealShuffle->getType() != ImagShuffle->getType()) {
-    LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
-    return nullptr;
-  }
-  if (!CheckDeinterleavingShuffle(RealShuffle)) {
-    LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
-    return nullptr;
-  }
-  if (!CheckDeinterleavingShuffle(ImagShuffle)) {
-    LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
+  if (!CheckShuffle(RealShuffle, 0, "Real") ||
+      !CheckShuffle(ImagShuffle, 1, "Imag"))
     return nullptr;
-  }
 
   CompositeNode *PlaceholderNode =
       prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
                            RealShuffle, ImagShuffle);
-  PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
-  FinalInstructions.insert(RealShuffle);
-  FinalInstructions.insert(ImagShuffle);
+  PlaceholderNode->ReplacementNode = Op0;
+  if (RealShuffle) {
+    FinalInstructions.insert(RealShuffle);
+  }
+  if (ImagShuffle) {
+    FinalInstructions.insert(ImagShuffle);
+  }
   return submitCompositeNode(PlaceholderNode);
 }
 
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-fast.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-fast.ll
index 918e175d0a70e..1cbd43447be4a 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-fast.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-fast.ll
@@ -91,6 +91,28 @@ entry:
   ret <4 x double> %interleaved.vec
 }
 
+define <4 x double> @mull_add_partial(<4 x double> %a, <4 x double> %b, <4 x double> %c) {
+; CHECK-LABEL: mull_add_partial:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fcmla v4.2d, v0.2d, v2.2d, #0
+; CHECK-NEXT:    fcmla v5.2d, v1.2d, v3.2d, #0
+; CHECK-NEXT:    mov v0.16b, v4.16b
+; CHECK-NEXT:    mov v1.16b, v5.16b
+; CHECK-NEXT:    ret
+entry:
+  %strided.vec = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec30 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec31 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %strided.vec33 = shufflevector <4 x double> %c, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec34 = shufflevector <4 x double> %c, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %0 = fmul fast <2 x double> %strided.vec30, %strided.vec
+  %1 = fadd fast <2 x double> %strided.vec33, %0
+  %2 = fmul fast <2 x double> %strided.vec31, %strided.vec
+  %3 = fadd fast <2 x double> %strided.vec34, %2
+  %interleaved.vec = shufflevector <2 x double> %1, <2 x double> %3, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
+  ret <4 x double> %interleaved.vec
+}
+
 ; a * b + c * d
 define <4 x double> @mul_add_mull(<4 x double> %a, <4 x double> %b, <4 x double> %c, <4 x double> %d) {
 ; CHECK-LABEL: mul_add_mull:
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-fast.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-fast.ll
index a40984426a203..2998a8f3c6d4e 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-fast.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-fast.ll
@@ -103,6 +103,33 @@ entry:
   ret <vscale x 4 x double> %interleaved.vec
 }
 
+define <vscale x 4 x double> @mull_add_partial(<vscale x 4 x double> %a, <vscale x 4 x double> %b, <vscale x 4 x double> %c) {
+; CHECK-LABEL: mull_add_partial:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    fcmla z4.d, p0/m, z0.d, z2.d, #0
+; CHECK-NEXT:    fcmla z5.d, p0/m, z1.d, z3.d, #0
+; CHECK-NEXT:    mov z0.d, z4.d
+; CHECK-NEXT:    mov z1.d, z5.d
+; CHECK-NEXT:    ret
+entry:
+  %strided.vec = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %a)
+  %v0 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 0
+  %v1 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 1
+  %strided.vec29 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %b)
+  %v2 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec29, 0
+  %v3 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec29, 1
+  %strided.vec31 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %c)
+  %v8 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec31, 0
+  %v9 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec31, 1
+  %v4 = fmul fast <vscale x 2 x double> %v2, %v0
+  %v5 = fadd fast <vscale x 2 x double> %v8, %v4
+  %v6 = fmul fast <vscale x 2 x double> %v3, %v0
+  %v7 = fadd fast <vscale x 2 x double> %v9, %v6
+  %interleaved.vec = tail call <vscale x 4 x double> @llvm.vector.interleave2.nxv4f64(<vscale x 2 x double> %v5, <vscale x 2 x double> %v7)
+  ret <vscale x 4 x double> %interleaved.vec
+}
+
 ; a * b + c * d
 define <vscale x 4 x double> @mul_add_mull(<vscale x 4 x double> %a, <vscale x 4 x double> %b, <vscale x 4 x double> %c, <vscale x 4 x double> %d) {
 ; CHECK-LABEL: mul_add_mull:
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll
index 1ed9cf2db24f7..5e4d9800a9b81 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll
@@ -38,25 +38,16 @@ entry:
   ret <4 x float> %interleaved.vec
 }
 
-; Expected to not transform
+; Expected to transform
 define <4 x float> @add_mul(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
 ; CHECK-LABEL: add_mul:
 ; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fsub v4.4s, v1.4s, v2.4s
+; CHECK-NEXT:    movi v3.2d, #0000000000000000
 ; CHECK-NEXT:    fsub v0.4s, v1.4s, v0.4s
-; CHECK-NEXT:    fsub v1.4s, v1.4s, v2.4s
-; CHECK-NEXT:    ext v3.16b, v2.16b, v2.16b, #8
-; CHECK-NEXT:    ext v4.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    ext v5.16b, v1.16b, v1.16b, #8
-; CHECK-NEXT:    zip2 v0.2s, v0.2s, v4.2s
-; CHECK-NEXT:    zip2 v4.2s, v2.2s, v3.2s
-; CHECK-NEXT:    zip1 v1.2s, v1.2s, v5.2s
-; CHECK-NEXT:    zip1 v2.2s, v2.2s, v3.2s
-; CHECK-NEXT:    fmul v5.2s, v4.2s, v0.2s
-; CHECK-NEXT:    fmul v3.2s, v1.2s, v4.2s
-; CHECK-NEXT:    fneg v4.2s, v5.2s
-; CHECK-NEXT:    fmla v3.2s, v0.2s, v2.2s
-; CHECK-NEXT:    fmla v4.2s, v1.2s, v2.2s
-; CHECK-NEXT:    zip1 v0.4s, v4.4s, v3.4s
+; CHECK-NEXT:    fcmla v3.4s, v4.4s, v2.4s, #0
+; CHECK-NEXT:    fcmla v3.4s, v0.4s, v2.4s, #90
+; CHECK-NEXT:    mov v0.16b, v3.16b
 ; CHECK-NEXT:    ret
 entry:
   %0 = fsub fast <4 x float> %b, %c
diff --git a/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll b/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll
index 16dc5a81e782b..c07ad70d18d39 100644
--- a/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll
@@ -35,30 +35,14 @@ entry:
   ret <4 x float> %interleaved.vec
 }
 
-; Expected to not transform
+; Expected to transform
 define arm_aapcs_vfpcc <4 x float> @add_mul(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
 ; CHECK-LABEL: add_mul:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .vsave {d8, d9}
-; CHECK-NEXT:    vpush {d8, d9}
 ; CHECK-NEXT:    vsub.f32 q3, q1, q2
-; CHECK-NEXT:    vsub.f32 q0, q1, q0
-; CHECK-NEXT:    vmov.f32 s4, s9
-; CHECK-NEXT:    vmov.f32 s13, s14
-; CHECK-NEXT:    vmov.f32 s5, s11
-; CHECK-NEXT:    vmov.f32 s0, s1
-; CHECK-NEXT:    vmul.f32 q4, q3, q1
-; CHECK-NEXT:    vmov.f32 s1, s3
-; CHECK-NEXT:    vmov.f32 s9, s10
-; CHECK-NEXT:    vfma.f32 q4, q2, q0
-; CHECK-NEXT:    vmul.f32 q0, q1, q0
-; CHECK-NEXT:    vneg.f32 q1, q0
-; CHECK-NEXT:    vmov.f32 s1, s16
-; CHECK-NEXT:    vfma.f32 q1, q2, q3
-; CHECK-NEXT:    vmov.f32 s3, s17
-; CHECK-NEXT:    vmov.f32 s0, s4
-; CHECK-NEXT:    vmov.f32 s2, s5
-; CHECK-NEXT:    vpop {d8, d9}
+; CHECK-NEXT:    vsub.f32 q1, q1, q0
+; CHECK-NEXT:    vcmul.f32 q0, q2, q3, #0
+; CHECK-NEXT:    vcmla.f32 q0, q2, q1, #90
 ; CHECK-NEXT:    bx lr
 entry:
   %0 = fsub fast <4 x float> %b, %c

>From 491d3d38ecd5dcbf8821184f1a0add93499e1c45 Mon Sep 17 00:00:00 2001
From: Yichao Yu <yyc1992 at gmail.com>
Date: Sun, 28 Dec 2025 17:12:48 -0500
Subject: [PATCH 04/12] Match fcmla with more generic contract pattern

Use an approach similar to how reassoc nodes are handled.
However, in this case we need to preserve the structure of the operations,
so instead of collecting an unordered set of multiplications to be added
together, we build a stack of multiplications that are accumulated in
stack order.

Compared to the old approach, the stack can have a depth of 1
(to match a single, unpaired partial multiplication) and can also be
arbitrarily deep (to match longer chains of complex computations).
As in the reassoc case, we can also walk the stack to pair up common
terms into complex values even when they are more than one level
apart.
---
 .../lib/CodeGen/ComplexDeinterleavingPass.cpp | 487 ++++++++++--------
 ...-deinterleaving-add-mull-fixed-contract.ll |  54 +-
 ...interleaving-add-mull-scalable-contract.ll |  58 ++-
 3 files changed, 348 insertions(+), 251 deletions(-)

diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index 3c77c2ac7dfc4..f3794103a87bf 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -275,6 +275,15 @@ class ComplexDeinterleavingGraph {
     bool IsNodeInverted;
   };
 
+  struct PartialMulNode {
+    PartialMulNode *prev;
+    Value *Common;
+    CompositeNode *UncommonNode;
+    CompositeNode *CommonNode{nullptr};
+    ComplexDeinterleavingRotation Rotation;
+    bool IsCommonReal() const { return Rotation == ComplexDeinterleavingRotation::Rotation_0 || Rotation == ComplexDeinterleavingRotation::Rotation_180; }
+  };
+
   explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
                                       const TargetLibraryInfo *TLI,
                                       unsigned Factor)
@@ -384,14 +393,9 @@ class ComplexDeinterleavingGraph {
   ///      i: ci - ar * bi
   /// 270: r: cr + ai * bi
   ///      i: ci - ai * br
-  CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag);
-
-  /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
-  /// is partially known from identifyPartialMul, filling in the other half of
-  /// the complex pair.
-  CompositeNode *
-  identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
-                              std::pair<Value *, Value *> &CommonOperandI);
+  CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag,
+                                    bool RealPositive=true, bool ImagPositive=true,
+                                    PartialMulNode *PN=nullptr);
 
   /// Identifies a complex add pattern and its rotation, based on the following
   /// patterns.
@@ -646,229 +650,265 @@ bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B, unsigned Factor) {
   return false;
 }
 
-ComplexDeinterleavingGraph::CompositeNode *
-ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
-    Instruction *Real, Instruction *Imag,
-    std::pair<Value *, Value *> &PartialMatch) {
-  LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
-                    << "\n");
-
-  if (!Real->hasOneUse() || !Imag->hasOneUse()) {
-    LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
-    return nullptr;
-  }
-
-  if ((Real->getOpcode() != Instruction::FMul &&
-       Real->getOpcode() != Instruction::Mul) ||
-      (Imag->getOpcode() != Instruction::FMul &&
-       Imag->getOpcode() != Instruction::Mul)) {
-    LLVM_DEBUG(
-        dbgs() << "  - Real or imaginary instruction is not fmul or mul\n");
-    return nullptr;
-  }
-
-  Value *R0 = Real->getOperand(0);
-  Value *R1 = Real->getOperand(1);
-  Value *I0 = Imag->getOperand(0);
-  Value *I1 = Imag->getOperand(1);
-
-  // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
-  // rotations and use the operand.
-  unsigned Negs = 0;
-  Value *Op;
-  if (match(R0, m_Neg(m_Value(Op)))) {
-    Negs |= 1;
-    R0 = Op;
-  } else if (match(R1, m_Neg(m_Value(Op)))) {
-    Negs |= 1;
-    R1 = Op;
-  }
-
-  if (isNeg(I0)) {
-    Negs |= 2;
-    Negs ^= 1;
-    I0 = Op;
-  } else if (match(I1, m_Neg(m_Value(Op)))) {
-    Negs |= 2;
-    Negs ^= 1;
-    I1 = Op;
-  }
-
-  ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
-
-  Value *CommonOperand;
-  Value *UncommonRealOp;
-  Value *UncommonImagOp;
-
-  if (R0 == I0 || R0 == I1) {
-    CommonOperand = R0;
-    UncommonRealOp = R1;
-  } else if (R1 == I0 || R1 == I1) {
-    CommonOperand = R1;
-    UncommonRealOp = R0;
-  } else {
-    LLVM_DEBUG(dbgs() << "  - No equal operand\n");
-    return nullptr;
-  }
-
-  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
-  if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
-      Rotation == ComplexDeinterleavingRotation::Rotation_270)
-    std::swap(UncommonRealOp, UncommonImagOp);
-
-  // Between identifyPartialMul and here we need to have found a complete valid
-  // pair from the CommonOperand of each part.
-  if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
-      Rotation == ComplexDeinterleavingRotation::Rotation_180)
-    PartialMatch.first = CommonOperand;
-  else
-    PartialMatch.second = CommonOperand;
-
-  if (!PartialMatch.first || !PartialMatch.second) {
-    LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
-    return nullptr;
-  }
-
-  CompositeNode *CommonNode =
-      identifyNode(PartialMatch.first, PartialMatch.second);
-  if (!CommonNode) {
-    LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
-    return nullptr;
-  }
-
-  CompositeNode *UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
-  if (!UncommonNode) {
-    LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
-    return nullptr;
-  }
-
-  CompositeNode *Node = prepareCompositeNode(
-      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
-  Node->Rotation = Rotation;
-  Node->addOperand(CommonNode);
-  Node->addOperand(UncommonNode);
-  return submitCompositeNode(Node);
-}
-
 ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
-                                               Instruction *Imag) {
-  LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
-                    << "\n");
-
-  // Determine rotation
-  auto IsAdd = [](unsigned Op) {
-    return Op == Instruction::FAdd || Op == Instruction::Add;
+                                               Instruction *Imag,
+                                               bool RealPositive,
+                                               bool ImagPositive,
+                                               PartialMulNode *PN) {
+  LLVM_DEBUG(dbgs() << "identifyPartialMul "
+                    << (RealPositive ? " + " : " - ") << *Real << " / "
+                    << (ImagPositive ? " + " : " - ") << *Imag << "\n");
+
+  auto GetProduct = [](Value *V1, Value *V2, bool IsPositive) -> Product {
+    if (isNeg(V1)) {
+      V1 = getNegOperand(V1);
+      IsPositive = !IsPositive;
+    }
+    if (isNeg(V2)) {
+      V2 = getNegOperand(V2);
+      IsPositive = !IsPositive;
+    }
+    return {V1, V2, IsPositive};
   };
-  auto IsSub = [](unsigned Op) {
-    return Op == Instruction::FSub || Op == Instruction::Sub;
+  auto GetAddend = [](Value *V, bool IsPositive) -> Addend {
+    if (isNeg(V)) {
+      V = getNegOperand(V);
+      IsPositive = !IsPositive;
+    }
+    return {V, IsPositive};
   };
-  ComplexDeinterleavingRotation Rotation;
-  if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
-    Rotation = ComplexDeinterleavingRotation::Rotation_0;
-  else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))
-    Rotation = ComplexDeinterleavingRotation::Rotation_90;
-  else if (IsSub(Real->getOpcode()) && IsSub(Imag->getOpcode()))
-    Rotation = ComplexDeinterleavingRotation::Rotation_180;
-  else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))
-    Rotation = ComplexDeinterleavingRotation::Rotation_270;
-  else {
-    LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
-    return nullptr;
-  }
 
-  if (isa<FPMathOperator>(Real) &&
-      (!Real->getFastMathFlags().allowContract() ||
-       !Imag->getFastMathFlags().allowContract())) {
-    LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
-    return nullptr;
-  }
+  auto ProcessMulAdd = [&](Product Mul, Addend Add, bool CheckAdd,
+                          SmallVectorImpl<Product> &Muls, Addend &Addend) {
+    Muls.push_back(Mul);
+    if (CheckAdd) {
+      if (auto AddI = dyn_cast<Instruction>(Add.first)) {
+        auto Op = AddI->getOpcode();
+        if (Op == Instruction::FMul || Op == Instruction::Mul) {
+          Muls.emplace_back(GetProduct(AddI->getOperand(0), AddI->getOperand(1),
+                                       Add.second));
+          return;
+        }
+      }
+    }
+    Addend = Add;
+  };
 
-  Value *CR = Real->getOperand(0);
-  Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));
-  if (!RealMulI)
-    return nullptr;
-  Value *CI = Imag->getOperand(0);
-  Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));
-  if (!ImagMulI)
-    return nullptr;
+  auto ProcessInst = [&](Instruction *I, bool IsPositive,
+                         SmallVectorImpl<Product> &Muls, Addend &Addend) {
+    if (isNeg(I)) {
+      I = dyn_cast<Instruction>(getNegOperand(I));
+      if (!I) {
+        return false;
+      }
+      IsPositive = !IsPositive;
+    }
+    if (auto II = getFMAOrMulAdd(I)) {
+      ProcessMulAdd(GetProduct(II->getArgOperand(0), II->getArgOperand(1),
+                               IsPositive),
+                    GetAddend(II->getArgOperand(2), IsPositive),
+                    II->getFastMathFlags().allowReassoc(), Muls, Addend);
+      return true;
+    }
 
-  if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {
-    LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
-    return nullptr;
-  }
+    if (isa<FPMathOperator>(I) && !I->getFastMathFlags().allowContract()) {
+      LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
+      return false;
+    }
 
-  Value *R0 = RealMulI->getOperand(0);
-  Value *R1 = RealMulI->getOperand(1);
-  Value *I0 = ImagMulI->getOperand(0);
-  Value *I1 = ImagMulI->getOperand(1);
+    unsigned Opcode = I->getOpcode();
+    bool IsSub;
+    if (Opcode == Instruction::FAdd || Opcode == Instruction::Add)
+      IsSub = false;
+    else if (Opcode == Instruction::FSub || Opcode == Instruction::Sub)
+      IsSub = true;
+    else
+      return false;
+    Value *Op0 = I->getOperand(0);
+    Value *Op1 = I->getOperand(1);
+    if (auto I0 = dyn_cast<Instruction>(Op0)) {
+      unsigned Opcode0 = I0->getOpcode();
+      if (I0->hasOneUse() &&
+          (Opcode0 == Instruction::FMul || Opcode0 == Instruction::Mul)) {
+        ProcessMulAdd(GetProduct(I0->getOperand(0), I0->getOperand(1),
+                                 IsPositive),
+                      GetAddend(Op1, IsPositive ^ IsSub),
+                      true, Muls, Addend);
+        return true;
+      }
+    }
+    if (auto I1 = dyn_cast<Instruction>(Op1)) {
+      unsigned Opcode1 = I1->getOpcode();
+      if (I1->hasOneUse() &&
+          (Opcode1 == Instruction::FMul || Opcode1 == Instruction::Mul)) {
+        ProcessMulAdd(GetProduct(I1->getOperand(0), I1->getOperand(1),
+                                 IsPositive ^ IsSub),
+                      GetAddend(Op0, IsPositive),
+                      false, Muls, Addend);
+        return true;
+      }
+    }
+    return false;
+  };
 
-  Value *CommonOperand;
-  Value *UncommonRealOp;
-  Value *UncommonImagOp;
+  auto MatchCommons = [&](PartialMulNode *PN, CompositeNode *CN) -> CompositeNode* {
+    assert(PN);
+    for (auto PN0 = PN; PN0; PN0 = PN0->prev) {
+      if (PN0->CommonNode)
+        continue;
+      auto Common0 = PN0->Common;
+      auto RealCommon0 = PN0->IsCommonReal();
+      for (auto PN1 = PN0->prev; PN1; PN1 = PN1->prev) {
+        if (PN1->CommonNode)
+          continue;
+        auto Common1 = PN1->Common;
+        if (RealCommon0 == PN1->IsCommonReal())
+          continue;
+        if (auto CommonNode = (RealCommon0 ? identifyNode(Common0, Common1) :
+                               identifyNode(Common1, Common0))) {
+          PN0->CommonNode = CommonNode;
+          PN1->CommonNode = CommonNode;
+          break;
+        }
+      }
+      if (!PN0->CommonNode) {
+        auto PoisonCommon = PoisonValue::get(Common0->getType());
+        if (auto CommonNode = (RealCommon0 ? identifyNode(Common0, PoisonCommon) :
+                               identifyNode(PoisonCommon, Common0))) {
+          PN0->CommonNode = CommonNode;
+          continue;
+        }
+        // Clear CommonNodes for the next round
+        for (; PN; PN = PN->prev) {
+          PN->CommonNode = nullptr;
+        }
+        return nullptr;
+      }
+    }
+    for (; PN; PN = PN->prev) {
+      CompositeNode *NewCN = prepareCompositeNode(
+          ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
+      NewCN->Rotation = PN->Rotation;
+      NewCN->addOperand(PN->CommonNode);
+      NewCN->addOperand(PN->UncommonNode);
+      if (CN) {
+        NewCN->addOperand(CN);
+      }
+      CN = submitCompositeNode(NewCN);
+    }
+    return CN;
+  };
 
-  if (R0 == I0 || R0 == I1) {
-    CommonOperand = R0;
-    UncommonRealOp = R1;
-  } else if (R1 == I0 || R1 == I1) {
-    CommonOperand = R1;
-    UncommonRealOp = R0;
-  } else {
-    LLVM_DEBUG(dbgs() << "  - No equal operand\n");
+  SmallVector<Product,2> RealMuls{};
+  SmallVector<Product,2> ImagMuls{};
+  Addend RealAddend{nullptr, true};
+  Addend ImagAddend{nullptr, true};
+  if (!ProcessInst(Real, RealPositive, RealMuls, RealAddend) ||
+      !ProcessInst(Imag, ImagPositive, ImagMuls, ImagAddend)) {
+    LLVM_DEBUG(dbgs() << "  - Failed to match PartialMul in Real/Imag terms.\n");
+    if (PN && RealPositive && ImagPositive) {
+      auto CN = identifyNode(Real, Imag);
+      if (CN) {
+        LLVM_DEBUG({
+          dbgs() << "  - Addends matched:\n";
+          CN->dump();
+        });
+        return MatchCommons(PN, CN);
+      }
+      LLVM_DEBUG(dbgs() << "  - Failed to match Addends "
+                 << *Real << " / " << *Imag << ".\n");
+    }
     return nullptr;
   }
-
-  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
-  if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
-      Rotation == ComplexDeinterleavingRotation::Rotation_270)
-    std::swap(UncommonRealOp, UncommonImagOp);
-
-  std::pair<Value *, Value *> PartialMatch(
-      (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
-       Rotation == ComplexDeinterleavingRotation::Rotation_180)
-          ? CommonOperand
-          : nullptr,
-      (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
-       Rotation == ComplexDeinterleavingRotation::Rotation_270)
-          ? CommonOperand
-          : nullptr);
-
-  auto *CRInst = dyn_cast<Instruction>(CR);
-  auto *CIInst = dyn_cast<Instruction>(CI);
-
-  if (!CRInst || !CIInst) {
-    LLVM_DEBUG(dbgs() << "  - Common operands are not instructions.\n");
+  assert(RealMuls.size() > 0 && ImagMuls.size() > 0);
+  if (RealMuls.size() != ImagMuls.size())
     return nullptr;
-  }
 
-  CompositeNode *CNode =
-      identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
-  if (!CNode) {
-    LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
+  auto ForeachMatch = [&](Product RealMul, Product ImagMul,
+                          PartialMulNode *PN, auto &&cb) -> CompositeNode* {
+    PartialMulNode NewPN{};
+    NewPN.prev = PN;
+    if (RealMul.IsPositive) {
+      NewPN.Rotation = (ImagMul.IsPositive ?
+                        ComplexDeinterleavingRotation::Rotation_0 :
+                        ComplexDeinterleavingRotation::Rotation_270);
+    }
+    else {
+      NewPN.Rotation = (ImagMul.IsPositive ?
+                        ComplexDeinterleavingRotation::Rotation_90 :
+                        ComplexDeinterleavingRotation::Rotation_180);
+    }
+    auto IdentifyUncommon = [&] (Value *Real, Value *Imag) {
+      return (NewPN.IsCommonReal() ? identifyNode(Real, Imag) :
+              identifyNode(Imag, Real));
+    };
+    if (RealMul.Multiplier == ImagMul.Multiplier &&
+        (NewPN.UncommonNode = IdentifyUncommon(RealMul.Multiplicand,
+                                               ImagMul.Multiplicand))) {
+        NewPN.Common = RealMul.Multiplier;
+        if (auto CN = cb(&NewPN)) {
+          return CN;
+        }
+    }
+    if (ImagMul.Multiplicand != ImagMul.Multiplier &&
+        RealMul.Multiplier == ImagMul.Multiplicand &&
+        (NewPN.UncommonNode = IdentifyUncommon(RealMul.Multiplicand,
+                                               ImagMul.Multiplier))) {
+        NewPN.Common = RealMul.Multiplier;
+        if (auto CN = cb(&NewPN)) {
+          return CN;
+        }
+    }
+    if (RealMul.Multiplicand == RealMul.Multiplier)
+      return nullptr;
+    if (RealMul.Multiplicand == ImagMul.Multiplier &&
+        (NewPN.UncommonNode = IdentifyUncommon(RealMul.Multiplier,
+                                               ImagMul.Multiplicand))) {
+        NewPN.Common = RealMul.Multiplicand;
+        if (auto CN = cb(&NewPN)) {
+          return CN;
+        }
+    }
+    if (ImagMul.Multiplicand != ImagMul.Multiplier &&
+        RealMul.Multiplicand == ImagMul.Multiplicand &&
+        (NewPN.UncommonNode = IdentifyUncommon(RealMul.Multiplier,
+                                               ImagMul.Multiplier))) {
+        NewPN.Common = RealMul.Multiplicand;
+        if (auto CN = cb(&NewPN)) {
+          return CN;
+        }
+    }
     return nullptr;
-  }
+  };
 
-  CompositeNode *UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
-  if (!UncommonRes) {
-    LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
-    return nullptr;
+  if (RealMuls.size() == 1) {
+    assert(RealAddend.first && ImagAddend.first);
+    if (!isa<Instruction>(RealAddend.first) || !isa<Instruction>(ImagAddend.first)) {
+      if (!RealAddend.second || !ImagAddend.second)
+        return nullptr;
+      auto CN = identifyNode(RealAddend.first, ImagAddend.first);
+      if (!CN)
+        return nullptr;
+      return ForeachMatch(RealMuls[0], ImagMuls[0], PN, [&](PartialMulNode *PN) {
+        return MatchCommons(PN, CN);
+      });
+    }
+    return ForeachMatch(RealMuls[0], ImagMuls[0], PN, [&](PartialMulNode *PN) {
+      return identifyPartialMul(cast<Instruction>(RealAddend.first),
+                                cast<Instruction>(ImagAddend.first),
+                                RealAddend.second, ImagAddend.second, PN);
+    });
   }
-
-  assert(PartialMatch.first && PartialMatch.second);
-  CompositeNode *CommonRes =
-      identifyNode(PartialMatch.first, PartialMatch.second);
-  if (!CommonRes) {
-    LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
-    return nullptr;
+  else {
+    assert(RealMuls.size() == 2);
+    assert(!RealAddend.first && !ImagAddend.first);
+    return ForeachMatch(RealMuls[0], ImagMuls[0], PN, [&](PartialMulNode *PN) {
+      return ForeachMatch(RealMuls[1], ImagMuls[1], PN, [&](PartialMulNode *PN) {
+        return MatchCommons(PN, nullptr);
+      });
+    });
   }
-
-  CompositeNode *Node = prepareCompositeNode(
-      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
-  Node->Rotation = Rotation;
-  Node->addOperand(CommonRes);
-  Node->addOperand(UncommonRes);
-  Node->addOperand(CNode);
-  return submitCompositeNode(Node);
 }
 
 ComplexDeinterleavingGraph::CompositeNode *
@@ -931,13 +971,6 @@ static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
          (OpcA == Instruction::Add && OpcB == Instruction::Sub);
 }
 
-static bool isInstructionPairMul(Instruction *A, Instruction *B) {
-  auto Pattern =
-      m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));
-
-  return match(A, Pattern) && match(B, Pattern);
-}
-
 static bool isInstructionPotentiallySymmetric(Instruction *I) {
   switch (I->getOpcode()) {
   case Instruction::FAdd:
@@ -1203,20 +1236,20 @@ ComplexDeinterleavingGraph::identifyNode(ComplexValues &Vals) {
     bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
         ComplexDeinterleavingOperation::CAdd, NewVTy);
 
-    if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
-      if (CompositeNode *CN = identifyPartialMul(Real, Imag))
+    if (HasCMulSupport && HasCAddSupport) {
+      if (CompositeNode *CN = identifyReassocNodes(Real, Imag)) {
         return CN;
+      }
     }
 
-    if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
-      if (CompositeNode *CN = identifyAdd(Real, Imag))
+    if (HasCMulSupport) {
+      if (CompositeNode *CN = identifyPartialMul(Real, Imag))
         return CN;
     }
 
-    if (HasCMulSupport && HasCAddSupport) {
-      if (CompositeNode *CN = identifyReassocNodes(Real, Imag)) {
+    if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
+      if (CompositeNode *CN = identifyAdd(Real, Imag))
         return CN;
-      }
     }
   }
 
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-contract.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-contract.ll
index 09672d1be2161..e55d2a8f9cd5e 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-contract.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-contract.ll
@@ -40,6 +40,36 @@ entry:
   ret <4 x double> %interleaved.vec
 }
 
+; a * b + c
+define <4 x double> @mull_accum(<4 x double> %a, <4 x double> %b, <4 x double> %c) {
+; CHECK-LABEL: mull_accum:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fcmla   v4.2d, v0.2d, v2.2d, #0
+; CHECK-NEXT:    fcmla   v5.2d, v1.2d, v3.2d, #0
+; CHECK-NEXT:    fcmla   v4.2d, v0.2d, v2.2d, #90
+; CHECK-NEXT:    fcmla   v5.2d, v1.2d, v3.2d, #90
+; CHECK-NEXT:    mov     v0.16b, v4.16b
+; CHECK-NEXT:    mov     v1.16b, v5.16b
+; CHECK-NEXT:    ret
+entry:
+  %strided.vec = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec28 = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %strided.vec30 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec31 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %strided.vec33 = shufflevector <4 x double> %c, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec34 = shufflevector <4 x double> %c, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %0 = fmul contract <2 x double> %strided.vec, %strided.vec31
+  %1 = fadd contract <2 x double> %strided.vec34, %0
+  %2 = fmul contract <2 x double> %strided.vec28, %strided.vec30
+  %3 = fadd contract <2 x double> %1, %2
+  %4 = fmul contract <2 x double> %strided.vec, %strided.vec30
+  %5 = fadd contract <2 x double> %strided.vec33, %4
+  %6 = fmul contract <2 x double> %strided.vec28, %strided.vec31
+  %7 = fsub contract <2 x double> %5, %6
+  %interleaved.vec = shufflevector <2 x double> %7, <2 x double> %3, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
+  ret <4 x double> %interleaved.vec
+}
+
 ; a * b + c * d
 define <4 x double> @mul_add_mull(<4 x double> %a, <4 x double> %b, <4 x double> %c, <4 x double> %d) {
 ; CHECK-LABEL: mul_add_mull:
@@ -48,14 +78,14 @@ define <4 x double> @mul_add_mull(<4 x double> %a, <4 x double> %b, <4 x double>
 ; CHECK-NEXT:    movi v17.2d, #0000000000000000
 ; CHECK-NEXT:    movi v18.2d, #0000000000000000
 ; CHECK-NEXT:    movi v19.2d, #0000000000000000
-; CHECK-NEXT:    fcmla v16.2d, v2.2d, v0.2d, #0
-; CHECK-NEXT:    fcmla v18.2d, v3.2d, v1.2d, #0
-; CHECK-NEXT:    fcmla v17.2d, v7.2d, v5.2d, #0
-; CHECK-NEXT:    fcmla v19.2d, v6.2d, v4.2d, #0
 ; CHECK-NEXT:    fcmla v16.2d, v2.2d, v0.2d, #90
 ; CHECK-NEXT:    fcmla v18.2d, v3.2d, v1.2d, #90
 ; CHECK-NEXT:    fcmla v17.2d, v7.2d, v5.2d, #90
 ; CHECK-NEXT:    fcmla v19.2d, v6.2d, v4.2d, #90
+; CHECK-NEXT:    fcmla v16.2d, v2.2d, v0.2d, #0
+; CHECK-NEXT:    fcmla v18.2d, v3.2d, v1.2d, #0
+; CHECK-NEXT:    fcmla v17.2d, v7.2d, v5.2d, #0
+; CHECK-NEXT:    fcmla v19.2d, v6.2d, v4.2d, #0
 ; CHECK-NEXT:    fadd v1.2d, v18.2d, v17.2d
 ; CHECK-NEXT:    fadd v0.2d, v16.2d, v19.2d
 ; CHECK-NEXT:    ret
@@ -94,14 +124,14 @@ define <4 x double> @mul_sub_mull(<4 x double> %a, <4 x double> %b, <4 x double>
 ; CHECK-NEXT:    movi v17.2d, #0000000000000000
 ; CHECK-NEXT:    movi v18.2d, #0000000000000000
 ; CHECK-NEXT:    movi v19.2d, #0000000000000000
-; CHECK-NEXT:    fcmla v16.2d, v2.2d, v0.2d, #0
-; CHECK-NEXT:    fcmla v18.2d, v3.2d, v1.2d, #0
-; CHECK-NEXT:    fcmla v17.2d, v7.2d, v5.2d, #0
-; CHECK-NEXT:    fcmla v19.2d, v6.2d, v4.2d, #0
 ; CHECK-NEXT:    fcmla v16.2d, v2.2d, v0.2d, #90
 ; CHECK-NEXT:    fcmla v18.2d, v3.2d, v1.2d, #90
 ; CHECK-NEXT:    fcmla v17.2d, v7.2d, v5.2d, #90
 ; CHECK-NEXT:    fcmla v19.2d, v6.2d, v4.2d, #90
+; CHECK-NEXT:    fcmla v16.2d, v2.2d, v0.2d, #0
+; CHECK-NEXT:    fcmla v18.2d, v3.2d, v1.2d, #0
+; CHECK-NEXT:    fcmla v17.2d, v7.2d, v5.2d, #0
+; CHECK-NEXT:    fcmla v19.2d, v6.2d, v4.2d, #0
 ; CHECK-NEXT:    fsub v1.2d, v18.2d, v17.2d
 ; CHECK-NEXT:    fsub v0.2d, v16.2d, v19.2d
 ; CHECK-NEXT:    ret
@@ -140,14 +170,14 @@ define <4 x double> @mul_conj_mull(<4 x double> %a, <4 x double> %b, <4 x double
 ; CHECK-NEXT:    movi v17.2d, #0000000000000000
 ; CHECK-NEXT:    movi v18.2d, #0000000000000000
 ; CHECK-NEXT:    movi v19.2d, #0000000000000000
-; CHECK-NEXT:    fcmla v16.2d, v2.2d, v0.2d, #0
-; CHECK-NEXT:    fcmla v18.2d, v3.2d, v1.2d, #0
-; CHECK-NEXT:    fcmla v17.2d, v5.2d, v7.2d, #0
-; CHECK-NEXT:    fcmla v19.2d, v4.2d, v6.2d, #0
 ; CHECK-NEXT:    fcmla v16.2d, v2.2d, v0.2d, #90
 ; CHECK-NEXT:    fcmla v18.2d, v3.2d, v1.2d, #90
 ; CHECK-NEXT:    fcmla v17.2d, v5.2d, v7.2d, #270
 ; CHECK-NEXT:    fcmla v19.2d, v4.2d, v6.2d, #270
+; CHECK-NEXT:    fcmla v16.2d, v2.2d, v0.2d, #0
+; CHECK-NEXT:    fcmla v18.2d, v3.2d, v1.2d, #0
+; CHECK-NEXT:    fcmla v17.2d, v5.2d, v7.2d, #0
+; CHECK-NEXT:    fcmla v19.2d, v4.2d, v6.2d, #0
 ; CHECK-NEXT:    fadd v1.2d, v18.2d, v17.2d
 ; CHECK-NEXT:    fadd v0.2d, v16.2d, v19.2d
 ; CHECK-NEXT:    ret
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-contract.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-contract.ll
index 258eaabee9376..9988e9337d1ef 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-contract.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-contract.ll
@@ -45,6 +45,40 @@ entry:
   ret <vscale x 4 x double> %interleaved.vec
 }
 
+; a * b + c
+define <vscale x 4 x double> @mull_accum(<vscale x 4 x double> %a, <vscale x 4 x double> %b, <vscale x 4 x double> %c) {
+; CHECK-LABEL: mull_accum:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue   p0.d
+; CHECK-NEXT:    fcmla   z4.d, p0/m, z0.d, z2.d, #0
+; CHECK-NEXT:    fcmla   z5.d, p0/m, z1.d, z3.d, #0
+; CHECK-NEXT:    fcmla   z4.d, p0/m, z0.d, z2.d, #90
+; CHECK-NEXT:    fcmla   z5.d, p0/m, z1.d, z3.d, #90
+; CHECK-NEXT:    mov     z0.d, z4.d
+; CHECK-NEXT:    mov     z1.d, z5.d
+; CHECK-NEXT:    ret
+entry:
+  %strided.vec = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %a)
+  %0 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 0
+  %1 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 1
+  %strided.vec29 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %b)
+  %2 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec29, 0
+  %3 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec29, 1
+  %strided.vec31 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %c)
+  %4 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec31, 0
+  %5 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec31, 1
+  %6 = fmul contract <vscale x 2 x double> %0, %3
+  %7 = fadd contract <vscale x 2 x double> %6, %5
+  %8 = fmul contract <vscale x 2 x double> %1, %2
+  %9 = fadd contract <vscale x 2 x double> %8, %7
+  %10 = fmul contract <vscale x 2 x double> %0, %2
+  %11 = fadd contract <vscale x 2 x double> %4, %10
+  %12 = fmul contract <vscale x 2 x double> %1, %3
+  %13 = fsub contract <vscale x 2 x double> %11, %12
+  %interleaved.vec = tail call <vscale x 4 x double> @llvm.vector.interleave2.nxv4f64(<vscale x 2 x double> %13, <vscale x 2 x double> %9)
+  ret <vscale x 4 x double> %interleaved.vec
+}
+
 ; a * b + c * d
 define <vscale x 4 x double> @mul_add_mull(<vscale x 4 x double> %a, <vscale x 4 x double> %b, <vscale x 4 x double> %c, <vscale x 4 x double> %d) {
 ; CHECK-LABEL: mul_add_mull:
@@ -54,14 +88,14 @@ define <vscale x 4 x double> @mul_add_mull(<vscale x 4 x double> %a, <vscale x 4
 ; CHECK-NEXT:    movi v26.2d, #0000000000000000
 ; CHECK-NEXT:    movi v27.2d, #0000000000000000
 ; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    fcmla z24.d, p0/m, z2.d, z0.d, #0
-; CHECK-NEXT:    fcmla z25.d, p0/m, z3.d, z1.d, #0
-; CHECK-NEXT:    fcmla z27.d, p0/m, z6.d, z4.d, #0
-; CHECK-NEXT:    fcmla z26.d, p0/m, z7.d, z5.d, #0
 ; CHECK-NEXT:    fcmla z24.d, p0/m, z2.d, z0.d, #90
 ; CHECK-NEXT:    fcmla z25.d, p0/m, z3.d, z1.d, #90
 ; CHECK-NEXT:    fcmla z27.d, p0/m, z6.d, z4.d, #90
 ; CHECK-NEXT:    fcmla z26.d, p0/m, z7.d, z5.d, #90
+; CHECK-NEXT:    fcmla z24.d, p0/m, z2.d, z0.d, #0
+; CHECK-NEXT:    fcmla z25.d, p0/m, z3.d, z1.d, #0
+; CHECK-NEXT:    fcmla z27.d, p0/m, z6.d, z4.d, #0
+; CHECK-NEXT:    fcmla z26.d, p0/m, z7.d, z5.d, #0
 ; CHECK-NEXT:    fadd z0.d, z24.d, z27.d
 ; CHECK-NEXT:    fadd z1.d, z25.d, z26.d
 ; CHECK-NEXT:    ret
@@ -105,14 +139,14 @@ define <vscale x 4 x double> @mul_sub_mull(<vscale x 4 x double> %a, <vscale x 4
 ; CHECK-NEXT:    movi v26.2d, #0000000000000000
 ; CHECK-NEXT:    movi v27.2d, #0000000000000000
 ; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    fcmla z24.d, p0/m, z2.d, z0.d, #0
-; CHECK-NEXT:    fcmla z25.d, p0/m, z3.d, z1.d, #0
-; CHECK-NEXT:    fcmla z27.d, p0/m, z6.d, z4.d, #0
-; CHECK-NEXT:    fcmla z26.d, p0/m, z7.d, z5.d, #0
 ; CHECK-NEXT:    fcmla z24.d, p0/m, z2.d, z0.d, #90
 ; CHECK-NEXT:    fcmla z25.d, p0/m, z3.d, z1.d, #90
 ; CHECK-NEXT:    fcmla z27.d, p0/m, z6.d, z4.d, #90
 ; CHECK-NEXT:    fcmla z26.d, p0/m, z7.d, z5.d, #90
+; CHECK-NEXT:    fcmla z24.d, p0/m, z2.d, z0.d, #0
+; CHECK-NEXT:    fcmla z25.d, p0/m, z3.d, z1.d, #0
+; CHECK-NEXT:    fcmla z27.d, p0/m, z6.d, z4.d, #0
+; CHECK-NEXT:    fcmla z26.d, p0/m, z7.d, z5.d, #0
 ; CHECK-NEXT:    fsub z0.d, z24.d, z27.d
 ; CHECK-NEXT:    fsub z1.d, z25.d, z26.d
 ; CHECK-NEXT:    ret
@@ -156,14 +190,14 @@ define <vscale x 4 x double> @mul_conj_mull(<vscale x 4 x double> %a, <vscale x
 ; CHECK-NEXT:    movi v26.2d, #0000000000000000
 ; CHECK-NEXT:    movi v27.2d, #0000000000000000
 ; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    fcmla z24.d, p0/m, z2.d, z0.d, #0
-; CHECK-NEXT:    fcmla z25.d, p0/m, z3.d, z1.d, #0
-; CHECK-NEXT:    fcmla z27.d, p0/m, z4.d, z6.d, #0
-; CHECK-NEXT:    fcmla z26.d, p0/m, z5.d, z7.d, #0
 ; CHECK-NEXT:    fcmla z24.d, p0/m, z2.d, z0.d, #90
 ; CHECK-NEXT:    fcmla z25.d, p0/m, z3.d, z1.d, #90
 ; CHECK-NEXT:    fcmla z27.d, p0/m, z4.d, z6.d, #270
 ; CHECK-NEXT:    fcmla z26.d, p0/m, z5.d, z7.d, #270
+; CHECK-NEXT:    fcmla z24.d, p0/m, z2.d, z0.d, #0
+; CHECK-NEXT:    fcmla z25.d, p0/m, z3.d, z1.d, #0
+; CHECK-NEXT:    fcmla z27.d, p0/m, z4.d, z6.d, #0
+; CHECK-NEXT:    fcmla z26.d, p0/m, z5.d, z7.d, #0
 ; CHECK-NEXT:    fadd z0.d, z24.d, z27.d
 ; CHECK-NEXT:    fadd z1.d, z25.d, z26.d
 ; CHECK-NEXT:    ret

>From c49619cfa2a7566a16fa2e6edf0505bd84a04a02 Mon Sep 17 00:00:00 2001
From: Yichao Yu <yyc1992 at gmail.com>
Date: Sun, 28 Dec 2025 17:41:01 -0500
Subject: [PATCH 05/12] Slightly simplify and optimize symmetric operation
 opcode check

We already confirm that every instruction has the same opcode as the first
one, so there is no need to check whether that opcode is potentially
symmetric for every single instruction; checking it once up front suffices.
---
 llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index f3794103a87bf..4347fc0b0d27d 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -971,8 +971,8 @@ static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
          (OpcA == Instruction::Add && OpcB == Instruction::Sub);
 }
 
-static bool isInstructionPotentiallySymmetric(Instruction *I) {
-  switch (I->getOpcode()) {
+static bool isInstructionPotentiallySymmetric(unsigned OpCode) {
+  switch (OpCode) {
   case Instruction::FAdd:
   case Instruction::FSub:
   case Instruction::FMul:
@@ -990,16 +990,14 @@ ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::identifySymmetricOperation(ComplexValues &Vals) {
   auto *FirstReal = cast<Instruction>(Vals[0].Real);
   unsigned FirstOpc = FirstReal->getOpcode();
+  if (!isInstructionPotentiallySymmetric(FirstOpc))
+    return nullptr;
   for (auto &V : Vals) {
     auto *Real = cast<Instruction>(V.Real);
     auto *Imag = cast<Instruction>(V.Imag);
     if (Real->getOpcode() != FirstOpc || Imag->getOpcode() != FirstOpc)
       return nullptr;
 
-    if (!isInstructionPotentiallySymmetric(Real) ||
-        !isInstructionPotentiallySymmetric(Imag))
-      return nullptr;
-
     if (isa<FPMathOperator>(FirstReal))
       if (Real->getFastMathFlags() != FirstReal->getFastMathFlags() ||
           Imag->getFastMathFlags() != FirstReal->getFastMathFlags())
@@ -1032,7 +1030,7 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(ComplexValues &Vals) {
 
   auto Node =
       prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, Vals);
-  Node->Opcode = FirstReal->getOpcode();
+  Node->Opcode = FirstOpc;
   if (isa<FPMathOperator>(FirstReal))
     Node->Flags = FirstReal->getFastMathFlags();
 

>From 6d33aa165afa6eb7d380e1c804c3e390124dd63a Mon Sep 17 00:00:00 2001
From: Yichao Yu <yyc1992 at gmail.com>
Date: Sun, 28 Dec 2025 18:28:48 -0500
Subject: [PATCH 06/12] Try flipping operand order when deinterleaving
 symmetric operations

---
 .../lib/CodeGen/ComplexDeinterleavingPass.cpp | 21 +++++++++++++---
 ...-deinterleaving-add-mull-fixed-contract.ll | 25 ++++++++-----------
 ...interleaving-add-mull-scalable-contract.ll | 24 +++++++-----------
 3 files changed, 36 insertions(+), 34 deletions(-)

diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index 4347fc0b0d27d..177229768c16f 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -1012,15 +1012,28 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(ComplexValues &Vals) {
   }
 
   CompositeNode *Op0 = identifyNode(OpVals);
-  CompositeNode *Op1 = nullptr;
-  if (Op0 == nullptr)
-    return nullptr;
+  bool FlipImag = false;
+  if (Op0 == nullptr) {
+    if (FirstOpc == Instruction::FAdd || FirstOpc == Instruction::FMul ||
+        FirstOpc == Instruction::Add || FirstOpc == Instruction::Mul) {
+      FlipImag = true;
+      unsigned NVals = Vals.size();
+      for (unsigned I = 0; I < NVals; I++) {
+        OpVals[I].Imag = cast<Instruction>(Vals[I].Imag)->getOperand(1);
+      }
+      Op0 = identifyNode(OpVals);
+    }
+    if (Op0 == nullptr) {
+      return nullptr;
+    }
+  }
 
+  CompositeNode *Op1 = nullptr;
   if (FirstReal->isBinaryOp()) {
     OpVals.clear();
     for (auto &V : Vals) {
       auto *R1 = cast<Instruction>(V.Real)->getOperand(1);
-      auto *I1 = cast<Instruction>(V.Imag)->getOperand(1);
+      auto *I1 = cast<Instruction>(V.Imag)->getOperand(FlipImag ? 0 : 1);
       OpVals.push_back({R1, I1});
     }
     Op1 = identifyNode(OpVals);
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-contract.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-contract.ll
index e55d2a8f9cd5e..b1bc17c41801c 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-contract.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-contract.ll
@@ -7,19 +7,14 @@ target triple = "aarch64"
 define <4 x double> @mull_add(<4 x double> %a, <4 x double> %b, <4 x double> %c) {
 ; CHECK-LABEL: mull_add:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    zip2 v4.2d, v2.2d, v3.2d
-; CHECK-NEXT:    zip2 v5.2d, v0.2d, v1.2d
-; CHECK-NEXT:    zip1 v0.2d, v0.2d, v1.2d
-; CHECK-NEXT:    zip1 v2.2d, v2.2d, v3.2d
-; CHECK-NEXT:    fmul v1.2d, v5.2d, v4.2d
-; CHECK-NEXT:    fmul v3.2d, v0.2d, v4.2d
-; CHECK-NEXT:    fneg v1.2d, v1.2d
-; CHECK-NEXT:    fmla v3.2d, v2.2d, v5.2d
-; CHECK-NEXT:    fmla v1.2d, v2.2d, v0.2d
-; CHECK-NEXT:    fadd v1.2d, v2.2d, v1.2d
-; CHECK-NEXT:    fadd v2.2d, v3.2d, v4.2d
-; CHECK-NEXT:    zip1 v0.2d, v1.2d, v2.2d
-; CHECK-NEXT:    zip2 v1.2d, v1.2d, v2.2d
+; CHECK-NEXT:    movi    v6.2d, #0000000000000000
+; CHECK-NEXT:    movi    v7.2d, #0000000000000000
+; CHECK-NEXT:    fcmla   v7.2d, v2.2d, v0.2d, #90
+; CHECK-NEXT:    fcmla   v6.2d, v3.2d, v1.2d, #90
+; CHECK-NEXT:    fcmla   v7.2d, v2.2d, v0.2d, #0
+; CHECK-NEXT:    fcmla   v6.2d, v3.2d, v1.2d, #0
+; CHECK-NEXT:    fadd    v0.2d, v4.2d, v7.2d
+; CHECK-NEXT:    fadd    v1.2d, v5.2d, v6.2d
 ; CHECK-NEXT:    ret
 entry:
   %strided.vec = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 0, i32 2>
@@ -32,8 +27,8 @@ entry:
   %3 = fmul contract <2 x double> %strided.vec, %strided.vec30
   %4 = fmul contract <2 x double> %strided.vec28, %strided.vec31
   %5 = fsub contract <2 x double> %3, %4
-  %strided.vec33 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> <i32 0, i32 2>
-  %strided.vec34 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %strided.vec33 = shufflevector <4 x double> %c, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec34 = shufflevector <4 x double> %c, <4 x double> poison, <2 x i32> <i32 1, i32 3>
   %6 = fadd contract <2 x double> %strided.vec33, %5
   %7 = fadd contract <2 x double> %2, %strided.vec34
   %interleaved.vec = shufflevector <2 x double> %6, <2 x double> %7, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-contract.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-contract.ll
index 9988e9337d1ef..47f5d33b5a859 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-contract.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-contract.ll
@@ -7,21 +7,15 @@ target triple = "aarch64-unknown-linux-gnu"
 define <vscale x 4 x double> @mull_add(<vscale x 4 x double> %a, <vscale x 4 x double> %b, <vscale x 4 x double> %c) {
 ; CHECK-LABEL: mull_add:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    uzp2 z6.d, z0.d, z1.d
-; CHECK-NEXT:    uzp1 z0.d, z0.d, z1.d
-; CHECK-NEXT:    uzp2 z1.d, z2.d, z3.d
-; CHECK-NEXT:    uzp1 z2.d, z2.d, z3.d
-; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    fmul z7.d, z0.d, z1.d
-; CHECK-NEXT:    fmul z1.d, z6.d, z1.d
-; CHECK-NEXT:    fmad z6.d, p0/m, z2.d, z7.d
-; CHECK-NEXT:    fnmsb z0.d, p0/m, z2.d, z1.d
-; CHECK-NEXT:    uzp2 z1.d, z4.d, z5.d
-; CHECK-NEXT:    uzp1 z2.d, z4.d, z5.d
-; CHECK-NEXT:    fadd z2.d, z2.d, z0.d
-; CHECK-NEXT:    fadd z1.d, z6.d, z1.d
-; CHECK-NEXT:    zip1 z0.d, z2.d, z1.d
-; CHECK-NEXT:    zip2 z1.d, z2.d, z1.d
+; CHECK-NEXT:    movi    v6.2d, #0000000000000000
+; CHECK-NEXT:    movi    v7.2d, #0000000000000000
+; CHECK-NEXT:    ptrue   p0.d
+; CHECK-NEXT:    fcmla   z7.d, p0/m, z2.d, z0.d, #90
+; CHECK-NEXT:    fcmla   z6.d, p0/m, z3.d, z1.d, #90
+; CHECK-NEXT:    fcmla   z7.d, p0/m, z2.d, z0.d, #0
+; CHECK-NEXT:    fcmla   z6.d, p0/m, z3.d, z1.d, #0
+; CHECK-NEXT:    fadd    z0.d, z4.d, z7.d
+; CHECK-NEXT:    fadd    z1.d, z5.d, z6.d
 ; CHECK-NEXT:    ret
 entry:
   %strided.vec = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %a)

>From 502baec0f660dfe52744a45efda74b4fd628156f Mon Sep 17 00:00:00 2001
From: Yichao Yu <yyc1992 at gmail.com>
Date: Sun, 28 Dec 2025 22:50:12 -0500
Subject: [PATCH 07/12] Handle negative addend when contracting to fcmla

We propagate the negative sign to the top level to maximize the chance
of it being merged with other operations
(e.g. cancelling out another negation or being folded into an add/sub).
---
 .../lib/CodeGen/ComplexDeinterleavingPass.cpp | 41 +++++++++++++++----
 ...-deinterleaving-add-mull-fixed-contract.ll | 30 ++++++++++++++
 ...interleaving-add-mull-scalable-contract.ll | 36 ++++++++++++++++
 3 files changed, 99 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index 177229768c16f..6d46e5bacd097 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -382,6 +382,14 @@ class ComplexDeinterleavingGraph {
     return Node;
   }
 
+  CompositeNode *negCompositeNode(CompositeNode *Node) {
+    auto NegNode = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
+                                        nullptr, nullptr);
+    NegNode->Opcode = Instruction::FNeg;
+    NegNode->addOperand(Node);
+    return submitCompositeNode(NegNode);
+  }
+
   /// Identifies a complex partial multiply pattern and its rotation, based on
   /// the following patterns
   ///
@@ -634,6 +642,14 @@ static const IntrinsicInst *getFMAOrMulAdd(const Instruction *I) {
   return nullptr;
 }
 
+static inline ComplexDeinterleavingRotation
+flipRotation(ComplexDeinterleavingRotation Rotation, bool Cond=true)
+{
+  if (!Cond)
+    return Rotation;
+  return ComplexDeinterleavingRotation(unsigned(Rotation) ^ 2);
+}
+
 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B, unsigned Factor) {
   ComplexDeinterleavingGraph Graph(TL, TLI, Factor);
   if (Graph.collectPotentialReductions(B))
@@ -752,7 +768,8 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
     return false;
   };
 
-  auto MatchCommons = [&](PartialMulNode *PN, CompositeNode *CN) -> CompositeNode* {
+  auto MatchCommons = [&](PartialMulNode *PN,
+                          CompositeNode *CN, bool CNPositive) -> CompositeNode* {
     assert(PN);
     for (auto PN0 = PN; PN0; PN0 = PN0->prev) {
       if (PN0->CommonNode)
@@ -789,7 +806,7 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
     for (; PN; PN = PN->prev) {
       CompositeNode *NewCN = prepareCompositeNode(
           ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
-      NewCN->Rotation = PN->Rotation;
+      NewCN->Rotation = flipRotation(PN->Rotation, !CNPositive);
       NewCN->addOperand(PN->CommonNode);
       NewCN->addOperand(PN->UncommonNode);
       if (CN) {
@@ -797,6 +814,9 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
       }
       CN = submitCompositeNode(NewCN);
     }
+    if (!CNPositive) {
+      return negCompositeNode(CN);
+    }
     return CN;
   };
 
@@ -807,14 +827,14 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
   if (!ProcessInst(Real, RealPositive, RealMuls, RealAddend) ||
       !ProcessInst(Imag, ImagPositive, ImagMuls, ImagAddend)) {
     LLVM_DEBUG(dbgs() << "  - Failed to match PartialMul in Real/Imag terms.\n");
-    if (PN && RealPositive && ImagPositive) {
+    if (PN && RealPositive == ImagPositive) {
       auto CN = identifyNode(Real, Imag);
       if (CN) {
         LLVM_DEBUG({
           dbgs() << "  - Addends matched:\n";
           CN->dump();
         });
-        return MatchCommons(PN, CN);
+        return MatchCommons(PN, CN, RealPositive);
       }
       LLVM_DEBUG(dbgs() << "  - Failed to match Addends "
                  << *Real << " / " << *Imag << ".\n");
@@ -885,13 +905,13 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
   if (RealMuls.size() == 1) {
     assert(RealAddend.first && ImagAddend.first);
     if (!isa<Instruction>(RealAddend.first) || !isa<Instruction>(ImagAddend.first)) {
-      if (!RealAddend.second || !ImagAddend.second)
+      if (RealAddend.second != ImagAddend.second)
         return nullptr;
       auto CN = identifyNode(RealAddend.first, ImagAddend.first);
       if (!CN)
         return nullptr;
       return ForeachMatch(RealMuls[0], ImagMuls[0], PN, [&](PartialMulNode *PN) {
-        return MatchCommons(PN, CN);
+        return MatchCommons(PN, CN, RealAddend.second);
       });
     }
     return ForeachMatch(RealMuls[0], ImagMuls[0], PN, [&](PartialMulNode *PN) {
@@ -905,7 +925,7 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
     assert(!RealAddend.first && !ImagAddend.first);
     return ForeachMatch(RealMuls[0], ImagMuls[0], PN, [&](PartialMulNode *PN) {
       return ForeachMatch(RealMuls[1], ImagMuls[1], PN, [&](PartialMulNode *PN) {
-        return MatchCommons(PN, nullptr);
+        return MatchCommons(PN, nullptr, true);
       });
     });
   }
@@ -2383,7 +2403,12 @@ static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
   Value *I;
   switch (Opcode) {
   case Instruction::FNeg:
-    I = B.CreateFNeg(InputA);
+    // We use FNeg to encode both floating point and integer negation
+    if (InputA->getType()->isIntOrIntVectorTy()) {
+      I = B.CreateSub(Constant::getNullValue(InputA->getType()), InputA);
+    } else {
+      I = B.CreateFNeg(InputA);
+    }
     break;
   case Instruction::FAdd:
     I = B.CreateFAdd(InputA, InputB);
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-contract.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-contract.ll
index b1bc17c41801c..9fab5071e79db 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-contract.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-fixed-contract.ll
@@ -65,6 +65,36 @@ entry:
   ret <4 x double> %interleaved.vec
 }
 
+; a * b - c
+define <4 x double> @mull_neg_accum(<4 x double> %a, <4 x double> %b, <4 x double> %c) {
+; CHECK-LABEL: mull_neg_accum:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    fcmla   v4.2d, v0.2d, v2.2d, #180
+; CHECK-NEXT:    fcmla   v5.2d, v1.2d, v3.2d, #180
+; CHECK-NEXT:    fcmla   v4.2d, v0.2d, v2.2d, #270
+; CHECK-NEXT:    fcmla   v5.2d, v1.2d, v3.2d, #270
+; CHECK-NEXT:    fneg    v0.2d, v4.2d
+; CHECK-NEXT:    fneg    v1.2d, v5.2d
+; CHECK-NEXT:    ret
+entry:
+  %strided.vec = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec28 = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %strided.vec30 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec31 = shufflevector <4 x double> %b, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %strided.vec33 = shufflevector <4 x double> %c, <4 x double> poison, <2 x i32> <i32 0, i32 2>
+  %strided.vec34 = shufflevector <4 x double> %c, <4 x double> poison, <2 x i32> <i32 1, i32 3>
+  %0 = fmul contract <2 x double> %strided.vec, %strided.vec31
+  %1 = fsub contract <2 x double> %0, %strided.vec34
+  %2 = fmul contract <2 x double> %strided.vec28, %strided.vec30
+  %3 = fadd contract <2 x double> %1, %2
+  %4 = fmul contract <2 x double> %strided.vec, %strided.vec30
+  %5 = fsub contract <2 x double> %4, %strided.vec33
+  %6 = fmul contract <2 x double> %strided.vec28, %strided.vec31
+  %7 = fsub contract <2 x double> %5, %6
+  %interleaved.vec = shufflevector <2 x double> %7, <2 x double> %3, <4 x i32> <i32 0, i32 2, i32 1, i32 3>
+  ret <4 x double> %interleaved.vec
+}
+
 ; a * b + c * d
 define <4 x double> @mul_add_mull(<4 x double> %a, <4 x double> %b, <4 x double> %c, <4 x double> %d) {
 ; CHECK-LABEL: mul_add_mull:
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-contract.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-contract.ll
index 47f5d33b5a859..1b1d35bf6420f 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-contract.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-add-mull-scalable-contract.ll
@@ -73,6 +73,42 @@ entry:
   ret <vscale x 4 x double> %interleaved.vec
 }
 
+; a * b - c
+define <vscale x 4 x double> @mull_neg_accum(<vscale x 4 x double> %a, <vscale x 4 x double> %b, <vscale x 4 x double> %c) {
+; CHECK-LABEL: mull_neg_accum:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    ptrue   p0.d
+; CHECK-NEXT:    fcmla   z4.d, p0/m, z0.d, z2.d, #180
+; CHECK-NEXT:    fcmla   z5.d, p0/m, z1.d, z3.d, #180
+; CHECK-NEXT:    fcmla   z4.d, p0/m, z0.d, z2.d, #270
+; CHECK-NEXT:    fcmla   z5.d, p0/m, z1.d, z3.d, #270
+; CHECK-NEXT:    movprfx z0, z4
+; CHECK-NEXT:    fneg    z0.d, p0/m, z4.d
+; CHECK-NEXT:    movprfx z1, z5
+; CHECK-NEXT:    fneg    z1.d, p0/m, z5.d
+; CHECK-NEXT:    ret
+entry:
+  %strided.vec = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %a)
+  %0 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 0
+  %1 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec, 1
+  %strided.vec29 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %b)
+  %2 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec29, 0
+  %3 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec29, 1
+  %strided.vec31 = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %c)
+  %4 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec31, 0
+  %5 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %strided.vec31, 1
+  %6 = fmul contract <vscale x 2 x double> %0, %3
+  %7 = fsub contract <vscale x 2 x double> %6, %5
+  %8 = fmul contract <vscale x 2 x double> %1, %2
+  %9 = fadd contract <vscale x 2 x double> %8, %7
+  %10 = fmul contract <vscale x 2 x double> %0, %2
+  %11 = fsub contract <vscale x 2 x double> %10, %4
+  %12 = fmul contract <vscale x 2 x double> %1, %3
+  %13 = fsub contract <vscale x 2 x double> %11, %12
+  %interleaved.vec = tail call <vscale x 4 x double> @llvm.vector.interleave2.nxv4f64(<vscale x 2 x double> %13, <vscale x 2 x double> %9)
+  ret <vscale x 4 x double> %interleaved.vec
+}
+
 ; a * b + c * d
 define <vscale x 4 x double> @mul_add_mull(<vscale x 4 x double> %a, <vscale x 4 x double> %b, <vscale x 4 x double> %c, <vscale x 4 x double> %d) {
 ; CHECK-LABEL: mul_add_mull:

>From e72ff131d63db66f677fa0d853e5e4adf3dd0a14 Mon Sep 17 00:00:00 2001
From: Yichao Yu <yyc1992 at gmail.com>
Date: Sun, 28 Dec 2025 22:52:42 -0500
Subject: [PATCH 08/12] Improve negative addend handling for reassoc case

If we cannot find a positive addend, we can instead find a negative one
and use that as the accumulator.
In the worst case we need to negate the final result, but in exchange
we get rid of an add/sub between addends and of the zero initialization
of the accumulator.
---
 .../lib/CodeGen/ComplexDeinterleavingPass.cpp | 60 ++++++++++++++-----
 .../complex-deinterleaving-mixed-cases.ll     |  7 +--
 .../complex-deinterleaving-multiuses.ll       | 34 +++++------
 .../mve-complex-deinterleaving-mixed-cases.ll | 12 ++--
 4 files changed, 70 insertions(+), 43 deletions(-)

diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index 6d46e5bacd097..d96da91e355c2 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -432,18 +432,21 @@ class ComplexDeinterleavingGraph {
   CompositeNode *identifyAdditions(AddendList &RealAddends,
                                    AddendList &ImagAddends,
                                    std::optional<FastMathFlags> Flags,
-                                   CompositeNode *Accumulator);
+                                   CompositeNode *Accumulator,
+                                   bool &AccumPositive);
 
-  /// Extract one addend that have both real and imaginary parts positive.
-  CompositeNode *extractPositiveAddend(AddendList &RealAddends,
-                                       AddendList &ImagAddends);
+  /// Extract one addend that have both real and imaginary parts positive/negative.
+  CompositeNode *extractAddend(AddendList &RealAddends,
+                               AddendList &ImagAddends,
+                               bool Positive);
 
   /// Determine if sum of multiplications of complex numbers can be formed from
   /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
   /// to it. Return nullptr if it is not possible to construct a complex number.
   CompositeNode *identifyMultiplications(SmallVectorImpl<Product> &RealMuls,
                                          SmallVectorImpl<Product> &ImagMuls,
-                                         CompositeNode *Accumulator);
+                                         CompositeNode *Accumulator,
+                                         bool AccumPositive);
 
   /// Go through pairs of multiplication (one Real and one Imag) and find all
   /// possible candidates for partial multiplication and put them into \p
@@ -1435,21 +1438,32 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
     return nullptr;
 
   CompositeNode *FinalNode = nullptr;
+  bool AddendPositive = true;
   if (!RealMuls.empty() || !ImagMuls.empty()) {
     // If there are multiplicands, extract positive addend and use it as an
     // accumulator
-    FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
-    FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
+    FinalNode = extractAddend(RealAddends, ImagAddends, true);
+    if (!FinalNode) {
+      FinalNode = extractAddend(RealAddends, ImagAddends, false);
+      if (FinalNode) {
+        AddendPositive = false;
+      }
+    }
+    FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode,
+                                        AddendPositive);
     if (!FinalNode)
       return nullptr;
   }
 
   // Identify and process remaining additions
   if (!RealAddends.empty() || !ImagAddends.empty()) {
-    FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
+    FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode,
+                                  AddendPositive);
     if (!FinalNode)
       return nullptr;
   }
+  if (!AddendPositive)
+    FinalNode = negCompositeNode(FinalNode);
   assert(FinalNode && "FinalNode can not be nullptr here");
   assert(FinalNode->Vals.size() == 1);
   // Set the Real and Imag fields of the final node and submit it
@@ -1513,7 +1527,7 @@ bool ComplexDeinterleavingGraph::collectPartialMuls(
 ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::identifyMultiplications(
     SmallVectorImpl<Product> &RealMuls, SmallVectorImpl<Product> &ImagMuls,
-    CompositeNode *Accumulator = nullptr) {
+    CompositeNode *Accumulator, bool AccumPositive) {
   if (RealMuls.size() != ImagMuls.size())
     return nullptr;
 
@@ -1650,7 +1664,7 @@ ComplexDeinterleavingGraph::identifyMultiplications(
 
     CompositeNode *NodeMul = prepareCompositeNode(
         ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
-    NodeMul->Rotation = Rotation;
+    NodeMul->Rotation = flipRotation(Rotation, !AccumPositive);
     NodeMul->addOperand(NodeA);
     NodeMul->addOperand(NodeB);
     if (Result)
@@ -1692,7 +1706,8 @@ ComplexDeinterleavingGraph::identifyMultiplications(
 ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::identifyAdditions(
     AddendList &RealAddends, AddendList &ImagAddends,
-    std::optional<FastMathFlags> Flags, CompositeNode *Accumulator = nullptr) {
+    std::optional<FastMathFlags> Flags, CompositeNode *Accumulator,
+    bool &AccumPositive) {
   if (RealAddends.size() != ImagAddends.size())
     return nullptr;
 
@@ -1701,8 +1716,15 @@ ComplexDeinterleavingGraph::identifyAdditions(
   if (Accumulator)
     Result = Accumulator;
   // Otherwise find an element with both positive real and imaginary parts.
-  else
-    Result = extractPositiveAddend(RealAddends, ImagAddends);
+  else {
+    Result = extractAddend(RealAddends, ImagAddends, true);
+    if (!Result) {
+      Result = extractAddend(RealAddends, ImagAddends, false);
+      if (Result) {
+        AccumPositive = false;
+      }
+    }
+  }
 
   if (!Result)
     return nullptr;
@@ -1723,6 +1745,7 @@ ComplexDeinterleavingGraph::identifyAdditions(
         Rotation = ComplexDeinterleavingRotation::Rotation_180;
       else
         Rotation = ComplexDeinterleavingRotation::Rotation_270;
+      Rotation = flipRotation(Rotation, !AccumPositive);
 
       CompositeNode *AddNode = nullptr;
       if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
@@ -1759,6 +1782,10 @@ ComplexDeinterleavingGraph::identifyAdditions(
           } else {
             TmpNode->Opcode = Instruction::Sub;
           }
+          if (!AccumPositive) {
+            std::swap(Result, AddNode);
+            AccumPositive = true;
+          }
         } else {
           TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
                                          nullptr, nullptr);
@@ -1782,13 +1809,14 @@ ComplexDeinterleavingGraph::identifyAdditions(
 }
 
 ComplexDeinterleavingGraph::CompositeNode *
-ComplexDeinterleavingGraph::extractPositiveAddend(AddendList &RealAddends,
-                                                  AddendList &ImagAddends) {
+ComplexDeinterleavingGraph::extractAddend(AddendList &RealAddends,
+                                          AddendList &ImagAddends,
+                                          bool Positive) {
   for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
     for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
       auto [R, IsPositiveR] = *ItR;
       auto [I, IsPositiveI] = *ItI;
-      if (IsPositiveR && IsPositiveI) {
+      if (IsPositiveR == Positive && IsPositiveI == Positive) {
         auto Result = identifyNode(R, I);
         if (Result) {
           RealAddends.erase(ItR);
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll
index 5e4d9800a9b81..126dc591bbc4f 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll
@@ -376,10 +376,9 @@ entry:
 define <4 x float> @mul_subequal(<4 x float> %a, <4 x float> %b, <4 x float> %c) {
 ; CHECK-LABEL: mul_subequal:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    movi v3.2d, #0000000000000000
-; CHECK-NEXT:    fcmla v3.4s, v1.4s, v0.4s, #0
-; CHECK-NEXT:    fcmla v3.4s, v1.4s, v0.4s, #90
-; CHECK-NEXT:    fsub v0.4s, v3.4s, v2.4s
+; CHECK-NEXT:    fcmla v2.4s, v1.4s, v0.4s, #180
+; CHECK-NEXT:    fcmla v2.4s, v1.4s, v0.4s, #270
+; CHECK-NEXT:    fneg v0.4s, v2.4s
 ; CHECK-NEXT:    ret
 entry:
   %strided.vec = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> <i32 0, i32 2>
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-multiuses.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-multiuses.ll
index 039025dafa0d6..6aa4c8146fd9a 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-multiuses.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-multiuses.ll
@@ -299,33 +299,33 @@ define void @mul_add_common_mul_add_mul(<4 x double> %a, <4 x double> %b, <4 x d
 ; CHECK-LABEL: mul_add_common_mul_add_mul:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    movi v16.2d, #0000000000000000
-; CHECK-NEXT:    movi v17.2d, #0000000000000000
-; CHECK-NEXT:    ldr q19, [sp, #112]
 ; CHECK-NEXT:    ldp q18, q20, [sp, #80]
+; CHECK-NEXT:    ldr q19, [sp, #112]
+; CHECK-NEXT:    movi v17.2d, #0000000000000000
 ; CHECK-NEXT:    ldr q21, [sp, #64]
-; CHECK-NEXT:    movi v22.2d, #0000000000000000
 ; CHECK-NEXT:    fcmla v16.2d, v18.2d, v19.2d, #0
 ; CHECK-NEXT:    fcmla v17.2d, v21.2d, v20.2d, #0
-; CHECK-NEXT:    fcmla v22.2d, v1.2d, v3.2d, #0
 ; CHECK-NEXT:    fcmla v16.2d, v18.2d, v19.2d, #90
-; CHECK-NEXT:    movi v18.2d, #0000000000000000
 ; CHECK-NEXT:    fcmla v17.2d, v21.2d, v20.2d, #90
-; CHECK-NEXT:    fcmla v22.2d, v1.2d, v3.2d, #90
 ; CHECK-NEXT:    fcmla v16.2d, v5.2d, v7.2d, #0
-; CHECK-NEXT:    fcmla v18.2d, v0.2d, v2.2d, #0
 ; CHECK-NEXT:    fcmla v17.2d, v4.2d, v6.2d, #0
 ; CHECK-NEXT:    fcmla v16.2d, v5.2d, v7.2d, #90
-; CHECK-NEXT:    fcmla v18.2d, v0.2d, v2.2d, #90
 ; CHECK-NEXT:    fcmla v17.2d, v4.2d, v6.2d, #90
-; CHECK-NEXT:    ldp q3, q0, [sp, #32]
-; CHECK-NEXT:    ldp q2, q1, [sp]
-; CHECK-NEXT:    fsub v4.2d, v22.2d, v16.2d
-; CHECK-NEXT:    fsub v5.2d, v18.2d, v17.2d
-; CHECK-NEXT:    fcmla v16.2d, v0.2d, v1.2d, #0
-; CHECK-NEXT:    fcmla v17.2d, v3.2d, v2.2d, #0
-; CHECK-NEXT:    stp q5, q4, [x0]
-; CHECK-NEXT:    fcmla v16.2d, v0.2d, v1.2d, #90
-; CHECK-NEXT:    fcmla v17.2d, v3.2d, v2.2d, #90
+; CHECK-NEXT:    ldp q7, q6, [sp, #32]
+; CHECK-NEXT:    mov     v4.16b, v16.16b
+; CHECK-NEXT:    mov     v5.16b, v17.16b
+; CHECK-NEXT:    fcmla v4.2d, v1.2d, v3.2d, #180
+; CHECK-NEXT:    fcmla v5.2d, v0.2d, v2.2d, #180
+; CHECK-NEXT:    fcmla v4.2d, v1.2d, v3.2d, #270
+; CHECK-NEXT:    ldp q3, q1, [sp]
+; CHECK-NEXT:    fcmla v5.2d, v0.2d, v2.2d, #270
+; CHECK-NEXT:    fcmla v16.2d, v6.2d, v1.2d, #0
+; CHECK-NEXT:    fcmla v17.2d, v7.2d, v3.2d, #0
+; CHECK-NEXT:    fneg v0.2d, v4.2d
+; CHECK-NEXT:    fneg v2.2d, v5.2d
+; CHECK-NEXT:    fcmla v16.2d, v6.2d, v1.2d, #90
+; CHECK-NEXT:    fcmla v17.2d, v7.2d, v3.2d, #90
+; CHECK-NEXT:    stp q2, q0, [x0]
 ; CHECK-NEXT:    stp q17, q16, [x1]
 ; CHECK-NEXT:    ret
 entry:
diff --git a/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll b/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll
index c07ad70d18d39..837477365ca00 100644
--- a/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-mixed-cases.ll
@@ -403,14 +403,14 @@ define <4 x float> @mul_subequal(<4 x float> %a, <4 x float> %b, <4 x float> %c)
 ; CHECK-LABEL: mul_subequal:
 ; CHECK:       @ %bb.0: @ %entry
 ; CHECK-NEXT:    vmov d0, r0, r1
-; CHECK-NEXT:    mov r1, sp
+; CHECK-NEXT:    mov r0, sp
+; CHECK-NEXT:    add r1, sp, #16
+; CHECK-NEXT:    vldrw.u32 q1, [r0]
 ; CHECK-NEXT:    vldrw.u32 q2, [r1]
 ; CHECK-NEXT:    vmov d1, r2, r3
-; CHECK-NEXT:    add r0, sp, #16
-; CHECK-NEXT:    vcmul.f32 q3, q0, q2, #0
-; CHECK-NEXT:    vldrw.u32 q1, [r0]
-; CHECK-NEXT:    vcmla.f32 q3, q0, q2, #90
-; CHECK-NEXT:    vsub.f32 q0, q3, q1
+; CHECK-NEXT:    vcmla.f32 q2, q0, q1, #180
+; CHECK-NEXT:    vcmla.f32 q2, q0, q1, #270
+; CHECK-NEXT:    vneg.f32 q0, q2
 ; CHECK-NEXT:    vmov r0, r1, d0
 ; CHECK-NEXT:    vmov r2, r3, d1
 ; CHECK-NEXT:    bx lr

>From fcbb1be54329ce62142f3b81c72458baf23d8afd Mon Sep 17 00:00:00 2001
From: Yichao Yu <yyc1992 at gmail.com>
Date: Mon, 29 Dec 2025 15:55:47 -0500
Subject: [PATCH 09/12] Recognize single standalone partial complex
 multiplication without accumulator

---
 llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp | 17 ++++++++++++++++-
 1 file changed, 16 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index d96da91e355c2..f7de31ce9942c 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -731,12 +731,19 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
       return true;
     }
 
+    unsigned Opcode = I->getOpcode();
+    if (I->hasOneUse() &&
+        (Opcode == Instruction::FMul || Opcode == Instruction::Mul)) {
+      Muls.push_back(GetProduct(I->getOperand(0), I->getOperand(1),
+                                IsPositive));
+      return true;
+    }
+
     if (isa<FPMathOperator>(I) && !I->getFastMathFlags().allowContract()) {
       LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
       return false;
     }
 
-    unsigned Opcode = I->getOpcode();
     bool IsSub;
     if (Opcode == Instruction::FAdd || Opcode == Instruction::Add)
       IsSub = false;
@@ -906,6 +913,14 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
   };
 
   if (RealMuls.size() == 1) {
+    if (!RealAddend.first && !ImagAddend.first) {
+      return ForeachMatch(RealMuls[0], ImagMuls[0], PN, [&](PartialMulNode *PN) {
+        return MatchCommons(PN, nullptr, RealAddend.second);
+      });
+    }
+    if (!RealAddend.first || !ImagAddend.first) {
+      return nullptr;
+    }
     assert(RealAddend.first && ImagAddend.first);
     if (!isa<Instruction>(RealAddend.first) || !isa<Instruction>(ImagAddend.first)) {
       if (RealAddend.second != ImagAddend.second)

>From 95c10ccb21f465108a642f5b2ddf00b79e44bd44 Mon Sep 17 00:00:00 2001
From: Yichao Yu <yyc1992 at gmail.com>
Date: Mon, 29 Dec 2025 15:56:50 -0500
Subject: [PATCH 10/12] Relax splat recognition condition

For fixed-length vectors, a splat's shuffle mask can repeat a non-zero
lane index, so only require all mask elements to be equal.
---
 llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index f7de31ce9942c..a93c981d04025 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -2350,7 +2350,7 @@ ComplexDeinterleavingGraph::identifySplat(ComplexValues &Vals) {
     if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
       return false;
 
-    return all_equal(Mask) && Mask[0] == 0;
+    return all_equal(Mask);
   };
 
   // The splats must meet the following requirements:

>From accb7e12c5ff239faf9d6600df6870a253db0252 Mon Sep 17 00:00:00 2001
From: Yichao Yu <yyc1992 at gmail.com>
Date: Mon, 29 Dec 2025 16:32:59 -0500
Subject: [PATCH 11/12] Try matching complex add with both operand orders
 of the addition

---
 .../lib/CodeGen/ComplexDeinterleavingPass.cpp | 48 ++++++++++++-------
 1 file changed, 32 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index a93c981d04025..26f2c5912d84d 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -980,23 +980,39 @@ ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
     return nullptr;
   }
 
-  CompositeNode *ResA = identifyNode(AR, AI);
-  if (!ResA) {
-    LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
-    return nullptr;
-  }
-  CompositeNode *ResB = identifyNode(BR, BI);
-  if (!ResB) {
-    LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
-    return nullptr;
-  }
+  auto MatchCAdd = [&](Instruction *AR, Instruction *BI,
+                       Instruction *AI, Instruction *BR) -> CompositeNode* {
+    CompositeNode *ResA = identifyNode(AR, AI);
+    if (!ResA) {
+      LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
+      return nullptr;
+    }
+    CompositeNode *ResB = identifyNode(BR, BI);
+    if (!ResB) {
+      LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
+      return nullptr;
+    }
 
-  CompositeNode *Node =
-      prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
-  Node->Rotation = Rotation;
-  Node->addOperand(ResA);
-  Node->addOperand(ResB);
-  return submitCompositeNode(Node);
+    CompositeNode *Node =
+        prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
+    Node->Rotation = Rotation;
+    Node->addOperand(ResA);
+    Node->addOperand(ResB);
+    return submitCompositeNode(Node);
+  };
+
+  if (auto Res = MatchCAdd(AR, BI, AI, BR))
+    return Res;
+  if (Rotation == ComplexDeinterleavingRotation::Rotation_90) {
+    if (BR != AI) {
+      return MatchCAdd(AR, BI, BR, AI);
+    }
+  } else {
+    if (AR != BI) {
+      return MatchCAdd(BI, AR, AI, BR);
+    }
+  }
+  return nullptr;
 }
 
 static bool isInstructionPairAdd(Instruction *A, Instruction *B) {

>From ff699cb163094a0841e1171112b14f1033c4b609 Mon Sep 17 00:00:00 2001
From: Yichao Yu <yyc1992 at gmail.com>
Date: Mon, 29 Dec 2025 17:19:29 -0500
Subject: [PATCH 12/12] Handle add/sub without the contract flag as well

---
 .../lib/CodeGen/ComplexDeinterleavingPass.cpp | 27 +++++++++++++++----
 .../complex-deinterleaving-uniform-cases.ll   | 19 +++++--------
 ...ve-complex-deinterleaving-uniform-cases.ll | 23 +++-------------
 3 files changed, 32 insertions(+), 37 deletions(-)

diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index 26f2c5912d84d..f86459f80fc79 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -281,6 +281,7 @@ class ComplexDeinterleavingGraph {
     CompositeNode *UncommonNode;
     CompositeNode *CommonNode{nullptr};
     ComplexDeinterleavingRotation Rotation;
+    bool AllowContract;
     bool IsCommonReal() const { return Rotation == ComplexDeinterleavingRotation::Rotation_0 || Rotation == ComplexDeinterleavingRotation::Rotation_180; }
   };
 
@@ -679,6 +680,8 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
                     << (RealPositive ? " + " : " - ") << *Real << " / "
                     << (ImagPositive ? " + " : " - ") << *Imag << "\n");
 
+  bool AllowContract = true;
+
   auto GetProduct = [](Value *V1, Value *V2, bool IsPositive) -> Product {
     if (isNeg(V1)) {
       V1 = getNegOperand(V1);
@@ -740,8 +743,7 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
     }
 
     if (isa<FPMathOperator>(I) && !I->getFastMathFlags().allowContract()) {
-      LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
-      return false;
+      AllowContract = false;
     }
 
     bool IsSub;
@@ -816,13 +818,27 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
     for (; PN; PN = PN->prev) {
       CompositeNode *NewCN = prepareCompositeNode(
           ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
-      NewCN->Rotation = flipRotation(PN->Rotation, !CNPositive);
+      if (!CNPositive && CN && !PN->AllowContract) {
+        NewCN->Rotation = PN->Rotation;
+      } else {
+        NewCN->Rotation = flipRotation(PN->Rotation, !CNPositive);
+      }
       NewCN->addOperand(PN->CommonNode);
       NewCN->addOperand(PN->UncommonNode);
-      if (CN) {
+      if (CN && PN->AllowContract) {
         NewCN->addOperand(CN);
       }
-      CN = submitCompositeNode(NewCN);
+      submitCompositeNode(NewCN);
+      if (CN && !PN->AllowContract) {
+        auto AddNode = prepareCompositeNode(
+            ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
+        AddNode->Opcode = CNPositive ? Instruction::FAdd : Instruction::FSub;
+        CNPositive = true;
+        AddNode->addOperand(NewCN);
+        AddNode->addOperand(CN);
+        NewCN = submitCompositeNode(AddNode);
+      }
+      CN = NewCN;
     }
     if (!CNPositive) {
       return negCompositeNode(CN);
@@ -859,6 +875,7 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
                           PartialMulNode *PN, auto &&cb) -> CompositeNode* {
     PartialMulNode NewPN{};
     NewPN.prev = PN;
+    NewPN.AllowContract = AllowContract;
     if (RealMul.IsPositive) {
       NewPN.Rotation = (ImagMul.IsPositive ?
                         ComplexDeinterleavingRotation::Rotation_0 :
diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-uniform-cases.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-uniform-cases.ll
index 13434fabefa78..a50a332b180ff 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-uniform-cases.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-uniform-cases.ll
@@ -27,22 +27,15 @@ entry:
   ret <4 x float> %interleaved.vec
 }
 
-; Expected to not transform
+; Expected to transform
 define <4 x float> @simple_mul_no_contract(<4 x float> %a, <4 x float> %b) {
 ; CHECK-LABEL: simple_mul_no_contract:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ext v2.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    ext v3.16b, v1.16b, v1.16b, #8
-; CHECK-NEXT:    zip1 v4.2s, v0.2s, v2.2s
-; CHECK-NEXT:    zip2 v0.2s, v0.2s, v2.2s
-; CHECK-NEXT:    zip1 v2.2s, v1.2s, v3.2s
-; CHECK-NEXT:    zip2 v1.2s, v1.2s, v3.2s
-; CHECK-NEXT:    fmul v3.2s, v1.2s, v4.2s
-; CHECK-NEXT:    fmul v4.2s, v2.2s, v4.2s
-; CHECK-NEXT:    fmul v1.2s, v0.2s, v1.2s
-; CHECK-NEXT:    fmla v3.2s, v0.2s, v2.2s
-; CHECK-NEXT:    fsub v0.2s, v4.2s, v1.2s
-; CHECK-NEXT:    zip1 v0.4s, v0.4s, v3.4s
+; CHECK-NEXT:    movi    v2.2d, #0000000000000000
+; CHECK-NEXT:    movi    v3.2d, #0000000000000000
+; CHECK-NEXT:    fcmla   v3.4s, v1.4s, v0.4s, #0
+; CHECK-NEXT:    fcmla   v2.4s, v1.4s, v0.4s, #90
+; CHECK-NEXT:    fadd    v0.4s, v3.4s, v2.4s
 ; CHECK-NEXT:    ret
 entry:
   %strided.vec = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> <i32 0, i32 2>
diff --git a/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-uniform-cases.ll b/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-uniform-cases.ll
index dc67abc1be07e..c2af945107ce2 100644
--- a/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-uniform-cases.ll
+++ b/llvm/test/CodeGen/Thumb2/mve-complex-deinterleaving-uniform-cases.ll
@@ -26,28 +26,13 @@ entry:
   ret <4 x float> %interleaved.vec
 }
 
-; Expected to not transform
+; Expected to transform
 define arm_aapcs_vfpcc <4 x float> @simple_mul_no_contract(<4 x float> %a, <4 x float> %b) {
 ; CHECK-LABEL: simple_mul_no_contract:
 ; CHECK:       @ %bb.0: @ %entry
-; CHECK-NEXT:    .vsave {d8, d9, d10, d11}
-; CHECK-NEXT:    vpush {d8, d9, d10, d11}
-; CHECK-NEXT:    vmov.f32 s8, s5
-; CHECK-NEXT:    vmov.f32 s12, s1
-; CHECK-NEXT:    vmov.f32 s9, s7
-; CHECK-NEXT:    vmov.f32 s13, s3
-; CHECK-NEXT:    vmov.f32 s1, s2
-; CHECK-NEXT:    vmul.f32 q4, q3, q2
-; CHECK-NEXT:    vmov.f32 s5, s6
-; CHECK-NEXT:    vmul.f32 q2, q2, q0
-; CHECK-NEXT:    vmul.f32 q5, q1, q0
-; CHECK-NEXT:    vfma.f32 q2, q1, q3
-; CHECK-NEXT:    vsub.f32 q4, q5, q4
-; CHECK-NEXT:    vmov.f32 s1, s8
-; CHECK-NEXT:    vmov.f32 s0, s16
-; CHECK-NEXT:    vmov.f32 s2, s17
-; CHECK-NEXT:    vmov.f32 s3, s9
-; CHECK-NEXT:    vpop {d8, d9, d10, d11}
+; CHECK-NEXT:    vcmul.f32       q2, q0, q1, #90
+; CHECK-NEXT:    vcmul.f32       q3, q0, q1, #0
+; CHECK-NEXT:    vadd.f32        q0, q3, q2
 ; CHECK-NEXT:    bx lr
 entry:
   %strided.vec = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> <i32 0, i32 2>



More information about the llvm-commits mailing list