[llvm] [NVPTX] Optimize v16i8 reductions (PR #67322)
Artem Belevich via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 25 13:30:05 PDT 2023
================
@@ -5294,6 +5295,98 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
return Result;
}
+static SDValue PerformLOADCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ SelectionDAG &DAG = DCI.DAG;
+ LoadSDNode *LD = cast<LoadSDNode>(N);
+
+ // Lower a v16i8 load into a LoadV4 operation with i32 results instead of
+ // letting ReplaceLoadVector split it into smaller loads during legalization.
+ // This is done at dag-combine1 time, so that vector operations with i8
+ // elements can be optimised away instead of being needlessly split during
+ // legalization, which involves storing to the stack and loading it back.
+ EVT VT = N->getValueType(0);
+ if (VT != MVT::v16i8)
+ return SDValue();
+
+ SDLoc DL(N);
+
+ // Create a v4i32 vector load operation, effectively <4 x v4i8>.
+ unsigned Opc = NVPTXISD::LoadV4;
+ EVT NewVT = MVT::v4i32;
+ EVT EleVT = NewVT.getVectorElementType();
+ unsigned NumEles = NewVT.getVectorNumElements();
+ EVT RetVTs[] = {EleVT, EleVT, EleVT, EleVT, MVT::Other};
+ SDVTList RetVTList = DAG.getVTList(RetVTs);
+ SmallVector<SDValue, 8> Ops(N->op_begin(), N->op_end());
+ Ops.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
+ SDValue NewLoad = DAG.getMemIntrinsicNode(Opc, DL, RetVTList, Ops, NewVT,
+ LD->getMemOperand());
+ SDValue NewChain = NewLoad.getValue(NumEles);
+
+ // Create a vector of the same type returned by the original load.
+ SmallVector<SDValue, 4> Eles;
+ SDValue Vec;
+ for (unsigned i = 0; i < NumEles; i++)
+ Eles.push_back(NewLoad.getValue(i));
+ Vec = DAG.getBuildVector(NewVT, DL, Eles);
+ Vec = DCI.DAG.getBitcast(VT, Vec);
+
+ // Wrap the new vector and chain from the new load.
+ return DCI.DAG.getMergeValues({Vec, NewChain}, DL);
----------------
Artem-B wrote:
Nit: These can all be combined. We're not doing anything interesting with the intermediate nodes.
```
return DCI.DAG.getMergeValues({
DCI.DAG.getBitcast(VT, DCI.DAG.getBuildVector(NewVT, DL, Eles)),
NewChain}, DL);
```
https://github.com/llvm/llvm-project/pull/67322
More information about the llvm-commits mailing list