[llvm] 2fef449 - [LLVM][AArch64] Enable verifyTargetSDNode for scalable vectors and fix the fallout. (#104820)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 4 03:07:14 PDT 2024
Author: Paul Walker
Date: 2024-09-04T11:07:11+01:00
New Revision: 2fef449f30e2f484897cb199e3338a1520803c7d
URL: https://github.com/llvm/llvm-project/commit/2fef449f30e2f484897cb199e3338a1520803c7d
DIFF: https://github.com/llvm/llvm-project/commit/2fef449f30e2f484897cb199e3338a1520803c7d.diff
LOG: [LLVM][AArch64] Enable verifyTargetSDNode for scalable vectors and fix the fallout. (#104820)
Fix incorrect use of AArch64ISD::UZP1/UUNPK{HI,LO} in:
  * AArch64TargetLowering::LowerDIV
  * AArch64TargetLowering::LowerINSERT_SUBVECTOR

The latter highlighted DAG combines that relied on broken behaviour,
which this patch also fixes.
Added:
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 1735ff5cd69748..5e3f9364ac3e12 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14908,10 +14908,11 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
// NOP cast operands to the largest legal vector of the same element count.
if (VT.isFloatingPoint()) {
Vec0 = getSVESafeBitCast(NarrowVT, Vec0, DAG);
- Vec1 = getSVESafeBitCast(WideVT, Vec1, DAG);
+ Vec1 = getSVESafeBitCast(NarrowVT, Vec1, DAG);
} else {
// Legal integer vectors are already their largest so Vec0 is fine as is.
Vec1 = DAG.getNode(ISD::ANY_EXTEND, DL, WideVT, Vec1);
+ Vec1 = DAG.getNode(AArch64ISD::NVCAST, DL, NarrowVT, Vec1);
}
// To replace the top/bottom half of vector V with vector SubV we widen the
@@ -14920,11 +14921,13 @@ SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
SDValue Narrow;
if (Idx == 0) {
SDValue HiVec0 = DAG.getNode(AArch64ISD::UUNPKHI, DL, WideVT, Vec0);
+ HiVec0 = DAG.getNode(AArch64ISD::NVCAST, DL, NarrowVT, HiVec0);
Narrow = DAG.getNode(AArch64ISD::UZP1, DL, NarrowVT, Vec1, HiVec0);
} else {
assert(Idx == InVT.getVectorMinNumElements() &&
"Invalid subvector index!");
SDValue LoVec0 = DAG.getNode(AArch64ISD::UUNPKLO, DL, WideVT, Vec0);
+ LoVec0 = DAG.getNode(AArch64ISD::NVCAST, DL, NarrowVT, LoVec0);
Narrow = DAG.getNode(AArch64ISD::UZP1, DL, NarrowVT, LoVec0, Vec1);
}
@@ -15024,7 +15027,9 @@ SDValue AArch64TargetLowering::LowerDIV(SDValue Op, SelectionDAG &DAG) const {
SDValue Op1Hi = DAG.getNode(UnpkHi, dl, WidenedVT, Op.getOperand(1));
SDValue ResultLo = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0Lo, Op1Lo);
SDValue ResultHi = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0Hi, Op1Hi);
- return DAG.getNode(AArch64ISD::UZP1, dl, VT, ResultLo, ResultHi);
+ SDValue ResultLoCast = DAG.getNode(AArch64ISD::NVCAST, dl, VT, ResultLo);
+ SDValue ResultHiCast = DAG.getNode(AArch64ISD::NVCAST, dl, VT, ResultHi);
+ return DAG.getNode(AArch64ISD::UZP1, dl, VT, ResultLoCast, ResultHiCast);
}
bool AArch64TargetLowering::shouldExpandBuildVectorWithShuffles(
@@ -22739,7 +22744,19 @@ static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
SDValue Rshrnb = DAG.getNode(
AArch64ISD::RSHRNB_I, DL, ResVT,
{RShOperand, DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
- return DAG.getNode(ISD::BITCAST, DL, VT, Rshrnb);
+ return DAG.getNode(AArch64ISD::NVCAST, DL, VT, Rshrnb);
+}
+
+static SDValue isNVCastToHalfWidthElements(SDValue V) {
+ if (V.getOpcode() != AArch64ISD::NVCAST)
+ return SDValue();
+
+ SDValue Op = V.getOperand(0);
+ if (V.getValueType().getVectorElementCount() !=
+ Op.getValueType().getVectorElementCount() * 2)
+ return SDValue();
+
+ return Op;
}
static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
@@ -22802,25 +22819,37 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
if (SDValue Urshr = tryCombineExtendRShTrunc(N, DAG))
return Urshr;
- if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op0, DAG, Subtarget))
- return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);
+ if (SDValue PreCast = isNVCastToHalfWidthElements(Op0)) {
+ if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(PreCast, DAG, Subtarget)) {
+ Rshrnb = DAG.getNode(AArch64ISD::NVCAST, DL, ResVT, Rshrnb);
+ return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);
+ }
+ }
- if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op1, DAG, Subtarget))
- return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Rshrnb);
+ if (SDValue PreCast = isNVCastToHalfWidthElements(Op1)) {
+ if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(PreCast, DAG, Subtarget)) {
+ Rshrnb = DAG.getNode(AArch64ISD::NVCAST, DL, ResVT, Rshrnb);
+ return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Rshrnb);
+ }
+ }
- // uzp1(unpklo(uzp1(x, y)), z) => uzp1(x, z)
- if (Op0.getOpcode() == AArch64ISD::UUNPKLO) {
- if (Op0.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
- SDValue X = Op0.getOperand(0).getOperand(0);
- return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, X, Op1);
+ // uzp1<ty>(nvcast(unpklo(uzp1<ty>(x, y))), z) => uzp1<ty>(x, z)
+ if (SDValue PreCast = isNVCastToHalfWidthElements(Op0)) {
+ if (PreCast.getOpcode() == AArch64ISD::UUNPKLO) {
+ if (PreCast.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
+ SDValue X = PreCast.getOperand(0).getOperand(0);
+ return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, X, Op1);
+ }
}
}
- // uzp1(x, unpkhi(uzp1(y, z))) => uzp1(x, z)
- if (Op1.getOpcode() == AArch64ISD::UUNPKHI) {
- if (Op1.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
- SDValue Z = Op1.getOperand(0).getOperand(1);
- return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Z);
+ // uzp1<ty>(x, nvcast(unpkhi(uzp1<ty>(y, z)))) => uzp1<ty>(x, z)
+ if (SDValue PreCast = isNVCastToHalfWidthElements(Op1)) {
+ if (PreCast.getOpcode() == AArch64ISD::UUNPKHI) {
+ if (PreCast.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
+ SDValue Z = PreCast.getOperand(0).getOperand(1);
+ return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Z);
+ }
}
}
@@ -29415,9 +29444,6 @@ void AArch64TargetLowering::verifyTargetSDNode(const SDNode *N) const {
VT.isInteger() && "Expected integer vectors!");
assert(OpVT.getSizeInBits() == VT.getSizeInBits() &&
"Expected vectors of equal size!");
- // TODO: Enable assert once bogus creations have been fixed.
- if (VT.isScalableVector())
- break;
assert(OpVT.getVectorElementCount() == VT.getVectorElementCount() * 2 &&
"Expected result vector with half the lanes of its input!");
break;
@@ -29435,12 +29461,27 @@ void AArch64TargetLowering::verifyTargetSDNode(const SDNode *N) const {
EVT Op1VT = N->getOperand(1).getValueType();
assert(VT.isVector() && Op0VT.isVector() && Op1VT.isVector() &&
"Expected vectors!");
- // TODO: Enable assert once bogus creations have been fixed.
- if (VT.isScalableVector())
- break;
assert(VT == Op0VT && VT == Op1VT && "Expected matching vectors!");
break;
}
+ case AArch64ISD::RSHRNB_I: {
+ assert(N->getNumValues() == 1 && "Expected one result!");
+ assert(N->getNumOperands() == 2 && "Expected two operands!");
+ EVT VT = N->getValueType(0);
+ EVT Op0VT = N->getOperand(0).getValueType();
+ EVT Op1VT = N->getOperand(1).getValueType();
+ assert(VT.isVector() && VT.isInteger() &&
+ "Expected integer vector result type!");
+ assert(Op0VT.isVector() && Op0VT.isInteger() &&
+ "Expected first operand to be an integer vector!");
+ assert(VT.getSizeInBits() == Op0VT.getSizeInBits() &&
+ "Expected vectors of equal size!");
+ assert(VT.getVectorElementCount() == Op0VT.getVectorElementCount() * 2 &&
+ "Expected input vector with half the lanes of its result!");
+ assert(Op1VT == MVT::i32 && isa<ConstantSDNode>(N->getOperand(1)) &&
+ "Expected second operand to be a constant i32!");
+ break;
+ }
}
}
#endif
More information about the llvm-commits mailing list