[llvm] [AArch64] Support symmetric complex deinterleaving with higher factors (PR #151295)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 30 01:54:29 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: David Sherwood (david-arm)
<details>
<summary>Changes</summary>
For loops such as this:
struct foo {
double a, b;
};
void foo(struct foo *dst, struct foo *src, int n) {
for (int i = 0; i < n; i++) {
dst[i].a += src[i].a * 3.2;
dst[i].b += src[i].b * 3.2;
}
}
the complex deinterleaving pass will spot that the deinterleaving
associated with the structured loads cancels out the interleaving
associated with the structured stores. This works even though the
values are not truly "complex" numbers, because the pass can also
handle symmetric operations. This is great because it means we can
then perform normal loads and stores instead. However, this currently
only happens for an interleave factor of 2, and we can do the same
for higher interleave factors, e.g. 4:
struct foo {
double a, b, c, d;
};
void foo(struct foo *dst, struct foo *src, int n) {
for (int i = 0; i < n; i++) {
dst[i].a += src[i].a * 3.2;
dst[i].b += src[i].b * 3.2;
dst[i].c += src[i].c * 3.2;
dst[i].d += src[i].d * 3.2;
}
}
This PR extends the pass to effectively treat such structures as
a set of complex numbers, i.e. as if they were declared as:
struct foo_alt {
std::complex<double> x, y;
};
with equivalence between members:
foo_alt.x.real == foo.a
foo_alt.x.imag == foo.b
foo_alt.y.real == foo.c
foo_alt.y.imag == foo.d
I've written the code to handle sets with arbitrary numbers of
complex values, but since we currently only handle interleave
factors of 2 and 4, I've restricted the sets to 1 or 2 complex
numbers. Also, for now I've restricted support for an interleave
factor of 4 to purely symmetric operations. However, it could
later be extended to handle complex multiplications, reductions,
etc.
Fixes: https://github.com/llvm/llvm-project/issues/144795
---
Patch is 48.43 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/151295.diff
4 Files Affected:
- (modified) llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp (+321-139)
- (modified) llvm/test/CodeGen/AArch64/complex-deinterleaving-reductions-scalable.ll (+110)
- (added) llvm/test/CodeGen/AArch64/complex-deinterleaving-symmetric-fixed.ll (+76)
- (added) llvm/test/CodeGen/AArch64/complex-deinterleaving-symmetric-scalable.ll (+113)
``````````diff
diff --git a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
index 9b2851eb42b40..2787227e0a255 100644
--- a/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
+++ b/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
@@ -67,6 +67,7 @@
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Target/TargetMachine.h"
@@ -107,6 +108,42 @@ static bool isNeg(Value *V);
/// Returns the operand for negation operation.
static Value *getNegOperand(Value *V);
+namespace {
+struct ComplexValue {
+ Value *Real = nullptr;
+ Value *Imag = nullptr;
+
+ bool operator==(const ComplexValue &Other) const {
+ return Real == Other.Real && Imag == Other.Imag;
+ }
+};
+hash_code hash_value(const ComplexValue &Arg) {
+ return hash_combine(DenseMapInfo<Value *>::getHashValue(Arg.Real),
+ DenseMapInfo<Value *>::getHashValue(Arg.Imag));
+}
+} // end namespace
+typedef SmallVector<struct ComplexValue, 2> ComplexValues;
+
+namespace llvm {
+template <> struct DenseMapInfo<ComplexValue> {
+ static inline ComplexValue getEmptyKey() {
+ return {DenseMapInfo<Value *>::getEmptyKey(),
+ DenseMapInfo<Value *>::getEmptyKey()};
+ }
+ static inline ComplexValue getTombstoneKey() {
+ return {DenseMapInfo<Value *>::getTombstoneKey(),
+ DenseMapInfo<Value *>::getTombstoneKey()};
+ }
+ static unsigned getHashValue(const ComplexValue &Val) {
+ return hash_combine(DenseMapInfo<Value *>::getHashValue(Val.Real),
+ DenseMapInfo<Value *>::getHashValue(Val.Imag));
+ }
+ static bool isEqual(const ComplexValue &LHS, const ComplexValue &RHS) {
+ return LHS.Real == RHS.Real && LHS.Imag == RHS.Imag;
+ }
+};
+} // end namespace llvm
+
namespace {
template <typename T, typename IterT>
std::optional<T> findCommonBetweenCollections(IterT A, IterT B) {
@@ -145,7 +182,13 @@ struct ComplexDeinterleavingCompositeNode {
ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
Value *R, Value *I)
- : Operation(Op), Real(R), Imag(I) {}
+ : Operation(Op) {
+ Vals.push_back({R, I});
+ }
+
+ ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
+ ComplexValues &Other)
+ : Operation(Op), Vals(Other) {}
private:
friend class ComplexDeinterleavingGraph;
@@ -155,8 +198,7 @@ struct ComplexDeinterleavingCompositeNode {
public:
ComplexDeinterleavingOperation Operation;
- Value *Real;
- Value *Imag;
+ ComplexValues Vals;
// This two members are required exclusively for generating
// ComplexDeinterleavingOperation::Symmetric operations.
@@ -192,10 +234,12 @@ struct ComplexDeinterleavingCompositeNode {
};
OS << "- CompositeNode: " << this << "\n";
- OS << " Real: ";
- PrintValue(Real);
- OS << " Imag: ";
- PrintValue(Imag);
+ for (unsigned I = 0; I < Vals.size(); I++) {
+ OS << " Real(" << I << ") : ";
+ PrintValue(Vals[I].Real);
+ OS << " Imag(" << I << ") : ";
+ PrintValue(Vals[I].Imag);
+ }
OS << " ReplacementNode: ";
PrintValue(ReplacementNode);
OS << " Operation: " << (int)Operation << "\n";
@@ -233,14 +277,16 @@ class ComplexDeinterleavingGraph {
};
explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
- const TargetLibraryInfo *TLI)
- : TL(TL), TLI(TLI) {}
+ const TargetLibraryInfo *TLI,
+ unsigned Factor)
+ : TL(TL), TLI(TLI), Factor(Factor) {}
private:
const TargetLowering *TL = nullptr;
const TargetLibraryInfo *TLI = nullptr;
+ unsigned Factor;
SmallVector<NodePtr> CompositeNodes;
- DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult;
+ DenseMap<ComplexValues, NodePtr> CachedResult;
SmallPtrSet<Instruction *, 16> FinalInstructions;
@@ -305,10 +351,26 @@ class ComplexDeinterleavingGraph {
I);
}
+ NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
+ ComplexValues &Vals) {
+#ifndef NDEBUG
+ for (auto &V : Vals) {
+ assert(
+ ((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
+ Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
+ (V.Real && V.Imag)) &&
+ "Reduction related nodes must have Real and Imaginary parts");
+ }
+#endif
+ return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation,
+ Vals);
+ }
+
NodePtr submitCompositeNode(NodePtr Node) {
CompositeNodes.push_back(Node);
- if (Node->Real)
- CachedResult[{Node->Real, Node->Imag}] = Node;
+ if (Node->Vals[0].Real) {
+ CachedResult[Node->Vals] = Node;
+ }
return Node;
}
@@ -340,11 +402,17 @@ class ComplexDeinterleavingGraph {
/// 270: r: ar + bi
/// i: ai - br
NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
- NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
+ NodePtr identifySymmetricOperation(ComplexValues &Vals);
NodePtr identifyPartialReduction(Value *R, Value *I);
NodePtr identifyDotProduct(Value *Inst);
- NodePtr identifyNode(Value *R, Value *I);
+ NodePtr identifyNode(ComplexValues &Vals);
+
+ NodePtr identifyNode(Value *R, Value *I) {
+ ComplexValues Vals;
+ Vals.push_back({R, I});
+ return identifyNode(Vals);
+ }
/// Determine if a sum of complex numbers can be formed from \p RealAddends
/// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
@@ -390,13 +458,13 @@ class ComplexDeinterleavingGraph {
/// odd indices for /pImag instructions (only for fixed-width vectors)
/// * Using two extractvalue instructions applied to `vector.deinterleave2`
/// intrinsic (for both fixed and scalable vectors)
- NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
+ NodePtr identifyDeinterleave(ComplexValues &Vals);
/// identifying the operation that represents a complex number repeated in a
/// Splat vector. There are two possible types of splats: ConstantExpr with
/// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
/// initialization mask with all values set to zero.
- NodePtr identifySplat(Value *Real, Value *Imag);
+ NodePtr identifySplat(ComplexValues &Vals);
NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
@@ -447,7 +515,7 @@ class ComplexDeinterleaving {
bool runOnFunction(Function &F);
private:
- bool evaluateBasicBlock(BasicBlock *B);
+ bool evaluateBasicBlock(BasicBlock *B, unsigned Factor);
const TargetLowering *TL = nullptr;
const TargetLibraryInfo *TLI = nullptr;
@@ -500,7 +568,15 @@ bool ComplexDeinterleaving::runOnFunction(Function &F) {
bool Changed = false;
for (auto &B : F)
- Changed |= evaluateBasicBlock(&B);
+ Changed |= evaluateBasicBlock(&B, 2);
+
+ // TODO: Permit changes for both interleave factors in the same function.
+ if (!Changed) {
+ for (auto &B : F)
+ Changed |= evaluateBasicBlock(&B, 4);
+ }
+
+ // TODO: We can also support interleave factors of 6 and 8 if needed.
return Changed;
}
@@ -545,8 +621,8 @@ Value *getNegOperand(Value *V) {
return I->getOperand(1);
}
-bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
- ComplexDeinterleavingGraph Graph(TL, TLI);
+bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B, unsigned Factor) {
+ ComplexDeinterleavingGraph Graph(TL, TLI, Factor);
if (Graph.collectPotentialReductions(B))
Graph.identifyReductionNodes();
@@ -669,6 +745,7 @@ ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
Instruction *Imag) {
LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
<< "\n");
+
// Determine rotation
auto IsAdd = [](unsigned Op) {
return Op == Instruction::FAdd || Op == Instruction::Add;
@@ -865,43 +942,57 @@ static bool isInstructionPotentiallySymmetric(Instruction *I) {
}
ComplexDeinterleavingGraph::NodePtr
-ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
- Instruction *Imag) {
- if (Real->getOpcode() != Imag->getOpcode())
- return nullptr;
+ComplexDeinterleavingGraph::identifySymmetricOperation(ComplexValues &Vals) {
+ auto *FirstReal = cast<Instruction>(Vals[0].Real);
+ unsigned FirstOpc = FirstReal->getOpcode();
+ for (auto &V : Vals) {
+ auto *Real = cast<Instruction>(V.Real);
+ auto *Imag = cast<Instruction>(V.Imag);
+ if (Real->getOpcode() != FirstOpc || Imag->getOpcode() != FirstOpc)
+ return nullptr;
- if (!isInstructionPotentiallySymmetric(Real) ||
- !isInstructionPotentiallySymmetric(Imag))
- return nullptr;
+ if (!isInstructionPotentiallySymmetric(Real) ||
+ !isInstructionPotentiallySymmetric(Imag))
+ return nullptr;
- auto *R0 = Real->getOperand(0);
- auto *I0 = Imag->getOperand(0);
+ if (isa<FPMathOperator>(FirstReal))
+ if (Real->getFastMathFlags() != FirstReal->getFastMathFlags() ||
+ Imag->getFastMathFlags() != FirstReal->getFastMathFlags())
+ return nullptr;
+ }
+
+ ComplexValues OpVals;
+ for (auto &V : Vals) {
+ auto *R0 = cast<Instruction>(V.Real)->getOperand(0);
+ auto *I0 = cast<Instruction>(V.Imag)->getOperand(0);
+ OpVals.push_back({R0, I0});
+ }
- NodePtr Op0 = identifyNode(R0, I0);
+ NodePtr Op0 = identifyNode(OpVals);
NodePtr Op1 = nullptr;
if (Op0 == nullptr)
return nullptr;
- if (Real->isBinaryOp()) {
- auto *R1 = Real->getOperand(1);
- auto *I1 = Imag->getOperand(1);
- Op1 = identifyNode(R1, I1);
+ if (FirstReal->isBinaryOp()) {
+ OpVals.clear();
+ for (auto &V : Vals) {
+ auto *R1 = cast<Instruction>(V.Real)->getOperand(1);
+ auto *I1 = cast<Instruction>(V.Imag)->getOperand(1);
+ OpVals.push_back({R1, I1});
+ }
+ Op1 = identifyNode(OpVals);
if (Op1 == nullptr)
return nullptr;
}
- if (isa<FPMathOperator>(Real) &&
- Real->getFastMathFlags() != Imag->getFastMathFlags())
- return nullptr;
-
- auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,
- Real, Imag);
- Node->Opcode = Real->getOpcode();
- if (isa<FPMathOperator>(Real))
- Node->Flags = Real->getFastMathFlags();
+ auto Node =
+ prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric, Vals);
+ Node->Opcode = FirstReal->getOpcode();
+ if (isa<FPMathOperator>(FirstReal))
+ Node->Flags = FirstReal->getFastMathFlags();
Node->addOperand(Op0);
- if (Real->isBinaryOp())
+ if (FirstReal->isBinaryOp())
Node->addOperand(Op1);
return submitCompositeNode(Node);
@@ -909,7 +1000,6 @@ ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyDotProduct(Value *V) {
-
if (!TL->isComplexDeinterleavingOperationSupported(
ComplexDeinterleavingOperation::CDot, V->getType())) {
LLVM_DEBUG(dbgs() << "Target doesn't support complex deinterleaving "
@@ -1054,65 +1144,77 @@ ComplexDeinterleavingGraph::identifyPartialReduction(Value *R, Value *I) {
}
ComplexDeinterleavingGraph::NodePtr
-ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
- auto It = CachedResult.find({R, I});
+ComplexDeinterleavingGraph::identifyNode(ComplexValues &Vals) {
+ auto It = CachedResult.find(Vals);
if (It != CachedResult.end()) {
LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
return It->second;
}
- if (NodePtr CN = identifyPartialReduction(R, I))
- return CN;
-
- bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
- if (!IsReduction && R->getType() != I->getType())
- return nullptr;
+ if (Vals.size() == 1) {
+ assert(Factor == 2 && "Can only handle interleave factors of 2");
+ Value *R = Vals[0].Real;
+ Value *I = Vals[0].Imag;
+ if (NodePtr CN = identifyPartialReduction(R, I))
+ return CN;
+ bool IsReduction = RealPHI == R && (!ImagPHI || ImagPHI == I);
+ if (!IsReduction && R->getType() != I->getType())
+ return nullptr;
+ }
- if (NodePtr CN = identifySplat(R, I))
+ if (NodePtr CN = identifySplat(Vals))
return CN;
- auto *Real = dyn_cast<Instruction>(R);
- auto *Imag = dyn_cast<Instruction>(I);
- if (!Real || !Imag)
- return nullptr;
+ for (auto &V : Vals) {
+ auto *Real = dyn_cast<Instruction>(V.Real);
+ auto *Imag = dyn_cast<Instruction>(V.Imag);
+ if (!Real || !Imag)
+ return nullptr;
+ }
- if (NodePtr CN = identifyDeinterleave(Real, Imag))
+ if (NodePtr CN = identifyDeinterleave(Vals))
return CN;
- if (NodePtr CN = identifyPHINode(Real, Imag))
- return CN;
+ if (Vals.size() == 1) {
+ assert(Factor == 2 && "Can only handle interleave factors of 2");
+ auto *Real = dyn_cast<Instruction>(Vals[0].Real);
+ auto *Imag = dyn_cast<Instruction>(Vals[0].Imag);
+ if (NodePtr CN = identifyPHINode(Real, Imag))
+ return CN;
- if (NodePtr CN = identifySelectNode(Real, Imag))
- return CN;
+ if (NodePtr CN = identifySelectNode(Real, Imag))
+ return CN;
- auto *VTy = cast<VectorType>(Real->getType());
- auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
+ auto *VTy = cast<VectorType>(Real->getType());
+ auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
- bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
- ComplexDeinterleavingOperation::CMulPartial, NewVTy);
- bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
- ComplexDeinterleavingOperation::CAdd, NewVTy);
+ bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
+ ComplexDeinterleavingOperation::CMulPartial, NewVTy);
+ bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
+ ComplexDeinterleavingOperation::CAdd, NewVTy);
- if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
- if (NodePtr CN = identifyPartialMul(Real, Imag))
- return CN;
- }
+ if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
+ if (NodePtr CN = identifyPartialMul(Real, Imag))
+ return CN;
+ }
- if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
- if (NodePtr CN = identifyAdd(Real, Imag))
- return CN;
- }
+ if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
+ if (NodePtr CN = identifyAdd(Real, Imag))
+ return CN;
+ }
- if (HasCMulSupport && HasCAddSupport) {
- if (NodePtr CN = identifyReassocNodes(Real, Imag))
- return CN;
+ if (HasCMulSupport && HasCAddSupport) {
+ if (NodePtr CN = identifyReassocNodes(Real, Imag)) {
+ return CN;
+ }
+ }
}
- if (NodePtr CN = identifySymmetricOperation(Real, Imag))
+ if (NodePtr CN = identifySymmetricOperation(Vals))
return CN;
LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");
- CachedResult[{R, I}] = nullptr;
+ CachedResult[Vals] = nullptr;
return nullptr;
}
@@ -1256,9 +1358,10 @@ ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
return nullptr;
}
assert(FinalNode && "FinalNode can not be nullptr here");
+ assert(FinalNode->Vals.size() == 1);
// Set the Real and Imag fields of the final node and submit it
- FinalNode->Real = Real;
- FinalNode->Imag = Imag;
+ FinalNode->Vals[0].Real = Real;
+ FinalNode->Vals[0].Imag = Imag;
submitCompositeNode(FinalNode);
return FinalNode;
}
@@ -1381,7 +1484,7 @@ ComplexDeinterleavingGraph::identifyMultiplications(
auto NodeA = It->second;
auto NodeB = PMI.Node;
- auto IsMultiplicandReal = PMI.Common == NodeA->Real;
+ auto IsMultiplicandReal = PMI.Common == NodeA->Vals[0].Real;
// The following table illustrates the relationship between multiplications
// and rotations. If we consider the multiplication (X + iY) * (U + iV), we
// can see:
@@ -1423,10 +1526,10 @@ ComplexDeinterleavingGraph::identifyMultiplications(
LLVM_DEBUG({
dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
- dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
- dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
- dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
- dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
+ dbgs().indent(4) << "X: " << *NodeA->Vals[0].Real << "\n";
+ dbgs().indent(4) << "Y: " << *NodeA->Vals[0].Imag << "\n";
+ dbgs().indent(4) << "U: " << *NodeB->Vals[0].Real << "\n";
+ dbgs().indent(4) << "V: " << *NodeB->Vals[0].Imag << "\n";
dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
});
@@ -1595,10 +1698,13 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
ComplexDeinterleavingOperation::ReductionOperation ||
RootNode->Operation ==
ComplexDeinterleavingOperation::ReductionSingle);
+ assert(RootNode->Vals.size() == 1 &&
+ "Cannot handle reductions involving multiple complex values");
// Find out which part, Real or Imag, comes later, and only if we come to
// the latest part, add it to OrderedRoots.
- auto *R = cast<Instruction>(RootNode->Real);
- auto *I = RootNode->Imag ? cast<Instruction>(RootNode->Imag) : nullptr;
+ auto *R = cast<Instruction>(RootNode->Vals[0].Real);
+ auto *I = RootNode->Vals[0].Imag ? cast<Instruction>(RootNode->Vals[0].Imag)
+ : nullptr;
Instruction *ReplacementAnchor;
if (I)
@@ -1631,6 +1737,8 @@ bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
bool FoundPotentialReduction = false;
+ if (Factor != 2)
+ return false;
auto *Br = dyn_cast<BranchInst>(B->getTerminator());
if (!Br || Br->getNumSuccessors() != 2)
@@ -1682,6 +1790,8 @@ bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
}
void ComplexDeinterleavingGraph::identifyReductionNodes() {
+ assert(Factor == 2 && "Cannot handle multiple complex values");
+
SmallVector<bool> Processed(ReductionInfo.size(), false);
SmallVector<Instruction *> OperationInstruction;
for (auto &P : ReductionInfo)
@@ -1771,11 +1881,11 @@ void ComplexDeinterleavingGraph::identifyReductionNodes() {
}
bool ComplexDeinterleavingGraph::checkNodes() {
-
bool FoundDeinterleaveNode = false;
for (NodePtr N : CompositeNodes) {
if (!N->areOperandsValid())
return false;
+
if (N->Operation == ComplexDeinterleavingOperation::Deinterleave)
FoundDeinterleaveNode = true;
}
@@ -1861,17 +1971,33 @@ bool ComplexDeinterleavingGraph::checkNodes() {
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {
- if (Intrinsic->getIntrinsicID() != Intrinsic::vector_interleave2)
+ if (Intrinsic::getInterleaveIntrinsicID(Factor) !=
+ Intrinsic->getIntrinsicID())
return nullptr;
- auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));
- auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));
- if (!Real || !Imag)
- return nullptr;
+ ComplexValues Vals;
+ for (unsigned I = 0; I < Factor; I += 2) {
+ auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(I));
+ auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(I + 1));
+ if (!Real || !Imag)
+ return nullptr;
+ Vals.push_back({Real, Imag});
+ }
- return identifyNode(Real, Imag);
+ ComplexDeinterleavingGraph::NodePtr Node1 = identifyNode(Vals);
+ if (!Node1)
+ return nullptr;
+ return Node1;
}
+ // TODO: We could also add support for fixed-width interleave factors of 4
+ // and above, but currently for symmetric operations the interleaves and
+ // deinterleaves are...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/151295
More information about the llvm-commits
mailing list