[llvm] 1a9fbf6 - [X86] combineLoad - reuse an existing VBROADCAST_LOAD constant for a smaller vector load of the same constant

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Mon Nov 13 04:00:47 PST 2023


Author: Simon Pilgrim
Date: 2023-11-13T11:59:04Z
New Revision: 1a9fbf61661558d4f3e03390161400db734aab59

URL: https://github.com/llvm/llvm-project/commit/1a9fbf61661558d4f3e03390161400db734aab59
DIFF: https://github.com/llvm/llvm-project/commit/1a9fbf61661558d4f3e03390161400db734aab59.diff

LOG: [X86] combineLoad - reuse an existing VBROADCAST_LOAD constant for a smaller vector load of the same constant

Extends the existing code that performed something similar for SUBV_BROADCAST_LOAD, but this is just for cases where AVX2 targets load full-width 128-bit constant vectors but broadcast the equivalent 256-bit constant vector.

Fixes AVX2 case for Issue #70947

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/vec_fabs.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 8a883ad26a78d96..3d44c50b44e6234 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -49785,25 +49785,47 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
     }
   }
 
-  // If we also broadcast this as a subvector to a wider type, then just extract
-  // the lowest subvector.
+  // If we also broadcast this to a wider type, then just extract the lowest
+  // subvector.
   if (Ext == ISD::NON_EXTLOAD && Subtarget.hasAVX() && Ld->isSimple() &&
       (RegVT.is128BitVector() || RegVT.is256BitVector())) {
     SDValue Ptr = Ld->getBasePtr();
     SDValue Chain = Ld->getChain();
-    for (SDNode *User : Ptr->uses()) {
-      if (User != N && User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
-          cast<MemIntrinsicSDNode>(User)->getBasePtr() == Ptr &&
+    for (SDNode *User : Chain->uses()) {
+      if (User != N &&
+          (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD ||
+           User->getOpcode() == X86ISD::VBROADCAST_LOAD) &&
           cast<MemIntrinsicSDNode>(User)->getChain() == Chain &&
-          cast<MemIntrinsicSDNode>(User)->getMemoryVT().getSizeInBits() ==
-              MemVT.getSizeInBits() &&
           !User->hasAnyUseOfValue(1) &&
           User->getValueSizeInBits(0).getFixedValue() >
               RegVT.getFixedSizeInBits()) {
-        SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, SDLoc(N),
-                                           RegVT.getSizeInBits());
-        Extract = DAG.getBitcast(RegVT, Extract);
-        return DCI.CombineTo(N, Extract, SDValue(User, 1));
+        if (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
+            cast<MemIntrinsicSDNode>(User)->getBasePtr() == Ptr &&
+            cast<MemIntrinsicSDNode>(User)->getMemoryVT().getSizeInBits() ==
+                MemVT.getSizeInBits()) {
+          SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, SDLoc(N),
+                                             RegVT.getSizeInBits());
+          Extract = DAG.getBitcast(RegVT, Extract);
+          return DCI.CombineTo(N, Extract, SDValue(User, 1));
+        }
+        if (User->getOpcode() == X86ISD::VBROADCAST_LOAD &&
+            getTargetConstantFromBasePtr(Ptr)) {
+          // See if we are loading a constant that has also been broadcast.
+          APInt Undefs, UserUndefs;
+          SmallVector<APInt> Bits, UserBits;
+          if (getTargetConstantBitsFromNode(SDValue(N, 0), 8, Undefs, Bits) &&
+              getTargetConstantBitsFromNode(SDValue(User, 0), 8, UserUndefs,
+                                            UserBits)) {
+            UserUndefs = UserUndefs.trunc(Undefs.getBitWidth());
+            UserBits.truncate(Bits.size());
+            if (Bits == UserBits && UserUndefs.isSubsetOf(Undefs)) {
+              SDValue Extract = extractSubVector(
+                  SDValue(User, 0), 0, DAG, SDLoc(N), RegVT.getSizeInBits());
+              Extract = DAG.getBitcast(RegVT, Extract);
+              return DCI.CombineTo(N, Extract, SDValue(User, 1));
+            }
+          }
+        }
       }
     }
   }

diff  --git a/llvm/test/CodeGen/X86/vec_fabs.ll b/llvm/test/CodeGen/X86/vec_fabs.ll
index 0377d74fdcdb0f5..8876d2f9b19928e 100644
--- a/llvm/test/CodeGen/X86/vec_fabs.ll
+++ b/llvm/test/CodeGen/X86/vec_fabs.ll
@@ -332,10 +332,9 @@ define void @PR70947(ptr %src, ptr %dst) {
 ; X86-AVX2:       # %bb.0:
 ; X86-AVX2-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-AVX2-NEXT:    movl {{[0-9]+}}(%esp), %ecx
-; X86-AVX2-NEXT:    vmovups 32(%ecx), %xmm0
-; X86-AVX2-NEXT:    vbroadcastsd {{.*#+}} ymm1 = [NaN,NaN,NaN,NaN]
-; X86-AVX2-NEXT:    vandps (%ecx), %ymm1, %ymm1
-; X86-AVX2-NEXT:    vandps {{\.?LCPI[0-9]+_[0-9]+}}, %xmm0, %xmm0
+; X86-AVX2-NEXT:    vbroadcastsd {{.*#+}} ymm0 = [NaN,NaN,NaN,NaN]
+; X86-AVX2-NEXT:    vandps (%ecx), %ymm0, %ymm1
+; X86-AVX2-NEXT:    vandps 32(%ecx), %xmm0, %xmm0
 ; X86-AVX2-NEXT:    vmovups %ymm1, (%eax)
 ; X86-AVX2-NEXT:    vmovups %xmm0, 16(%eax)
 ; X86-AVX2-NEXT:    vzeroupper
@@ -378,10 +377,9 @@ define void @PR70947(ptr %src, ptr %dst) {
 ;
 ; X64-AVX2-LABEL: PR70947:
 ; X64-AVX2:       # %bb.0:
-; X64-AVX2-NEXT:    vmovups 32(%rdi), %xmm0
-; X64-AVX2-NEXT:    vbroadcastsd {{.*#+}} ymm1 = [NaN,NaN,NaN,NaN]
-; X64-AVX2-NEXT:    vandps (%rdi), %ymm1, %ymm1
-; X64-AVX2-NEXT:    vandps {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0, %xmm0
+; X64-AVX2-NEXT:    vbroadcastsd {{.*#+}} ymm0 = [NaN,NaN,NaN,NaN]
+; X64-AVX2-NEXT:    vandps (%rdi), %ymm0, %ymm1
+; X64-AVX2-NEXT:    vandps 32(%rdi), %xmm0, %xmm0
 ; X64-AVX2-NEXT:    vmovups %ymm1, (%rsi)
 ; X64-AVX2-NEXT:    vmovups %xmm0, 16(%rsi)
 ; X64-AVX2-NEXT:    vzeroupper


        


More information about the llvm-commits mailing list