[llvm] 69a2115 - [DAG] Fold trunc(srl(extract_elt(vec,c1),c2)) -> extract_elt(bitcast(vec),c3) (#107987)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Sep 13 07:14:02 PDT 2024
Author: Simon Pilgrim
Date: 2024-09-13T15:13:58+01:00
New Revision: 69a21154caa5b53d302cd3bfd7ce0ec1a0c3d985
URL: https://github.com/llvm/llvm-project/commit/69a21154caa5b53d302cd3bfd7ce0ec1a0c3d985
DIFF: https://github.com/llvm/llvm-project/commit/69a21154caa5b53d302cd3bfd7ce0ec1a0c3d985.diff
LOG: [DAG] Fold trunc(srl(extract_elt(vec,c1),c2)) -> extract_elt(bitcast(vec),c3) (#107987)
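Extends the existing trunc(extract_elt(vec,c1)) -> extract_elt(bitcast(vec),c3) fold so that it also looks through an srl whose constant shift amount is a multiple of the truncated type's bit width, selecting the correspondingly offset sub-element instead.
Noticed while working on #107404.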
Added:
Modified:
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/test/CodeGen/AArch64/expand-select.ll
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index bb907633e1f824..fe8ae5c9e9af6a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -15142,26 +15142,42 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
// Note: We only run this optimization after type legalization (which often
// creates this pattern) and before operation legalization after which
// we need to be more careful about the vector instructions that we generate.
- if (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
- LegalTypes && !LegalOperations && N0->hasOneUse() && VT != MVT::i1) {
- EVT VecTy = N0.getOperand(0).getValueType();
- EVT ExTy = N0.getValueType();
+ if (LegalTypes && !LegalOperations && VT.isScalarInteger() && VT != MVT::i1 &&
+ N0->hasOneUse()) {
EVT TrTy = N->getValueType(0);
+ SDValue Src = N0;
+
+ // Check for cases where we shift down an upper element before truncation.
+ int EltOffset = 0;
+ if (Src.getOpcode() == ISD::SRL && Src.getOperand(0)->hasOneUse()) {
+ if (auto ShAmt = DAG.getValidShiftAmount(Src)) {
+ if ((*ShAmt % TrTy.getSizeInBits()) == 0) {
+ Src = Src.getOperand(0);
+ EltOffset = *ShAmt / TrTy.getSizeInBits();
+ }
+ }
+ }
- auto EltCnt = VecTy.getVectorElementCount();
- unsigned SizeRatio = ExTy.getSizeInBits()/TrTy.getSizeInBits();
- auto NewEltCnt = EltCnt * SizeRatio;
+ if (Src.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
+ EVT VecTy = Src.getOperand(0).getValueType();
+ EVT ExTy = Src.getValueType();
- EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, NewEltCnt);
- assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");
+ auto EltCnt = VecTy.getVectorElementCount();
+ unsigned SizeRatio = ExTy.getSizeInBits() / TrTy.getSizeInBits();
+ auto NewEltCnt = EltCnt * SizeRatio;
- SDValue EltNo = N0->getOperand(1);
- if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
- int Elt = EltNo->getAsZExtVal();
- int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1));
- return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
- DAG.getBitcast(NVT, N0.getOperand(0)),
- DAG.getVectorIdxConstant(Index, DL));
+ EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, NewEltCnt);
+ assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");
+
+ SDValue EltNo = Src->getOperand(1);
+ if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
+ int Elt = EltNo->getAsZExtVal();
+ int Index = isLE ? (Elt * SizeRatio + EltOffset)
+ : (Elt * SizeRatio + (SizeRatio - 1) - EltOffset);
+ return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
+ DAG.getBitcast(NVT, Src.getOperand(0)),
+ DAG.getVectorIdxConstant(Index, DL));
+ }
}
}
diff --git a/llvm/test/CodeGen/AArch64/expand-select.ll b/llvm/test/CodeGen/AArch64/expand-select.ll
index f8397290ab5e14..1ed2e09c6b4d43 100644
--- a/llvm/test/CodeGen/AArch64/expand-select.ll
+++ b/llvm/test/CodeGen/AArch64/expand-select.ll
@@ -33,24 +33,20 @@ define void @bar(i32 %In1, <2 x i96> %In2, <2 x i96> %In3, ptr %Out) {
; CHECK: // %bb.0:
; CHECK-NEXT: and w8, w0, #0x1
; CHECK-NEXT: fmov s0, wzr
-; CHECK-NEXT: ldr x11, [sp, #16]
+; CHECK-NEXT: ldr x10, [sp, #16]
; CHECK-NEXT: fmov s1, w8
-; CHECK-NEXT: ldp x9, x10, [sp]
; CHECK-NEXT: cmeq v0.4s, v1.4s, v0.4s
-; CHECK-NEXT: dup v1.4s, v0.s[0]
-; CHECK-NEXT: mov x8, v1.d[1]
-; CHECK-NEXT: lsr x8, x8, #32
-; CHECK-NEXT: tst w8, #0x1
; CHECK-NEXT: fmov w8, s0
-; CHECK-NEXT: csel x10, x5, x10, ne
-; CHECK-NEXT: csel x9, x4, x9, ne
-; CHECK-NEXT: stur x9, [x11, #12]
; CHECK-NEXT: tst w8, #0x1
-; CHECK-NEXT: str w10, [x11, #20]
-; CHECK-NEXT: csel x8, x2, x6, ne
+; CHECK-NEXT: ldp x9, x8, [sp]
+; CHECK-NEXT: csel x11, x2, x6, ne
+; CHECK-NEXT: str x11, [x10]
+; CHECK-NEXT: csel x9, x4, x9, ne
+; CHECK-NEXT: csel x8, x5, x8, ne
+; CHECK-NEXT: stur x9, [x10, #12]
; CHECK-NEXT: csel x9, x3, x7, ne
-; CHECK-NEXT: str x8, [x11]
-; CHECK-NEXT: str w9, [x11, #8]
+; CHECK-NEXT: str w8, [x10, #20]
+; CHECK-NEXT: str w9, [x10, #8]
; CHECK-NEXT: ret
%cond = and i32 %In1, 1
%cbool = icmp eq i32 %cond, 0
More information about the llvm-commits mailing list