[llvm] [NVPTX] Lower 16xi8 and 8xi8 stores efficiently (PR #73646)

Uday Bondhugula via llvm-commits llvm-commits at lists.llvm.org
Tue Nov 28 21:10:46 PST 2023


https://github.com/bondhugula updated https://github.com/llvm/llvm-project/pull/73646

>From 6ab9db7ed2f835e68d1177c87d48d9c8bcbf9d99 Mon Sep 17 00:00:00 2001
From: Uday Bondhugula <uday at polymagelabs.com>
Date: Tue, 28 Nov 2023 14:34:03 +0530
Subject: [PATCH] [NVPTX] Lower 16xi8 and 8xi8 stores efficiently

Lower 16xi8 vector stores in NVPTX ISel efficiently, using a single
st.v4.u32 instead of multiple st.v4.u8, along the lines of what is already
done for vector loads and 8xf16. Similarly, lower 8xi8 stores using st.v2.u32.
---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp  | 52 ++++++++++++++++++--
 llvm/test/CodeGen/NVPTX/i8x4-instructions.ll |  7 ++-
 llvm/test/CodeGen/NVPTX/vector-stores.ll     | 16 ++++++
 3 files changed, 68 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 61285c6ba98dffa..7537e4dfcc236e3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -508,6 +508,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2i16, Expand);
   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2i16, Expand);
 
+  // Conversion to/from i8/i8x4 is always legal.
   setOperationAction(ISD::BUILD_VECTOR, MVT::v4i8, Custom);
   setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4i8, Custom);
   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
@@ -717,8 +718,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
 
   // We have some custom DAG combine patterns for these nodes
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
-                       ISD::LOAD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM,
-                       ISD::VSELECT});
+                       ISD::LOAD, ISD::MUL, ISD::SHL, ISD::SREM, ISD::STORE,
+                       ISD::UREM, ISD::VSELECT});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -2916,7 +2917,6 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
         DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
                                 MemSD->getMemoryVT(), MemSD->getMemOperand());
 
-    // return DCI.CombineTo(N, NewSt, true);
     return NewSt;
   }
 
@@ -5557,6 +5557,50 @@ static SDValue PerformLOADCombine(SDNode *N,
       DL);
 }
 
+// Lower a v16i8 (or v8i8) store into a StoreV4 (or StoreV2) operation with
+// i32 operands instead of letting legalization split it into smaller stores.
+// This is done at dag-combine time, so that vector operations with i8
+// elements can be optimised away instead of being needlessly split during
+// legalization, which involves storing to the stack and loading it back.
+static SDValue PerformSTORECombine(SDNode *N,
+                                   TargetLowering::DAGCombinerInfo &DCI) {
+  SelectionDAG &DAG = DCI.DAG;
+  StoreSDNode *ST = cast<StoreSDNode>(N);
+  EVT VT = ST->getValue().getValueType();
+  if (VT != MVT::v16i8 && VT != MVT::v8i8)
+    return SDValue();
+
+  // The replacement yields only a chain and stores the full vector, so it
+  // cannot model truncating or indexed stores. PTX vector accesses must be
+  // aligned to the whole vector size, so leave under-aligned stores alone.
+  if (ST->isTruncatingStore() || !ST->isUnindexed() ||
+      ST->getAlign() < Align(VT.getStoreSize().getFixedValue()))
+    return SDValue();
+
+  // Reinterpret the bytes as i32 lanes: v16i8 -> v4i32, v8i8 -> v2i32.
+  unsigned Opc = VT == MVT::v16i8 ? NVPTXISD::StoreV4 : NVPTXISD::StoreV2;
+  EVT NewVT = VT == MVT::v16i8 ? MVT::v4i32 : MVT::v2i32;
+  EVT EltVT = NewVT.getVectorElementType();
+  unsigned NumElts = NewVT.getVectorNumElements();
+  SDValue NewStoreValue = DAG.getBitcast(NewVT, ST->getValue());
+
+  // Operands of the new store: the chain, one i32 per lane, then the
+  // remaining operands of the original store (those after the stored value).
+  SDLoc DL(N);
+  SmallVector<SDValue, 8> Ops;
+  Ops.reserve(N->getNumOperands() + NumElts - 1);
+  Ops.push_back(ST->getChain());
+  for (unsigned I = 0; I < NumElts; ++I)
+    Ops.push_back(DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT,
+                              NewStoreValue, DAG.getVectorIdxConstant(I, DL)));
+  Ops.append(N->op_begin() + 2, N->op_end());
+
+  SDValue NewStore = DAG.getMemIntrinsicNode(Opc, DL, DAG.getVTList(MVT::Other),
+                                             Ops, NewVT, ST->getMemOperand());
+  // The new store produces only a chain; it replaces the original store.
+  return NewStore.getValue(0);
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5578,6 +5622,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
       return PerformSETCCCombine(N, DCI, STI.getSmVersion());
     case ISD::LOAD:
       return PerformLOADCombine(N, DCI);
