[llvm] [NVPTX] fold movs into loads and stores (PR #144581)

Princeton Ferro via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 18 10:26:34 PDT 2025


================
@@ -5047,26 +5044,244 @@ PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
   return SDValue();
 }
 
-static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front,
-                                         std::size_t Back) {
+/// Combine extractelts into a load by increasing the number of return values.
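+///
+/// ex:
+/// v: v2f16,ch = load <p>
+/// a: f16 = extractelt v:0, 0
+/// b: f16 = extractelt v:0, 1
+///
+/// ...is turned into...
+///
+/// a: f16, b: f16, ch = LoadV2 <p>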
+static SDValue
+combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+  // Don't run this optimization before the legalizer
+  if (!DCI.isAfterLegalizeDAG())
+    return SDValue();
+
+  EVT ElemVT = N->getValueType(0);
+  if (!Isv2x16VT(ElemVT))
+    return SDValue();
+
+  // Check that every use of the load is either an extractelt of one of its
+  // results or a use of the chain/glue values.
+  if (!all_of(N->uses(), [&](SDUse &U) {
+        // Skip glue, chain nodes
+        if (U.getValueType() == MVT::Glue || U.getValueType() == MVT::Other)
+          return true;
+        if (U.getUser()->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
+          if (N->getOpcode() != ISD::LOAD)
+            return true;
+          // Since this is an ISD::LOAD, check that all extractelts are used.
+          // If any are unused, we don't want to defeat another optimization
+          // that will narrow the load.
+          //
+          // For example:
+          //
+          // L: v2f16,ch = load <p>
+          // e0: f16 = extractelt L:0, 0
+          // e1: f16 = extractelt L:0, 1        <-- unused
+          // store e0
+          //
+          // Can be optimized by DAGCombiner to:
+          //
+          // L: f16,ch = load <p>
+          // store L:0
+          return !U.getUser()->use_empty();
+        }
+
+        // Otherwise, this use prevents us from splitting a value.
+        return false;
+      }))
+    return SDValue();
+
+  auto *LD = cast<MemSDNode>(N);
+  EVT MemVT = LD->getMemoryVT();
+  SDLoc DL(LD);
+
+  // The new opcode after we double the number of return values.
+  NVPTXISD::NodeType Opcode;
+  SmallVector<SDValue> Operands(LD->ops());
+  unsigned OldNumValues;
+  switch (LD->getOpcode()) {
+  case ISD::LOAD:
+    OldNumValues = 1;
+    // Any packed type is legal, so the legalizer will not have lowered
+    // ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
+    // here.
+    Opcode = NVPTXISD::LoadV2;
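+    // Unlike ISD::LOAD, which encodes its extension type in the node itself,
+    // NVPTXISD::LoadV2 takes the extension type as an extra trailing operand.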
+    Operands.push_back(DCI.DAG.getIntPtrConstant(
+        cast<LoadSDNode>(LD)->getExtensionType(), DL));
+    break;
+  case NVPTXISD::LoadParamV2:
+    OldNumValues = 2;
+    Opcode = NVPTXISD::LoadParamV4;
+    break;
+  case NVPTXISD::LoadV2:
+    OldNumValues = 2;
+    Opcode = NVPTXISD::LoadV4;
+    break;
+  case NVPTXISD::LoadV4:
+    // PTX doesn't support v8 for 16-bit values.
+  case NVPTXISD::LoadV8:
+    // PTX doesn't support the next doubling of outputs.
+    return SDValue();
+  default:
+    llvm_unreachable("Unhandled load opcode");
+  }
+
+  SmallVector<EVT> NewVTs(OldNumValues * 2, ElemVT.getVectorElementType());
+  // Add the remaining chain and glue values.
+  for (unsigned I = OldNumValues; I < LD->getNumValues(); ++I)
+    NewVTs.push_back(LD->getValueType(I));
+
+  // Create the new load
+  SDValue NewLoad =
+      DCI.DAG.getMemIntrinsicNode(Opcode, DL, DCI.DAG.getVTList(NewVTs),
+                                  Operands, MemVT, LD->getMemOperand());
+
+  // Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
+  // the outputs the same. These nodes will be optimized away in later
+  // DAGCombiner iterations.
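+  //
+  // For example, a widened ISD::LOAD with one chain result:
+  //
+  //   a: f16, b: f16, ch = LoadV2 <p>
+  //
+  // ...is rewrapped as...
+  //
+  //   v: v2f16 = BUILD_VECTOR a, b
+  //   merge_values v, ch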
+  SmallVector<SDValue> Results;
+  for (unsigned I = 0; I < NewLoad->getNumValues();) {
+    if (NewLoad->getValueType(I) == ElemVT.getVectorElementType()) {
+      Results.push_back(DCI.DAG.getBuildVector(
+          ElemVT, DL, {NewLoad.getValue(I), NewLoad.getValue(I + 1)}));
+      I += 2;
+    } else {
+      Results.push_back(NewLoad.getValue(I));
+      I += 1;
+    }
+  }
+
+  return DCI.DAG.getMergeValues(Results, DL);
+}
+
+/// Fold a packing mov into a store. This may help lower register pressure.
+///
+/// ex:
+/// v: v2f16 = build_vector a:f16, b:f16
+/// StoreRetval v
+///
+/// ...is turned into...
+///
+/// StoreRetvalV2 a:f16, b:f16
+static SDValue combinePackingMovIntoStore(SDNode *N,
+                                          TargetLowering::DAGCombinerInfo &DCI,
+                                          unsigned Front, unsigned Back) {
+  // We want to run this as late as possible since other optimizations may
+  // eliminate the BUILD_VECTORs.
+  if (!DCI.isAfterLegalizeDAG())
+    return SDValue();
+
+  // Get the type of the operands being stored.
+  EVT ElementVT = N->getOperand(Front).getValueType();
+
+  if (!Isv2x16VT(ElementVT))
+    return SDValue();
+
+  auto *ST = cast<MemSDNode>(N);
+  EVT MemVT = ElementVT.getVectorElementType();
+
+  // The new opcode after we double the number of operands.
+  NVPTXISD::NodeType Opcode;
+  switch (N->getOpcode()) {
+  case ISD::STORE:
+    // Any packed type is legal, so the legalizer will not have lowered
+    // ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to do
+    // it here.
+    MemVT = ST->getMemoryVT();
+    Opcode = NVPTXISD::StoreV2;
+    break;
+  case NVPTXISD::StoreParam:
+    Opcode = NVPTXISD::StoreParamV2;
+    break;
+  case NVPTXISD::StoreParamV2:
+    Opcode = NVPTXISD::StoreParamV4;
+    break;
+  case NVPTXISD::StoreRetval:
+    Opcode = NVPTXISD::StoreRetvalV2;
+    break;
+  case NVPTXISD::StoreRetvalV2:
+    Opcode = NVPTXISD::StoreRetvalV4;
+    break;
+  case NVPTXISD::StoreV2:
+    MemVT = ST->getMemoryVT();
+    Opcode = NVPTXISD::StoreV4;
+    break;
+  case NVPTXISD::StoreV4:
+    // PTX doesn't support v8 for 16-bit values.
+  case NVPTXISD::StoreParamV4:
+  case NVPTXISD::StoreRetvalV4:
+  case NVPTXISD::StoreV8:
+    // PTX doesn't support the next doubling of operands for these opcodes.
+    return SDValue();
+  default:
+    llvm_unreachable("Unhandled store opcode");
+  }
+
+  // Scan the data operands: each must be a BUILD_VECTOR whose elements we
+  // gather into the new operand list.
+  SmallVector<SDValue, 4> Operands(N->ops().take_front(Front));
+  for (SDValue BV : N->ops().drop_front(Front).drop_back(Back)) {
+    if (BV.getOpcode() != ISD::BUILD_VECTOR)
+      return SDValue();
+
+    // If the operand has multiple uses, this optimization can increase register
+    // pressure.
+    if (!BV.hasOneUse())
+      return SDValue();
+
+    // DAGCombiner visits nodes bottom-up. Check the BUILD_VECTOR operands for
+    // any signs they may be folded by some other pattern or rule.
+    for (SDValue Op : BV->ops()) {
+      // Peek through bitcasts
+      if (Op.getOpcode() == ISD::BITCAST)
+        Op = Op.getOperand(0);
+
+      // This may be folded into a PRMT.
+      if (Op.getValueType() == MVT::i16 && Op.getOpcode() == ISD::TRUNCATE &&
+          Op->getOperand(0).getValueType() == MVT::i32)
+        return SDValue();
+
+      // This may be folded into cvt.bf16x2
+      if (Op.getOpcode() == ISD::FP_ROUND)
+        return SDValue();
+    }
+    Operands.insert(Operands.end(), {BV.getOperand(0), BV.getOperand(1)});
----------------
Prince781 wrote:

Done.

https://github.com/llvm/llvm-project/pull/144581

