[llvm] [AArch64] Improve lowering of truncating uzp1 (PR #82457)
David Sherwood via llvm-commits
llvm-commits at lists.llvm.org
Thu Feb 22 02:04:58 PST 2024
================
@@ -21059,20 +21055,53 @@ static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
return SDValue();
auto getSourceOp = [](SDValue Operand) -> SDValue {
- const unsigned Opcode = Operand.getOpcode();
- if (Opcode == ISD::TRUNCATE)
- return Operand->getOperand(0);
- if (Opcode == ISD::BITCAST &&
- Operand->getOperand(0).getOpcode() == ISD::TRUNCATE)
- return Operand->getOperand(0)->getOperand(0);
- return SDValue();
+ if (Operand.getOpcode() == ISD::BITCAST)
+ Operand = Operand->getOperand(0);
+ if (Operand.getOpcode() == ISD::AssertSext ||
+ Operand.getOpcode() == ISD::AssertZext)
+ Operand = Operand->getOperand(0);
+ return Operand;
};
SDValue SourceOp0 = getSourceOp(Op0);
SDValue SourceOp1 = getSourceOp(Op1);
- if (!SourceOp0 || !SourceOp1)
+ auto IsTruncatingUZP1Concat = [](SDNode *N, LLVMContext &Ctx) -> bool {
+ if (N->getOpcode() != AArch64ISD::UZP1)
+ return false;
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ if (Op0.getOpcode() != ISD::BITCAST || Op1.getOpcode() != ISD::BITCAST)
+ return false;
+ EVT Op0Ty = Op0.getOperand(0).getValueType();
+ if (Op0Ty != Op1.getOperand(0).getValueType())
+ return false;
+
+ EVT ResVT = N->getValueType(0);
+ return ResVT.widenIntegerVectorElementType(Ctx).getHalfNumVectorElementsVT(
+ Ctx) == Op0Ty;
+ };
+
+ // truncating uzp1(x=uzp1, y=uzp1) -> trunc(concat (x, y))
+ // This is similar to the transform below, except that it looks for truncation
+ // done using the uzp1 node.
+ LLVMContext &Ctx = *DAG.getContext();
+ if (IsTruncatingUZP1Concat(N, Ctx) &&
+ IsTruncatingUZP1Concat(SourceOp0.getNode(), Ctx) &&
+ IsTruncatingUZP1Concat(SourceOp1.getNode(), Ctx)) {
+ SDValue UZP1 =
+ DAG.getNode(ISD::CONCAT_VECTORS, DL,
+ SourceOp0.getValueType().getDoubleNumVectorElementsVT(Ctx),
+ SourceOp0, SourceOp1);
+ return DAG.getNode(ISD::TRUNCATE, DL, ResVT, UZP1);
+ }
+
+ // uzp1(xtn x, xtn y) -> xtn(uzp1 (x, y))
+ if (SourceOp0.getOpcode() != ISD::TRUNCATE ||
+ SourceOp1.getOpcode() != ISD::TRUNCATE)
return SDValue();
+ SourceOp0 = SourceOp0.getOperand(0);
----------------
david-arm wrote:
It looks like you've extended the existing optimisation to look through AssertZext (and AssertSext) nodes, so we could now have
UZP1(ASSERTZEXT(TRUNC(Op1)), ASSERTZEXT(TRUNC(Op2)))
->
TRUNC(UZP1(BITCAST(Op1), BITCAST(Op2)))
Doesn't that defeat the purpose of the patch, i.e. to preserve the AssertZext and AssertSext DAG nodes for as long as possible?
https://github.com/llvm/llvm-project/pull/82457
More information about the llvm-commits
mailing list