[llvm] 04a8070 - Revert "Revert "[CodeGen] Extend reduction support in ComplexDeinterleaving pass to support predication""

Igor Kirillov via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 23 03:13:56 PDT 2023


Author: Igor Kirillov
Date: 2023-06-23T10:13:22Z
New Revision: 04a8070b46da2bcd47d0a134922409dc16bb9d57

URL: https://github.com/llvm/llvm-project/commit/04a8070b46da2bcd47d0a134922409dc16bb9d57
DIFF: https://github.com/llvm/llvm-project/commit/04a8070b46da2bcd47d0a134922409dc16bb9d57.diff

LOG: Revert "Revert "[CodeGen] Extend reduction support in ComplexDeinterleaving pass to support predication""

Adds the capability to recognize SelectInst instructions that appear in the IR.
These instructions are generated for reductions during scalable vectorization
when the loop body contains conditions or when
"-prefer-predicate-over-epilogue=predicate-dont-vectorize" is set.

Differential Revision: https://reviews.llvm.org/D152558

This reverts commit ab09654832dba5cef8baa6400fdfd3e4d1495624.

Reason: Reapplying after removing unnecessary default case in switch expression.

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
    llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
    llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-predicated-scalable.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h b/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
index e7119078d8a3b..d75d0bdf26cbc 100644
--- a/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
+++ b/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
@@ -41,6 +41,7 @@ enum class ComplexDeinterleavingOperation {
   Symmetric,
   ReductionPHI,
   ReductionOperation,
+  ReductionSelect,
 };
 
 enum class ComplexDeinterleavingRotation {

diff  --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index a30edf5613e45..9339ba3072473 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -371,6 +371,10 @@ class ComplexDeinterleavingGraph {
 
   NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
 
+  /// Identifies SelectInsts in a loop that has reduction with predication masks
+  /// and/or predicated tail folding
+  NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);
+
   Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
 
   /// Complete IR modifications after producing new reduction operation:
@@ -889,6 +893,9 @@ ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
   if (NodePtr CN = identifyPHINode(Real, Imag))
     return CN;
 
+  if (NodePtr CN = identifySelectNode(Real, Imag))
+    return CN;
+
   auto *VTy = cast<VectorType>(Real->getType());
   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
 
@@ -1713,6 +1720,45 @@ ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
   return submitCompositeNode(PlaceholderNode);
 }
 
+ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
+                                               Instruction *Imag) {
+  auto *SelectReal = dyn_cast<SelectInst>(Real);
+  auto *SelectImag = dyn_cast<SelectInst>(Imag);
+  if (!SelectReal || !SelectImag)
+    return nullptr;
+
+  Instruction *MaskA, *MaskB;
+  Instruction *AR, *AI, *RA, *BI;
+  if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
+                            m_Instruction(RA))) ||
+      !match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
+                            m_Instruction(BI))))
+    return nullptr;
+
+  if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))
+    return nullptr;
+
+  if (!MaskA->getType()->isVectorTy())
+    return nullptr;
+
+  auto NodeA = identifyNode(AR, AI);
+  if (!NodeA)
+    return nullptr;
+
+  auto NodeB = identifyNode(RA, BI);
+  if (!NodeB)
+    return nullptr;
+
+  NodePtr PlaceholderNode = prepareCompositeNode(
+      ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
+  PlaceholderNode->addOperand(NodeA);
+  PlaceholderNode->addOperand(NodeB);
+  FinalInstructions.insert(MaskA);
+  FinalInstructions.insert(MaskB);
+  return submitCompositeNode(PlaceholderNode);
+}
+
 static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
                                    FastMathFlags Flags, Value *InputA,
                                    Value *InputB) {
@@ -1787,6 +1833,19 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
     ReplacementNode = replaceNode(Builder, Node->Operands[0]);
     processReductionOperation(ReplacementNode, Node);
     break;
+  case ComplexDeinterleavingOperation::ReductionSelect: {
+    auto *MaskReal = Node->Real->getOperand(0);
+    auto *MaskImag = Node->Imag->getOperand(0);
+    auto *A = replaceNode(Builder, Node->Operands[0]);
+    auto *B = replaceNode(Builder, Node->Operands[1]);
+    auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
+        cast<VectorType>(MaskReal->getType()));
+    auto *NewMask =
+        Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
+                                NewMaskTy, {MaskReal, MaskImag});
+    ReplacementNode = Builder.CreateSelect(NewMask, A, B);
+    break;
+  }
   }
 
   assert(ReplacementNode && "Target failed to create Intrinsic call.");

