[llvm] b4f9c3a - [CodeGen] Refactor ComplexDeinterleaving to run identification on Values instead of Instructions

Igor Kirillov via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 3 03:36:02 PDT 2023


Author: Igor Kirillov
Date: 2023-07-03T10:35:14Z
New Revision: b4f9c3a933e80de40ee7860db75bb1088cf9bfa7

URL: https://github.com/llvm/llvm-project/commit/b4f9c3a933e80de40ee7860db75bb1088cf9bfa7
DIFF: https://github.com/llvm/llvm-project/commit/b4f9c3a933e80de40ee7860db75bb1088cf9bfa7.diff

LOG: [CodeGen] Refactor ComplexDeinterleaving to run identification on Values instead of Instructions

This change will make it easier to add identification of complex constants
in future patches.

Differential Revision: https://reviews.llvm.org/D153446

Added: 
    

Modified: 
    llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index 2464ebab221d4a..3a4b94d5eae271 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -130,7 +130,7 @@ class ComplexDeinterleavingGraph;
 struct ComplexDeinterleavingCompositeNode {
 
   ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
-                                     Instruction *R, Instruction *I)
+                                     Value *R, Value *I)
       : Operation(Op), Real(R), Imag(I) {}
 
 private:
@@ -140,8 +140,8 @@ struct ComplexDeinterleavingCompositeNode {
 
 public:
   ComplexDeinterleavingOperation Operation;
-  Instruction *Real;
-  Instruction *Imag;
+  Value *Real;
+  Value *Imag;
 
   // This two members are required exclusively for generating
   // ComplexDeinterleavingOperation::Symmetric operations.
@@ -192,19 +192,19 @@ struct ComplexDeinterleavingCompositeNode {
 class ComplexDeinterleavingGraph {
 public:
   struct Product {
-    Instruction *Multiplier;
-    Instruction *Multiplicand;
+    Value *Multiplier;
+    Value *Multiplicand;
     bool IsPositive;
   };
 
-  using Addend = std::pair<Instruction *, bool>;
+  using Addend = std::pair<Value *, bool>;
   using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
   using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
 
   // Helper struct for holding info about potential partial multiplication
   // candidates
   struct PartialMulCandidate {
-    Instruction *Common;
+    Value *Common;
     NodePtr Node;
     unsigned RealIdx;
     unsigned ImagIdx;
@@ -270,7 +270,7 @@ class ComplexDeinterleavingGraph {
   std::map<PHINode *, PHINode *> OldToNewPHI;
 
   NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
-                               Instruction *R, Instruction *I) {
+                               Value *R, Value *I) {
     assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
              Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
             (R && I)) &&
@@ -308,9 +308,9 @@ class ComplexDeinterleavingGraph {
   /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
   /// is partially known from identifyPartialMul, filling in the other half of
   /// the complex pair.
-  NodePtr identifyNodeWithImplicitAdd(
-      Instruction *I, Instruction *J,
-      std::pair<Instruction *, Instruction *> &CommonOperandI);
+  NodePtr
+  identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
+                              std::pair<Value *, Value *> &CommonOperandI);
 
   /// Identifies a complex add pattern and its rotation, based on the following
   /// patterns.
@@ -322,7 +322,7 @@ class ComplexDeinterleavingGraph {
   NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
   NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
 
-  NodePtr identifyNode(Instruction *I, Instruction *J);
+  NodePtr identifyNode(Value *R, Value *I);
 
   /// Determine if a sum of complex numbers can be formed from \p RealAddends
   /// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
@@ -521,7 +521,7 @@ bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
 ComplexDeinterleavingGraph::NodePtr
 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
     Instruction *Real, Instruction *Imag,
-    std::pair<Instruction *, Instruction *> &PartialMatch) {
+    std::pair<Value *, Value *> &PartialMatch) {
   LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
                     << "\n");
 
@@ -536,52 +536,38 @@ ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
     return nullptr;
   }
 
-  Instruction *R0 = dyn_cast<Instruction>(Real->getOperand(0));
-  Instruction *R1 = dyn_cast<Instruction>(Real->getOperand(1));
-  Instruction *I0 = dyn_cast<Instruction>(Imag->getOperand(0));
-  Instruction *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
-  if (!R0 || !R1 || !I0 || !I1) {
-    LLVM_DEBUG(dbgs() << "  - Mul operand not Instruction\n");
-    return nullptr;
-  }
+  Value *R0 = Real->getOperand(0);
+  Value *R1 = Real->getOperand(1);
+  Value *I0 = Imag->getOperand(0);
+  Value *I1 = Imag->getOperand(1);
 
   // A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
   // rotations and use the operand.
   unsigned Negs = 0;
-  SmallVector<Instruction *> FNegs;
-  if (R0->getOpcode() == Instruction::FNeg ||
-      R1->getOpcode() == Instruction::FNeg) {
+  Value *Op;
+  if (match(R0, m_Neg(m_Value(Op)))) {
     Negs |= 1;
-    if (R0->getOpcode() == Instruction::FNeg) {
-      FNegs.push_back(R0);
-      R0 = dyn_cast<Instruction>(R0->getOperand(0));
-    } else {
-      FNegs.push_back(R1);
-      R1 = dyn_cast<Instruction>(R1->getOperand(0));
-    }
-    if (!R0 || !R1)
-      return nullptr;
+    R0 = Op;
+  } else if (match(R1, m_Neg(m_Value(Op)))) {
+    Negs |= 1;
+    R1 = Op;
   }
-  if (I0->getOpcode() == Instruction::FNeg ||
-      I1->getOpcode() == Instruction::FNeg) {
+
+  if (match(I0, m_Neg(m_Value(Op)))) {
     Negs |= 2;
     Negs ^= 1;
-    if (I0->getOpcode() == Instruction::FNeg) {
-      FNegs.push_back(I0);
-      I0 = dyn_cast<Instruction>(I0->getOperand(0));
-    } else {
-      FNegs.push_back(I1);
-      I1 = dyn_cast<Instruction>(I1->getOperand(0));
-    }
-    if (!I0 || !I1)
-      return nullptr;
+    I0 = Op;
+  } else if (match(I1, m_Neg(m_Value(Op)))) {
+    Negs |= 2;
+    Negs ^= 1;
+    I1 = Op;
   }
 
   ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
 
-  Instruction *CommonOperand;
-  Instruction *UncommonRealOp;
-  Instruction *UncommonImagOp;
+  Value *CommonOperand;
+  Value *UncommonRealOp;
+  Value *UncommonImagOp;
 
   if (R0 == I0 || R0 == I1) {
     CommonOperand = R0;
@@ -676,18 +662,14 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
     return nullptr;
   }
 
-  Instruction *R0 = dyn_cast<Instruction>(RealMulI->getOperand(0));
-  Instruction *R1 = dyn_cast<Instruction>(RealMulI->getOperand(1));
-  Instruction *I0 = dyn_cast<Instruction>(ImagMulI->getOperand(0));
-  Instruction *I1 = dyn_cast<Instruction>(ImagMulI->getOperand(1));
-  if (!R0 || !R1 || !I0 || !I1) {
-    LLVM_DEBUG(dbgs() << "  - Mul operand not Instruction\n");
-    return nullptr;
-  }
+  Value *R0 = RealMulI->getOperand(0);
+  Value *R1 = RealMulI->getOperand(1);
+  Value *I0 = ImagMulI->getOperand(0);
+  Value *I1 = ImagMulI->getOperand(1);
 
-  Instruction *CommonOperand;
-  Instruction *UncommonRealOp;
-  Instruction *UncommonImagOp;
+  Value *CommonOperand;
+  Value *UncommonRealOp;
+  Value *UncommonImagOp;
 
   if (R0 == I0 || R0 == I1) {
     CommonOperand = R0;
@@ -705,7 +687,7 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
       Rotation == ComplexDeinterleavingRotation::Rotation_270)
     std::swap(UncommonRealOp, UncommonImagOp);
 
-  std::pair<Instruction *, Instruction *> PartialMatch(
+  std::pair<Value *, Value *> PartialMatch(
       (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
        Rotation == ComplexDeinterleavingRotation::Rotation_180)
           ? CommonOperand
@@ -840,11 +822,8 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
       !isInstructionPotentiallySymmetric(Imag))
     return nullptr;
 
-  auto *R0 = dyn_cast<Instruction>(Real->getOperand(0));
-  auto *I0 = dyn_cast<Instruction>(Imag->getOperand(0));
-
-  if (!R0 || !I0)
-    return nullptr;
+  auto *R0 = Real->getOperand(0);
+  auto *I0 = Imag->getOperand(0);
 
   NodePtr Op0 = identifyNode(R0, I0);
   NodePtr Op1 = nullptr;
@@ -852,11 +831,8 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
     return nullptr;
 
   if (Real->isBinaryOp()) {
-    auto *R1 = dyn_cast<Instruction>(Real->getOperand(1));
-    auto *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
-    if (!R1 || !I1)
-      return nullptr;
-
+    auto *R1 = Real->getOperand(1);
+    auto *I1 = Imag->getOperand(1);
     Op1 = identifyNode(R1, I1);
     if (Op1 == nullptr)
       return nullptr;
@@ -880,15 +856,20 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
 }
 
 ComplexDeinterleavingGraph::NodePtr
-ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
-  LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n");
-  assert(Real->getType() == Imag->getType() &&
+ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
+  LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
+  assert(R->getType() == I->getType() &&
         "Real and imaginary parts should not have different types");
-  if (NodePtr CN = getContainingComposite(Real, Imag)) {
+  if (NodePtr CN = getContainingComposite(R, I)) {
     LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
     return CN;
   }
 
+  auto *Real = dyn_cast<Instruction>(R);
+  auto *Imag = dyn_cast<Instruction>(I);
+  if (!Real || !Imag)
+    return nullptr;
+
   if (NodePtr CN = identifyDeinterleave(Real, Imag))
     return CN;
 
@@ -931,6 +912,7 @@ ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
 ComplexDeinterleavingGraph::NodePtr
 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
                                                  Instruction *Imag) {
+
   if ((Real->getOpcode() != Instruction::FAdd &&
        Real->getOpcode() != Instruction::FSub &&
        Real->getOpcode() != Instruction::FNeg) ||
@@ -967,8 +949,10 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
         continue;
 
       Instruction *I = dyn_cast<Instruction>(V);
-      if (!I)
-        return false;
+      if (!I) {
+        Addends.emplace_back(V, IsPositive);
+        continue;
+      }
 
       // If an instruction has more than one user, it indicates that it either
       // has an external user, which will be later checked by the checkNodes
@@ -989,20 +973,18 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
         Worklist.emplace_back(I->getOperand(1), !IsPositive);
         Worklist.emplace_back(I->getOperand(0), IsPositive);
       } else if (I->getOpcode() == Instruction::FMul) {
-        auto *A = dyn_cast<Instruction>(I->getOperand(0));
-        if (A && A->getOpcode() == Instruction::FNeg) {
-          A = dyn_cast<Instruction>(A->getOperand(0));
+        Value *A, *B;
+        if (match(I->getOperand(0), m_FNeg(m_Value(A)))) {
           IsPositive = !IsPositive;
+        } else {
+          A = I->getOperand(0);
         }
-        if (!A)
-          return false;
-        auto *B = dyn_cast<Instruction>(I->getOperand(1));
-        if (B && B->getOpcode() == Instruction::FNeg) {
-          B = dyn_cast<Instruction>(B->getOperand(0));
+
+        if (match(I->getOperand(1), m_FNeg(m_Value(B)))) {
           IsPositive = !IsPositive;
+        } else {
+          B = I->getOperand(1);
         }
-        if (!B)
-          return false;
         Muls.push_back(Product{A, B, IsPositive});
       } else if (I->getOpcode() == Instruction::FNeg) {
         Worklist.emplace_back(I->getOperand(0), !IsPositive);
@@ -1059,7 +1041,7 @@ bool ComplexDeinterleavingGraph::collectPartialMuls(
     std::vector<PartialMulCandidate> &PartialMulCandidates) {
   // Helper function to extract a common operand from two products
   auto FindCommonInstruction = [](const Product &Real,
-                                  const Product &Imag) -> Instruction * {
+                                  const Product &Imag) -> Value * {
     if (Real.Multiplicand == Imag.Multiplicand ||
         Real.Multiplicand == Imag.Multiplier)
       return Real.Multiplicand;
@@ -1087,18 +1069,17 @@ bool ComplexDeinterleavingGraph::collectPartialMuls(
       auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
                                                    : ImagMuls[j].Multiplicand;
 
-      bool Inverted = false;
       auto Node = identifyNode(A, B);
-      if (!Node) {
-        std::swap(A, B);
-        Inverted = true;
-        Node = identifyNode(A, B);
+      if (Node) {
+        FoundCommon = true;
+        PartialMulCandidates.push_back({Common, Node, i, j, false});
       }
-      if (!Node)
-        continue;
 
-      FoundCommon = true;
-      PartialMulCandidates.push_back({Common, Node, i, j, Inverted});
+      Node = identifyNode(B, A);
+      if (Node) {
+        FoundCommon = true;
+        PartialMulCandidates.push_back({Common, Node, i, j, true});
+      }
     }
     if (!FoundCommon)
       return false;
@@ -1118,7 +1099,7 @@ ComplexDeinterleavingGraph::identifyMultiplications(
     return nullptr;
 
   // Map to store common instruction to node pointers
-  std::map<Instruction *, NodePtr> CommonToNode;
+  std::map<Value *, NodePtr> CommonToNode;
   std::vector<bool> Processed(Info.size(), false);
   for (unsigned I = 0; I < Info.size(); ++I) {
     if (Processed[I])
@@ -1838,8 +1819,8 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
     processReductionOperation(ReplacementNode, Node);
     break;
   case ComplexDeinterleavingOperation::ReductionSelect: {
-    auto *MaskReal = Node->Real->getOperand(0);
-    auto *MaskImag = Node->Imag->getOperand(0);
+    auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
+    auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
     auto *A = replaceNode(Builder, Node->Operands[0]);
     auto *B = replaceNode(Builder, Node->Operands[1]);
     auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
@@ -1860,11 +1841,13 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
 
 void ComplexDeinterleavingGraph::processReductionOperation(
     Value *OperationReplacement, RawNodePtr Node) {
-  auto *OldPHIReal = ReductionInfo[Node->Real].first;
-  auto *OldPHIImag = ReductionInfo[Node->Imag].first;
+  auto *Real = cast<Instruction>(Node->Real);
+  auto *Imag = cast<Instruction>(Node->Imag);
+  auto *OldPHIReal = ReductionInfo[Real].first;
+  auto *OldPHIImag = ReductionInfo[Imag].first;
   auto *NewPHI = OldToNewPHI[OldPHIReal];
 
-  auto *VTy = cast<VectorType>(Node->Real->getType());
+  auto *VTy = cast<VectorType>(Real->getType());
   auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
 
   // We have to interleave initial origin values coming from IncomingBlock
@@ -1880,8 +1863,8 @@ void ComplexDeinterleavingGraph::processReductionOperation(
 
   // Deinterleave complex vector outside of loop so that it can be finally
   // reduced
-  auto *FinalReductionReal = ReductionInfo[Node->Real].second;
-  auto *FinalReductionImag = ReductionInfo[Node->Imag].second;
+  auto *FinalReductionReal = ReductionInfo[Real].second;
+  auto *FinalReductionImag = ReductionInfo[Imag].second;
 
   Builder.SetInsertPoint(
       &*FinalReductionReal->getParent()->getFirstInsertionPt());
@@ -1890,11 +1873,11 @@ void ComplexDeinterleavingGraph::processReductionOperation(
       OperationReplacement->getType(), OperationReplacement);
 
   auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
-  FinalReductionReal->replaceUsesOfWith(Node->Real, NewReal);
+  FinalReductionReal->replaceUsesOfWith(Real, NewReal);
 
   Builder.SetInsertPoint(FinalReductionImag);
   auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
-  FinalReductionImag->replaceUsesOfWith(Node->Imag, NewImag);
+  FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
 }
 
 void ComplexDeinterleavingGraph::replaceNodes() {
@@ -1911,10 +1894,12 @@ void ComplexDeinterleavingGraph::replaceNodes() {
 
     if (RootNode->Operation ==
         ComplexDeinterleavingOperation::ReductionOperation) {
-      ReductionInfo[RootNode->Real].first->removeIncomingValue(BackEdge);
-      ReductionInfo[RootNode->Imag].first->removeIncomingValue(BackEdge);
-      DeadInstrRoots.push_back(RootNode->Real);
-      DeadInstrRoots.push_back(RootNode->Imag);
+      auto *RootReal = cast<Instruction>(RootNode->Real);
+      auto *RootImag = cast<Instruction>(RootNode->Imag);
+      ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
+      ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
+      DeadInstrRoots.push_back(cast<Instruction>(RootReal));
+      DeadInstrRoots.push_back(cast<Instruction>(RootImag));
     } else {
       assert(R && "Unable to find replacement for RootInstruction");
       DeadInstrRoots.push_back(RootInstruction);


        


More information about the llvm-commits mailing list