[llvm] c692e87 - [CodeGen] Enable processing of interconnected complex number operations

Igor Kirillov via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 18 06:06:40 PDT 2023


Author: Igor Kirillov
Date: 2023-04-18T13:05:49Z
New Revision: c692e87ab8e7d3c7d8e2365572ffb41f6ec9ac1d

URL: https://github.com/llvm/llvm-project/commit/c692e87ab8e7d3c7d8e2365572ffb41f6ec9ac1d
DIFF: https://github.com/llvm/llvm-project/commit/c692e87ab8e7d3c7d8e2365572ffb41f6ec9ac1d.diff

LOG: [CodeGen] Enable processing of interconnected complex number operations

With this patch, ComplexDeinterleavingPass now has the ability to handle
any number of interconnected operations involving complex numbers.
For example, the patch enables the processing of code like the following:

for (int i = 0; i < 1000; ++i) {
    a[i] =  w[i] * v[i];
    b[i] =  w[i] * u[i];
}

This code has multiple arrays containing complex numbers and a common
subexpression `w[i]` that appears in both multiplications.

Differential Revision: https://reviews.llvm.org/D146988

Added: 
    

Modified: 
    llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
    llvm/test/CodeGen/AArch64/complex-deinterleaving-multiuses.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index ff0c5d530747b..3cfe935e2cca3 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -137,19 +137,12 @@ struct ComplexDeinterleavingCompositeNode {
   Instruction *Real;
   Instruction *Imag;
 
-  // Instructions that should only exist within this node, there should be no
-  // users of these instructions outside the node. An example of these would be
-  // the multiply instructions of a partial multiply operation.
-  SmallVector<Instruction *> InternalInstructions;
   ComplexDeinterleavingRotation Rotation;
   SmallVector<RawNodePtr> Operands;
   Value *ReplacementNode = nullptr;
 
-  void addInstruction(Instruction *I) { InternalInstructions.push_back(I); }
   void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
 
-  bool hasAllInternalUses(SmallPtrSet<Instruction *, 16> &AllInstructions);
-
   void dump() { dump(dbgs()); }
   void dump(raw_ostream &OS) {
     auto PrintValue = [&](Value *V) {
@@ -181,12 +174,6 @@ struct ComplexDeinterleavingCompositeNode {
       OS << "    - ";
       PrintNodeRef(Op);
     }
-    OS << "  InternalInstructions:\n";
-    for (const auto &I : InternalInstructions) {
-      OS << "    - \"";
-      I->print(OS, true);
-      OS << "\"\n";
-    }
   }
 };
 
@@ -194,14 +181,22 @@ class ComplexDeinterleavingGraph {
 public:
   using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
   using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
-  explicit ComplexDeinterleavingGraph(const TargetLowering *tl) : TL(tl) {}
+  explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
+                                      const TargetLibraryInfo *TLI)
+      : TL(TL), TLI(TLI) {}
 
 private:
   const TargetLowering *TL = nullptr;
-  Instruction *RootValue = nullptr;
-  NodePtr RootNode;
+  const TargetLibraryInfo *TLI = nullptr;
   SmallVector<NodePtr> CompositeNodes;
-  SmallPtrSet<Instruction *, 16> AllInstructions;
+
+  SmallPtrSet<Instruction *, 16> FinalInstructions;
+
+  /// Root instructions are instructions from which complex computation starts
+  std::map<Instruction *, NodePtr> RootToNode;
+
+  /// Topologically sorted root instructions
+  SmallVector<Instruction *, 1> OrderedRoots;
 
   NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
                                Instruction *R, Instruction *I) {
@@ -211,10 +206,6 @@ class ComplexDeinterleavingGraph {
 
   NodePtr submitCompositeNode(NodePtr Node) {
     CompositeNodes.push_back(Node);
-    AllInstructions.insert(Node->Real);
-    AllInstructions.insert(Node->Imag);
-    for (auto *I : Node->InternalInstructions)
-      AllInstructions.insert(I);
     return Node;
   }
 
@@ -271,6 +262,10 @@ class ComplexDeinterleavingGraph {
   /// current graph.
   bool identifyNodes(Instruction *RootI);
 
+  /// Check that every instruction, from the roots to the leaves, has internal
+  /// uses.
+  bool checkNodes();
+
   /// Perform the actual replacement of the underlying instruction graph.
   void replaceNodes();
 };
@@ -368,9 +363,7 @@ static bool isDeinterleavingMask(ArrayRef<int> Mask) {
 }
 
 bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
-  bool Changed = false;
-
-  SmallVector<Instruction *> DeadInstrRoots;
+  ComplexDeinterleavingGraph Graph(TL, TLI);
 
   for (auto &I : *B) {
     auto *SVI = dyn_cast<ShuffleVectorInst>(&I);
@@ -382,22 +375,15 @@ bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
     if (!isInterleavingMask(SVI->getShuffleMask()))
       continue;
 
-    ComplexDeinterleavingGraph Graph(TL);
-    if (!Graph.identifyNodes(SVI))
-      continue;
-
-    Graph.replaceNodes();
-    DeadInstrRoots.push_back(SVI);
-    Changed = true;
+    Graph.identifyNodes(SVI);
   }
 
-  for (const auto &I : DeadInstrRoots) {
-    if (!I || I->getParent() == nullptr)
-      continue;
-    llvm::RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
+  if (Graph.checkNodes()) {
+    Graph.replaceNodes();
+    return true;
   }
 
-  return Changed;
+  return false;
 }
 
 ComplexDeinterleavingGraph::NodePtr
@@ -511,7 +497,6 @@ ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
   Node->Rotation = Rotation;
   Node->addOperand(CommonNode);
   Node->addOperand(UncommonNode);
-  Node->InternalInstructions.append(FNegs);
   return submitCompositeNode(Node);
 }
 
@@ -627,8 +612,6 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
 
   NodePtr Node = prepareCompositeNode(
       ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
-  Node->addInstruction(RealMulI);
-  Node->addInstruction(ImagMulI);
   Node->Rotation = Rotation;
   Node->addOperand(CommonRes);
   Node->addOperand(UncommonRes);
@@ -846,6 +829,8 @@ ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
         prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Shuffle,
                              RealShuffle, ImagShuffle);
     PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);
