[llvm] [DAG] Fold trunc(srl(extract_elt(vec,c1),c2)) -> extract_elt(bitcast(vec),c3) (PR #107987)
Simon Pilgrim via llvm-commits
llvm-commits at lists.llvm.org
Wed Sep 11 04:03:18 PDT 2024
https://github.com/RKSimon updated https://github.com/llvm/llvm-project/pull/107987
>From a73c803922e8696acf63a450ab16e4a0958541cd Mon Sep 17 00:00:00 2001
From: Simon Pilgrim <llvm-dev at redking.me.uk>
Date: Tue, 10 Sep 2024 09:46:18 +0100
Subject: [PATCH] [DAG] Fold trunc(srl(extract_elt(vec,c1),c2)) ->
extract_elt(bitcast(vec),c3)
Extends existing trunc(extract_elt(vec,c1)) -> extract_elt(bitcast(vec),c3) fold.
Initial refactor to support #107404
---
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 48 ++++++++++++-------
llvm/test/CodeGen/AArch64/expand-select.ll | 22 ++++-----
2 files changed, 41 insertions(+), 29 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index bb907633e1f824..fe8ae5c9e9af6a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -15142,26 +15142,42 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
// Note: We only run this optimization after type legalization (which often
// creates this pattern) and before operation legalization after which
// we need to be more careful about the vector instructions that we generate.
- if (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
- LegalTypes && !LegalOperations && N0->hasOneUse() && VT != MVT::i1) {
- EVT VecTy = N0.getOperand(0).getValueType();
- EVT ExTy = N0.getValueType();
+ if (LegalTypes && !LegalOperations && VT.isScalarInteger() && VT != MVT::i1 &&
+ N0->hasOneUse()) {
EVT TrTy = N->getValueType(0);
+ SDValue Src = N0;
+
+ // Check for cases where we shift down an upper element before truncation.
+ int EltOffset = 0;
+ if (Src.getOpcode() == ISD::SRL && Src.getOperand(0)->hasOneUse()) {
+ if (auto ShAmt = DAG.getValidShiftAmount(Src)) {
+ if ((*ShAmt % TrTy.getSizeInBits()) == 0) {
+ Src = Src.getOperand(0);
+ EltOffset = *ShAmt / TrTy.getSizeInBits();
+ }
+ }
+ }
- auto EltCnt = VecTy.getVectorElementCount();
- unsigned SizeRatio = ExTy.getSizeInBits()/TrTy.getSizeInBits();
- auto NewEltCnt = EltCnt * SizeRatio;
+ if (Src.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
+ EVT VecTy = Src.getOperand(0).getValueType();
+ EVT ExTy = Src.getValueType();
- EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, NewEltCnt);
- assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");
+ auto EltCnt = VecTy.getVectorElementCount();
+ unsigned SizeRatio = ExTy.getSizeInBits() / TrTy.getSizeInBits();
+ auto NewEltCnt = EltCnt * SizeRatio;
- SDValue EltNo = N0->getOperand(1);
- if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
- int Elt = EltNo->getAsZExtVal();
- int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1));
- return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
- DAG.getBitcast(NVT, N0.getOperand(0)),
- DAG.getVectorIdxConstant(Index, DL));
+ EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, NewEltCnt);
+ assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");
+
+ SDValue EltNo = Src->getOperand(1);
+ if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
+ int Elt = EltNo->getAsZExtVal();
+ int Index = isLE ? (Elt * SizeRatio + EltOffset)
+ : (Elt * SizeRatio + (SizeRatio - 1) - EltOffset);
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
+ DAG.getBitcast(NVT, Src.getOperand(0)),
+ DAG.getVectorIdxConstant(Index, DL));
+ }
}
}
diff --git a/llvm/test/CodeGen/AArch64/expand-select.ll b/llvm/test/CodeGen/AArch64/expand-select.ll
index f8397290ab5e14..1ed2e09c6b4d43 100644
--- a/llvm/test/CodeGen/AArch64/expand-select.ll
+++ b/llvm/test/CodeGen/AArch64/expand-select.ll
@@ -33,24 +33,20 @@ define void @bar(i32 %In1, <2 x i96> %In2, <2 x i96> %In3, ptr %Out) {
; CHECK: // %bb.0:
; CHECK-NEXT: and w8, w0, #0x1
; CHECK-NEXT: fmov s0, wzr
-; CHECK-NEXT: ldr x11, [sp, #16]
+; CHECK-NEXT: ldr x10, [sp, #16]
; CHECK-NEXT: fmov s1, w8
-; CHECK-NEXT: ldp x9, x10, [sp]
; CHECK-NEXT: cmeq v0.4s, v1.4s, v0.4s
-; CHECK-NEXT: dup v1.4s, v0.s[0]
-; CHECK-NEXT: mov x8, v1.d[1]
-; CHECK-NEXT: lsr x8, x8, #32
-; CHECK-NEXT: tst w8, #0x1
; CHECK-NEXT: fmov w8, s0
-; CHECK-NEXT: csel x10, x5, x10, ne
-; CHECK-NEXT: csel x9, x4, x9, ne
-; CHECK-NEXT: stur x9, [x11, #12]
; CHECK-NEXT: tst w8, #0x1
-; CHECK-NEXT: str w10, [x11, #20]
-; CHECK-NEXT: csel x8, x2, x6, ne
+; CHECK-NEXT: ldp x9, x8, [sp]
+; CHECK-NEXT: csel x11, x2, x6, ne
+; CHECK-NEXT: str x11, [x10]
+; CHECK-NEXT: csel x9, x4, x9, ne
+; CHECK-NEXT: csel x8, x5, x8, ne
+; CHECK-NEXT: stur x9, [x10, #12]
; CHECK-NEXT: csel x9, x3, x7, ne
-; CHECK-NEXT: str x8, [x11]
-; CHECK-NEXT: str w9, [x11, #8]
+; CHECK-NEXT: str w8, [x10, #20]
+; CHECK-NEXT: str w9, [x10, #8]
; CHECK-NEXT: ret
%cond = and i32 %In1, 1
%cbool = icmp eq i32 %cond, 0
More information about the llvm-commits
mailing list