[llvm] [NVPTX] Lower 16xi8 and 8xi8 stores efficiently (PR #73646)

Uday Bondhugula via llvm-commits llvm-commits at lists.llvm.org
Wed Nov 29 18:26:12 PST 2023


https://github.com/bondhugula updated https://github.com/llvm/llvm-project/pull/73646

>From c1973012575bcf34e76d5e10c0c0f71a74c1e11a Mon Sep 17 00:00:00 2001
From: Uday Bondhugula <uday at polymagelabs.com>
Date: Tue, 28 Nov 2023 14:34:03 +0530
Subject: [PATCH] [NVPTX] Lower 16xi8 and 8xi8 stores efficiently

Lower 16xi8 vector stores in NVPTX ISel efficiently using a single
st.v4.u32 instead of multiple st.v4.u8, along the lines of the existing
handling of v16i8 vector loads and v8f16 vectors. Similarly, lower 8xi8
stores using a single st.v2.u32.
---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp  | 53 ++++++++++++++++++--
 llvm/test/CodeGen/NVPTX/i8x4-instructions.ll |  7 ++-
 llvm/test/CodeGen/NVPTX/vector-stores.ll     | 16 ++++++
 3 files changed, 69 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 61285c6ba98dffa..b975825dae4b6a0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -508,6 +508,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2i16, Expand);
   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2i16, Expand);
 
+  // Conversion to/from i8/i8x4 is always legal.
   setOperationAction(ISD::BUILD_VECTOR, MVT::v4i8, Custom);
   setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4i8, Custom);
   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
@@ -717,8 +718,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
 
   // We have some custom DAG combine patterns for these nodes
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
-                       ISD::LOAD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM,
-                       ISD::VSELECT});
+                       ISD::LOAD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::STORE,
+                       ISD::UREM, ISD::VSELECT});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -2916,7 +2917,6 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
         DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
                                 MemSD->getMemoryVT(), MemSD->getMemOperand());
 
-    // return DCI.CombineTo(N, NewSt, true);
     return NewSt;
   }
 
@@ -5557,6 +5557,51 @@ static SDValue PerformLOADCombine(SDNode *N,
       DL);
 }
 
+// Lower a v16i8 (or v8i8) store into a single StoreV4 (or StoreV2) node with
+// i32 operands (st.v4.u32 / st.v2.u32) instead of letting the type legalizer
+// split it into smaller stores. This is meant to run at dag-combine1 time, so
+// that vector operations with i8 elements can be optimised away instead of
+// being needlessly split during legalization, which involves storing to the
+// stack and loading it back. Indexed, truncating and under-aligned stores are
+// left untouched.
+static SDValue PerformSTORECombine(SDNode *N,
+                                   TargetLowering::DAGCombinerInfo &DCI) {
+  SelectionDAG &DAG = DCI.DAG;
+  StoreSDNode *ST = cast<StoreSDNode>(N);
+  EVT VT = ST->getValue().getValueType();
+  // Indexed stores have an extra (pointer) result and truncating stores a
+  // narrower memory type; neither maps onto a plain StoreV2/StoreV4.
+  if ((VT != MVT::v16i8 && VT != MVT::v8i8) || !ST->isUnindexed() ||
+      ST->isTruncatingStore())
+    return SDValue();
+
+  // PTX vector accesses must be aligned to the whole vector size (16 bytes
+  // for st.v4.u32, 8 for st.v2.u32); never emit a misaligned vector store.
+  if (ST->getAlign() < Align(VT.getStoreSize().getFixedValue()))
+    return SDValue();
+
+  unsigned Opc = VT == MVT::v16i8 ? NVPTXISD::StoreV4 : NVPTXISD::StoreV2;
+  EVT NewVT = VT == MVT::v16i8 ? MVT::v4i32 : MVT::v2i32;
+  unsigned NumElts = NewVT.getVectorNumElements();
+
+  // Reinterpret the i8 lanes as i32 lanes: v16i8 -> v4i32, v8i8 -> v2i32.
+  SDValue NewStoreValue = DAG.getBitcast(NewVT, ST->getValue());
+
+  // Operands: chain, the i32 elements, then the remaining store operands.
+  SDLoc DL(N);
+  SmallVector<SDValue, 8> Ops;
+  Ops.reserve(N->getNumOperands() + NumElts - 1);
+  Ops.push_back(N->getOperand(0));
+  for (unsigned I = 0; I < NumElts; ++I)
+    Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32,
+                              NewStoreValue, DAG.getIntPtrConstant(I, DL)));
+  Ops.append(N->op_begin() + 2, N->op_end());
+
+  SDValue NewStore = DAG.getMemIntrinsicNode(Opc, DL, DAG.getVTList(MVT::Other),
+                                             Ops, NewVT, ST->getMemOperand());
+  return NewStore.getValue(0);
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5578,6 +5623,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
       return PerformSETCCCombine(N, DCI, STI.getSmVersion());
     case ISD::LOAD:
       return PerformLOADCombine(N, DCI);
