[llvm] 69a2115 - [DAG] Fold trunc(srl(extract_elt(vec,c1),c2)) -> extract_elt(bitcast(vec),c3) (#107987)

via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 13 07:14:02 PDT 2024


Author: Simon Pilgrim
Date: 2024-09-13T15:13:58+01:00
New Revision: 69a21154caa5b53d302cd3bfd7ce0ec1a0c3d985

URL: https://github.com/llvm/llvm-project/commit/69a21154caa5b53d302cd3bfd7ce0ec1a0c3d985
DIFF: https://github.com/llvm/llvm-project/commit/69a21154caa5b53d302cd3bfd7ce0ec1a0c3d985.diff

LOG: [DAG] Fold trunc(srl(extract_elt(vec,c1),c2)) -> extract_elt(bitcast(vec),c3) (#107987)

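Extends the existing trunc(extract_elt(vec,c1)) -> extract_elt(bitcast(vec),c3) fold to also look through a one-use logical right shift (srl) of the extracted element, provided the constant shift amount c2 is a whole multiple of the truncated type's bit width; the shift is then absorbed into the new element index c3.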

Noticed while working on #107404

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/test/CodeGen/AArch64/expand-select.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index bb907633e1f824..fe8ae5c9e9af6a 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -15142,26 +15142,42 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
   // Note: We only run this optimization after type legalization (which often
   // creates this pattern) and before operation legalization after which
   // we need to be more careful about the vector instructions that we generate.
-  if (N0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
-      LegalTypes && !LegalOperations && N0->hasOneUse() && VT != MVT::i1) {
-    EVT VecTy = N0.getOperand(0).getValueType();
-    EVT ExTy = N0.getValueType();
+  if (LegalTypes && !LegalOperations && VT.isScalarInteger() && VT != MVT::i1 &&
+      N0->hasOneUse()) {
     EVT TrTy = N->getValueType(0);
+    SDValue Src = N0;
+
+    // Check for cases where we shift down an upper element before truncation.
+    int EltOffset = 0;
+    if (Src.getOpcode() == ISD::SRL && Src.getOperand(0)->hasOneUse()) {
+      if (auto ShAmt = DAG.getValidShiftAmount(Src)) {
+        if ((*ShAmt % TrTy.getSizeInBits()) == 0) {
+          Src = Src.getOperand(0);
+          EltOffset = *ShAmt / TrTy.getSizeInBits();
+        }
+      }
+    }
 
-    auto EltCnt = VecTy.getVectorElementCount();
-    unsigned SizeRatio = ExTy.getSizeInBits()/TrTy.getSizeInBits();
-    auto NewEltCnt = EltCnt * SizeRatio;
+    if (Src.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
+      EVT VecTy = Src.getOperand(0).getValueType();
+      EVT ExTy = Src.getValueType();
 
-    EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, NewEltCnt);
-    assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");
+      auto EltCnt = VecTy.getVectorElementCount();
+      unsigned SizeRatio = ExTy.getSizeInBits() / TrTy.getSizeInBits();
+      auto NewEltCnt = EltCnt * SizeRatio;
 
-    SDValue EltNo = N0->getOperand(1);
-    if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
-      int Elt = EltNo->getAsZExtVal();
-      int Index = isLE ? (Elt*SizeRatio) : (Elt*SizeRatio + (SizeRatio-1));
-      return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
-                         DAG.getBitcast(NVT, N0.getOperand(0)),
-                         DAG.getVectorIdxConstant(Index, DL));
+      EVT NVT = EVT::getVectorVT(*DAG.getContext(), TrTy, NewEltCnt);
+      assert(NVT.getSizeInBits() == VecTy.getSizeInBits() && "Invalid Size");
+
+      SDValue EltNo = Src->getOperand(1);
+      if (isa<ConstantSDNode>(EltNo) && isTypeLegal(NVT)) {
+        int Elt = EltNo->getAsZExtVal();
+        int Index = isLE ? (Elt * SizeRatio + EltOffset)
+                         : (Elt * SizeRatio + (SizeRatio - 1) - EltOffset);
+        return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, TrTy,
+                           DAG.getBitcast(NVT, Src.getOperand(0)),
+                           DAG.getVectorIdxConstant(Index, DL));
+      }
     }
   }
 

diff  --git a/llvm/test/CodeGen/AArch64/expand-select.ll b/llvm/test/CodeGen/AArch64/expand-select.ll
index f8397290ab5e14..1ed2e09c6b4d43 100644
--- a/llvm/test/CodeGen/AArch64/expand-select.ll
+++ b/llvm/test/CodeGen/AArch64/expand-select.ll
@@ -33,24 +33,20 @@ define void @bar(i32 %In1, <2 x i96> %In2, <2 x i96> %In3, ptr %Out) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    and w8, w0, #0x1
 ; CHECK-NEXT:    fmov s0, wzr
-; CHECK-NEXT:    ldr x11, [sp, #16]
+; CHECK-NEXT:    ldr x10, [sp, #16]
 ; CHECK-NEXT:    fmov s1, w8
-; CHECK-NEXT:    ldp x9, x10, [sp]
 ; CHECK-NEXT:    cmeq v0.4s, v1.4s, v0.4s
-; CHECK-NEXT:    dup v1.4s, v0.s[0]
-; CHECK-NEXT:    mov x8, v1.d[1]
-; CHECK-NEXT:    lsr x8, x8, #32
-; CHECK-NEXT:    tst w8, #0x1
 ; CHECK-NEXT:    fmov w8, s0
-; CHECK-NEXT:    csel x10, x5, x10, ne
-; CHECK-NEXT:    csel x9, x4, x9, ne
-; CHECK-NEXT:    stur x9, [x11, #12]
 ; CHECK-NEXT:    tst w8, #0x1
-; CHECK-NEXT:    str w10, [x11, #20]
-; CHECK-NEXT:    csel x8, x2, x6, ne
+; CHECK-NEXT:    ldp x9, x8, [sp]
+; CHECK-NEXT:    csel x11, x2, x6, ne
+; CHECK-NEXT:    str x11, [x10]
+; CHECK-NEXT:    csel x9, x4, x9, ne
+; CHECK-NEXT:    csel x8, x5, x8, ne
+; CHECK-NEXT:    stur x9, [x10, #12]
 ; CHECK-NEXT:    csel x9, x3, x7, ne
-; CHECK-NEXT:    str x8, [x11]
-; CHECK-NEXT:    str w9, [x11, #8]
+; CHECK-NEXT:    str w8, [x10, #20]
+; CHECK-NEXT:    str w9, [x10, #8]
 ; CHECK-NEXT:    ret
   %cond = and i32 %In1, 1
   %cbool = icmp eq i32 %cond, 0


        


More information about the llvm-commits mailing list