[llvm] [NVPTX] Preserve v16i8 vector loads when legalizing (PR #67322)

Pierre-Andre Saulais via llvm-commits llvm-commits at lists.llvm.org
Sat Sep 30 01:56:01 PDT 2023


https://github.com/pasaulais updated https://github.com/llvm/llvm-project/pull/67322

From da91480413ba34bd21a8af3fd38012d1be4142df Mon Sep 17 00:00:00 2001
From: Pierre-Andre Saulais <pierre-andre at codeplay.com>
Date: Thu, 21 Sep 2023 11:42:19 +0100
Subject: [PATCH] [NVPTX] Preserve v16i8 vector loads when legalizing

This is done by lowering v16i8 loads into LoadV4 operations with i32
results instead of letting ReplaceLoadVector split them into smaller
loads during legalization. The combine runs at dag-combine1 time, so
that vector operations with i8 elements can be optimized away instead
of being needlessly split during legalization, which involves storing
to the stack and loading back.
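
For example (a sketch, not part of the patch), a 16-byte-aligned v16i8
load such as the one below should now select to a single ld.v4.u32
instruction, as checked by the combine_v16i8 test added here, rather
than being split into narrower loads:

  ; Illustrative function (not from the test file in this patch).
  define <16 x i8> @load_v16i8(ptr %p) {
    %v = load <16 x i8>, ptr %p, align 16
    ret <16 x i8> %v
  }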
---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   |  44 ++++++-
 .../test/CodeGen/NVPTX/LoadStoreVectorizer.ll | 123 ++++++++++++++++++
 2 files changed, 165 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index b24aae4792ce6a6..bb54fecd7c8013a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -671,8 +671,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);
 
   // We have some custom DAG combine patterns for these nodes
-  setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL,
-                       ISD::SREM, ISD::UREM, ISD::EXTRACT_VECTOR_ELT});
+  setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
+                       ISD::LOAD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5292,6 +5292,44 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
   return Result;
 }
 
+static SDValue PerformLOADCombine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI) {
+  SelectionDAG &DAG = DCI.DAG;
+  LoadSDNode *LD = cast<LoadSDNode>(N);
+
+  // Lower a v16i8 load into a LoadV4 operation with i32 results instead of
+  // letting ReplaceLoadVector split it into smaller loads during legalization.
+  // This is done at dag-combine1 time, so that vector operations with i8
+  // elements can be optimized away instead of being needlessly split during
+  // legalization, which involves storing to the stack and loading back.
+  EVT VT = N->getValueType(0);
+  if (VT != MVT::v16i8)
+    return SDValue();
+
+  SDLoc DL(N);
+
+  // Create a v4i32 vector load operation, effectively <4 x v4i8>.
+  unsigned Opc = NVPTXISD::LoadV4;
+  EVT NewVT = MVT::v4i32;
+  EVT EltVT = NewVT.getVectorElementType();
+  unsigned NumElts = NewVT.getVectorNumElements();
+  EVT RetVTs[] = {EltVT, EltVT, EltVT, EltVT, MVT::Other};
+  SDVTList RetVTList = DAG.getVTList(RetVTs);
+  SmallVector<SDValue, 8> Ops(N->ops());
+  Ops.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
+  SDValue NewLoad = DAG.getMemIntrinsicNode(Opc, DL, RetVTList, Ops, NewVT,
+                                            LD->getMemOperand());
+  SDValue NewChain = NewLoad.getValue(NumElts);
+
+  // Create a vector of the same type returned by the original load.
+  SmallVector<SDValue, 4> Elts;
+  for (unsigned i = 0; i < NumElts; i++)
+    Elts.push_back(NewLoad.getValue(i));
+  return DAG.getMergeValues(
+      {DAG.getBitcast(VT, DAG.getBuildVector(NewVT, DL, Elts)), NewChain},
+      DL);
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5311,6 +5349,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
       return PerformREMCombine(N, DCI, OptLevel);
     case ISD::SETCC:
       return PerformSETCCCombine(N, DCI);
+    case ISD::LOAD:
+      return PerformLOADCombine(N, DCI);
     case NVPTXISD::StoreRetval:
     case NVPTXISD::StoreRetvalV2:
     case NVPTXISD::StoreRetvalV4:
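
In IR terms, the DAG produced by PerformLOADCombine is roughly
equivalent to loading four i32 lanes and bitcasting them back to the
original vector type (an editorial sketch, not code from the patch):

  ; Illustrative function showing the effect of the combine.
  define <16 x i8> @sketch(ptr %p) {
    %lanes = load <4 x i32>, ptr %p, align 16
    %v = bitcast <4 x i32> %lanes to <16 x i8>
    ret <16 x i8> %v
  }

The combine itself emits an NVPTXISD::LoadV4 node with four scalar i32
results plus a chain, then reassembles the results with a build_vector
and a bitcast, preserving the original load's memory operand.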
diff --git a/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll b/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
index 4f13b6d9d1a8a9d..868a06e2a850cc8 100644
--- a/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
+++ b/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
@@ -52,3 +52,126 @@ define float @ff(ptr %p) {
   %sum = fadd float %sum3, %v4
   ret float %sum
 }
+
+define void @combine_v16i8(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr2) {
+  ; ENABLED-LABEL: combine_v16i8
+  ; ENABLED: ld.v4.u32
+  %val0 = load i8, ptr %ptr1, align 16
+  %ptr1.1 = getelementptr inbounds i8, ptr %ptr1, i64 1
+  %val1 = load i8, ptr %ptr1.1, align 1
+  %ptr1.2 = getelementptr inbounds i8, ptr %ptr1, i64 2
+  %val2 = load i8, ptr %ptr1.2, align 2
+  %ptr1.3 = getelementptr inbounds i8, ptr %ptr1, i64 3
+  %val3 = load i8, ptr %ptr1.3, align 1
+  %ptr1.4 = getelementptr inbounds i8, ptr %ptr1, i64 4
+  %val4 = load i8, ptr %ptr1.4, align 4
+  %ptr1.5 = getelementptr inbounds i8, ptr %ptr1, i64 5
+  %val5 = load i8, ptr %ptr1.5, align 1
+  %ptr1.6 = getelementptr inbounds i8, ptr %ptr1, i64 6
+  %val6 = load i8, ptr %ptr1.6, align 2
+  %ptr1.7 = getelementptr inbounds i8, ptr %ptr1, i64 7
+  %val7 = load i8, ptr %ptr1.7, align 1
+  %ptr1.8 = getelementptr inbounds i8, ptr %ptr1, i64 8
+  %val8 = load i8, ptr %ptr1.8, align 8
+  %ptr1.9 = getelementptr inbounds i8, ptr %ptr1, i64 9
+  %val9 = load i8, ptr %ptr1.9, align 1
+  %ptr1.10 = getelementptr inbounds i8, ptr %ptr1, i64 10
+  %val10 = load i8, ptr %ptr1.10, align 2
+  %ptr1.11 = getelementptr inbounds i8, ptr %ptr1, i64 11
+  %val11 = load i8, ptr %ptr1.11, align 1
+  %ptr1.12 = getelementptr inbounds i8, ptr %ptr1, i64 12
+  %val12 = load i8, ptr %ptr1.12, align 4
+  %ptr1.13 = getelementptr inbounds i8, ptr %ptr1, i64 13
+  %val13 = load i8, ptr %ptr1.13, align 1
+  %ptr1.14 = getelementptr inbounds i8, ptr %ptr1, i64 14
+  %val14 = load i8, ptr %ptr1.14, align 2
+  %ptr1.15 = getelementptr inbounds i8, ptr %ptr1, i64 15
+  %val15 = load i8, ptr %ptr1.15, align 1
+  %lane0 = zext i8 %val0 to i32
+  %lane1 = zext i8 %val1 to i32
+  %lane2 = zext i8 %val2 to i32
+  %lane3 = zext i8 %val3 to i32
+  %lane4 = zext i8 %val4 to i32
+  %lane5 = zext i8 %val5 to i32
+  %lane6 = zext i8 %val6 to i32
+  %lane7 = zext i8 %val7 to i32
+  %lane8 = zext i8 %val8 to i32
+  %lane9 = zext i8 %val9 to i32
+  %lane10 = zext i8 %val10 to i32
+  %lane11 = zext i8 %val11 to i32
+  %lane12 = zext i8 %val12 to i32
+  %lane13 = zext i8 %val13 to i32
+  %lane14 = zext i8 %val14 to i32
+  %lane15 = zext i8 %val15 to i32
+  %red.1 = add i32 %lane0, %lane1
+  %red.2 = add i32 %red.1, %lane2
+  %red.3 = add i32 %red.2, %lane3
+  %red.4 = add i32 %red.3, %lane4
+  %red.5 = add i32 %red.4, %lane5
+  %red.6 = add i32 %red.5, %lane6
+  %red.7 = add i32 %red.6, %lane7
+  %red.8 = add i32 %red.7, %lane8
+  %red.9 = add i32 %red.8, %lane9
+  %red.10 = add i32 %red.9, %lane10
+  %red.11 = add i32 %red.10, %lane11
+  %red.12 = add i32 %red.11, %lane12
+  %red.13 = add i32 %red.12, %lane13
+  %red.14 = add i32 %red.13, %lane14
+  %red = add i32 %red.14, %lane15
+  store i32 %red, ptr %ptr2, align 4
+  ret void
+}
+
+define void @combine_v8i16(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr2) {
+  ; ENABLED-LABEL: combine_v8i16
+  ; ENABLED: ld.v4.b32
+  %val0 = load i16, ptr %ptr1, align 16
+  %ptr1.1 = getelementptr inbounds i16, ptr %ptr1, i64 1
+  %val1 = load i16, ptr %ptr1.1, align 2
+  %ptr1.2 = getelementptr inbounds i16, ptr %ptr1, i64 2
+  %val2 = load i16, ptr %ptr1.2, align 4
+  %ptr1.3 = getelementptr inbounds i16, ptr %ptr1, i64 3
+  %val3 = load i16, ptr %ptr1.3, align 2
+  %ptr1.4 = getelementptr inbounds i16, ptr %ptr1, i64 4
+  %val4 = load i16, ptr %ptr1.4, align 4
+  %ptr1.5 = getelementptr inbounds i16, ptr %ptr1, i64 5
+  %val5 = load i16, ptr %ptr1.5, align 2
+  %ptr1.6 = getelementptr inbounds i16, ptr %ptr1, i64 6
+  %val6 = load i16, ptr %ptr1.6, align 4
+  %ptr1.7 = getelementptr inbounds i16, ptr %ptr1, i64 7
+  %val7 = load i16, ptr %ptr1.7, align 2
+  %lane0 = zext i16 %val0 to i32
+  %lane1 = zext i16 %val1 to i32
+  %lane2 = zext i16 %val2 to i32
+  %lane3 = zext i16 %val3 to i32
+  %lane4 = zext i16 %val4 to i32
+  %lane5 = zext i16 %val5 to i32
+  %lane6 = zext i16 %val6 to i32
+  %lane7 = zext i16 %val7 to i32
+  %red.1 = add i32 %lane0, %lane1
+  %red.2 = add i32 %red.1, %lane2
+  %red.3 = add i32 %red.2, %lane3
+  %red.4 = add i32 %red.3, %lane4
+  %red.5 = add i32 %red.4, %lane5
+  %red.6 = add i32 %red.5, %lane6
+  %red = add i32 %red.6, %lane7
+  store i32 %red, ptr %ptr2, align 4
+  ret void
+}
+
+define void @combine_v4i32(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr2) {
+  ; ENABLED-LABEL: combine_v4i32
+  ; ENABLED: ld.v4.u32
+  %val0 = load i32, ptr %ptr1, align 16
+  %ptr1.1 = getelementptr inbounds i32, ptr %ptr1, i64 1
+  %val1 = load i32, ptr %ptr1.1, align 4
+  %ptr1.2 = getelementptr inbounds i32, ptr %ptr1, i64 2
+  %val2 = load i32, ptr %ptr1.2, align 8
+  %ptr1.3 = getelementptr inbounds i32, ptr %ptr1, i64 3
+  %val3 = load i32, ptr %ptr1.3, align 4
+  %red.1 = add i32 %val0, %val1
+  %red.2 = add i32 %red.1, %val2
+  %red = add i32 %red.2, %val3
+  store i32 %red, ptr %ptr2, align 4
+  ret void
+}


