[llvm] 56cd3bc - [X86] Directly emit VBROADCAST_LOAD from constant pool in lowerBuildVectorAsBroadcast

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 3 10:40:29 PST 2020


Author: Craig Topper
Date: 2020-03-03T10:39:10-08:00
New Revision: 56cd3bc209e00b419911e88672bb5c6488b4c5ab

URL: https://github.com/llvm/llvm-project/commit/56cd3bc209e00b419911e88672bb5c6488b4c5ab
DIFF: https://github.com/llvm/llvm-project/commit/56cd3bc209e00b419911e88672bb5c6488b4c5ab.diff

LOG: [X86] Directly emit VBROADCAST_LOAD from constant pool in lowerBuildVectorAsBroadcast

Also add a DAG combine to combine different sized broadcasts from
constant pool to avoid a regression.

Differential Revision: https://reviews.llvm.org/D75412

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index cef768e95dda..4ca7f803132f 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -8558,12 +8558,14 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
           unsigned Repeat = VT.getSizeInBits() / SplatBitSize;
 
           unsigned Alignment = cast<ConstantPoolSDNode>(CP)->getAlignment();
-          Ld = DAG.getLoad(
-              CVT, dl, DAG.getEntryNode(), CP,
-              MachinePointerInfo::getConstantPool(DAG.getMachineFunction()),
-              Alignment);
-          SDValue Brdcst = DAG.getNode(X86ISD::VBROADCAST, dl,
-                                       MVT::getVectorVT(CVT, Repeat), Ld);
+          SDVTList Tys =
+              DAG.getVTList(MVT::getVectorVT(CVT, Repeat), MVT::Other);
+          SDValue Ops[] = {DAG.getEntryNode(), CP};
+          MachinePointerInfo MPI =
+              MachinePointerInfo::getConstantPool(DAG.getMachineFunction());
+          SDValue Brdcst = DAG.getMemIntrinsicNode(
+              X86ISD::VBROADCAST_LOAD, dl, Tys, Ops, CVT, MPI, Alignment,
+              MachineMemOperand::MOLoad);
           return DAG.getBitcast(VT, Brdcst);
         } else if (SplatBitSize == 32 || SplatBitSize == 64) {
           // Splatted value can fit in one FLOAT constant in constant pool.
@@ -8582,12 +8584,14 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
           unsigned Repeat = VT.getSizeInBits() / SplatBitSize;
 
           unsigned Alignment = cast<ConstantPoolSDNode>(CP)->getAlignment();
-          Ld = DAG.getLoad(
-              CVT, dl, DAG.getEntryNode(), CP,
-              MachinePointerInfo::getConstantPool(DAG.getMachineFunction()),
-              Alignment);
-          SDValue Brdcst = DAG.getNode(X86ISD::VBROADCAST, dl,
-                                       MVT::getVectorVT(CVT, Repeat), Ld);
+          SDVTList Tys =
+              DAG.getVTList(MVT::getVectorVT(CVT, Repeat), MVT::Other);
+          SDValue Ops[] = {DAG.getEntryNode(), CP};
+          MachinePointerInfo MPI =
+              MachinePointerInfo::getConstantPool(DAG.getMachineFunction());
+          SDValue Brdcst = DAG.getMemIntrinsicNode(
+              X86ISD::VBROADCAST_LOAD, dl, Tys, Ops, CVT, MPI, Alignment,
+              MachineMemOperand::MOLoad);
           return DAG.getBitcast(VT, Brdcst);
         } else if (SplatBitSize > 64) {
           // Load the vector of constants and broadcast it.
@@ -8667,12 +8671,13 @@ static SDValue lowerBuildVectorAsBroadcast(BuildVectorSDNode *BVOp,
       SDValue CP =
           DAG.getConstantPool(C, TLI.getPointerTy(DAG.getDataLayout()));
       unsigned Alignment = cast<ConstantPoolSDNode>(CP)->getAlignment();
-      Ld = DAG.getLoad(
-          CVT, dl, DAG.getEntryNode(), CP,
-          MachinePointerInfo::getConstantPool(DAG.getMachineFunction()),
-          Alignment);
 
-      return DAG.getNode(X86ISD::VBROADCAST, dl, VT, Ld);
+      SDVTList Tys = DAG.getVTList(VT, MVT::Other);
+      SDValue Ops[] = {DAG.getEntryNode(), CP};
+      MachinePointerInfo MPI =
+          MachinePointerInfo::getConstantPool(DAG.getMachineFunction());
+      return DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, dl, Tys, Ops, CVT,
+                                     MPI, Alignment, MachineMemOperand::MOLoad);
     }
   }
 
@@ -46828,6 +46833,41 @@ static SDValue combineFP_EXTEND(SDNode *N, SelectionDAG &DAG,
   return DAG.getNode(ISD::FP_EXTEND, dl, VT, Cvt);
 }
 
+// Replace N with a subvector (index 0) of a wider VBROADCAST_LOAD that shares
+// its base pointer, input chain and memory size. Both output chains must be
+// unused so that sharing the wider load cannot cause memory ordering issues.
+static SDValue combineVBROADCAST_LOAD(SDNode *N, SelectionDAG &DAG,
+                                      TargetLowering::DAGCombinerInfo &DCI) {
+  // Only do this if the chain result is unused.
+  if (N->hasAnyUseOfValue(1))
+    return SDValue();
+
+  auto *MemIntrin = cast<MemIntrinsicSDNode>(N);
+
+  SDValue Ptr = MemIntrin->getBasePtr();
+  SDValue Chain = MemIntrin->getChain();
+  EVT VT = N->getSimpleValueType(0);
+  EVT MemVT = MemIntrin->getMemoryVT();
+
+  // Scan the other users of our base pointer for a wider broadcast whose input
+  // chain and memory VT size match ours and whose chain result is also unused.
+  for (SDNode *User : Ptr->uses())
+    if (User != N && User->getOpcode() == X86ISD::VBROADCAST_LOAD &&
+        cast<MemIntrinsicSDNode>(User)->getBasePtr() == Ptr &&
+        cast<MemIntrinsicSDNode>(User)->getChain() == Chain &&
+        cast<MemIntrinsicSDNode>(User)->getMemoryVT().getSizeInBits() ==
+            MemVT.getSizeInBits() &&
+        !User->hasAnyUseOfValue(1) &&
+        User->getValueSizeInBits(0) > VT.getSizeInBits()) {
+      SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, SDLoc(N),
+                                         VT.getSizeInBits());
+      Extract = DAG.getBitcast(VT, Extract);
+      return DCI.CombineTo(N, Extract, SDValue(User, 1));
+    }
+
+  return SDValue();
+}
+
 static SDValue combineFP_ROUND(SDNode *N, SelectionDAG &DAG,
                                const X86Subtarget &Subtarget) {
   if (!Subtarget.hasF16C() || Subtarget.useSoftFloat())
@@ -47027,6 +47067,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::STRICT_FP_EXTEND:
   case ISD::FP_EXTEND:      return combineFP_EXTEND(N, DAG, Subtarget);
   case ISD::FP_ROUND:       return combineFP_ROUND(N, DAG, Subtarget);
+  case X86ISD::VBROADCAST_LOAD: return combineVBROADCAST_LOAD(N, DAG, DCI);
   }
 
   return SDValue();


        


More information about the llvm-commits mailing list