[llvm] 0b80288 - [NVPTX] Preserve v16i8 vector loads when legalizing
Luke Drummond via llvm-commits
llvm-commits at lists.llvm.org
Thu Oct 19 04:37:15 PDT 2023
Author: Pierre-Andre Saulais
Date: 2023-10-19T12:34:25+01:00
New Revision: 0b80288e9e0b12f9680d9f2cfdff5686c38982d2
URL: https://github.com/llvm/llvm-project/commit/0b80288e9e0b12f9680d9f2cfdff5686c38982d2
DIFF: https://github.com/llvm/llvm-project/commit/0b80288e9e0b12f9680d9f2cfdff5686c38982d2.diff
LOG: [NVPTX] Preserve v16i8 vector loads when legalizing
This is done by lowering v16i8 loads into LoadV4 operations with i32
results instead of letting ReplaceLoadVector split it into smaller
loads during legalization. This is done at dag-combine1 time, so that
vector operations with i8 elements can be optimised away instead of
being needlessly split during legalization, which involves storing to
the stack and loading it back.
Added:
Modified:
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index a935c0e16a5523c..617009334dd201c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -701,8 +701,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);
// We have some custom DAG combine patterns for these nodes
- setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL,
- ISD::SREM, ISD::UREM, ISD::EXTRACT_VECTOR_ELT,
+ setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
+ ISD::LOAD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM,
ISD::VSELECT});
// setcc for f16x2 and bf16x2 needs special handling to prevent
@@ -5479,6 +5479,45 @@ static SDValue PerformVSELECTCombine(SDNode *N,
return DCI.DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v4i8, E);
}
+// Lower a v16i8 load into a LoadV4 operation with i32 results instead of
+// letting ReplaceLoadVector split it into smaller loads during legalization.
+// This is done at dag-combine1 time, so that vector operations with i8
+// elements can be optimised away instead of being needlessly split during
+// legalization, which involves storing to the stack and loading it back.
+//
+// Returns the {v16i8 value, chain} pair that replaces N, or an empty SDValue
+// when N is not a v16i8 load known to be aligned to its full 16-byte size.
+static SDValue PerformLOADCombine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI) {
+  SelectionDAG &DAG = DCI.DAG;
+  LoadSDNode *LD = cast<LoadSDNode>(N);
+
+  EVT VT = N->getValueType(0);
+  if (VT != MVT::v16i8)
+    return SDValue();
+
+  // The replacement is a single 16-byte vector access, and PTX requires vector
+  // loads to be aligned to their full size. Only fire when the load is known
+  // to be aligned to its whole store size; otherwise leave it to the regular
+  // lowering path rather than emitting a misaligned ld.v4.
+  if (LD->getAlign() < Align(VT.getStoreSize().getFixedValue()))
+    return SDValue();
+
+  SDLoc DL(N);
+
+  // Create a v4i32 vector load operation, effectively <4 x v4i8>: four i32
+  // results followed by the output chain.
+  const EVT NewVT = MVT::v4i32;
+  const EVT EltVT = NewVT.getVectorElementType();
+  const unsigned NumElts = NewVT.getVectorNumElements();
+  EVT RetVTs[] = {EltVT, EltVT, EltVT, EltVT, MVT::Other};
+  SDVTList RetVTList = DAG.getVTList(RetVTs);
+
+  // Reuse the original load's operands and append its extension type as an
+  // extra constant operand.
+  // NOTE(review): assumed to match the LoadV4 operand layout that instruction
+  // selection expects (as built by ReplaceLoadVector) -- confirm.
+  SmallVector<SDValue, 8> Ops(N->ops());
+  Ops.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
+  SDValue NewLoad = DAG.getMemIntrinsicNode(NVPTXISD::LoadV4, DL, RetVTList,
+                                            Ops, NewVT, LD->getMemOperand());
+  SDValue NewChain = NewLoad.getValue(NumElts);
+
+  // Reassemble the four i32 results into a v4i32 and bitcast it back to the
+  // v16i8 type the original load produced.
+  SmallVector<SDValue, 4> Elts;
+  for (unsigned I = 0; I != NumElts; ++I)
+    Elts.push_back(NewLoad.getValue(I));
+  SDValue Vec = DAG.getBitcast(VT, DAG.getBuildVector(NewVT, DL, Elts));
+  return DAG.getMergeValues({Vec, NewChain}, DL);
+}
+
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5498,6 +5537,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return PerformREMCombine(N, DCI, OptLevel);
case ISD::SETCC:
return PerformSETCCCombine(N, DCI);
+ case ISD::LOAD:
+ return PerformLOADCombine(N, DCI);
case NVPTXISD::StoreRetval:
case NVPTXISD::StoreRetvalV2:
case NVPTXISD::StoreRetvalV4:
diff --git a/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll b/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
index 4f13b6d9d1a8a9d..868a06e2a850cc8 100644
--- a/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
+++ b/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
@@ -52,3 +52,126 @@ define float @ff(ptr %p) {
%sum = fadd float %sum3, %v4
ret float %sum
}
+
+; Sixteen adjacent i8 loads from a 16-byte-aligned %ptr1 are zero-extended,
+; summed into one i32, and stored to %ptr2, so every load has a use. Once they
+; are merged into a single <16 x i8> load (presumably by the vectorizer run
+; selected by this prefix; the RUN lines are outside this excerpt), the NVPTX
+; v16i8 load combine should emit one 128-bit vector load of four 32-bit values
+; instead of splitting it. The 16-byte alignment of %ptr1 matters because PTX
+; requires 128-bit vector accesses to be 16-byte aligned.
+define void @combine_v16i8(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr2) {
+ ; ENABLED-LABEL: combine_v16i8
+ ; ENABLED: ld.v4.u32
+ %val0 = load i8, ptr %ptr1, align 16
+ %ptr1.1 = getelementptr inbounds i8, ptr %ptr1, i64 1
+ %val1 = load i8, ptr %ptr1.1, align 1
+ %ptr1.2 = getelementptr inbounds i8, ptr %ptr1, i64 2
+ %val2 = load i8, ptr %ptr1.2, align 2
+ %ptr1.3 = getelementptr inbounds i8, ptr %ptr1, i64 3
+ %val3 = load i8, ptr %ptr1.3, align 1
+ %ptr1.4 = getelementptr inbounds i8, ptr %ptr1, i64 4
+ %val4 = load i8, ptr %ptr1.4, align 4
+ %ptr1.5 = getelementptr inbounds i8, ptr %ptr1, i64 5
+ %val5 = load i8, ptr %ptr1.5, align 1
+ %ptr1.6 = getelementptr inbounds i8, ptr %ptr1, i64 6
+ %val6 = load i8, ptr %ptr1.6, align 2
+ %ptr1.7 = getelementptr inbounds i8, ptr %ptr1, i64 7
+ %val7 = load i8, ptr %ptr1.7, align 1
+ %ptr1.8 = getelementptr inbounds i8, ptr %ptr1, i64 8
+ %val8 = load i8, ptr %ptr1.8, align 8
+ %ptr1.9 = getelementptr inbounds i8, ptr %ptr1, i64 9
+ %val9 = load i8, ptr %ptr1.9, align 1
+ %ptr1.10 = getelementptr inbounds i8, ptr %ptr1, i64 10
+ %val10 = load i8, ptr %ptr1.10, align 2
+ %ptr1.11 = getelementptr inbounds i8, ptr %ptr1, i64 11
+ %val11 = load i8, ptr %ptr1.11, align 1
+ %ptr1.12 = getelementptr inbounds i8, ptr %ptr1, i64 12
+ %val12 = load i8, ptr %ptr1.12, align 4
+ %ptr1.13 = getelementptr inbounds i8, ptr %ptr1, i64 13
+ %val13 = load i8, ptr %ptr1.13, align 1
+ %ptr1.14 = getelementptr inbounds i8, ptr %ptr1, i64 14
+ %val14 = load i8, ptr %ptr1.14, align 2
+ %ptr1.15 = getelementptr inbounds i8, ptr %ptr1, i64 15
+ %val15 = load i8, ptr %ptr1.15, align 1
+ %lane0 = zext i8 %val0 to i32
+ %lane1 = zext i8 %val1 to i32
+ %lane2 = zext i8 %val2 to i32
+ %lane3 = zext i8 %val3 to i32
+ %lane4 = zext i8 %val4 to i32
+ %lane5 = zext i8 %val5 to i32
+ %lane6 = zext i8 %val6 to i32
+ %lane7 = zext i8 %val7 to i32
+ %lane8 = zext i8 %val8 to i32
+ %lane9 = zext i8 %val9 to i32
+ %lane10 = zext i8 %val10 to i32
+ %lane11 = zext i8 %val11 to i32
+ %lane12 = zext i8 %val12 to i32
+ %lane13 = zext i8 %val13 to i32
+ %lane14 = zext i8 %val14 to i32
+ %lane15 = zext i8 %val15 to i32
+ %red.1 = add i32 %lane0, %lane1
+ %red.2 = add i32 %red.1, %lane2
+ %red.3 = add i32 %red.2, %lane3
+ %red.4 = add i32 %red.3, %lane4
+ %red.5 = add i32 %red.4, %lane5
+ %red.6 = add i32 %red.5, %lane6
+ %red.7 = add i32 %red.6, %lane7
+ %red.8 = add i32 %red.7, %lane8
+ %red.9 = add i32 %red.8, %lane9
+ %red.10 = add i32 %red.9, %lane10
+ %red.11 = add i32 %red.10, %lane11
+ %red.12 = add i32 %red.11, %lane12
+ %red.13 = add i32 %red.12, %lane13
+ %red.14 = add i32 %red.13, %lane14
+ %red = add i32 %red.14, %lane15
+ store i32 %red, ptr %ptr2, align 4
+ ret void
+}
+
+; Same pattern with eight adjacent i16 loads (16 bytes in total) from a
+; 16-byte-aligned %ptr1. The NVPTX load combine added in this change only
+; matches v16i8, so this case is not rewritten by it. The expectation is that
+; the merged access is still emitted as one four-element 32-bit vector load
+; (presumably the i16 pairs are packed into 32-bit registers; confirm against
+; the v8i16 lowering).
+define void @combine_v8i16(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr2) {
+ ; ENABLED-LABEL: combine_v8i16
+ ; ENABLED: ld.v4.b32
+ %val0 = load i16, ptr %ptr1, align 16
+ %ptr1.1 = getelementptr inbounds i16, ptr %ptr1, i64 1
+ %val1 = load i16, ptr %ptr1.1, align 2
+ %ptr1.2 = getelementptr inbounds i16, ptr %ptr1, i64 2
+ %val2 = load i16, ptr %ptr1.2, align 4
+ %ptr1.3 = getelementptr inbounds i16, ptr %ptr1, i64 3
+ %val3 = load i16, ptr %ptr1.3, align 2
+ %ptr1.4 = getelementptr inbounds i16, ptr %ptr1, i64 4
+ %val4 = load i16, ptr %ptr1.4, align 4
+ %ptr1.5 = getelementptr inbounds i16, ptr %ptr1, i64 5
+ %val5 = load i16, ptr %ptr1.5, align 2
+ %ptr1.6 = getelementptr inbounds i16, ptr %ptr1, i64 6
+ %val6 = load i16, ptr %ptr1.6, align 4
+ %ptr1.7 = getelementptr inbounds i16, ptr %ptr1, i64 7
+ %val7 = load i16, ptr %ptr1.7, align 2
+ %lane0 = zext i16 %val0 to i32
+ %lane1 = zext i16 %val1 to i32
+ %lane2 = zext i16 %val2 to i32
+ %lane3 = zext i16 %val3 to i32
+ %lane4 = zext i16 %val4 to i32
+ %lane5 = zext i16 %val5 to i32
+ %lane6 = zext i16 %val6 to i32
+ %lane7 = zext i16 %val7 to i32
+ %red.1 = add i32 %lane0, %lane1
+ %red.2 = add i32 %red.1, %lane2
+ %red.3 = add i32 %red.2, %lane3
+ %red.4 = add i32 %red.3, %lane4
+ %red.5 = add i32 %red.4, %lane5
+ %red.6 = add i32 %red.5, %lane6
+ %red = add i32 %red.6, %lane7
+ store i32 %red, ptr %ptr2, align 4
+ ret void
+}
+
+; Baseline case. Four adjacent i32 loads from a 16-byte-aligned %ptr1 are
+; summed and stored to %ptr2. They are expected to become a single
+; four-element u32 vector load without any help from the v16i8 load combine,
+; which only matches v16i8.
+define void @combine_v4i32(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr2) {
+ ; ENABLED-LABEL: combine_v4i32
+ ; ENABLED: ld.v4.u32
+ %val0 = load i32, ptr %ptr1, align 16
+ %ptr1.1 = getelementptr inbounds i32, ptr %ptr1, i64 1
+ %val1 = load i32, ptr %ptr1.1, align 4
+ %ptr1.2 = getelementptr inbounds i32, ptr %ptr1, i64 2
+ %val2 = load i32, ptr %ptr1.2, align 8
+ %ptr1.3 = getelementptr inbounds i32, ptr %ptr1, i64 3
+ %val3 = load i32, ptr %ptr1.3, align 4
+ %red.1 = add i32 %val0, %val1
+ %red.2 = add i32 %red.1, %val2
+ %red = add i32 %red.2, %val3
+ store i32 %red, ptr %ptr2, align 4
+ ret void
+}
More information about the llvm-commits
mailing list