[llvm] e6d62c9 - [X86] IsElementEquivalent - pull out vector element count mismatch code. NFC.
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 6 10:18:55 PDT 2025
Author: Simon Pilgrim
Date: 2025-06-06T18:06:54+01:00
New Revision: e6d62c910fdc26cda58d21db84c5ef01b910c81d
URL: https://github.com/llvm/llvm-project/commit/e6d62c910fdc26cda58d21db84c5ef01b910c81d
DIFF: https://github.com/llvm/llvm-project/commit/e6d62c910fdc26cda58d21db84c5ef01b910c81d.diff
LOG: [X86] IsElementEquivalent - pull out vector element count mismatch code. NFC.
All cases rely on the ops having the same vector element count as the mask size, and this is unlikely to change now that we handle bitcasts, so just early out.
Added:
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 13635305f6a89..34e3f52bf7ff9 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -9782,20 +9782,23 @@ static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp,
return false;
EVT VT = Op.getValueType();
+ EVT ExpectedVT = ExpectedOp.getValueType();
+
+ // Sources must be vectors and match the mask's element count.
+ if (!VT.isVector() || !ExpectedVT.isVector() ||
+ (int)VT.getVectorNumElements() != MaskSize ||
+ (int)ExpectedVT.getVectorNumElements() != MaskSize)
+ return false;
+
switch (Op.getOpcode()) {
case ISD::BUILD_VECTOR:
// If the values are build vectors, we can look through them to find
// equivalent inputs that make the shuffles equivalent.
- // TODO: Handle MaskSize != Op.getNumOperands()?
- if (MaskSize == (int)Op.getNumOperands() &&
- MaskSize == (int)ExpectedOp.getNumOperands())
- return Op.getOperand(Idx) == ExpectedOp.getOperand(ExpectedIdx);
- break;
+ return Op.getOperand(Idx) == ExpectedOp.getOperand(ExpectedIdx);
case ISD::BITCAST: {
SDValue Src = peekThroughBitcasts(Op);
EVT SrcVT = Src.getValueType();
- if (Op == ExpectedOp && SrcVT.isVector() &&
- (int)VT.getVectorNumElements() == MaskSize) {
+ if (Op == ExpectedOp && SrcVT.isVector()) {
if ((SrcVT.getScalarSizeInBits() % VT.getScalarSizeInBits()) == 0) {
unsigned Scale = SrcVT.getScalarSizeInBits() / VT.getScalarSizeInBits();
return (Idx % Scale) == (ExpectedIdx % Scale) &&
@@ -9816,23 +9819,21 @@ static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp,
}
case ISD::VECTOR_SHUFFLE: {
auto *SVN = cast<ShuffleVectorSDNode>(Op);
- return Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize &&
+ return Op == ExpectedOp &&
SVN->getMaskElt(Idx) == SVN->getMaskElt(ExpectedIdx);
}
case X86ISD::VBROADCAST:
case X86ISD::VBROADCAST_LOAD:
- // TODO: Handle MaskSize != VT.getVectorNumElements()?
- return (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize);
+ return Op == ExpectedOp;
case X86ISD::SUBV_BROADCAST_LOAD:
- // TODO: Handle MaskSize != VT.getVectorNumElements()?
- if (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize) {
+ if (Op == ExpectedOp) {
auto *MemOp = cast<MemSDNode>(Op);
unsigned NumMemElts = MemOp->getMemoryVT().getVectorNumElements();
return (Idx % NumMemElts) == (ExpectedIdx % NumMemElts);
}
break;
case X86ISD::VPERMI: {
- if (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize) {
+ if (Op == ExpectedOp) {
SmallVector<int, 8> Mask;
DecodeVPERMMask(MaskSize, Op.getConstantOperandVal(1), Mask);
SDValue Src = Op.getOperand(0);
@@ -9849,20 +9850,16 @@ static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp,
case X86ISD::PACKSS:
case X86ISD::PACKUS:
// HOP(X,X) can refer to the elt from the lower/upper half of a lane.
- // TODO: Handle MaskSize != NumElts?
// TODO: Handle HOP(X,Y) vs HOP(Y,X) equivalence cases.
if (Op == ExpectedOp && Op.getOperand(0) == Op.getOperand(1)) {
int NumElts = VT.getVectorNumElements();
- if (MaskSize == NumElts) {
- int NumLanes = VT.getSizeInBits() / 128;
- int NumEltsPerLane = NumElts / NumLanes;
- int NumHalfEltsPerLane = NumEltsPerLane / 2;
- bool SameLane =
- (Idx / NumEltsPerLane) == (ExpectedIdx / NumEltsPerLane);
- bool SameElt =
- (Idx % NumHalfEltsPerLane) == (ExpectedIdx % NumHalfEltsPerLane);
- return SameLane && SameElt;
- }
+ int NumLanes = VT.getSizeInBits() / 128;
+ int NumEltsPerLane = NumElts / NumLanes;
+ int NumHalfEltsPerLane = NumEltsPerLane / 2;
+ bool SameLane = (Idx / NumEltsPerLane) == (ExpectedIdx / NumEltsPerLane);
+ bool SameElt =
+ (Idx % NumHalfEltsPerLane) == (ExpectedIdx % NumHalfEltsPerLane);
+ return SameLane && SameElt;
}
break;
}
More information about the llvm-commits mailing list