[llvm] Add support for single reductions in ComplexDeinterleavingPass (PR #112875)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Oct 18 03:10:18 PDT 2024
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
⚠️ C/C++ code formatter, clang-format, found issues in your code. ⚠️
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff 3764d0ff15ef281974879002e27857a041bd5b9c d77a3e402d05290903f75ce2dc7478f14a3943ef --extensions h,cpp -- llvm/include/llvm/CodeGen/ComplexDeinterleavingPass.h llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
``````````
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index 08287a4d5e..a79105ab0c 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -201,9 +201,7 @@ public:
}
}
- bool AreOperandsValid() {
- return OperandsValid;
- }
+ bool AreOperandsValid() { return OperandsValid; }
};
class ComplexDeinterleavingGraph {
@@ -918,11 +916,11 @@ ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I, bool &FromCache) {
return It->second;
}
- if(NodePtr CN = identifyPartialReduction(R, I))
+ if (NodePtr CN = identifyPartialReduction(R, I))
return CN;
bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
- if(!IsReduction && R->getType() != I->getType())
+ if (!IsReduction && R->getType() != I->getType())
return nullptr;
if (NodePtr CN = identifySplat(R, I))
@@ -1450,18 +1448,20 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
if (It != RootToNode.end()) {
auto RootNode = It->second;
assert(RootNode->Operation ==
- ComplexDeinterleavingOperation::ReductionOperation || RootNode->Operation == ComplexDeinterleavingOperation::ReductionSingle);
+ ComplexDeinterleavingOperation::ReductionOperation ||
+ RootNode->Operation ==
+ ComplexDeinterleavingOperation::ReductionSingle);
// Find out which part, Real or Imag, comes later, and only if we come to
// the latest part, add it to OrderedRoots.
auto *R = cast<Instruction>(RootNode->Real);
auto *I = RootNode->Imag ? cast<Instruction>(RootNode->Imag) : nullptr;
Instruction *ReplacementAnchor;
- if(I)
+ if (I)
ReplacementAnchor = R->comesBefore(I) ? I : R;
- else
+ else
ReplacementAnchor = R;
-
+
if (ReplacementAnchor != RootI)
return false;
OrderedRoots.push_back(RootI);
@@ -1553,7 +1553,7 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
if (Processed[j])
continue;
-
+
auto *Imag = OperationInstruction[j];
if (Real->getType() != Imag->getType())
continue;
@@ -1588,18 +1588,20 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
// We want to check that we have 2 operands, but the function attributes
// being counted as operands bloats this value.
- if(Real->getNumOperands() < 2)
+ if (Real->getNumOperands() < 2)
continue;
RealPHI = ReductionInfo[Real].first;
ImagPHI = nullptr;
PHIsFound = false;
auto Node = identifyNode(Real->getOperand(0), Real->getOperand(1));
- if(Node && PHIsFound) {
- LLVM_DEBUG(dbgs() << "Identified single reduction starting from instruction: "
- << *Real << "/" << *ReductionInfo[Real].second << "\n");
+ if (Node && PHIsFound) {
+ LLVM_DEBUG(
+ dbgs() << "Identified single reduction starting from instruction: "
+ << *Real << "/" << *ReductionInfo[Real].second << "\n");
Processed[i] = true;
- auto RootNode = prepareCompositeNode(ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
+ auto RootNode = prepareCompositeNode(
+ ComplexDeinterleavingOperation::ReductionSingle, Real, nullptr);
RootNode->addOperand(Node);
RootToNode[Real] = RootNode;
submitCompositeNode(RootNode);
@@ -2059,7 +2061,8 @@ Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
return ReplacementNode;
}
-void ComplexDeinterleavingGraph::processReductionSingle(Value *OperationReplacement, RawNodePtr Node) {
+void ComplexDeinterleavingGraph::processReductionSingle(
+ Value *OperationReplacement, RawNodePtr Node) {
auto *Real = cast<Instruction>(Node->Real);
auto *OldPHI = ReductionInfo[Real].first;
auto *NewPHI = OldToNewPHI[OldPHI];
@@ -2071,21 +2074,22 @@ void ComplexDeinterleavingGraph::processReductionSingle(Value *OperationReplacem
IRBuilder<> Builder(Incoming->getTerminator());
Value *NewInit = nullptr;
- if(auto *C = dyn_cast<Constant>(Init)) {
- if(C->isZeroValue())
+ if (auto *C = dyn_cast<Constant>(Init)) {
+ if (C->isZeroValue())
NewInit = Constant::getNullValue(NewVTy);
}
if (!NewInit)
NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
- {Init, Constant::getNullValue(VTy)});
+ {Init, Constant::getNullValue(VTy)});
NewPHI->addIncoming(NewInit, Incoming);
NewPHI->addIncoming(OperationReplacement, BackEdge);
auto *FinalReduction = ReductionInfo[Real].second;
Builder.SetInsertPoint(&*FinalReduction->getParent()->getFirstInsertionPt());
- // TODO Ensure that the `AddReduce` here matches the original, found in `FinalReduction`
+ // TODO Ensure that the `AddReduce` here matches the original, found in
+ // `FinalReduction`
auto *AddReduce = Builder.CreateAddReduce(OperationReplacement);
FinalReduction->replaceAllUsesWith(AddReduce);
}
@@ -2151,7 +2155,8 @@ void ComplexDeinterleavingGraph::replaceNodes() {
ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
DeadInstrRoots.push_back(RootReal);
DeadInstrRoots.push_back(RootImag);
- } else if(RootNode->Operation == ComplexDeinterleavingOperation::ReductionSingle) {
+ } else if (RootNode->Operation ==
+ ComplexDeinterleavingOperation::ReductionSingle) {
auto *RootInst = cast<Instruction>(RootNode->Real);
ReductionInfo[RootInst].first->removeIncomingValue(BackEdge);
DeadInstrRoots.push_back(ReductionInfo[RootInst].second);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8068bb6740..6f58ed4f6c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -29184,7 +29184,10 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
if (TyWidth > 128) {
int Stride = Ty->getElementCount().getKnownMinValue() / 2;
- int AccStride = cast<VectorType>(Accumulator->getType())->getElementCount().getKnownMinValue() / 2;
+ int AccStride = cast<VectorType>(Accumulator->getType())
+ ->getElementCount()
+ .getKnownMinValue() /
+ 2;
auto *HalfTy = VectorType::getHalfElementsVectorType(Ty);
auto *LowerSplitA = B.CreateExtractVector(HalfTy, InputA, B.getInt64(0));
auto *LowerSplitB = B.CreateExtractVector(HalfTy, InputB, B.getInt64(0));
@@ -29195,19 +29198,22 @@ Value *AArch64TargetLowering::createComplexDeinterleavingIR(
Value *LowerSplitAcc = nullptr;
Value *UpperSplitAcc = nullptr;
Type *FullTy = Ty;
- FullTy = Accumulator->getType();
- auto *HalfAccTy = VectorType::getHalfElementsVectorType(cast<VectorType>(Accumulator->getType()));
- LowerSplitAcc = B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(0));
- UpperSplitAcc =
- B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(AccStride));
+ FullTy = Accumulator->getType();
+ auto *HalfAccTy = VectorType::getHalfElementsVectorType(
+ cast<VectorType>(Accumulator->getType()));
+ LowerSplitAcc =
+ B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(0));
+ UpperSplitAcc =
+ B.CreateExtractVector(HalfAccTy, Accumulator, B.getInt64(AccStride));
auto *LowerSplitInt = createComplexDeinterleavingIR(
B, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc);
auto *UpperSplitInt = createComplexDeinterleavingIR(
B, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc);
- auto *Result = B.CreateInsertVector(FullTy, PoisonValue::get(FullTy), LowerSplitInt,
- B.getInt64(0));
- return B.CreateInsertVector(FullTy, Result, UpperSplitInt, B.getInt64(AccStride));
+ auto *Result = B.CreateInsertVector(FullTy, PoisonValue::get(FullTy),
+ LowerSplitInt, B.getInt64(0));
+ return B.CreateInsertVector(FullTy, Result, UpperSplitInt,
+ B.getInt64(AccStride));
}
if (OperationType == ComplexDeinterleavingOperation::CMulPartial) {
``````````
</details>
https://github.com/llvm/llvm-project/pull/112875
More information about the llvm-commits mailing list