[llvm] [SLP][REVEC] Fix CommonMask is transformed into vector form but used outside finalize. (PR #120952)

Alexey Bataev via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 26 09:50:32 PST 2024


alexey-bataev wrote:

> > > Take the following code as an example.
> > > ```
> > >     if (Action) {
> > >       Value *Vec = InVectors.front();
> > >       if (InVectors.size() == 2) {
> > >         Vec = createShuffle(Vec, InVectors.back(), CommonMask);
> > >         InVectors.pop_back();
> > >       } else {
> > >         Vec = createShuffle(Vec, nullptr, CommonMask);
> > >       }
> > >       for (unsigned Idx = 0, Sz = CommonMask.size(); Idx < Sz; ++Idx)
> > >         if (CommonMask[Idx] != PoisonMaskElem)
> > >           CommonMask[Idx] = Idx;
> > >       assert(VF > 0 &&
> > >              "Expected vector length for the final value before action.");
> > >       unsigned VecVF = cast<FixedVectorType>(Vec->getType())->getNumElements();
> > >       if (VecVF < VF) {
> > >         SmallVector<int> ResizeMask(VF, PoisonMaskElem);
> > >         std::iota(ResizeMask.begin(), std::next(ResizeMask.begin(), VecVF), 0);
> > >         Vec = createShuffle(Vec, nullptr, ResizeMask);
> > >       }
> > >       Action(Vec, CommonMask);
> > >       InVectors.front() = Vec;
> > >     }
> > > ```
> > > 
> > > How do you know whether the `<4 x Ty>` in `unsigned VecVF = cast<FixedVectorType>(Vec->getType())->getNumElements();` is two `<2 x Ty>` (so VF is 2) or four `Ty` (so VF is 4)?
> > 
> > 
> > It should be considered as four `Ty`; we should operate on the actual types here.
> 
> Suppose VL contains two `<2 x Ty>` values:
> 
> ```
>       Res = ShuffleBuilder.finalize(
>           E->ReuseShuffleIndices, SubVectors, SubVectorsMask, E->Scalars.size(),
>           [&](Value *&Vec, SmallVectorImpl<int> &Mask) {
>             TryPackScalars(NonConstants, Mask, /*IsRootPoison=*/false);
>             Vec = ShuffleBuilder.gather(NonConstants, Mask.size(), Vec);
>           });
> ```
> 
> The VF it expects is 2, but `unsigned VecVF = cast<FixedVectorType>(Vec->getType())->getNumElements();` here is 4? So we compare `VecVF` (4) with the incoming `VF` (2)?

The VF here should be expanded to 4 too.
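To make "expand the VF" concrete, here is a minimal standalone sketch that assumes each REVEC scalar is a fixed vector of N elements and that VF and masks arriving at the finalize action are counted in those scalars. The helpers `expandVF`, `expandMask` and `buildResizeMask` are hypothetical illustrations, not code from the SLP vectorizer. `PoisonMaskElem` mirrors LLVM's value of -1, and `buildResizeMask` mimics the `VecVF < VF` resize step from the quoted excerpt.

```cpp
#include <cassert>
#include <numeric>
#include <vector>

// Mirrors llvm::PoisonMaskElem (-1); redefined here so the sketch is standalone.
constexpr int PoisonMaskElem = -1;

// Hypothetical helper: convert a VF counted in REVEC "scalars" (each one a
// vector of ScalarNumElts elements) into a VF counted in real vector elements.
unsigned expandVF(unsigned ScalarVF, unsigned ScalarNumElts) {
  return ScalarVF * ScalarNumElts;
}

// Hypothetical helper: rewrite a mask given in scalar units into element
// units. Scalar index I becomes elements I*N .. I*N+N-1; poison stays poison.
std::vector<int> expandMask(const std::vector<int> &ScalarMask,
                            unsigned ScalarNumElts) {
  std::vector<int> Result;
  Result.reserve(ScalarMask.size() * ScalarNumElts);
  for (int M : ScalarMask)
    for (unsigned J = 0; J < ScalarNumElts; ++J)
      Result.push_back(M == PoisonMaskElem
                           ? PoisonMaskElem
                           : M * static_cast<int>(ScalarNumElts) +
                                 static_cast<int>(J));
  return Result;
}

// Mimics the resize step in the quoted code: when the vector has fewer
// elements than VF, keep its first VecVF lanes and pad the rest with poison.
// Returns an empty mask when no resize is needed.
std::vector<int> buildResizeMask(unsigned VecVF, unsigned VF) {
  if (VecVF >= VF)
    return {};
  std::vector<int> ResizeMask(VF, PoisonMaskElem);
  std::iota(ResizeMask.begin(), ResizeMask.begin() + VecVF, 0);
  return ResizeMask;
}
```

For example, if `Vec` is `<4 x Ty>` (two `<2 x Ty>` scalars) and the caller expects 4 scalars, comparing `VecVF` (4) with the unexpanded VF (4) skips the resize. Comparing against `expandVF(4, 2)` (8) yields the mask `{0, 1, 2, 3, -1, -1, -1, -1}`, which is the resize the element-level types actually require.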

https://github.com/llvm/llvm-project/pull/120952