[llvm] eb2af3a - [ComplexDeinterleaving] Use BumpPtrAllocator for CompositeNodes (NFC) (#153217)

via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 20 00:59:35 PDT 2025


Author: Benjamin Maxwell
Date: 2025-08-20T08:59:31+01:00
New Revision: eb2af3a5beaaa71a6088b8e0c940b209fdbea3c8

URL: https://github.com/llvm/llvm-project/commit/eb2af3a5beaaa71a6088b8e0c940b209fdbea3c8
DIFF: https://github.com/llvm/llvm-project/commit/eb2af3a5beaaa71a6088b8e0c940b209fdbea3c8.diff

LOG: [ComplexDeinterleaving] Use BumpPtrAllocator for CompositeNodes (NFC) (#153217)

I was looking over this pass and noticed it was using shared pointers
for CompositeNodes. However, all nodes are owned by the deinterleaving
graph and are not released until the graph is destroyed. This means a
bump allocator and raw pointers can be used, which have a simpler
ownership model and less overhead than shared pointers.

The changes in this PR are to:
- Add a `SpecificBumpPtrAllocator<CompositeNode>` to the
  `ComplexDeinterleavingGraph`
  - This allocates new nodes and will destroy (running their destructors)
    and deallocate them when the graph is destroyed
- Replace `NodePtr` and `RawNodePtr` with `CompositeNode *`

Added: 
    

Modified: 
    llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index cd21e254c446e..de95e0aaf2cba 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -70,6 +70,7 @@
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
+#include "llvm/Support/Allocator.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include <algorithm>
@@ -192,8 +193,7 @@ struct ComplexDeinterleavingCompositeNode {
 
 private:
   friend class ComplexDeinterleavingGraph;
-  using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
-  using RawNodePtr = ComplexDeinterleavingCompositeNode *;
+  using CompositeNode = ComplexDeinterleavingCompositeNode;
   bool OperandsValid = true;
 
 public:
@@ -207,13 +207,13 @@ struct ComplexDeinterleavingCompositeNode {
 
   ComplexDeinterleavingRotation Rotation =
       ComplexDeinterleavingRotation::Rotation_0;
-  SmallVector<RawNodePtr> Operands;
+  SmallVector<CompositeNode *> Operands;
   Value *ReplacementNode = nullptr;
 
-  void addOperand(NodePtr Node) {
-    if (!Node || !Node.get())
+  void addOperand(CompositeNode *Node) {
+    if (!Node)
       OperandsValid = false;
-    Operands.push_back(Node.get());
+    Operands.push_back(Node);
   }
 
   void dump() { dump(dbgs()); }
@@ -226,7 +226,7 @@ struct ComplexDeinterleavingCompositeNode {
       } else
         OS << "nullptr\n";
     };
-    auto PrintNodeRef = [&](RawNodePtr Ptr) {
+    auto PrintNodeRef = [&](CompositeNode *Ptr) {
       if (Ptr)
         OS << Ptr << "\n";
       else
@@ -263,14 +263,13 @@ class ComplexDeinterleavingGraph {
   };
 
   using Addend = std::pair<Value *, bool>;
-  using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
-  using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
+  using CompositeNode = ComplexDeinterleavingCompositeNode::CompositeNode;
 
   // Helper struct for holding info about potential partial multiplication
   // candidates
   struct PartialMulCandidate {
     Value *Common;
-    NodePtr Node;
+    CompositeNode *Node;
     unsigned RealIdx;
     unsigned ImagIdx;
     bool IsNodeInverted;
@@ -285,13 +284,14 @@ class ComplexDeinterleavingGraph {
   const TargetLowering *TL = nullptr;
   const TargetLibraryInfo *TLI = nullptr;
   unsigned Factor;
-  SmallVector<NodePtr> CompositeNodes;
-  DenseMap<ComplexValues, NodePtr> CachedResult;
+  SmallVector<CompositeNode *> CompositeNodes;
+  DenseMap<ComplexValues, CompositeNode *> CachedResult;
+  SpecificBumpPtrAllocator<ComplexDeinterleavingCompositeNode> Allocator;
 
   SmallPtrSet<Instruction *, 16> FinalInstructions;
 
   /// Root instructions are instructions from which complex computation starts
-  std::map<Instruction *, NodePtr> RootToNode;
+  std::map<Instruction *, CompositeNode *> RootToNode;
 
   /// Topologically sorted root instructions
   SmallVector<Instruction *, 1> OrderedRoots;
@@ -341,18 +341,18 @@ class ComplexDeinterleavingGraph {
   /// replacement process.
   std::map<PHINode *, PHINode *> OldToNewPHI;
 
-  NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
-                               Value *R, Value *I) {
+  CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
+                                      Value *R, Value *I) {
     assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
              Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
             (R && I)) &&
            "Reduction related nodes must have Real and Imaginary parts");
-    return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
-                                                                I);
+    return new (Allocator.Allocate())
+        ComplexDeinterleavingCompositeNode(Operation, R, I);
   }
 
-  NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
-                               ComplexValues &Vals) {
+  CompositeNode *prepareCompositeNode(ComplexDeinterleavingOperation Operation,
+                                      ComplexValues &Vals) {
 #ifndef NDEBUG
     for (auto &V : Vals) {
       assert(
@@ -362,11 +362,11 @@ class ComplexDeinterleavingGraph {
           "Reduction related nodes must have Real and Imaginary parts");
     }
 #endif
-    return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation,
-                                                                Vals);
+    return new (Allocator.Allocate())
+        ComplexDeinterleavingCompositeNode(Operation, Vals);
   }
 
-  NodePtr submitCompositeNode(NodePtr Node) {
+  CompositeNode *submitCompositeNode(CompositeNode *Node) {
     CompositeNodes.push_back(Node);
     if (Node->Vals[0].Real)
       CachedResult[Node->Vals] = Node;
@@ -384,12 +384,12 @@ class ComplexDeinterleavingGraph {
   ///      i: ci - ar * bi
   /// 270: r: cr + ai * bi
   ///      i: ci - ai * br
-  NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
+  CompositeNode *identifyPartialMul(Instruction *Real, Instruction *Imag);
 
   /// Identify the other branch of a Partial Mul, taking the CommonOperandI that
   /// is partially known from identifyPartialMul, filling in the other half of
   /// the complex pair.
-  NodePtr
+  CompositeNode *
   identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
                               std::pair<Value *, Value *> &CommonOperandI);
 
@@ -400,14 +400,14 @@ class ComplexDeinterleavingGraph {
   ///      i: ai + br
   /// 270: r: ar + bi
   ///      i: ai - br
-  NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
-  NodePtr identifySymmetricOperation(ComplexValues &Vals);
-  NodePtr identifyPartialReduction(Value *R, Value *I);
-  NodePtr identifyDotProduct(Value *Inst);
+  CompositeNode *identifyAdd(Instruction *Real, Instruction *Imag);
+  CompositeNode *identifySymmetricOperation(ComplexValues &Vals);
+  CompositeNode *identifyPartialReduction(Value *R, Value *I);
+  CompositeNode *identifyDotProduct(Value *Inst);
 
-  NodePtr identifyNode(ComplexValues &Vals);
+  CompositeNode *identifyNode(ComplexValues &Vals);
 
-  NodePtr identifyNode(Value *R, Value *I) {
+  CompositeNode *identifyNode(Value *R, Value *I) {
     ComplexValues Vals;
     Vals.push_back({R, I});
     return identifyNode(Vals);
@@ -417,21 +417,21 @@ class ComplexDeinterleavingGraph {
   /// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
   /// Return nullptr if it is not possible to construct a complex number.
   /// \p Flags are needed to generate symmetric Add and Sub operations.
-  NodePtr identifyAdditions(std::list<Addend> &RealAddends,
-                            std::list<Addend> &ImagAddends,
-                            std::optional<FastMathFlags> Flags,
-                            NodePtr Accumulator);
+  CompositeNode *identifyAdditions(std::list<Addend> &RealAddends,
+                                   std::list<Addend> &ImagAddends,
+                                   std::optional<FastMathFlags> Flags,
+                                   CompositeNode *Accumulator);
 
   /// Extract one addend that have both real and imaginary parts positive.
-  NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
-                                std::list<Addend> &ImagAddends);
+  CompositeNode *extractPositiveAddend(std::list<Addend> &RealAddends,
+                                       std::list<Addend> &ImagAddends);
 
   /// Determine if sum of multiplications of complex numbers can be formed from
   /// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
   /// to it. Return nullptr if it is not possible to construct a complex number.
-  NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
-                                  std::vector<Product> &ImagMuls,
-                                  NodePtr Accumulator);
+  CompositeNode *identifyMultiplications(std::vector<Product> &RealMuls,
+                                         std::vector<Product> &ImagMuls,
+                                         CompositeNode *Accumulator);
 
   /// Go through pairs of multiplication (one Real and one Imag) and find all
   /// possible candidates for partial multiplication and put them into \p
@@ -446,9 +446,9 @@ class ComplexDeinterleavingGraph {
   /// function takes this into consideration and employs a more general approach
   /// to identify complex computations. Initially, it gathers all the addends
   /// and multiplicands and then constructs a complex expression from them.
-  NodePtr identifyReassocNodes(Instruction *I, Instruction *J);
+  CompositeNode *identifyReassocNodes(Instruction *I, Instruction *J);
 
-  NodePtr identifyRoot(Instruction *I);
+  CompositeNode *identifyRoot(Instruction *I);
 
   /// Identifies the Deinterleave operation applied to a vector containing
   /// complex numbers. There are two ways to represent the Deinterleave
@@ -458,29 +458,30 @@ class ComplexDeinterleavingGraph {
   /// * Using N extractvalue instructions applied to `vector.deinterleaveN`
   /// intrinsics (for both fixed and scalable vectors) where N is a multiple of
   /// 2.
-  NodePtr identifyDeinterleave(ComplexValues &Vals);
+  CompositeNode *identifyDeinterleave(ComplexValues &Vals);
 
   /// identifying the operation that represents a complex number repeated in a
   /// Splat vector. There are two possible types of splats: ConstantExpr with
   /// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
   /// initialization mask with all values set to zero.
-  NodePtr identifySplat(ComplexValues &Vals);
+  CompositeNode *identifySplat(ComplexValues &Vals);
 
-  NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
+  CompositeNode *identifyPHINode(Instruction *Real, Instruction *Imag);
 
   /// Identifies SelectInsts in a loop that has reduction with predication masks
   /// and/or predicated tail folding
-  NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);
+  CompositeNode *identifySelectNode(Instruction *Real, Instruction *Imag);
 
-  Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
+  Value *replaceNode(IRBuilderBase &Builder, CompositeNode *Node);
 
   /// Complete IR modifications after producing new reduction operation:
   /// * Populate the PHINode generated for
   /// ComplexDeinterleavingOperation::ReductionPHI
   /// * Deinterleave the final value outside of the loop and repurpose original
   /// reduction users
-  void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
-  void processReductionSingle(Value *OperationReplacement, RawNodePtr Node);
+  void processReductionOperation(Value *OperationReplacement,
+                                 CompositeNode *Node);
+  void processReductionSingle(Value *OperationReplacement, CompositeNode *Node);
 
 public:
   void dump() { dump(dbgs()); }
@@ -637,7 +638,7 @@ bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B, unsigned Factor) {
   return false;
 }
 
-ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
     Instruction *Real, Instruction *Imag,
     std::pair<Value *, Value *> &PartialMatch) {
@@ -720,19 +721,20 @@ ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
     return nullptr;
   }
 
-  NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
+  CompositeNode *CommonNode =
+      identifyNode(PartialMatch.first, PartialMatch.second);
   if (!CommonNode) {
     LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
     return nullptr;
   }
 
-  NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
+  CompositeNode *UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
   if (!UncommonNode) {
     LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
     return nullptr;
   }
 
-  NodePtr Node = prepareCompositeNode(
+  CompositeNode *Node = prepareCompositeNode(
       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
   Node->Rotation = Rotation;
   Node->addOperand(CommonNode);
@@ -740,7 +742,7 @@ ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
   return submitCompositeNode(Node);
 }
 
-ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
                                                Instruction *Imag) {
   LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
@@ -831,26 +833,28 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
     return nullptr;
   }
 
-  NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
+  CompositeNode *CNode =
+      identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);
   if (!CNode) {
     LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
     return nullptr;
   }
 
-  NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
+  CompositeNode *UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);
   if (!UncommonRes) {
     LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
     return nullptr;
   }
 
   assert(PartialMatch.first && PartialMatch.second);
-  NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);
+  CompositeNode *CommonRes =
+      identifyNode(PartialMatch.first, PartialMatch.second);
   if (!CommonRes) {
     LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
     return nullptr;
   }
 
-  NodePtr Node = prepareCompositeNode(
+  CompositeNode *Node = prepareCompositeNode(
       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
   Node->Rotation = Rotation;
   Node->addOperand(CommonRes);
@@ -859,7 +863,7 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
   return submitCompositeNode(Node);
 }
 
-ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
   LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");
 
@@ -890,18 +894,18 @@ ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
     return nullptr;
   }
 
-  NodePtr ResA = identifyNode(AR, AI);
+  CompositeNode *ResA = identifyNode(AR, AI);
   if (!ResA) {
     LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");
     return nullptr;
   }
-  NodePtr ResB = identifyNode(BR, BI);
+  CompositeNode *ResB = identifyNode(BR, BI);
   if (!ResB) {
     LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");
     return nullptr;
   }
 
-  NodePtr Node =
+  CompositeNode *Node =
       prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
   Node->Rotation = Rotation;
   Node->addOperand(ResA);
@@ -941,7 +945,7 @@ static bool isInstructionPotentiallySymmetric(Instruction *I) {
   }
 }
 
-ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::identifySymmetricOperation(ComplexValues &Vals) {
   auto *FirstReal = cast<Instruction>(Vals[0].Real);
   unsigned FirstOpc = FirstReal->getOpcode();
@@ -968,8 +972,8 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(ComplexValues &Vals) {
     OpVals.push_back({R0, I0});
   }
 
-  NodePtr Op0 = identifyNode(OpVals);
-  NodePtr Op1 = nullptr;
+  CompositeNode *Op0 = identifyNode(OpVals);
+  CompositeNode *Op1 = nullptr;
   if (Op0 == nullptr)
     return nullptr;
 
@@ -998,7 +1002,7 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(ComplexValues &Vals) {
   return submitCompositeNode(Node);
 }
 
-ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
   if (!TL->isComplexDeinterleavingOperationSupported(
           ComplexDeinterleavingOperation::CDot, V->getType())) {
@@ -1011,10 +1015,10 @@ ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
   auto *Inst = cast<Instruction>(V);
   auto *RealUser = cast<Instruction>(*Inst->user_begin());
 
-  NodePtr CN =
+  CompositeNode *CN =
       prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr);
 
-  NodePtr ANode;
+  CompositeNode *ANode = nullptr;
 
   const Intrinsic::ID PartialReduceInt =
       Intrinsic::experimental_vector_partial_reduce_add;
@@ -1098,7 +1102,7 @@ ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
   if (Phi->getType() != VTy && RealUser->getType() != VTy)
     return nullptr;
 
-  NodePtr Node = identifyNode(AReal, AImag);
+  CompositeNode *Node = identifyNode(AReal, AImag);
 
   // In the case that a node was identified to figure out the rotation, ensure
   // that trying to identify a node with AReal and AImag post-unwrap results in
@@ -1118,7 +1122,7 @@ ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
   return submitCompositeNode(CN);
 }
 
-ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
   // Partial reductions don't support non-vector types, so check these first
   if (!isa<VectorType>(R->getType()) || !isa<VectorType>(I->getType()))
@@ -1137,13 +1141,13 @@ ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
                     Intrinsic::experimental_vector_partial_reduce_add)
     return nullptr;
 
-  if (NodePtr CN = identifyDotProduct(IInst))
+  if (CompositeNode *CN = identifyDotProduct(IInst))
     return CN;
 
   return nullptr;
 }
 
-ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::identifyNode(ComplexValues &Vals) {
   auto It = CachedResult.find(Vals);
   if (It != CachedResult.end()) {
@@ -1155,14 +1159,14 @@ ComplexDeinterleavingGraph::identifyNode(ComplexValues &Vals) {
     assert(Factor == 2 && "Can only handle interleave factors of 2");
     Value *R = Vals[0].Real;
     Value *I = Vals[0].Imag;
-    if (NodePtr CN = identifyPartialReduction(R, I))
+    if (CompositeNode *CN = identifyPartialReduction(R, I))
       return CN;
     bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
     if (!IsReduction && R->getType() != I->getType())
       return nullptr;
   }
 
-  if (NodePtr CN = identifySplat(Vals))
+  if (CompositeNode *CN = identifySplat(Vals))
     return CN;
 
   for (auto &V : Vals) {
@@ -1172,17 +1176,17 @@ ComplexDeinterleavingGraph::identifyNode(ComplexValues &Vals) {
       return nullptr;
   }
 
-  if (NodePtr CN = identifyDeinterleave(Vals))
+  if (CompositeNode *CN = identifyDeinterleave(Vals))
     return CN;
 
   if (Vals.size() == 1) {
     assert(Factor == 2 && "Can only handle interleave factors of 2");
     auto *Real = dyn_cast<Instruction>(Vals[0].Real);
     auto *Imag = dyn_cast<Instruction>(Vals[0].Imag);
-    if (NodePtr CN = identifyPHINode(Real, Imag))
+    if (CompositeNode *CN = identifyPHINode(Real, Imag))
       return CN;
 
-    if (NodePtr CN = identifySelectNode(Real, Imag))
+    if (CompositeNode *CN = identifySelectNode(Real, Imag))
       return CN;
 
     auto *VTy = cast<VectorType>(Real->getType());
@@ -1194,23 +1198,23 @@ ComplexDeinterleavingGraph::identifyNode(ComplexValues &Vals) {
         ComplexDeinterleavingOperation::CAdd, NewVTy);
 
     if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
-      if (NodePtr CN = identifyPartialMul(Real, Imag))
+      if (CompositeNode *CN = identifyPartialMul(Real, Imag))
         return CN;
     }
 
     if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
-      if (NodePtr CN = identifyAdd(Real, Imag))
+      if (CompositeNode *CN = identifyAdd(Real, Imag))
         return CN;
     }
 
     if (HasCMulSupport && HasCAddSupport) {
-      if (NodePtr CN = identifyReassocNodes(Real, Imag)) {
+      if (CompositeNode *CN = identifyReassocNodes(Real, Imag)) {
         return CN;
       }
     }
   }
 
-  if (NodePtr CN = identifySymmetricOperation(Vals))
+  if (CompositeNode *CN = identifySymmetricOperation(Vals))
     return CN;
 
   LLVM_DEBUG(dbgs() << "  - Not recognised as a valid pattern.\n");
@@ -1218,7 +1222,7 @@ ComplexDeinterleavingGraph::identifyNode(ComplexValues &Vals) {
   return nullptr;
 }
 
-ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
                                                  Instruction *Imag) {
   auto IsOperationSupported = [](unsigned Opcode) -> bool {
@@ -1341,7 +1345,7 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
   if (RealAddends.size() != ImagAddends.size())
     return nullptr;
 
-  NodePtr FinalNode;
+  CompositeNode *FinalNode = nullptr;
   if (!RealMuls.empty() || !ImagMuls.empty()) {
     // If there are multiplicands, extract positive addend and use it as an
     // accumulator
@@ -1417,10 +1421,10 @@ bool ComplexDeinterleavingGraph::collectPartialMuls(
   return true;
 }
 
-ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::identifyMultiplications(
     std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
-    NodePtr Accumulator = nullptr) {
+    CompositeNode *Accumulator = nullptr) {
   if (RealMuls.size() != ImagMuls.size())
     return nullptr;
 
@@ -1429,7 +1433,7 @@ ComplexDeinterleavingGraph::identifyMultiplications(
     return nullptr;
 
   // Map to store common instruction to node pointers
-  std::map<Value *, NodePtr> CommonToNode;
+  std::map<Value *, CompositeNode *> CommonToNode;
   std::vector<bool> Processed(Info.size(), false);
   for (unsigned I = 0; I < Info.size(); ++I) {
     if (Processed[I])
@@ -1461,7 +1465,7 @@ ComplexDeinterleavingGraph::identifyMultiplications(
 
   std::vector<bool> ProcessedReal(RealMuls.size(), false);
   std::vector<bool> ProcessedImag(ImagMuls.size(), false);
-  NodePtr Result = Accumulator;
+  CompositeNode *Result = Accumulator;
   for (auto &PMI : Info) {
     if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
       continue;
@@ -1533,7 +1537,7 @@ ComplexDeinterleavingGraph::identifyMultiplications(
       dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
     });
 
-    NodePtr NodeMul = prepareCompositeNode(
+    CompositeNode *NodeMul = prepareCompositeNode(
         ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
     NodeMul->Rotation = Rotation;
     NodeMul->addOperand(NodeA);
@@ -1574,14 +1578,14 @@ ComplexDeinterleavingGraph::identifyMultiplications(
   return Result;
 }
 
-ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::identifyAdditions(
     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
-    std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
+    std::optional<FastMathFlags> Flags, CompositeNode *Accumulator = nullptr) {
   if (RealAddends.size() != ImagAddends.size())
     return nullptr;
 
-  NodePtr Result;
+  CompositeNode *Result = nullptr;
   // If we have accumulator use it as first addend
   if (Accumulator)
     Result = Accumulator;
@@ -1609,7 +1613,7 @@ ComplexDeinterleavingGraph::identifyAdditions(
       else
         Rotation = ComplexDeinterleavingRotation::Rotation_270;
 
-      NodePtr AddNode;
+      CompositeNode *AddNode = nullptr;
       if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
           Rotation == ComplexDeinterleavingRotation::Rotation_180) {
         AddNode = identifyNode(R, I);
@@ -1624,7 +1628,7 @@ ComplexDeinterleavingGraph::identifyAdditions(
           dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
         });
 
-        NodePtr TmpNode;
+        CompositeNode *TmpNode = nullptr;
         if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {
           TmpNode = prepareCompositeNode(
               ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
@@ -1666,7 +1670,7 @@ ComplexDeinterleavingGraph::identifyAdditions(
   return Result;
 }
 
-ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::extractPositiveAddend(
     std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
   for (auto ItR = RealAddends.begin(); ItR != RealAddends.end(); ++ItR) {
@@ -1882,7 +1886,7 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
 
 bool ComplexDeinterleavingGraph::checkNodes() {
   bool FoundDeinterleaveNode = false;
-  for (NodePtr N : CompositeNodes) {
+  for (CompositeNode *N : CompositeNodes) {
     if (!N->areOperandsValid())
       return false;
 
@@ -1968,7 +1972,7 @@ bool ComplexDeinterleavingGraph::checkNodes() {
   return !RootToNode.empty();
 }
 
-ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
   if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
     if (Intrinsic::getInterleaveIntrinsicID(Factor) !=
@@ -1984,7 +1988,7 @@ ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
       Vals.push_back({Real, Imag});
     }
 
-    ComplexDeinterleavingGraph::NodePtr Node1 = identifyNode(Vals);
+    ComplexDeinterleavingGraph::CompositeNode *Node1 = identifyNode(Vals);
     if (!Node1)
       return nullptr;
     return Node1;
@@ -2015,7 +2019,7 @@ ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
   return identifyNode(Real, Imag);
 }
 
-ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
   Instruction *II = nullptr;
 
@@ -2047,7 +2051,7 @@ ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
       return nullptr;
 
     // The remaining should match too.
-    NodePtr PlaceholderNode = prepareCompositeNode(
+    CompositeNode *PlaceholderNode = prepareCompositeNode(
         llvm::ComplexDeinterleavingOperation::Deinterleave, Vals);
     PlaceholderNode->ReplacementNode = II->getOperand(0);
     for (auto &V : Vals) {
@@ -2145,7 +2149,7 @@ ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
     return nullptr;
   }
 
-  NodePtr PlaceholderNode =
+  CompositeNode *PlaceholderNode =
       prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
                            RealShuffle, ImagShuffle);
   PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
@@ -2154,7 +2158,7 @@ ComplexDeinterleavingGraph::identifyDeinterleave(ComplexValues &Vals) {
   return submitCompositeNode(PlaceholderNode);
 }
 
-ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::identifySplat(ComplexValues &Vals) {
   auto IsSplat = [](Value *V) -> bool {
     // Fixed-width vector with constants
@@ -2220,24 +2224,24 @@ ComplexDeinterleavingGraph::identifySplat(ComplexValues &Vals) {
       FinalInstructions.insert(Imag);
     }
   }
-  NodePtr PlaceholderNode =
+  CompositeNode *PlaceholderNode =
       prepareCompositeNode(ComplexDeinterleavingOperation::Splat, Vals);
   return submitCompositeNode(PlaceholderNode);
 }
 
-ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
                                             Instruction *Imag) {
   if (Real != RealPHI || (ImagPHI && Imag != ImagPHI))
     return nullptr;
 
   PHIsFound = true;
-  NodePtr PlaceholderNode = prepareCompositeNode(
+  CompositeNode *PlaceholderNode = prepareCompositeNode(
       ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
   return submitCompositeNode(PlaceholderNode);
 }
 
-ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::CompositeNode *
 ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
                                                Instruction *Imag) {
   auto *SelectReal = dyn_cast<SelectInst>(Real);
@@ -2267,7 +2271,7 @@ ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
   if (!NodeB)
     return nullptr;
 
-  NodePtr PlaceholderNode = prepareCompositeNode(
+  CompositeNode *PlaceholderNode = prepareCompositeNode(
       ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
   PlaceholderNode->addOperand(NodeA);
   PlaceholderNode->addOperand(NodeB);
@@ -2311,17 +2315,18 @@ static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
 }
 
 Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
-                                               RawNodePtr Node) {
+                                               CompositeNode *Node) {
   if (Node->ReplacementNode)
     return Node->ReplacementNode;
 
-  auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {
+  auto ReplaceOperandIfExist = [&](CompositeNode *Node,
+                                   unsigned Idx) -> Value * {
     return Node->Operands.size() > Idx
                ? replaceNode(Builder, Node->Operands[Idx])
                : nullptr;
   };
 
-  Value *ReplacementNode;
+  Value *ReplacementNode = nullptr;
   switch (Node->Operation) {
   case ComplexDeinterleavingOperation::CDot: {
     Value *Input0 = ReplaceOperandIfExist(Node, 0);
@@ -2418,7 +2423,7 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
 }
 
 void ComplexDeinterleavingGraph::processReductionSingle(
-    Value *OperationReplacement, RawNodePtr Node) {
+    Value *OperationReplacement, CompositeNode *Node) {
   auto *Real = cast<Instruction>(Node->Vals[0].Real);
   auto *OldPHI = ReductionInfo[Real].first;
   auto *NewPHI = OldToNewPHI[OldPHI];
@@ -2450,7 +2455,7 @@ void ComplexDeinterleavingGraph::processReductionSingle(
 }
 
 void ComplexDeinterleavingGraph::processReductionOperation(
-    Value *OperationReplacement, RawNodePtr Node) {
+    Value *OperationReplacement, CompositeNode *Node) {
   auto *Real = cast<Instruction>(Node->Vals[0].Real);
   auto *Imag = cast<Instruction>(Node->Vals[0].Imag);
   auto *OldPHIReal = ReductionInfo[Real].first;
@@ -2496,7 +2501,7 @@ void ComplexDeinterleavingGraph::replaceNodes() {
 
     IRBuilder<> Builder(RootInstruction);
     auto RootNode = RootToNode[RootInstruction];
-    Value *R = replaceNode(Builder, RootNode.get());
+    Value *R = replaceNode(Builder, RootNode);
 
     if (RootNode->Operation ==
         ComplexDeinterleavingOperation::ReductionOperation) {


        


More information about the llvm-commits mailing list