[llvm] [IA]: Construct (de)interleave4 out of (de)interleave2 (PR #89276)
Paul Walker via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 21 06:47:17 PDT 2024
================
@@ -16585,17 +16585,87 @@ bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
return true;
}
+/// Try to match a factor-4 deinterleave expressed as a two-level tree of
+/// deinterleave2 calls rooted at \p DI:
+///
+///   DI  = deinterleave2(...)
+///   DI1 = deinterleave2(extractvalue DI, 0)   ; leaves A, B
+///   DI2 = deinterleave2(extractvalue DI, 1)   ; leaves C, D
+///
+/// On success the four leaf extractvalue instructions are stored in
+/// \p DeinterleavedValues (exactly four entries, in deinterleave order) and
+/// true is returned. On failure false is returned and \p DeinterleavedValues
+/// is left untouched.
+///
+/// Only the first two users of DI, the first user of each first-level
+/// extractvalue, and the first two users of each second-level node are
+/// inspected.
+static bool
+getDeinterleavedValues(Value *DI,
+                       SmallVectorImpl<Instruction *> &DeinterleavedValues) {
+  if (!DI->hasNUsesOrMore(2))
+    return false;
+  auto *Extr1 = dyn_cast<ExtractValueInst>(*(DI->user_begin()));
+  auto *Extr2 = dyn_cast<ExtractValueInst>(*(++DI->user_begin()));
+  if (!Extr1 || !Extr2)
+    return false;
+
+  if (!Extr1->hasNUsesOrMore(1) || !Extr2->hasNUsesOrMore(1))
+    return false;
+  auto *DI1 = *(Extr1->user_begin());
+  auto *DI2 = *(Extr2->user_begin());
+
+  if (!DI1->hasNUsesOrMore(2) || !DI2->hasNUsesOrMore(2))
+    return false;
+  // Leaf nodes of the deinterleave tree:
+  auto *A = dyn_cast<ExtractValueInst>(*(DI1->user_begin()));
+  auto *B = dyn_cast<ExtractValueInst>(*(++DI1->user_begin()));
+  auto *C = dyn_cast<ExtractValueInst>(*(DI2->user_begin()));
+  auto *D = dyn_cast<ExtractValueInst>(*(++DI2->user_begin()));
+  // Make sure that A,B,C,D are ExtractValue instructions before reading
+  // their indices.
+  if (!A || !B || !C || !D)
+    return false;
+
+  // Place each leaf into slot LeafIdx + 2 * ParentIdx. Nothing above has
+  // verified that DI1/DI2 are deinterleave2 calls or that the extracted
+  // fields are distinct, so reject out-of-range indices and duplicate slots
+  // rather than writing out of bounds or handing a null slot to match().
+  SmallVector<Instruction *, 4> Leaves(4, nullptr);
+  auto Place = [&Leaves](ExtractValueInst *Leaf, ExtractValueInst *Parent) {
+    unsigned LeafIdx = Leaf->getIndices()[0];
+    unsigned ParentIdx = Parent->getIndices()[0];
+    if (LeafIdx > 1 || ParentIdx > 1)
+      return false;
+    Instruction *&Slot = Leaves[LeafIdx + ParentIdx * 2];
+    if (Slot)
+      return false;
+    Slot = Leaf;
+    return true;
+  };
+  if (!Place(A, Extr1) || !Place(B, Extr1) || !Place(C, Extr2) ||
+      !Place(D, Extr2))
+    return false;
+
+  // Make sure that A,B,C,D match the deinterleave tree pattern.
+  if (!match(Leaves[0], m_ExtractValue<0>(m_Deinterleave2(
+                            m_ExtractValue<0>(m_Deinterleave2(m_Value()))))) ||
+      !match(Leaves[1], m_ExtractValue<1>(m_Deinterleave2(
+                            m_ExtractValue<0>(m_Deinterleave2(m_Value()))))) ||
+      !match(Leaves[2], m_ExtractValue<0>(m_Deinterleave2(
+                            m_ExtractValue<1>(m_Deinterleave2(m_Value()))))) ||
+      !match(Leaves[3], m_ExtractValue<1>(m_Deinterleave2(
+                            m_ExtractValue<1>(m_Deinterleave2(m_Value())))))) {
+    LLVM_DEBUG(dbgs() << "matching deinterleave4 failed\n");
+    return false;
+  }
+  // Order the values according to the deinterleaving order: the leaf taken
+  // from (ParentIdx, LeafIdx) ends up in slot ParentIdx + 2 * LeafIdx.
+  std::swap(Leaves[1], Leaves[2]);
+  // Only publish the result once the whole tree has matched.
+  DeinterleavedValues.assign(Leaves.begin(), Leaves.end());
+  return true;
+}
+
+/// Erase \p DeadRoot, which is expected to be an extractvalue of a
+/// deinterleave2 call. Then erase that deinterleave2 call and its operand
+/// (presumably the first-level extractvalue of the tree) as well, each one
+/// only if it has no remaining uses.
+///
+/// \p DeadRoot itself must already have no uses, as required by
+/// Instruction::eraseFromParent.
+static void deleteDeadDeinterleaveInstructions(Instruction *DeadRoot) {
+  Value *DeadDeinterleave = nullptr, *DeadExtract = nullptr;
+  [[maybe_unused]] bool Matched =
+      match(DeadRoot, m_ExtractValue(m_Value(DeadDeinterleave))) &&
+      match(DeadDeinterleave, m_Deinterleave2(m_Value(DeadExtract)));
+  assert(Matched && "Match is expected to succeed");
+  DeadRoot->eraseFromParent();
+  // Walk back up the tree and erase each level once it has become dead.
+  // dyn_cast_or_null means release builds, where the assert above is
+  // compiled out, never dereference a null or non-instruction operand.
+  if (auto *I = dyn_cast_or_null<Instruction>(DeadDeinterleave))
+    if (I->use_empty())
+      I->eraseFromParent();
+  if (auto *I = dyn_cast_or_null<Instruction>(DeadExtract))
+    if (I->use_empty())
+      I->eraseFromParent();
+}
+
bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
IntrinsicInst *DI, LoadInst *LI) const {
// Only deinterleave2 supported at present.
if (DI->getIntrinsicID() != Intrinsic::vector_deinterleave2)
return false;
- // Only a factor of 2 supported at present.
- const unsigned Factor = 2;
-
- VectorType *VTy = cast<VectorType>(DI->getType()->getContainedType(0));
+ SmallVector<Instruction *, 4> DeinterleavedValues;
const DataLayout &DL = DI->getModule()->getDataLayout();
+ unsigned Factor = 2;
+ VectorType *VTy = cast<VectorType>(DI->getType()->getContainedType(0));
+
+ if (getDeinterleavedValues(DI, DeinterleavedValues)) {
+ Factor = DeinterleavedValues.size();
+ VTy = cast<VectorType>(DeinterleavedValues[0]->getType());
+ }
----------------
paulwalker-arm wrote:
It would be nice to unify the code a little more, particularly because, after inspection, I think the current upstream Factor==2 code is not good, especially when compared to your new Factor==4 code. If you don't mind improving that as part of this PR, then I think it's worth implementing something like:
```
getDeinterleavedValues(DI, DeinterleavedValues...) {
if (getDeinterleaved4Values(DI, DeinterleavedValues...))
return true;
return getDeinterleaved2Values(DI, DeinterleavedValues...);
}
```
and then here you'd do as you've done for the interleave handling, i.e.
```
if (!getDeinterleavedValues(DI, DeinterleavedValues....))
return false;
```
I'm assuming `getDeinterleaved2Values` will be a fairly trivial function.
https://github.com/llvm/llvm-project/pull/89276
More information about the llvm-commits
mailing list