[llvm] 7f20407 - [CodeGen] Add support for Splats in ComplexDeinterleaving pass

Igor Kirillov via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 5 10:03:27 PDT 2023


Author: Igor Kirillov
Date: 2023-07-05T17:02:52Z
New Revision: 7f20407ceed8e713c5f193a2598358332b5ac0d3

URL: https://github.com/llvm/llvm-project/commit/7f20407ceed8e713c5f193a2598358332b5ac0d3
DIFF: https://github.com/llvm/llvm-project/commit/7f20407ceed8e713c5f193a2598358332b5ac0d3.diff

LOG: [CodeGen] Add support for Splats in ComplexDeinterleaving pass

This commit allows the generation of complex number intrinsics for expressions
with constants or loop invariants, which are represented as splats.
For instance, after vectorizing the loops in the following code snippets,
the ComplexDeinterleaving pass is able to generate complex number
intrinsics:

```
std::complex<double> x = ...;
for (int i = 0; i < N; ++i)
    c[i] = a[i] * b[i] * x;
```

or

```
for (int i = 0; i < N; ++i)
    c[i] = a[i] * b[i] * (11.0 + 3.0i);
```
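
At the IR level, such constants and loop-invariant values reach the pass as
splats of their real and imaginary lanes. The sketch below is illustrative
only (the %x.re/%x.im names are hypothetical, not taken from the patch or its
tests); it shows the insertelement/shufflevector shape with an all-zero mask
that identifySplat accepts for non-constant splats, while constant splats
appear instead as ConstantDataVector or ConstantExpr shufflevector operands:

```
; Hypothetical sketch: a loop-invariant complex value (%x.re, %x.im)
; broadcast into <2 x double> splats after vectorization. identifySplat
; matches the all-zero shuffle mask on both the real and imaginary parts.
%x.re.ins   = insertelement <2 x double> poison, double %x.re, i64 0
%x.re.splat = shufflevector <2 x double> %x.re.ins, <2 x double> poison, <2 x i32> zeroinitializer
%x.im.ins   = insertelement <2 x double> poison, double %x.im, i64 0
%x.im.splat = shufflevector <2 x double> %x.im.ins, <2 x double> poison, <2 x i32> zeroinitializer
```

When both sides match, replaceNode rebuilds the interleaved complex value with
the llvm.experimental.vector.interleave2 intrinsic, as shown in the diff below.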

Differential Revision: https://reviews.llvm.org/D153355

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
    llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
    llvm/test/CodeGen/AArch64/complex-deinterleaving-splat-scalable.ll
    llvm/test/CodeGen/AArch64/complex-deinterleaving-splat.ll

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h b/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
index d75d0bdf26cbc0..84a2673fecb5bf 100644
--- a/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
+++ b/llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h
@@ -38,6 +38,7 @@ enum class ComplexDeinterleavingOperation {
   // The following 'operations' are used to represent internal states. Backends
   // are not expected to try and support these in any capacity.
   Deinterleave,
+  Splat,
   Symmetric,
   ReductionPHI,
   ReductionOperation,

diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index 3a4b94d5eae271..9f2c665866d3c9 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -369,6 +369,12 @@ class ComplexDeinterleavingGraph {
   /// intrinsic (for both fixed and scalable vectors)
   NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
 
+  /// Identifies the operation that represents a complex number repeated in a
+  /// Splat vector. There are two possible types of splats: ConstantExpr with
+  /// the opcode ShuffleVector and ShuffleVectorInst. Both should have an
+  /// initialization mask with all values set to zero.
+  NodePtr identifySplat(Value *Real, Value *Imag);
+
   NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
 
   /// Identifies SelectInsts in a loop that has reduction with predication masks
@@ -865,6 +871,9 @@ ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
     return CN;
   }
 
+  if (NodePtr CN = identifySplat(R, I))
+    return CN;
+
   auto *Real = dyn_cast<Instruction>(R);
   auto *Imag = dyn_cast<Instruction>(I);
   if (!Real || !Imag)
@@ -1694,6 +1703,59 @@ ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
   return submitCompositeNode(PlaceholderNode);
 }
 
+ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
+  auto IsSplat = [](Value *V) -> bool {
+    // Fixed-width vector with constants
+    if (isa<ConstantDataVector>(V))
+      return true;
+
+    VectorType *VTy;
+    ArrayRef<int> Mask;
+    // Splats are represented differently depending on whether the repeated
+    // value is a constant or an Instruction
+    if (auto *Const = dyn_cast<ConstantExpr>(V)) {
+      if (Const->getOpcode() != Instruction::ShuffleVector)
+        return false;
+      VTy = cast<VectorType>(Const->getType());
+      Mask = Const->getShuffleMask();
+    } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
+      VTy = Shuf->getType();
+      Mask = Shuf->getShuffleMask();
+    } else {
+      return false;
+    }
+
+    // When the data type is <1 x Type>, it's not possible to differentiate
+    // between the ComplexDeinterleaving::Deinterleave and
+    // ComplexDeinterleaving::Splat operations.
+    if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)
+      return false;
+
+    return all_equal(Mask) && Mask[0] == 0;
+  };
+
+  if (!IsSplat(R) || !IsSplat(I))
+    return nullptr;
+
+  auto *Real = dyn_cast<Instruction>(R);
+  auto *Imag = dyn_cast<Instruction>(I);
+  if ((!Real && Imag) || (Real && !Imag))
+    return nullptr;
+
+  if (Real && Imag) {
+    // Non-constant splats should be in the same basic block
+    if (Real->getParent() != Imag->getParent())
+      return nullptr;
+
+    FinalInstructions.insert(Real);
+    FinalInstructions.insert(Imag);
+  }
+  NodePtr PlaceholderNode =
+      prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
+  return submitCompositeNode(PlaceholderNode);
+}
+
 ComplexDeinterleavingGraph::NodePtr
 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
                                             Instruction *Imag) {
@@ -1804,6 +1866,25 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
   case ComplexDeinterleavingOperation::Deinterleave:
     llvm_unreachable("Deinterleave node should already have ReplacementNode");
     break;
+  case ComplexDeinterleavingOperation::Splat: {
+    auto *NewTy = VectorType::getDoubleElementsVectorType(
+        cast<VectorType>(Node->Real->getType()));
+    auto *R = dyn_cast<Instruction>(Node->Real);
+    auto *I = dyn_cast<Instruction>(Node->Imag);
+    if (R && I) {
+      // Splats that are not constant are interleaved where they are located
+      Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
+      IRBuilder<> IRB(InsertPoint);
+      ReplacementNode =
+          IRB.CreateIntrinsic(Intrinsic::experimental_vector_interleave2, NewTy,
+                              {Node->Real, Node->Imag});
+    } else {
+      ReplacementNode =
+          Builder.CreateIntrinsic(Intrinsic::experimental_vector_interleave2,
+                                  NewTy, {Node->Real, Node->Imag});
+    }
+    break;
+  }
   case ComplexDeinterleavingOperation::ReductionPHI: {
     // If Operation is ReductionPHI, a new empty PHINode is created.
     // It is filled later when the ReductionOperation is processed.

diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-splat-scalable.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-splat-scalable.ll
index b15fbd2e563289..db290aee1b3b99 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-splat-scalable.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-splat-scalable.ll
@@ -8,23 +8,24 @@ target triple = "aarch64-arm-none-eabi"
 define <vscale x 4 x double> @complex_mul_const(<vscale x 4 x double> %a, <vscale x 4 x double> %b) {
 ; CHECK-LABEL: complex_mul_const:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    uzp1 z4.d, z2.d, z3.d
-; CHECK-NEXT:    uzp1 z5.d, z0.d, z1.d
-; CHECK-NEXT:    uzp2 z0.d, z0.d, z1.d
+; CHECK-NEXT:    mov z4.d, #0 // =0x0
 ; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    uzp2 z1.d, z2.d, z3.d
-; CHECK-NEXT:    fmul z2.d, z4.d, z0.d
-; CHECK-NEXT:    fmla z2.d, p0/m, z1.d, z5.d
-; CHECK-NEXT:    fmul z0.d, z1.d, z0.d
-; CHECK-NEXT:    fmov z1.d, #11.00000000
-; CHECK-NEXT:    fnmls z0.d, p0/m, z4.d, z5.d
-; CHECK-NEXT:    fmov z3.d, #3.00000000
-; CHECK-NEXT:    fmul z4.d, z2.d, z1.d
-; CHECK-NEXT:    fmul z2.d, z2.d, z3.d
-; CHECK-NEXT:    fmla z4.d, p0/m, z0.d, z3.d
-; CHECK-NEXT:    fnmsb z1.d, p0/m, z0.d, z2.d
-; CHECK-NEXT:    zip1 z0.d, z1.d, z4.d
-; CHECK-NEXT:    zip2 z1.d, z1.d, z4.d
+; CHECK-NEXT:    mov z5.d, z4.d
+; CHECK-NEXT:    mov z6.d, z4.d
+; CHECK-NEXT:    fcmla z5.d, p0/m, z0.d, z2.d, #0
+; CHECK-NEXT:    fcmla z6.d, p0/m, z1.d, z3.d, #0
+; CHECK-NEXT:    fcmla z5.d, p0/m, z0.d, z2.d, #90
+; CHECK-NEXT:    fcmla z6.d, p0/m, z1.d, z3.d, #90
+; CHECK-NEXT:    fmov z1.d, #3.00000000
+; CHECK-NEXT:    fmov z2.d, #11.00000000
+; CHECK-NEXT:    zip2 z3.d, z2.d, z1.d
+; CHECK-NEXT:    mov z0.d, z4.d
+; CHECK-NEXT:    zip1 z1.d, z2.d, z1.d
+; CHECK-NEXT:    fcmla z4.d, p0/m, z6.d, z3.d, #0
+; CHECK-NEXT:    fcmla z0.d, p0/m, z5.d, z1.d, #0
+; CHECK-NEXT:    fcmla z4.d, p0/m, z6.d, z3.d, #90
+; CHECK-NEXT:    fcmla z0.d, p0/m, z5.d, z1.d, #90
+; CHECK-NEXT:    mov z1.d, z4.d
 ; CHECK-NEXT:    ret
 entry:
   %strided.vec = tail call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.experimental.vector.deinterleave2.nxv4f64(<vscale x 4 x double> %a)
@@ -54,25 +55,26 @@ entry:
 define <vscale x 4 x double> @complex_mul_non_const(<vscale x 4 x double> %a, <vscale x 4 x double> %b, [2 x double] %c) {
 ; CHECK-LABEL: complex_mul_non_const:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    uzp1 z6.d, z2.d, z3.d
-; CHECK-NEXT:    uzp1 z7.d, z0.d, z1.d
-; CHECK-NEXT:    uzp2 z0.d, z0.d, z1.d
-; CHECK-NEXT:    ptrue p0.d
-; CHECK-NEXT:    uzp2 z1.d, z2.d, z3.d
-; CHECK-NEXT:    fmul z2.d, z6.d, z0.d
-; CHECK-NEXT:    fmla z2.d, p0/m, z1.d, z7.d
+; CHECK-NEXT:    mov z6.d, #0 // =0x0
+; CHECK-NEXT:    // kill: def $d5 killed $d5 def $z5
 ; CHECK-NEXT:    // kill: def $d4 killed $d4 def $z4
-; CHECK-NEXT:    fmul z0.d, z1.d, z0.d
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    mov z7.d, z6.d
+; CHECK-NEXT:    mov z24.d, z6.d
+; CHECK-NEXT:    mov z5.d, d5
 ; CHECK-NEXT:    mov z4.d, d4
-; CHECK-NEXT:    // kill: def $d5 killed $d5 def $z5
-; CHECK-NEXT:    mov z3.d, d5
-; CHECK-NEXT:    fnmls z0.d, p0/m, z6.d, z7.d
-; CHECK-NEXT:    fmul z1.d, z2.d, z4.d
-; CHECK-NEXT:    fmul z2.d, z2.d, z3.d
-; CHECK-NEXT:    fmla z1.d, p0/m, z0.d, z3.d
-; CHECK-NEXT:    fnmls z2.d, p0/m, z0.d, z4.d
-; CHECK-NEXT:    zip1 z0.d, z2.d, z1.d
-; CHECK-NEXT:    zip2 z1.d, z2.d, z1.d
+; CHECK-NEXT:    fcmla z7.d, p0/m, z0.d, z2.d, #0
+; CHECK-NEXT:    fcmla z24.d, p0/m, z1.d, z3.d, #0
+; CHECK-NEXT:    fcmla z7.d, p0/m, z0.d, z2.d, #90
+; CHECK-NEXT:    zip2 z2.d, z4.d, z5.d
+; CHECK-NEXT:    fcmla z24.d, p0/m, z1.d, z3.d, #90
+; CHECK-NEXT:    mov z0.d, z6.d
+; CHECK-NEXT:    zip1 z4.d, z4.d, z5.d
+; CHECK-NEXT:    fcmla z6.d, p0/m, z24.d, z2.d, #0
+; CHECK-NEXT:    fcmla z0.d, p0/m, z7.d, z4.d, #0
+; CHECK-NEXT:    fcmla z6.d, p0/m, z24.d, z2.d, #90
+; CHECK-NEXT:    fcmla z0.d, p0/m, z7.d, z4.d, #90
+; CHECK-NEXT:    mov z1.d, z6.d
 ; CHECK-NEXT:    ret
 entry:
   %c.coerce.fca.0.extract = extractvalue [2 x double] %c, 0

diff --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-splat.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-splat.ll
index 0123406f92113d..d27436b6be66a6 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-splat.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-splat.ll
@@ -9,24 +9,21 @@ target triple = "aarch64-arm-none-eabi"
 define <4 x double> @complex_mul_const(<4 x double> %a, <4 x double> %b) {
 ; CHECK-LABEL: complex_mul_const:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    zip1 v5.2d, v2.2d, v3.2d
-; CHECK-NEXT:    zip2 v6.2d, v0.2d, v1.2d
-; CHECK-NEXT:    zip2 v2.2d, v2.2d, v3.2d
-; CHECK-NEXT:    zip1 v0.2d, v0.2d, v1.2d
-; CHECK-NEXT:    fmov v4.2d, #3.00000000
-; CHECK-NEXT:    fmul v1.2d, v5.2d, v6.2d
-; CHECK-NEXT:    fmul v3.2d, v2.2d, v6.2d
-; CHECK-NEXT:    fmla v1.2d, v0.2d, v2.2d
-; CHECK-NEXT:    fneg v2.2d, v3.2d
-; CHECK-NEXT:    fmov v3.2d, #11.00000000
-; CHECK-NEXT:    fmul v6.2d, v1.2d, v4.2d
-; CHECK-NEXT:    fmla v2.2d, v0.2d, v5.2d
-; CHECK-NEXT:    fmul v1.2d, v1.2d, v3.2d
-; CHECK-NEXT:    fneg v5.2d, v6.2d
-; CHECK-NEXT:    fmla v1.2d, v4.2d, v2.2d
-; CHECK-NEXT:    fmla v5.2d, v3.2d, v2.2d
-; CHECK-NEXT:    zip1 v0.2d, v5.2d, v1.2d
-; CHECK-NEXT:    zip2 v1.2d, v5.2d, v1.2d
+; CHECK-NEXT:    movi v6.2d, #0000000000000000
+; CHECK-NEXT:    adrp x8, .LCPI0_0
+; CHECK-NEXT:    movi v5.2d, #0000000000000000
+; CHECK-NEXT:    movi v4.2d, #0000000000000000
+; CHECK-NEXT:    fcmla v6.2d, v3.2d, v1.2d, #0
+; CHECK-NEXT:    fcmla v5.2d, v2.2d, v0.2d, #0
+; CHECK-NEXT:    fcmla v6.2d, v3.2d, v1.2d, #90
+; CHECK-NEXT:    fcmla v5.2d, v2.2d, v0.2d, #90
+; CHECK-NEXT:    ldr q2, [x8, :lo12:.LCPI0_0]
+; CHECK-NEXT:    movi v0.2d, #0000000000000000
+; CHECK-NEXT:    fcmla v4.2d, v2.2d, v6.2d, #0
+; CHECK-NEXT:    fcmla v0.2d, v2.2d, v5.2d, #0
+; CHECK-NEXT:    fcmla v4.2d, v2.2d, v6.2d, #90
+; CHECK-NEXT:    fcmla v0.2d, v2.2d, v5.2d, #90
+; CHECK-NEXT:    mov v1.16b, v4.16b
 ; CHECK-NEXT:    ret
 entry:
   %strided.vec = shufflevector <4 x double> %a, <4 x double> poison, <2 x i32> <i32 0, i32 2>
@@ -55,24 +52,22 @@ entry:
 define <4 x double> @complex_mul_non_const(<4 x double> %a, <4 x double> %b, [2 x double] %c) {
 ; CHECK-LABEL: complex_mul_non_const:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    zip1 v6.2d, v2.2d, v3.2d
-; CHECK-NEXT:    // kill: def $d5 killed $d5 def $q5
+; CHECK-NEXT:    movi v6.2d, #0000000000000000
 ; CHECK-NEXT:    // kill: def $d4 killed $d4 def $q4
-; CHECK-NEXT:    zip2 v7.2d, v0.2d, v1.2d
-; CHECK-NEXT:    zip1 v0.2d, v0.2d, v1.2d
-; CHECK-NEXT:    zip2 v1.2d, v2.2d, v3.2d
-; CHECK-NEXT:    fmul v2.2d, v6.2d, v7.2d
-; CHECK-NEXT:    fmul v3.2d, v1.2d, v7.2d
-; CHECK-NEXT:    fmla v2.2d, v0.2d, v1.2d
-; CHECK-NEXT:    fneg v1.2d, v3.2d
-; CHECK-NEXT:    fmul v3.2d, v2.2d, v5.d[0]
-; CHECK-NEXT:    fmul v2.2d, v2.2d, v4.d[0]
-; CHECK-NEXT:    fmla v1.2d, v0.2d, v6.2d
-; CHECK-NEXT:    fneg v3.2d, v3.2d
-; CHECK-NEXT:    fmla v2.2d, v1.2d, v5.d[0]
-; CHECK-NEXT:    fmla v3.2d, v1.2d, v4.d[0]
-; CHECK-NEXT:    zip1 v0.2d, v3.2d, v2.2d
-; CHECK-NEXT:    zip2 v1.2d, v3.2d, v2.2d
+; CHECK-NEXT:    // kill: def $d5 killed $d5 def $q5
+; CHECK-NEXT:    movi v7.2d, #0000000000000000
+; CHECK-NEXT:    mov v4.d[1], v5.d[0]
+; CHECK-NEXT:    fcmla v6.2d, v2.2d, v0.2d, #0
+; CHECK-NEXT:    fcmla v7.2d, v3.2d, v1.2d, #0
+; CHECK-NEXT:    fcmla v6.2d, v2.2d, v0.2d, #90
+; CHECK-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-NEXT:    fcmla v7.2d, v3.2d, v1.2d, #90
+; CHECK-NEXT:    movi v0.2d, #0000000000000000
+; CHECK-NEXT:    fcmla v2.2d, v4.2d, v7.2d, #0
+; CHECK-NEXT:    fcmla v0.2d, v4.2d, v6.2d, #0
+; CHECK-NEXT:    fcmla v2.2d, v4.2d, v7.2d, #90
+; CHECK-NEXT:    fcmla v0.2d, v4.2d, v6.2d, #90
+; CHECK-NEXT:    mov v1.16b, v2.16b
 ; CHECK-NEXT:    ret
 entry:
   %c.coerce.fca.1.extract = extractvalue [2 x double] %c, 1


        

