[llvm] 40a81d3 - [CodeGen] Refactor IR generation functions to use IRBuilder in ComplexDeinterleaving pass

Igor Kirillov via llvm-commits llvm-commits at lists.llvm.org
Tue May 30 09:19:19 PDT 2023


Author: Igor Kirillov
Date: 2023-05-30T16:18:28Z
New Revision: 40a81d3100b416393557f015efc971497c0bea46

URL: https://github.com/llvm/llvm-project/commit/40a81d3100b416393557f015efc971497c0bea46
DIFF: https://github.com/llvm/llvm-project/commit/40a81d3100b416393557f015efc971497c0bea46.diff

LOG: [CodeGen] Refactor IR generation functions to use IRBuilder in ComplexDeinterleaving pass

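This patch updates the IR-generation functions of the ComplexDeinterleaving
pass, together with the TargetLowering::createComplexDeinterleavingIR hook and
its AArch64 and ARM implementations, to accept an IRBuilderBase reference as an
argument rather than an Instruction that indicates the insertion point for new
instructions.
This change is necessary to handle sophisticated -Ofast optimization cases
from D148558, where it is unclear which instruction should be used as the
insertion point for new operations.
Because replacement instructions are now emitted through a single builder
positioned at the root instruction, rather than through a separate builder
created at each node's own instruction, the order of the generated
instructions can change. The CHECK lines in
complex-deinterleaving-mixed-cases.ll are updated accordingly.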

Differential Revision: https://reviews.llvm.org/D148703

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/TargetLowering.h
    llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/lib/Target/ARM/ARMISelLowering.cpp
    llvm/lib/Target/ARM/ARMISelLowering.h
    llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index b2d73b286b0ad..908d881d7f6da 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -3196,7 +3196,7 @@ class TargetLoweringBase {
   /// If one cannot be created using all the given inputs, nullptr should be
   /// returned.
   virtual Value *createComplexDeinterleavingIR(
-      Instruction *I, ComplexDeinterleavingOperation OperationType,
+      IRBuilderBase &B, ComplexDeinterleavingOperation OperationType,
       ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB,
       Value *Accumulator = nullptr) const {
     return nullptr;

diff  --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index 5f06a666a5f2e..4351d68ebc87c 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -267,7 +267,7 @@ class ComplexDeinterleavingGraph {
   /// intrinsic (for both fixed and scalable vectors)
   NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
 
-  Value *replaceNode(RawNodePtr Node);
+  Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
 
 public:
   void dump() { dump(dbgs()); }
@@ -1011,7 +1011,8 @@ ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
   return submitCompositeNode(PlaceholderNode);
 }
 
-static Value *replaceSymmetricNode(ComplexDeinterleavingGraph::RawNodePtr Node,
+static Value *replaceSymmetricNode(IRBuilderBase &B,
+                                   ComplexDeinterleavingGraph::RawNodePtr Node,
                                    Value *InputA, Value *InputB) {
   Instruction *I = Node->Real;
   if (I->isUnaryOp())
@@ -1021,8 +1022,6 @@ static Value *replaceSymmetricNode(ComplexDeinterleavingGraph::RawNodePtr Node,
     assert(InputB && "Binary symmetric operations need two inputs, only one "
                      "was provided.");
 
-  IRBuilder<> B(I);
-
   switch (I->getOpcode()) {
   case Instruction::FNeg:
     return B.CreateFNegFMF(InputA, I);
@@ -1037,27 +1036,28 @@ static Value *replaceSymmetricNode(ComplexDeinterleavingGraph::RawNodePtr Node,
   return nullptr;
 }
 
-Value *ComplexDeinterleavingGraph::replaceNode(
-    ComplexDeinterleavingGraph::RawNodePtr Node) {
+Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
+                                               RawNodePtr Node) {
   if (Node->ReplacementNode)
     return Node->ReplacementNode;
 
-  Value *Input0 = replaceNode(Node->Operands[0]);
-  Value *Input1 =
-      Node->Operands.size() > 1 ? replaceNode(Node->Operands[1]) : nullptr;
-  Value *Accumulator =
-      Node->Operands.size() > 2 ? replaceNode(Node->Operands[2]) : nullptr;
+  Value *Input0 = replaceNode(Builder, Node->Operands[0]);
+  Value *Input1 = Node->Operands.size() > 1
+                      ? replaceNode(Builder, Node->Operands[1])
+                      : nullptr;
+  Value *Accumulator = Node->Operands.size() > 2
+                           ? replaceNode(Builder, Node->Operands[2])
+                           : nullptr;
 
   if (Input1)
     assert(Input0->getType() == Input1->getType() &&
            "Node inputs need to be of the same type");
 
   if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
-    Node->ReplacementNode = replaceSymmetricNode(Node, Input0, Input1);
+    Node->ReplacementNode = replaceSymmetricNode(Builder, Node, Input0, Input1);
   else
     Node->ReplacementNode = TL->createComplexDeinterleavingIR(
-        Node->Real, Node->Operation, Node->Rotation, Input0, Input1,
-        Accumulator);
+        Builder, Node->Operation, Node->Rotation, Input0, Input1, Accumulator);
 
   assert(Node->ReplacementNode && "Target failed to create Intrinsic call.");
   NumComplexTransformations += 1;
@@ -1074,7 +1074,7 @@ void ComplexDeinterleavingGraph::replaceNodes() {
 
     IRBuilder<> Builder(RootInstruction);
     auto RootNode = RootToNode[RootInstruction];
-    Value *R = replaceNode(RootNode.get());
+    Value *R = replaceNode(Builder, RootNode.get());
     assert(R && "Unable to find replacement for RootInstruction");
     DeadInstrRoots.push_back(RootInstruction);
     RootInstruction->replaceAllUsesWith(R);

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 0a628fc402d69..b8ae8a034e54c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -25286,14 +25286,12 @@ bool AArch64TargetLowering::isComplexDeinterleavingOperationSupported(
 }
 
 Value *AArch64TargetLowering::createComplexDeinterleavingIR(
-    Instruction *I, ComplexDeinterleavingOperation OperationType,
+    IRBuilderBase &B, ComplexDeinterleavingOperation OperationType,
     ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB,
     Value *Accumulator) const {
   VectorType *Ty = cast<VectorType>(InputA->getType());
   bool IsScalable = Ty->isScalableTy();
 
-  IRBuilder<> B(I);
-
   unsigned TyWidth =
       Ty->getScalarSizeInBits() * Ty->getElementCount().getKnownMinValue();
 
@@ -25317,9 +25315,9 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
           B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(Stride));
     }
     auto *LowerSplitInt = createComplexDeinterleavingIR(
-        I, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc);
+        B, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc);
     auto *UpperSplitInt = createComplexDeinterleavingIR(
-        I, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc);
+        B, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc);
 
     auto *Result = B.CreateInsertVector(Ty, PoisonValue::get(Ty), LowerSplitInt,
                                         B.getInt64(0));

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 689c2d1860064..cf766a74d6949 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -843,7 +843,7 @@ class AArch64TargetLowering : public TargetLowering {
       ComplexDeinterleavingOperation Operation, Type *Ty) const override;
 
   Value *createComplexDeinterleavingIR(
-      Instruction *I, ComplexDeinterleavingOperation OperationType,
+      IRBuilderBase &B, ComplexDeinterleavingOperation OperationType,
       ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB,
       Value *Accumulator = nullptr) const override;
 

diff  --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp
index 9eab7b0e53d12..9cde9205335fd 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.cpp
+++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp
@@ -22060,14 +22060,12 @@ bool ARMTargetLowering::isComplexDeinterleavingOperationSupported(
 }
 
 Value *ARMTargetLowering::createComplexDeinterleavingIR(
-    Instruction *I, ComplexDeinterleavingOperation OperationType,
+    IRBuilderBase &B, ComplexDeinterleavingOperation OperationType,
     ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB,
     Value *Accumulator) const {
 
   FixedVectorType *Ty = cast<FixedVectorType>(InputA->getType());
 
-  IRBuilder<> B(I);
-
   unsigned TyWidth = Ty->getScalarSizeInBits() * Ty->getNumElements();
 
   assert(TyWidth >= 128 && "Width of vector type must be at least 128 bits");
@@ -22092,9 +22090,9 @@ Value *ARMTargetLowering::createComplexDeinterleavingIR(
     }
 
     auto *LowerSplitInt = createComplexDeinterleavingIR(
-        I, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc);
+        B, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc);
     auto *UpperSplitInt = createComplexDeinterleavingIR(
-        I, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc);
+        B, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc);
 
     ArrayRef<int> JoinMask(&SplitSeqVec[0], Ty->getNumElements());
     return B.CreateShuffleVector(LowerSplitInt, UpperSplitInt, JoinMask);

diff  --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h
index 49fc5a50686a1..2dd54602ef61b 100644
--- a/llvm/lib/Target/ARM/ARMISelLowering.h
+++ b/llvm/lib/Target/ARM/ARMISelLowering.h
@@ -750,7 +750,7 @@ class VectorType;
         ComplexDeinterleavingOperation Operation, Type *Ty) const override;
 
     Value *createComplexDeinterleavingIR(
-        Instruction *I, ComplexDeinterleavingOperation OperationType,
+        IRBuilderBase &B, ComplexDeinterleavingOperation OperationType,
         ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB,
         Value *Accumulator = nullptr) const override;
 

diff  --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll
index 9aa6a856bc02c..65012899c97e3 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-mixed-cases.ll
@@ -220,11 +220,11 @@ define <4 x float> @mul_add90_mul(<4 x float> %a, <4 x float> %b, <4 x float> %c
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    movi v3.2d, #0000000000000000
 ; CHECK-NEXT:    movi v4.2d, #0000000000000000
-; CHECK-NEXT:    fcmla v3.4s, v1.4s, v0.4s, #0
-; CHECK-NEXT:    fcmla v4.4s, v2.4s, v0.4s, #0
-; CHECK-NEXT:    fcmla v3.4s, v1.4s, v0.4s, #90
-; CHECK-NEXT:    fcmla v4.4s, v2.4s, v0.4s, #90
-; CHECK-NEXT:    fcadd v0.4s, v4.4s, v3.4s, #90
+; CHECK-NEXT:    fcmla v3.4s, v2.4s, v0.4s, #0
+; CHECK-NEXT:    fcmla v4.4s, v1.4s, v0.4s, #0
+; CHECK-NEXT:    fcmla v3.4s, v2.4s, v0.4s, #90
+; CHECK-NEXT:    fcmla v4.4s, v1.4s, v0.4s, #90
+; CHECK-NEXT:    fcadd v0.4s, v3.4s, v4.4s, #90
 ; CHECK-NEXT:    ret
 entry:
   %ar = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> <i32 0, i32 2>


        


More information about the llvm-commits mailing list