+    FinalInstructions.insert(RealShuffle);
+    FinalInstructions.insert(ImagShuffle);
     return submitCompositeNode(PlaceholderNode);
   }
   if (RealShuffle || ImagShuffle) {
@@ -881,9 +866,7 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
   if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
     return false;
 
-  RootValue = RootI;
-  AllInstructions.insert(RootI);
-  RootNode = identifyNode(Real, Imag);
+  auto RootNode = identifyNode(Real, Imag);
 
   LLVM_DEBUG({
     Function *F = RootI->getFunction();
@@ -894,14 +877,86 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
     dbgs() << "\n";
   });
 
-  // Check all instructions have internal uses
-  for (const auto &Node : CompositeNodes) {
-    if (!Node->hasAllInternalUses(AllInstructions)) {
-      LLVM_DEBUG(dbgs() << "  - Invalid internal uses\n");
-      return false;
+  if (RootNode) {
+    RootToNode[RootI] = RootNode;
+    OrderedRoots.push_back(RootI);
+    return true;
+  }
+
+  return false;
+}
+
+bool ComplexDeinterleavingGraph::checkNodes() {
+  // Collect all instructions from roots to leaves
+  SmallPtrSet<Instruction *, 16> AllInstructions;
+  SmallVector<Instruction *, 8> Worklist;
+  for (auto *I : OrderedRoots)
+    Worklist.push_back(I);
+
+  // Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
+  // chains
+  while (!Worklist.empty()) {
+    auto *I = Worklist.back();
+    Worklist.pop_back();
+
+    if (!AllInstructions.insert(I).second)
+      continue;
+
+    for (Value *Op : I->operands()) {
+      if (auto *OpI = dyn_cast<Instruction>(Op)) {
+        if (!FinalInstructions.count(I))
+          Worklist.emplace_back(OpI);
+      }
     }
   }
-  return RootNode != nullptr;
+
+  // Find instructions that have users outside of chain
+  SmallVector<Instruction *, 2> OuterInstructions;
+  for (auto *I : AllInstructions) {
+    // Skip root nodes
+    if (RootToNode.count(I))
+      continue;
+
+    for (User *U : I->users()) {
+      if (AllInstructions.count(cast<Instruction>(U)))
+        continue;
+
+      // Found an instruction that is not used by XCMLA/XCADD chain
+      Worklist.emplace_back(I);
+      break;
+    }
+  }
+
+  // If any instructions are found to be used outside, find and remove roots
+  // that somehow connect to those instructions.
+  SmallPtrSet<Instruction *, 16> Visited;
+  while (!Worklist.empty()) {
+    auto *I = Worklist.back();
+    Worklist.pop_back();
+    if (!Visited.insert(I).second)
+      continue;
+
+    // Found an impacted root node. Removing it from the nodes to be
+    // deinterleaved
+    if (RootToNode.count(I)) {
+      LLVM_DEBUG(dbgs() << "Instruction " << *I
+                        << " could be deinterleaved but its chain of complex "
+                           "operations have an outside user\n");
+      RootToNode.erase(I);
+    }
+
+    if (!AllInstructions.count(I) || FinalInstructions.count(I))
+      continue;
+
+    for (User *U : I->users())
+      Worklist.emplace_back(cast<Instruction>(U));
+
+    for (Value *Op : I->operands()) {
+      if (auto *OpI = dyn_cast<Instruction>(Op))
+        Worklist.emplace_back(OpI);
+    }
+  }
+  return !RootToNode.empty();
 }
 
 static Value *replaceSymmetricNode(ComplexDeinterleavingGraph::RawNodePtr Node,
@@ -958,29 +1013,21 @@ Value *ComplexDeinterleavingGraph::replaceNode(
 }
 
 void ComplexDeinterleavingGraph::replaceNodes() {
-  Value *R = replaceNode(RootNode.get());
-  assert(R && "Unable to find replacement for RootValue");
-  RootValue->replaceAllUsesWith(R);
-}
-
-bool ComplexDeinterleavingCompositeNode::hasAllInternalUses(
-    SmallPtrSet<Instruction *, 16> &AllInstructions) {
-  if (Operation == ComplexDeinterleavingOperation::Shuffle)
-    return true;
+  SmallVector<Instruction *, 16> DeadInstrRoots;
+  for (auto *RootInstruction : OrderedRoots) {
+    // Check if this potential root went through check process and we can
+    // deinterleave it
+    if (!RootToNode.count(RootInstruction))
+      continue;
 
-  for (auto *User : Real->users()) {
-    if (!AllInstructions.contains(cast<Instruction>(User)))
-      return false;
+    IRBuilder<> Builder(RootInstruction);
+    auto RootNode = RootToNode[RootInstruction];
+    Value *R = replaceNode(RootNode.get());
+    assert(R && "Unable to find replacement for RootInstruction");
+    DeadInstrRoots.push_back(RootInstruction);
+    RootInstruction->replaceAllUsesWith(R);
   }
-  for (auto *User : Imag->users()) {
-    if (!AllInstructions.contains(cast<Instruction>(User)))
-      return false;
-  }
-  for (auto *I : InternalInstructions) {
-    for (auto *User : I->users()) {
-      if (!AllInstructions.contains(cast<Instruction>(User)))
-        return false;
-    }
-  }
-  return true;
+
+  for (auto *I : DeadInstrRoots)
+    RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
 }

diff  --git a/llvm/test/CodeGen/AArch64/complex-deinterleaving-multiuses.ll b/llvm/test/CodeGen/AArch64/complex-deinterleaving-multiuses.ll
index fe3d30677f084..4d84636e92ca2 100644
--- a/llvm/test/CodeGen/AArch64/complex-deinterleaving-multiuses.ll
+++ b/llvm/test/CodeGen/AArch64/complex-deinterleaving-multiuses.ll
@@ -2,30 +2,20 @@
 ; RUN: llc < %s --mattr=+complxnum,+neon -o - | FileCheck %s
 
 target triple = "aarch64-arm-none-eabi"
-; Expected to not transform
+; Expected to transform
 ;   *p = (a * b);
 ;   return (a * b) * a;
 define <4 x float> @mul_triangle(<4 x float> %a, <4 x float> %b, ptr %p) {
 ; CHECK-LABEL: mul_triangle:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    ext v2.16b, v0.16b, v0.16b, #8
-; CHECK-NEXT:    ext v3.16b, v1.16b, v1.16b, #8
-; CHECK-NEXT:    zip2 v4.2s, v0.2s, v2.2s
-; CHECK-NEXT:    zip1 v0.2s, v0.2s, v2.2s
-; CHECK-NEXT:    zip2 v5.2s, v1.2s, v3.2s
-; CHECK-NEXT:    zip1 v1.2s, v1.2s, v3.2s
-; CHECK-NEXT:    fmul v6.2s, v5.2s, v4.2s
-; CHECK-NEXT:    fneg v2.2s, v6.2s
-; CHECK-NEXT:    fmla v2.2s, v0.2s, v1.2s
-; CHECK-NEXT:    fmul v3.2s, v4.2s, v1.2s
-; CHECK-NEXT:    fmla v3.2s, v0.2s, v5.2s
-; CHECK-NEXT:    fmul v1.2s, v3.2s, v4.2s
-; CHECK-NEXT:    fmul v5.2s, v3.2s, v0.2s
-; CHECK-NEXT:    st2 { v2.2s, v3.2s }, [x0]
-; CHECK-NEXT:    fneg v1.2s, v1.2s
-; CHECK-NEXT:    fmla v5.2s, v4.2s, v2.2s
-; CHECK-NEXT:    fmla v1.2s, v0.2s, v2.2s
-; CHECK-NEXT:    zip1 v0.4s, v1.4s, v5.4s
+; CHECK-NEXT:    movi v3.2d, #0000000000000000
+; CHECK-NEXT:    movi v2.2d, #0000000000000000
+; CHECK-NEXT:    fcmla v3.4s, v1.4s, v0.4s, #0
+; CHECK-NEXT:    fcmla v3.4s, v1.4s, v0.4s, #90
+; CHECK-NEXT:    fcmla v2.4s, v0.4s, v3.4s, #0
+; CHECK-NEXT:    str q3, [x0]
+; CHECK-NEXT:    fcmla v2.4s, v0.4s, v3.4s, #90
+; CHECK-NEXT:    mov v0.16b, v2.16b
 ; CHECK-NEXT:    ret
 entry:
   %strided.vec = shufflevector <4 x float> %a, <4 x float> poison, <2 x i32> <i32 0, i32 2>


        


More information about the llvm-commits mailing list