+    case ISD::STORE:
+      return PerformSTORECombine(N, DCI);
     case NVPTXISD::StoreRetval:
     case NVPTXISD::StoreRetvalV2:
     case NVPTXISD::StoreRetvalV4:
diff --git a/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
index 1ec68b4a271bac9..55cf6fb8257627a 100644
--- a/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
@@ -790,10 +790,9 @@ define void @test_ldst_v8i8(ptr %a, ptr %b) {
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.u64 %rd2, [test_ldst_v8i8_param_1];
 ; CHECK-NEXT:    ld.param.u64 %rd1, [test_ldst_v8i8_param_0];
-; CHECK-NEXT:    ld.u32 %r1, [%rd1];
-; CHECK-NEXT:    ld.u32 %r2, [%rd1+4];
-; CHECK-NEXT:    st.u32 [%rd2+4], %r2;
-; CHECK-NEXT:    st.u32 [%rd2], %r1;
+; CHECK-NEXT:    ld.u32 %r1, [%rd1+4];
+; CHECK-NEXT:    ld.u32 %r2, [%rd1];
+; CHECK-NEXT:    st.v2.u32 [%rd2], {%r2, %r1};
 ; CHECK-NEXT:    ret;
   %t1 = load <8 x i8>, ptr %a
   store <8 x i8> %t1, ptr %b, align 16
diff --git a/llvm/test/CodeGen/NVPTX/vector-stores.ll b/llvm/test/CodeGen/NVPTX/vector-stores.ll
index df14553a7720576..8248bdbc1ee1c4d 100644
--- a/llvm/test/CodeGen/NVPTX/vector-stores.ll
+++ b/llvm/test/CodeGen/NVPTX/vector-stores.ll
@@ -37,3 +37,19 @@ define void @v16i8(ptr %a, ptr %b) {
   store <16 x i8> %v, ptr %b
   ret void
 }
+
+; CHECK-LABEL: .visible .func v16i8_store
+define void @v16i8_store(ptr %a, <16 x i8> %v) {
+  ; CHECK:      ld.param.u64   %rd1, [v16i8_store_param_0];
+  ; CHECK-NEXT: ld.param.v4.u32   {%r1, %r2, %r3, %r4}, [v16i8_store_param_1];
+  ; CHECK-NEXT: st.v4.u32   [%rd1], {%r1, %r2, %r3, %r4};
+  store <16 x i8> %v, ptr %a
+  ret void
+}
+
+; CHECK-LABEL: .visible .func v8i8_store
+define void @v8i8_store(ptr %a, <8 x i8> %v) {
+  ; CHECK: st.v2.u32 [%rd{{[0-9]+}}], {%r{{[0-9]+}}, %r{{[0-9]+}}};
+  store <8 x i8> %v, ptr %a
+  ret void
+}



More information about the llvm-commits mailing list