[llvm] [NVPTX] fold movs into loads and stores (PR #144581)
Alex MacLean via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 18 09:52:09 PDT 2025
================
@@ -5047,26 +5044,244 @@ PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
return SDValue();
}
-static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front,
- std::size_t Back) {
+/// Combine extractelts into a load by increasing the number of return values.
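+///
+/// ex:
+///   L: v2f16,ch = load <p>
+///   e0: f16 = extractelt L:0, 0
+///   e1: f16 = extractelt L:0, 1
+///
+/// ...is turned (once later combines fold away the interim
+/// BUILD_VECTOR/MERGE_VALUES nodes) into...
+///
+///   L: f16,f16,ch = LoadV2 <p>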
+static SDValue
+combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+ // Don't run this optimization before the legalizer
+ if (!DCI.isAfterLegalizeDAG())
+ return SDValue();
+
+ EVT ElemVT = N->getValueType(0);
+ if (!Isv2x16VT(ElemVT))
+ return SDValue();
+
+  // Check whether all outputs are either used by an extractelt or are
+  // glue/chain values.
+ if (!all_of(N->uses(), [&](SDUse &U) {
+ // Skip glue, chain nodes
+ if (U.getValueType() == MVT::Glue || U.getValueType() == MVT::Other)
+ return true;
+ if (U.getUser()->getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
+ if (N->getOpcode() != ISD::LOAD)
+ return true;
+          // Since this is an ISD::LOAD, check that all extractelts are used.
+          // If any are unused, bail out so we don't defeat another
+          // optimization that will narrow the load.
+ //
+ // For example:
+ //
+ // L: v2f16,ch = load <p>
+ // e0: f16 = extractelt L:0, 0
+ // e1: f16 = extractelt L:0, 1 <-- unused
+ // store e0
+ //
+ // Can be optimized by DAGCombiner to:
+ //
+ // L: f16,ch = load <p>
+ // store L:0
+ return !U.getUser()->use_empty();
+ }
+
+ // Otherwise, this use prevents us from splitting a value.
+ return false;
+ }))
+ return SDValue();
+
+ auto *LD = cast<MemSDNode>(N);
+ EVT MemVT = LD->getMemoryVT();
+ SDLoc DL(LD);
+
+  // The new opcode after we double the number of operands.
+ NVPTXISD::NodeType Opcode;
+ SmallVector<SDValue> Operands(LD->ops());
+ unsigned OldNumValues;
+ switch (LD->getOpcode()) {
+ case ISD::LOAD:
+ OldNumValues = 1;
+ // Any packed type is legal, so the legalizer will not have lowered
+ // ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
+ // here.
+ Opcode = NVPTXISD::LoadV2;
+ Operands.push_back(DCI.DAG.getIntPtrConstant(
+ cast<LoadSDNode>(LD)->getExtensionType(), DL));
+ break;
+ case NVPTXISD::LoadParamV2:
+ OldNumValues = 2;
+ Opcode = NVPTXISD::LoadParamV4;
+ break;
+ case NVPTXISD::LoadV2:
+ OldNumValues = 2;
+ Opcode = NVPTXISD::LoadV4;
+ break;
+ case NVPTXISD::LoadV4:
+ // PTX doesn't support v8 for 16-bit values
+ case NVPTXISD::LoadV8:
+ // PTX doesn't support the next doubling of outputs
+ return SDValue();
+  default:
+    llvm_unreachable("Unhandled load opcode");
+  }
+
+ SmallVector<EVT> NewVTs(OldNumValues * 2, ElemVT.getVectorElementType());
+  // Add the remaining chain and glue values.
+ for (unsigned I = OldNumValues; I < LD->getNumValues(); ++I)
+ NewVTs.push_back(LD->getValueType(I));
+
+ // Create the new load
+ SDValue NewLoad =
+ DCI.DAG.getMemIntrinsicNode(Opcode, DL, DCI.DAG.getVTList(NewVTs),
+ Operands, MemVT, LD->getMemOperand());
+
+ // Now we use a combination of BUILD_VECTORs and a MERGE_VALUES node to keep
+ // the outputs the same. These nodes will be optimized away in later
+ // DAGCombiner iterations.
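+  //
+  // For the ISD::LOAD case above, the replacement looks like:
+  //
+  //   v: v2f16 = BUILD_VECTOR NewLoad:0, NewLoad:1
+  //   merge_values v, NewLoad:2 (the chain)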
+ SmallVector<SDValue> Results;
+ for (unsigned I = 0; I < NewLoad->getNumValues();) {
+ if (NewLoad->getValueType(I) == ElemVT.getVectorElementType()) {
+ Results.push_back(DCI.DAG.getBuildVector(
+ ElemVT, DL, {NewLoad.getValue(I), NewLoad.getValue(I + 1)}));
+ I += 2;
+ } else {
+ Results.push_back(NewLoad.getValue(I));
+ I += 1;
+ }
+ }
+
+ return DCI.DAG.getMergeValues(Results, DL);
+}
+
+/// Fold a packing mov into a store. This may help lower register pressure.
+///
+/// ex:
+/// v: v2f16 = build_vector a:f16, b:f16
+/// StoreRetval v
+///
+/// ...is turned into...
+///
+/// StoreRetvalV2 a:f16, b:f16
+static SDValue combinePackingMovIntoStore(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI,
+ unsigned Front, unsigned Back) {
+ // We want to run this as late as possible since other optimizations may
+ // eliminate the BUILD_VECTORs.
+ if (!DCI.isAfterLegalizeDAG())
+ return SDValue();
+
+ // Get the type of the operands being stored.
+ EVT ElementVT = N->getOperand(Front).getValueType();
+
+ if (!Isv2x16VT(ElementVT))
+ return SDValue();
+
+ auto *ST = cast<MemSDNode>(N);
+ EVT MemVT = ElementVT.getVectorElementType();
+
+ // The new opcode after we double the number of operands.
+ NVPTXISD::NodeType Opcode;
+ switch (N->getOpcode()) {
+ case ISD::STORE:
+ // Any packed type is legal, so the legalizer will not have lowered
+ // ISD::STORE -> NVPTXISD::Store (unless it's under-aligned). We have to do
+ // it here.
+ MemVT = ST->getMemoryVT();
+ Opcode = NVPTXISD::StoreV2;
+ break;
+ case NVPTXISD::StoreParam:
+ Opcode = NVPTXISD::StoreParamV2;
+ break;
+ case NVPTXISD::StoreParamV2:
+ Opcode = NVPTXISD::StoreParamV4;
+ break;
+ case NVPTXISD::StoreRetval:
+ Opcode = NVPTXISD::StoreRetvalV2;
+ break;
+ case NVPTXISD::StoreRetvalV2:
+ Opcode = NVPTXISD::StoreRetvalV4;
+ break;
+ case NVPTXISD::StoreV2:
+ MemVT = ST->getMemoryVT();
+ Opcode = NVPTXISD::StoreV4;
+ break;
+ case NVPTXISD::StoreV4:
+ // PTX doesn't support v8 for 16-bit values
+ case NVPTXISD::StoreParamV4:
+ case NVPTXISD::StoreRetvalV4:
+ case NVPTXISD::StoreV8:
+ // PTX doesn't support the next doubling of operands for these opcodes.
+ return SDValue();
+ default:
+ llvm_unreachable("Unhandled store opcode");
+ }
+
+  // Scan the stored operands; if they are all BUILD_VECTORs, we will have
+  // gathered their elements into Operands by the end of this loop.
+ SmallVector<SDValue, 4> Operands(N->ops().take_front(Front));
+ for (SDValue BV : N->ops().drop_front(Front).drop_back(Back)) {
+ if (BV.getOpcode() != ISD::BUILD_VECTOR)
+ return SDValue();
+
+ // If the operand has multiple uses, this optimization can increase register
+ // pressure.
+ if (!BV.hasOneUse())
+ return SDValue();
+
+ // DAGCombiner visits nodes bottom-up. Check the BUILD_VECTOR operands for
+ // any signs they may be folded by some other pattern or rule.
+ for (SDValue Op : BV->ops()) {
+ // Peek through bitcasts
+ if (Op.getOpcode() == ISD::BITCAST)
+ Op = Op.getOperand(0);
+
+ // This may be folded into a PRMT.
+ if (Op.getValueType() == MVT::i16 && Op.getOpcode() == ISD::TRUNCATE &&
+ Op->getOperand(0).getValueType() == MVT::i32)
+ return SDValue();
+
+ // This may be folded into cvt.bf16x2
+ if (Op.getOpcode() == ISD::FP_ROUND)
+ return SDValue();
+ }
+ Operands.insert(Operands.end(), {BV.getOperand(0), BV.getOperand(1)});
----------------
AlexMaclean wrote:
Can this be accomplished with an `append`?
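For example, with the initializer-list overload of `SmallVector::append` (untested sketch):

```cpp
Operands.append({BV.getOperand(0), BV.getOperand(1)});
```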
https://github.com/llvm/llvm-project/pull/144581
More information about the llvm-commits
mailing list