[llvm] Add support for single reductions in ComplexDeinterleavingPass (PR #112875)

Sam Tebbs via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 26 08:10:48 PST 2024


================
@@ -891,17 +901,172 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
 }
 
 ComplexDeinterleavingGraph::NodePtr
-ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
-  LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
-  assert(R->getType() == I->getType() &&
-         "Real and imaginary parts should not have different types");
+ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
+  auto *Inst = cast<Instruction>(V);
+
+  if (!TL->isComplexDeinterleavingOperationSupported(
+          ComplexDeinterleavingOperation::CDot, Inst->getType())) {
+    LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
+                         "operation CDot with the type "
+                      << *Inst->getType() << "\n");
+    return nullptr;
+  }
+
+  auto *RealUser = cast<Instruction>(*Inst->user_begin());
+
+  NodePtr CN =
+      prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr);
+
+  NodePtr ANode;
+
+  const Intrinsic::ID PartialReduceInt =
+      Intrinsic::experimental_vector_partial_reduce_add;
+
+  Value *AReal = nullptr;
+  Value *AImag = nullptr;
+  Value *BReal = nullptr;
+  Value *BImag = nullptr;
+  Value *Phi = nullptr;
+
+  auto UnwrapCast = [](Value *V) -> Value * {
+    if (auto *CI = dyn_cast<CastInst>(V))
+      return CI->getOperand(0);
+    return V;
+  };
+
+  auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
+      m_Intrinsic<PartialReduceInt>(m_Value(Phi),
+                                    m_Mul(m_Value(BReal), m_Value(AReal))),
+      m_Neg(m_Mul(m_Value(BImag), m_Value(AImag))));
+
+  auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
+      m_Intrinsic<PartialReduceInt>(
+          m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))),
+      m_Mul(m_Value(BImag), m_Value(AReal)));
+
+  if (match(Inst, PatternRot0)) {
+    CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
+  } else if (match(Inst, PatternRot270)) {
+    CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
+  } else {
+    Value *A0, *A1;
+    // The rotations 90 and 180 share the same operation pattern, so inspect the
+    // order of the operands, identifying where the real and imaginary
+    // components of A go, to discern between the aforementioned rotations.
+    auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
+        m_Intrinsic<PartialReduceInt>(m_Value(Phi),
+                                      m_Mul(m_Value(BReal), m_Value(A0))),
+        m_Mul(m_Value(BImag), m_Value(A1)));
+
+    if (!match(Inst, PatternRot90Rot180))
+      return nullptr;
+
+    A0 = UnwrapCast(A0);
+    A1 = UnwrapCast(A1);
+
+    // Test whether A0 is the real part and A1 the imaginary part.
+    ANode = identifyNode(A0, A1);
+    if (!ANode) {
+      // Otherwise test the swapped order: A0 imaginary, A1 real.
+      ANode = identifyNode(A1, A0);
+      // Neither order forms a node, so the rotation cannot be determined.
+      if (!ANode)
+        return nullptr;
+      CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
+      AReal = A1;
+      AImag = A0;
+    } else {
+      AReal = A0;
+      AImag = A1;
+      CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
+    }
+  }
+
+  AReal = UnwrapCast(AReal);
+  AImag = UnwrapCast(AImag);
+  BReal = UnwrapCast(BReal);
+  BImag = UnwrapCast(BImag);
+
+  bool WasANodeFromCache = false;
+  NodePtr Node = identifyNode(AReal, AImag, &WasANodeFromCache);
+
+  // If a node (ANode) was already identified while determining the rotation,
+  // identifying AReal/AImag again after unwrapping must yield that same node,
+  // which is detected via the FromCache flag; otherwise give up.
+  if (Node && ANode && !WasANodeFromCache) {
+    LLVM_DEBUG(
+        dbgs()
+        << "Identified node is different from previously identified node. "
+           "Unable to confidently generate a complex operation node\n");
+    return nullptr;
+  }
+
+  CN->addOperand(Node);
+  CN->addOperand(identifyNode(BReal, BImag));
+  CN->addOperand(identifyNode(Phi, RealUser));
+
+  return submitCompositeNode(CN);
+}
+
+ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
+  if (!I->hasOneUser())
+    return nullptr;
+
+  VectorType *RealTy = dyn_cast<VectorType>(R->getType());
+  if (!RealTy)
+    return nullptr;
+  VectorType *ImagTy = dyn_cast<VectorType>(I->getType());
+  if (!ImagTy)
+    return nullptr;
+
+  if (RealTy->isScalableTy() != ImagTy->isScalableTy())
+    return nullptr;
+  if (RealTy->getElementType() != ImagTy->getElementType())
+    return nullptr;
+
+  // `I` is known to have exactly one user, so scan the users of the Phi (R)
+  // to check whether that user is shared between R and I.
+  auto *CommonUser = *I->user_begin();
+  bool CommonUserFound = false;
+  for (auto *User : R->users()) {
+    if (User == CommonUser) {
+      CommonUserFound = true;
+      break;
+    }
+  }
+
+  if (!CommonUserFound)
+    return nullptr;
+
+  auto *IInst = dyn_cast<IntrinsicInst>(CommonUser);
+  if (!IInst || IInst->getIntrinsicID() !=
+                    Intrinsic::experimental_vector_partial_reduce_add)
+    return nullptr;
+
+  if (NodePtr CN = identifyDotProduct(IInst))
+    return CN;
 
+  return nullptr;
+}
+
+ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I, bool *FromCache) {
----------------
SamTebbs33 wrote:

Can `FromCache` be a reference instead?

https://github.com/llvm/llvm-project/pull/112875


More information about the llvm-commits mailing list