[llvm] b4f9c3a - [CodeGen] Refactor ComplexDeinterleaving to run identification on Values instead of Instructions
Igor Kirillov via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 3 03:36:02 PDT 2023
Author: Igor Kirillov
Date: 2023-07-03T10:35:14Z
New Revision: b4f9c3a933e80de40ee7860db75bb1088cf9bfa7
URL: https://github.com/llvm/llvm-project/commit/b4f9c3a933e80de40ee7860db75bb1088cf9bfa7
DIFF: https://github.com/llvm/llvm-project/commit/b4f9c3a933e80de40ee7860db75bb1088cf9bfa7.diff
LOG: [CodeGen] Refactor ComplexDeinterleaving to run identification on Values instead of Instructions
This change will make it easier to add identification of complex constants
in future patches.
Differential Revision: https://reviews.llvm.org/D153446
Added:
Modified:
llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index 2464ebab221d4a..3a4b94d5eae271 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -130,7 +130,7 @@ class ComplexDeinterleavingGraph;
struct ComplexDeinterleavingCompositeNode {
ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
- Instruction *R, Instruction *I)
+ Value *R, Value *I)
: Operation(Op), Real(R), Imag(I) {}
private:
@@ -140,8 +140,8 @@ struct ComplexDeinterleavingCompositeNode {
public:
ComplexDeinterleavingOperation Operation;
- Instruction *Real;
- Instruction *Imag;
+ Value *Real;
+ Value *Imag;
// This two members are required exclusively for generating
// ComplexDeinterleavingOperation::Symmetric operations.
@@ -192,19 +192,19 @@ struct ComplexDeinterleavingCompositeNode {
class ComplexDeinterleavingGraph {
public:
struct Product {
- Instruction *Multiplier;
- Instruction *Multiplicand;
+ Value *Multiplier;
+ Value *Multiplicand;
bool IsPositive;
};
- using Addend = std::pair<Instruction *, bool>;
+ using Addend = std::pair<Value *, bool>;
using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
// Helper struct for holding info about potential partial multiplication
// candidates
struct PartialMulCandidate {
- Instruction *Common;
+ Value *Common;
NodePtr Node;
unsigned RealIdx;
unsigned ImagIdx;
@@ -270,7 +270,7 @@ class ComplexDeinterleavingGraph {
std::map<PHINode *, PHINode *> OldToNewPHI;
NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
- Instruction *R, Instruction *I) {
+ Value *R, Value *I) {
assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
(R && I)) &&
@@ -308,9 +308,9 @@ class ComplexDeinterleavingGraph {
/// Identify the other branch of a Partial Mul, taking the CommonOperandI that
/// is partially known from identifyPartialMul, filling in the other half of
/// the complex pair.
- NodePtr identifyNodeWithImplicitAdd(
- Instruction *I, Instruction *J,
- std::pair<Instruction *, Instruction *> &CommonOperandI);
+ NodePtr
+ identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
+ std::pair<Value *, Value *> &CommonOperandI);
/// Identifies a complex add pattern and its rotation, based on the following
/// patterns.
@@ -322,7 +322,7 @@ class ComplexDeinterleavingGraph {
NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
- NodePtr identifyNode(Instruction *I, Instruction *J);
+ NodePtr identifyNode(Value *R, Value *I);
/// Determine if a sum of complex numbers can be formed from \p RealAddends
/// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
@@ -521,7 +521,7 @@ bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
Instruction *Real, Instruction *Imag,
- std::pair<Instruction *, Instruction *> &PartialMatch) {
+ std::pair<Value *, Value *> &PartialMatch) {
LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
<< "\n");
@@ -536,52 +536,38 @@ ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
return nullptr;
}
- Instruction *R0 = dyn_cast<Instruction>(Real->getOperand(0));
- Instruction *R1 = dyn_cast<Instruction>(Real->getOperand(1));
- Instruction *I0 = dyn_cast<Instruction>(Imag->getOperand(0));
- Instruction *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
- if (!R0 || !R1 || !I0 || !I1) {
- LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n");
- return nullptr;
- }
+ Value *R0 = Real->getOperand(0);
+ Value *R1 = Real->getOperand(1);
+ Value *I0 = Imag->getOperand(0);
+ Value *I1 = Imag->getOperand(1);
// A +/+ has a rotation of 0. If any of the operands are fneg, we flip the
// rotations and use the operand.
unsigned Negs = 0;
- SmallVector<Instruction *> FNegs;
- if (R0->getOpcode() == Instruction::FNeg ||
- R1->getOpcode() == Instruction::FNeg) {
+ Value *Op;
+ if (match(R0, m_Neg(m_Value(Op)))) {
Negs |= 1;
- if (R0->getOpcode() == Instruction::FNeg) {
- FNegs.push_back(R0);
- R0 = dyn_cast<Instruction>(R0->getOperand(0));
- } else {
- FNegs.push_back(R1);
- R1 = dyn_cast<Instruction>(R1->getOperand(0));
- }
- if (!R0 || !R1)
- return nullptr;
+ R0 = Op;
+ } else if (match(R1, m_Neg(m_Value(Op)))) {
+ Negs |= 1;
+ R1 = Op;
}
- if (I0->getOpcode() == Instruction::FNeg ||
- I1->getOpcode() == Instruction::FNeg) {
+
+ if (match(I0, m_Neg(m_Value(Op)))) {
Negs |= 2;
Negs ^= 1;
- if (I0->getOpcode() == Instruction::FNeg) {
- FNegs.push_back(I0);
- I0 = dyn_cast<Instruction>(I0->getOperand(0));
- } else {
- FNegs.push_back(I1);
- I1 = dyn_cast<Instruction>(I1->getOperand(0));
- }
- if (!I0 || !I1)
- return nullptr;
+ I0 = Op;
+ } else if (match(I1, m_Neg(m_Value(Op)))) {
+ Negs |= 2;
+ Negs ^= 1;
+ I1 = Op;
}
ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;
- Instruction *CommonOperand;
- Instruction *UncommonRealOp;
- Instruction *UncommonImagOp;
+ Value *CommonOperand;
+ Value *UncommonRealOp;
+ Value *UncommonImagOp;
if (R0 == I0 || R0 == I1) {
CommonOperand = R0;
@@ -676,18 +662,14 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
return nullptr;
}
- Instruction *R0 = dyn_cast<Instruction>(RealMulI->getOperand(0));
- Instruction *R1 = dyn_cast<Instruction>(RealMulI->getOperand(1));
- Instruction *I0 = dyn_cast<Instruction>(ImagMulI->getOperand(0));
- Instruction *I1 = dyn_cast<Instruction>(ImagMulI->getOperand(1));
- if (!R0 || !R1 || !I0 || !I1) {
- LLVM_DEBUG(dbgs() << " - Mul operand not Instruction\n");
- return nullptr;
- }
+ Value *R0 = RealMulI->getOperand(0);
+ Value *R1 = RealMulI->getOperand(1);
+ Value *I0 = ImagMulI->getOperand(0);
+ Value *I1 = ImagMulI->getOperand(1);
- Instruction *CommonOperand;
- Instruction *UncommonRealOp;
- Instruction *UncommonImagOp;
+ Value *CommonOperand;
+ Value *UncommonRealOp;
+ Value *UncommonImagOp;
if (R0 == I0 || R0 == I1) {
CommonOperand = R0;
@@ -705,7 +687,7 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
Rotation == ComplexDeinterleavingRotation::Rotation_270)
std::swap(UncommonRealOp, UncommonImagOp);
- std::pair<Instruction *, Instruction *> PartialMatch(
+ std::pair<Value *, Value *> PartialMatch(
(Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
Rotation == ComplexDeinterleavingRotation::Rotation_180)
? CommonOperand
@@ -840,11 +822,8 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
!isInstructionPotentiallySymmetric(Imag))
return nullptr;
- auto *R0 = dyn_cast<Instruction>(Real->getOperand(0));
- auto *I0 = dyn_cast<Instruction>(Imag->getOperand(0));
-
- if (!R0 || !I0)
- return nullptr;
+ auto *R0 = Real->getOperand(0);
+ auto *I0 = Imag->getOperand(0);
NodePtr Op0 = identifyNode(R0, I0);
NodePtr Op1 = nullptr;
@@ -852,11 +831,8 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
return nullptr;
if (Real->isBinaryOp()) {
- auto *R1 = dyn_cast<Instruction>(Real->getOperand(1));
- auto *I1 = dyn_cast<Instruction>(Imag->getOperand(1));
- if (!R1 || !I1)
- return nullptr;
-
+ auto *R1 = Real->getOperand(1);
+ auto *I1 = Imag->getOperand(1);
Op1 = identifyNode(R1, I1);
if (Op1 == nullptr)
return nullptr;
@@ -880,15 +856,20 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
}
ComplexDeinterleavingGraph::NodePtr
-ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
- LLVM_DEBUG(dbgs() << "identifyNode on " << *Real << " / " << *Imag << "\n");
- assert(Real->getType() == Imag->getType() &&
+ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
+ LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
+ assert(R->getType() == I->getType() &&
"Real and imaginary parts should not have different types");
- if (NodePtr CN = getContainingComposite(Real, Imag)) {
+ if (NodePtr CN = getContainingComposite(R, I)) {
LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
return CN;
}
+ auto *Real = dyn_cast<Instruction>(R);
+ auto *Imag = dyn_cast<Instruction>(I);
+ if (!Real || !Imag)
+ return nullptr;
+
if (NodePtr CN = identifyDeinterleave(Real, Imag))
return CN;
@@ -931,6 +912,7 @@ ComplexDeinterleavingGraph::identifyNode(Instruction *Real, Instruction *Imag) {
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
Instruction *Imag) {
+
if ((Real->getOpcode() != Instruction::FAdd &&
Real->getOpcode() != Instruction::FSub &&
Real->getOpcode() != Instruction::FNeg) ||
@@ -967,8 +949,10 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
continue;
Instruction *I = dyn_cast<Instruction>(V);
- if (!I)
- return false;
+ if (!I) {
+ Addends.emplace_back(V, IsPositive);
+ continue;
+ }
// If an instruction has more than one user, it indicates that it either
// has an external user, which will be later checked by the checkNodes
@@ -989,20 +973,18 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
Worklist.emplace_back(I->getOperand(1), !IsPositive);
Worklist.emplace_back(I->getOperand(0), IsPositive);
} else if (I->getOpcode() == Instruction::FMul) {
- auto *A = dyn_cast<Instruction>(I->getOperand(0));
- if (A && A->getOpcode() == Instruction::FNeg) {
- A = dyn_cast<Instruction>(A->getOperand(0));
+ Value *A, *B;
+ if (match(I->getOperand(0), m_FNeg(m_Value(A)))) {
IsPositive = !IsPositive;
+ } else {
+ A = I->getOperand(0);
}
- if (!A)
- return false;
- auto *B = dyn_cast<Instruction>(I->getOperand(1));
- if (B && B->getOpcode() == Instruction::FNeg) {
- B = dyn_cast<Instruction>(B->getOperand(0));
+
+ if (match(I->getOperand(1), m_FNeg(m_Value(B)))) {
IsPositive = !IsPositive;
+ } else {
+ B = I->getOperand(1);
}
- if (!B)
- return false;
Muls.push_back(Product{A, B, IsPositive});
} else if (I->getOpcode() == Instruction::FNeg) {
Worklist.emplace_back(I->getOperand(0), !IsPositive);
@@ -1059,7 +1041,7 @@ bool ComplexDeinterleavingGraph::collectPartialMuls(
std::vector<PartialMulCandidate> &PartialMulCandidates) {
// Helper function to extract a common operand from two products
auto FindCommonInstruction = [](const Product &Real,
- const Product &Imag) -> Instruction * {
+ const Product &Imag) -> Value * {
if (Real.Multiplicand == Imag.Multiplicand ||
Real.Multiplicand == Imag.Multiplier)
return Real.Multiplicand;
@@ -1087,18 +1069,17 @@ bool ComplexDeinterleavingGraph::collectPartialMuls(
auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
: ImagMuls[j].Multiplicand;
- bool Inverted = false;
auto Node = identifyNode(A, B);
- if (!Node) {
- std::swap(A, B);
- Inverted = true;
- Node = identifyNode(A, B);
+ if (Node) {
+ FoundCommon = true;
+ PartialMulCandidates.push_back({Common, Node, i, j, false});
}
- if (!Node)
- continue;
- FoundCommon = true;
- PartialMulCandidates.push_back({Common, Node, i, j, Inverted});
+ Node = identifyNode(B, A);
+ if (Node) {
+ FoundCommon = true;
+ PartialMulCandidates.push_back({Common, Node, i, j, true});
+ }
}
if (!FoundCommon)
return false;
@@ -1118,7 +1099,7 @@ ComplexDeinterleavingGraph::identifyMultiplications(
return nullptr;
// Map to store common instruction to node pointers
- std::map<Instruction *, NodePtr> CommonToNode;
+ std::map<Value *, NodePtr> CommonToNode;
std::vector<bool> Processed(Info.size(), false);
for (unsigned I = 0; I < Info.size(); ++I) {
if (Processed[I])
@@ -1838,8 +1819,8 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
processReductionOperation(ReplacementNode, Node);
break;
case ComplexDeinterleavingOperation::ReductionSelect: {
- auto *MaskReal = Node->Real->getOperand(0);
- auto *MaskImag = Node->Imag->getOperand(0);
+ auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
+ auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
auto *A = replaceNode(Builder, Node->Operands[0]);
auto *B = replaceNode(Builder, Node->Operands[1]);
auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
@@ -1860,11 +1841,13 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
void ComplexDeinterleavingGraph::processReductionOperation(
Value *OperationReplacement, RawNodePtr Node) {
- auto *OldPHIReal = ReductionInfo[Node->Real].first;
- auto *OldPHIImag = ReductionInfo[Node->Imag].first;
+ auto *Real = cast<Instruction>(Node->Real);
+ auto *Imag = cast<Instruction>(Node->Imag);
+ auto *OldPHIReal = ReductionInfo[Real].first;
+ auto *OldPHIImag = ReductionInfo[Imag].first;
auto *NewPHI = OldToNewPHI[OldPHIReal];
- auto *VTy = cast<VectorType>(Node->Real->getType());
+ auto *VTy = cast<VectorType>(Real->getType());
auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
// We have to interleave initial origin values coming from IncomingBlock
@@ -1880,8 +1863,8 @@ void ComplexDeinterleavingGraph::processReductionOperation(
// Deinterleave complex vector outside of loop so that it can be finally
// reduced
- auto *FinalReductionReal = ReductionInfo[Node->Real].second;
- auto *FinalReductionImag = ReductionInfo[Node->Imag].second;
+ auto *FinalReductionReal = ReductionInfo[Real].second;
+ auto *FinalReductionImag = ReductionInfo[Imag].second;
Builder.SetInsertPoint(
&*FinalReductionReal->getParent()->getFirstInsertionPt());
@@ -1890,11 +1873,11 @@ void ComplexDeinterleavingGraph::processReductionOperation(
OperationReplacement->getType(), OperationReplacement);
auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
- FinalReductionReal->replaceUsesOfWith(Node->Real, NewReal);
+ FinalReductionReal->replaceUsesOfWith(Real, NewReal);
Builder.SetInsertPoint(FinalReductionImag);
auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
- FinalReductionImag->replaceUsesOfWith(Node->Imag, NewImag);
+ FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
}
void ComplexDeinterleavingGraph::replaceNodes() {
@@ -1911,10 +1894,12 @@ void ComplexDeinterleavingGraph::replaceNodes() {
if (RootNode->Operation ==
ComplexDeinterleavingOperation::ReductionOperation) {
- ReductionInfo[RootNode->Real].first->removeIncomingValue(BackEdge);
- ReductionInfo[RootNode->Imag].first->removeIncomingValue(BackEdge);
- DeadInstrRoots.push_back(RootNode->Real);
- DeadInstrRoots.push_back(RootNode->Imag);
+ auto *RootReal = cast<Instruction>(RootNode->Real);
+ auto *RootImag = cast<Instruction>(RootNode->Imag);
+ ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
+ ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
+ DeadInstrRoots.push_back(cast<Instruction>(RootReal));
+ DeadInstrRoots.push_back(cast<Instruction>(RootImag));
} else {
assert(R && "Unable to find replacement for RootInstruction");
DeadInstrRoots.push_back(RootInstruction);
More information about the llvm-commits mailing list