[llvm] Add support for single reductions in ComplexDeinterleavingPass (PR #112875)
Sam Tebbs via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 26 08:10:48 PST 2024
================
@@ -891,17 +901,172 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
}
ComplexDeinterleavingGraph::NodePtr
-ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
- LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
- assert(R->getType() == I->getType() &&
- "Real and imaginary parts should not have different types");
+ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
+ auto *Inst = cast<Instruction>(V);
+
+ if (!TL->isComplexDeinterleavingOperationSupported(
+ ComplexDeinterleavingOperation::CDot, Inst->getType())) {
+ LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
+ "operation CDot with the type "
+ << *Inst->getType() << "\n");
+ return nullptr;
+ }
+
+ auto *RealUser = cast<Instruction>(*Inst->user_begin());
+
+ NodePtr CN =
+ prepareCompositeNode(ComplexDeinterleavingOperation::CDot, Inst, nullptr);
+
+ NodePtr ANode;
+
+ const Intrinsic::ID PartialReduceInt =
+ Intrinsic::experimental_vector_partial_reduce_add;
+
+ Value *AReal = nullptr;
+ Value *AImag = nullptr;
+ Value *BReal = nullptr;
+ Value *BImag = nullptr;
+ Value *Phi = nullptr;
+
+ auto UnwrapCast = [](Value *V) -> Value * {
+ if (auto *CI = dyn_cast<CastInst>(V))
+ return CI->getOperand(0);
+ return V;
+ };
+
+ auto PatternRot0 = m_Intrinsic<PartialReduceInt>(
+ m_Intrinsic<PartialReduceInt>(m_Value(Phi),
+ m_Mul(m_Value(BReal), m_Value(AReal))),
+ m_Neg(m_Mul(m_Value(BImag), m_Value(AImag))));
+
+ auto PatternRot270 = m_Intrinsic<PartialReduceInt>(
+ m_Intrinsic<PartialReduceInt>(
+ m_Value(Phi), m_Neg(m_Mul(m_Value(BReal), m_Value(AImag)))),
+ m_Mul(m_Value(BImag), m_Value(AReal)));
+
+ if (match(Inst, PatternRot0)) {
+ CN->Rotation = ComplexDeinterleavingRotation::Rotation_0;
+ } else if (match(Inst, PatternRot270)) {
+ CN->Rotation = ComplexDeinterleavingRotation::Rotation_270;
+ } else {
+ Value *A0, *A1;
+ // The rotations 90 and 180 share the same operation pattern, so inspect the
+ // order of the operands, identifying where the real and imaginary
+ // components of A go, to discern between the aforementioned rotations.
+ auto PatternRot90Rot180 = m_Intrinsic<PartialReduceInt>(
+ m_Intrinsic<PartialReduceInt>(m_Value(Phi),
+ m_Mul(m_Value(BReal), m_Value(A0))),
+ m_Mul(m_Value(BImag), m_Value(A1)));
+
+ if (!match(Inst, PatternRot90Rot180))
+ return nullptr;
+
+ A0 = UnwrapCast(A0);
+ A1 = UnwrapCast(A1);
+
+ // Test if A0 is real/A1 is imag
+ ANode = identifyNode(A0, A1);
+ if (!ANode) {
+ // Test if A0 is imag/A1 is real
+ ANode = identifyNode(A1, A0);
+ // Unable to identify operand components, thus unable to identify rotation
+ if (!ANode)
+ return nullptr;
+ CN->Rotation = ComplexDeinterleavingRotation::Rotation_90;
+ AReal = A1;
+ AImag = A0;
+ } else {
+ AReal = A0;
+ AImag = A1;
+ CN->Rotation = ComplexDeinterleavingRotation::Rotation_180;
+ }
+ }
+
+ AReal = UnwrapCast(AReal);
+ AImag = UnwrapCast(AImag);
+ BReal = UnwrapCast(BReal);
+ BImag = UnwrapCast(BImag);
+
+ bool WasANodeFromCache = false;
+ NodePtr Node = identifyNode(AReal, AImag, &WasANodeFromCache);
+
+ // In the case that a node was identified to figure out the rotation, ensure
+ // that trying to identify a node with AReal and AImag post-unwrap results in
+ // the same node
+ if (Node && ANode && !WasANodeFromCache) {
+ LLVM_DEBUG(
+ dbgs()
+ << "Identified node is different from previously identified node. "
+ "Unable to confidently generate a complex operation node\n");
+ return nullptr;
+ }
+
+ CN->addOperand(Node);
+ CN->addOperand(identifyNode(BReal, BImag));
+ CN->addOperand(identifyNode(Phi, RealUser));
+
+ return submitCompositeNode(CN);
+}
+
+ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
+ if (!I->hasOneUser())
+ return nullptr;
+
+ VectorType *RealTy = dyn_cast<VectorType>(R->getType());
+ if (!RealTy)
+ return nullptr;
+ VectorType *ImagTy = dyn_cast<VectorType>(I->getType());
+ if (!ImagTy)
+ return nullptr;
+
+ if (RealTy->isScalableTy() != ImagTy->isScalableTy())
+ return nullptr;
+ if (RealTy->getElementType() != ImagTy->getElementType())
+ return nullptr;
+
+ // `I` is known to only have one user, so iterate over the Phi (R) users to
+ // find the common user between R and I
+ auto *CommonUser = *I->user_begin();
+ bool CommonUserFound = false;
+ for (auto *User : R->users()) {
+ if (User == CommonUser) {
+ CommonUserFound = true;
+ break;
+ }
+ }
+
+ if (!CommonUserFound)
+ return nullptr;
+
+ auto *IInst = dyn_cast<IntrinsicInst>(CommonUser);
+ if (!IInst || IInst->getIntrinsicID() !=
+ Intrinsic::experimental_vector_partial_reduce_add)
+ return nullptr;
+
+ if (NodePtr CN = identifyDotProduct(IInst))
+ return CN;
+ return nullptr;
+}
+
+ComplexDeinterleavingGraph::NodePtr
+ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I, bool *FromCache) {
----------------
SamTebbs33 wrote:
Can `FromCache` be a reference instead?
https://github.com/llvm/llvm-project/pull/112875
More information about the llvm-commits mailing list