+    case ISD::STORE:
+      return PerformSTORECombine(N, DCI);
     case NVPTXISD::StoreRetval:
     case NVPTXISD::StoreRetvalV2:
     case NVPTXISD::StoreRetvalV4:
diff --git a/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
index 1ec68b4a271bac9..55cf6fb8257627a 100644
--- a/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
@@ -790,10 +790,9 @@ define void @test_ldst_v8i8(ptr %a, ptr %b) {
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.u64 %rd2, [test_ldst_v8i8_param_1];
 ; CHECK-NEXT:    ld.param.u64 %rd1, [test_ldst_v8i8_param_0];
-; CHECK-NEXT:    ld.u32 %r1, [%rd1];
-; CHECK-NEXT:    ld.u32 %r2, [%rd1+4];
-; CHECK-NEXT:    st.u32 [%rd2+4], %r2;
-; CHECK-NEXT:    st.u32 [%rd2], %r1;
+; CHECK-NEXT:    ld.u32 %r1, [%rd1+4];
+; CHECK-NEXT:    ld.u32 %r2, [%rd1];
+; CHECK-NEXT:    st.v2.u32 [%rd2], {%r2, %r1};
 ; CHECK-NEXT:    ret;
   %t1 = load <8 x i8>, ptr %a
   store <8 x i8> %t1, ptr %b, align 16
diff --git a/llvm/test/CodeGen/NVPTX/vector-stores.ll b/llvm/test/CodeGen/NVPTX/vector-stores.ll
index df14553a7720576..8248bdbc1ee1c4d 100644
--- a/llvm/test/CodeGen/NVPTX/vector-stores.ll
+++ b/llvm/test/CodeGen/NVPTX/vector-stores.ll
@@ -37,3 +37,19 @@ define void @v16i8(ptr %a, ptr %b) {
   store <16 x i8> %v, ptr %b
   ret void
 }
+
+; CHECK-LABEL: .visible .func v16i8_store
+define void @v16i8_store(ptr %a, <16 x i8> %v) { ; whole <16 x i8> value stored by one st.v4.u32
+  ; CHECK:      ld.param.u64   %rd1, [v16i8_store_param_0];
+  ; CHECK-NEXT: ld.param.v4.u32   {%r1, %r2, %r3, %r4}, [v16i8_store_param_1];
+  ; CHECK-NEXT: st.v4.u32   [%rd1], {%r1, %r2, %r3, %r4};
+  store <16 x i8> %v, ptr %a
+  ret void
+}
+
+; CHECK-LABEL: .visible .func v8i8_store
+define void @v8i8_store(ptr %a, <8 x i8> %v) { ; whole <8 x i8> value stored by one st.v2.u32
+  ; CHECK: st.v2.u32 [%rd{{[0-9]+}}], {%r{{[0-9]+}}, %r{{[0-9]+}}};
+  store <8 x i8> %v, ptr %a
+  ret void
+}



More information about the llvm-commits mailing list