[llvm] Add support for single reductions in ComplexDeinterleavingPass (PR #112875)

Nicholas Guy via llvm-commits llvm-commits at lists.llvm.org
Tue Dec 10 07:07:45 PST 2024


================
@@ -891,17 +901,172 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
 }
 
 ComplexDeinterleavingGraph::NodePtr
-ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
-  LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
-  assert(R->getType() == I->getType() &&
-         "Real and imaginary parts should not have different types");
+ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
+  auto *Inst = cast<Instruction>(V);
+
+  if (!TL->isComplexDeinterleavingOperationSupported(
+          ComplexDeinterleavingOperation::CDot, Inst->getType())) {
+    LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
+                         "operation CDot with the type "
+                      << *Inst->getType() << "\n");
+    return nullptr;
+  }
+
+  auto *RealUser = cast<Instruction>(*Inst->user_begin());
+
+  NodePtr CN =
+      prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr);
+
+  NodePtr ANode;
+
+  const Intrinsic::ID PartialReduceInt =
+      Intrinsic::experimental_vector_partial_reduce_add;
+
+  Value *AReal = nullptr;
+  Value *AImag = nullptr;
+  Value *BReal = nullptr;
+  Value *BImag = nullptr;
+  Value *Phi = nullptr;
+
+  auto UnwrapCast = [](Value *V) -> Value * {
+    if (auto *CI = dyn_cast<CastInst>(V))
+      return CI->getOperand(0);
+    return V;
+  };
+
+  auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
+      m_Intrinsic<PartialReduceInt>(m_Value(Phi),
+                                    m_Mul(m_Value(BReal), m_Value(AReal))),
+      m_Neg(m_Mul(m_Value(BImag), m_Value(AImag))));
+
+  auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
+      m_Intrinsic<PartialReduceInt>(
+          m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))),
+      m_Mul(m_Value(BImag), m_Value(AReal)));
+
+  if (match(Inst, PatternRot0)) {
+    CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
+  } else if (match(Inst, PatternRot270)) {
+    CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
+  } else {
+    Value *A0, *A1;
+    // The rotations 90 and 180 share the same operation pattern, so inspect the
+    // order of the operands, identifying where the real and imaginary
+    // components of A go, to discern between the aforementioned rotations.
+    auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
+        m_Intrinsic<PartialReduceInt>(m_Value(Phi),
+                                      m_Mul(m_Value(BReal), m_Value(A0))),
+        m_Mul(m_Value(BImag), m_Value(A1)));
+
+    if (!match(Inst, PatternRot90Rot180))
+      return nullptr;
+
+    A0 = UnwrapCast(A0);
+    A1 = UnwrapCast(A1);
+
+    // Test if A0 is real/A1 is imag
+    ANode = identifyNode(A0, A1);
+    if (!ANode) {
+      // Test if A0 is imag/A1 is real
+      ANode = identifyNode(A1, A0);
+      // Unable to identify operand components, thus unable to identify rotation
+      if (!ANode)
+        return nullptr;
+      CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
+      AReal = A1;
+      AImag = A0;
+    } else {
+      AReal = A0;
+      AImag = A1;
+      CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
+    }
+  }
+
+  AReal = UnwrapCast(AReal);
+  AImag = UnwrapCast(AImag);
+  BReal = UnwrapCast(BReal);
+  BImag = UnwrapCast(BImag);
+
+  bool WasANodeFromCache = false;
+  NodePtr Node = identifyNode(AReal, AImag, &WasANodeFromCache);
+
+  // In the case that a node was identified to figure out the rotation, ensure
+  // that trying to identify a node with AReal and AImag post-unwrap results in
+  // the same node
+  if (Node && ANode && !WasANodeFromCache) {
+    LLVM_DEBUG(
+        dbgs()
+        << "Identified node is different from previously identified node. "
+           "Unable to confidently generate a complex operation node\n");
+    return nullptr;
+  }
+
+  CN->addOperand(Node);
+  CN->addOperand(identifyNode(BReal, BImag));
+  CN->addOperand(identifyNode(Phi, RealUser));
+
+  return submitCompositeNode(CN);
+}
+
+ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
+  if (!I->hasOneUser())
+    return nullptr;
----------------
NickGuy-Arm wrote:

The `hasOneUser` check was there to simplify things later, but it's not strictly necessary, so I've removed it and updated the code that follows so it no longer assumes a single user.

https://github.com/llvm/llvm-project/pull/112875


More information about the llvm-commits mailing list