diff  --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-predicated-scalable.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-predicated-scalable.ll
index 2ca15296e8e97..cfa3f3943b1bb 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-predicated-scalable.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-predicated-scalable.ll
@@ -20,8 +20,9 @@ define %"class.std::complex" @complex_mul_v2f64(ptr %a, ptr %b) {
 ; CHECK-NEXT:    mov x11, x10
 ; CHECK-NEXT:    mov z1.d, #0 // =0x0
 ; CHECK-NEXT:    rdvl x12, #2
-; CHECK-NEXT:    mov z0.d, z1.d
 ; CHECK-NEXT:    whilelo p1.d, xzr, x9
+; CHECK-NEXT:    zip2 z0.d, z1.d, z1.d
+; CHECK-NEXT:    zip1 z1.d, z1.d, z1.d
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:  .LBB0_1: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
@@ -29,29 +30,27 @@ define %"class.std::complex" @complex_mul_v2f64(ptr %a, ptr %b) {
 ; CHECK-NEXT:    add x14, x1, x8
 ; CHECK-NEXT:    zip1 p2.d, p1.d, p1.d
 ; CHECK-NEXT:    zip2 p3.d, p1.d, p1.d
-; CHECK-NEXT:    add x8, x8, x12
+; CHECK-NEXT:    mov z6.d, z1.d
+; CHECK-NEXT:    mov z7.d, z0.d
 ; CHECK-NEXT:    ld1d { z2.d }, p3/z, [x13, #1, mul vl]
 ; CHECK-NEXT:    ld1d { z3.d }, p2/z, [x13]
 ; CHECK-NEXT:    ld1d { z4.d }, p3/z, [x14, #1, mul vl]
 ; CHECK-NEXT:    ld1d { z5.d }, p2/z, [x14]
-; CHECK-NEXT:    uzp2 z6.d, z3.d, z2.d
-; CHECK-NEXT:    uzp1 z2.d, z3.d, z2.d
-; CHECK-NEXT:    uzp1 z7.d, z5.d, z4.d
-; CHECK-NEXT:    uzp2 z4.d, z5.d, z4.d
-; CHECK-NEXT:    movprfx z3, z1
-; CHECK-NEXT:    fmla z3.d, p0/m, z7.d, z6.d
-; CHECK-NEXT:    fmad z7.d, p0/m, z2.d, z0.d
-; CHECK-NEXT:    fmad z2.d, p0/m, z4.d, z3.d
-; CHECK-NEXT:    movprfx z3, z7
-; CHECK-NEXT:    fmls z3.d, p0/m, z4.d, z6.d
-; CHECK-NEXT:    mov z1.d, p1/m, z2.d
-; CHECK-NEXT:    mov z0.d, p1/m, z3.d
 ; CHECK-NEXT:    whilelo p1.d, x11, x9
+; CHECK-NEXT:    add x8, x8, x12
 ; CHECK-NEXT:    add x11, x11, x10
+; CHECK-NEXT:    fcmla z6.d, p0/m, z5.d, z3.d, #0
+; CHECK-NEXT:    fcmla z7.d, p0/m, z4.d, z2.d, #0
+; CHECK-NEXT:    fcmla z6.d, p0/m, z5.d, z3.d, #90
+; CHECK-NEXT:    fcmla z7.d, p0/m, z4.d, z2.d, #90
+; CHECK-NEXT:    mov z0.d, p3/m, z7.d
+; CHECK-NEXT:    mov z1.d, p2/m, z6.d
 ; CHECK-NEXT:    b.mi .LBB0_1
 ; CHECK-NEXT:  // %bb.2: // %exit.block
+; CHECK-NEXT:    uzp2 z2.d, z1.d, z0.d
+; CHECK-NEXT:    uzp1 z0.d, z1.d, z0.d
 ; CHECK-NEXT:    faddv d0, p0, z0.d
-; CHECK-NEXT:    faddv d1, p0, z1.d
+; CHECK-NEXT:    faddv d1, p0, z2.d
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 killed $z1
 ; CHECK-NEXT:    ret
@@ -122,39 +121,38 @@ define %"class.std::complex" @complex_mul_predicated_v2f64(ptr %a, ptr %b, ptr %
 ; CHECK-NEXT:    and x11, x11, x12
 ; CHECK-NEXT:    mov z1.d, #0 // =0x0
 ; CHECK-NEXT:    rdvl x12, #2
-; CHECK-NEXT:    mov z0.d, z1.d
+; CHECK-NEXT:    zip2 z0.d, z1.d, z1.d
+; CHECK-NEXT:    zip1 z1.d, z1.d, z1.d
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:  .LBB1_1: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    ld1w { z2.d }, p0/z, [x2, x9, lsl #2]
 ; CHECK-NEXT:    add x13, x0, x8
 ; CHECK-NEXT:    add x14, x1, x8
+; CHECK-NEXT:    mov z6.d, z1.d
+; CHECK-NEXT:    mov z7.d, z0.d
 ; CHECK-NEXT:    add x9, x9, x10
 ; CHECK-NEXT:    add x8, x8, x12
-; CHECK-NEXT:    cmpne p1.d, p0/z, z2.d, #0
-; CHECK-NEXT:    zip1 p2.d, p1.d, p1.d
-; CHECK-NEXT:    zip2 p3.d, p1.d, p1.d
-; CHECK-NEXT:    ld1d { z2.d }, p3/z, [x13, #1, mul vl]
-; CHECK-NEXT:    ld1d { z3.d }, p2/z, [x13]
-; CHECK-NEXT:    ld1d { z4.d }, p3/z, [x14, #1, mul vl]
-; CHECK-NEXT:    ld1d { z5.d }, p2/z, [x14]
+; CHECK-NEXT:    cmpne p2.d, p0/z, z2.d, #0
+; CHECK-NEXT:    zip1 p1.d, p2.d, p2.d
+; CHECK-NEXT:    zip2 p2.d, p2.d, p2.d
+; CHECK-NEXT:    ld1d { z2.d }, p2/z, [x13, #1, mul vl]
+; CHECK-NEXT:    ld1d { z3.d }, p1/z, [x13]
+; CHECK-NEXT:    ld1d { z4.d }, p2/z, [x14, #1, mul vl]
+; CHECK-NEXT:    ld1d { z5.d }, p1/z, [x14]
 ; CHECK-NEXT:    cmp x11, x9
-; CHECK-NEXT:    uzp2 z6.d, z3.d, z2.d
-; CHECK-NEXT:    uzp1 z2.d, z3.d, z2.d
-; CHECK-NEXT:    uzp1 z3.d, z5.d, z4.d
-; CHECK-NEXT:    movprfx z7, z0
-; CHECK-NEXT:    fmla z7.d, p0/m, z3.d, z2.d
-; CHECK-NEXT:    fmad z3.d, p0/m, z6.d, z1.d
-; CHECK-NEXT:    uzp2 z4.d, z5.d, z4.d
-; CHECK-NEXT:    fmad z2.d, p0/m, z4.d, z3.d
-; CHECK-NEXT:    movprfx z5, z7
-; CHECK-NEXT:    fmls z5.d, p0/m, z4.d, z6.d
-; CHECK-NEXT:    mov z0.d, p1/m, z5.d
-; CHECK-NEXT:    mov z1.d, p1/m, z2.d
+; CHECK-NEXT:    fcmla z6.d, p0/m, z5.d, z3.d, #0
+; CHECK-NEXT:    fcmla z7.d, p0/m, z4.d, z2.d, #0
+; CHECK-NEXT:    fcmla z6.d, p0/m, z5.d, z3.d, #90
+; CHECK-NEXT:    fcmla z7.d, p0/m, z4.d, z2.d, #90
+; CHECK-NEXT:    mov z0.d, p2/m, z7.d
+; CHECK-NEXT:    mov z1.d, p1/m, z6.d
 ; CHECK-NEXT:    b.ne .LBB1_1
 ; CHECK-NEXT:  // %bb.2: // %exit.block
+; CHECK-NEXT:    uzp2 z2.d, z1.d, z0.d
+; CHECK-NEXT:    uzp1 z0.d, z1.d, z0.d
 ; CHECK-NEXT:    faddv d0, p0, z0.d
-; CHECK-NEXT:    faddv d1, p0, z1.d
+; CHECK-NEXT:    faddv d1, p0, z2.d
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 killed $z1
 ; CHECK-NEXT:    ret
@@ -223,42 +221,41 @@ define %"class.std::complex" @complex_mul_predicated_x2_v2f64(ptr %a, ptr %b, pt
 ; CHECK-NEXT:    mov x8, xzr
 ; CHECK-NEXT:    mov x9, xzr
 ; CHECK-NEXT:    mov z1.d, #0 // =0x0
-; CHECK-NEXT:    mov z0.d, z1.d
 ; CHECK-NEXT:    cntd x11
-; CHECK-NEXT:    whilelo p1.d, xzr, x10
 ; CHECK-NEXT:    rdvl x12, #2
+; CHECK-NEXT:    whilelo p1.d, xzr, x10
+; CHECK-NEXT:    zip2 z0.d, z1.d, z1.d
+; CHECK-NEXT:    zip1 z1.d, z1.d, z1.d
 ; CHECK-NEXT:    ptrue p0.d
 ; CHECK-NEXT:  .LBB2_1: // %vector.body
 ; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    ld1w { z2.d }, p1/z, [x2, x9, lsl #2]
 ; CHECK-NEXT:    add x13, x0, x8
 ; CHECK-NEXT:    add x14, x1, x8
+; CHECK-NEXT:    mov z6.d, z1.d
+; CHECK-NEXT:    mov z7.d, z0.d
 ; CHECK-NEXT:    add x9, x9, x11
 ; CHECK-NEXT:    add x8, x8, x12
-; CHECK-NEXT:    cmpne p2.d, p1/z, z2.d, #0
-; CHECK-NEXT:    zip1 p1.d, p2.d, p2.d
-; CHECK-NEXT:    zip2 p3.d, p2.d, p2.d
+; CHECK-NEXT:    cmpne p1.d, p1/z, z2.d, #0
+; CHECK-NEXT:    zip1 p2.d, p1.d, p1.d
+; CHECK-NEXT:    zip2 p3.d, p1.d, p1.d
 ; CHECK-NEXT:    ld1d { z2.d }, p3/z, [x13, #1, mul vl]
-; CHECK-NEXT:    ld1d { z3.d }, p1/z, [x13]
+; CHECK-NEXT:    ld1d { z3.d }, p2/z, [x13]
 ; CHECK-NEXT:    ld1d { z4.d }, p3/z, [x14, #1, mul vl]
-; CHECK-NEXT:    ld1d { z5.d }, p1/z, [x14]
+; CHECK-NEXT:    ld1d { z5.d }, p2/z, [x14]
 ; CHECK-NEXT:    whilelo p1.d, x9, x10
-; CHECK-NEXT:    uzp1 z6.d, z3.d, z2.d
-; CHECK-NEXT:    uzp2 z2.d, z3.d, z2.d
-; CHECK-NEXT:    uzp1 z7.d, z5.d, z4.d
-; CHECK-NEXT:    uzp2 z4.d, z5.d, z4.d
-; CHECK-NEXT:    movprfx z3, z0
-; CHECK-NEXT:    fmla z3.d, p0/m, z7.d, z6.d
-; CHECK-NEXT:    fmad z7.d, p0/m, z2.d, z1.d
-; CHECK-NEXT:    fmsb z2.d, p0/m, z4.d, z3.d
-; CHECK-NEXT:    movprfx z3, z7
-; CHECK-NEXT:    fmla z3.d, p0/m, z4.d, z6.d
-; CHECK-NEXT:    mov z1.d, p2/m, z3.d
-; CHECK-NEXT:    mov z0.d, p2/m, z2.d
+; CHECK-NEXT:    fcmla z6.d, p0/m, z5.d, z3.d, #0
+; CHECK-NEXT:    fcmla z7.d, p0/m, z4.d, z2.d, #0
+; CHECK-NEXT:    fcmla z6.d, p0/m, z5.d, z3.d, #90
+; CHECK-NEXT:    fcmla z7.d, p0/m, z4.d, z2.d, #90
+; CHECK-NEXT:    mov z0.d, p3/m, z7.d
+; CHECK-NEXT:    mov z1.d, p2/m, z6.d
 ; CHECK-NEXT:    b.mi .LBB2_1
 ; CHECK-NEXT:  // %bb.2: // %exit.block
+; CHECK-NEXT:    uzp2 z2.d, z1.d, z0.d
+; CHECK-NEXT:    uzp1 z0.d, z1.d, z0.d
 ; CHECK-NEXT:    faddv d0, p0, z0.d
-; CHECK-NEXT:    faddv d1, p0, z1.d
+; CHECK-NEXT:    faddv d1, p0, z2.d
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT:    // kill: def $d1 killed $d1 killed $z1
 ; CHECK-NEXT:    ret


        


More information about the llvm-commits mailing list