[llvm] [NVPTX] Don't use stack memory when bitcasting to/from v2i8 (PR #113928)

via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 28 17:00:44 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-nvptx

Author: None (peterbell10)

<details>
<summary>Changes</summary>

`v2i8` is an unsupported type, so we hit the default legalization rules, which perform the bitcast through stack memory; this is very inefficient on GPUs.

This adds a custom lowering that packs `v2i8` into `i16` and from there uses another bitcast node to reach the final desired type. It also adds the inverse lowering, which unpacks `i16` into `v2i8`.

---
Full diff: https://github.com/llvm/llvm-project/pull/113928.diff


3 Files Affected:

- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+50) 
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.h (+2) 
- (added) llvm/test/CodeGen/NVPTX/i8x2-instructions.ll (+36) 


``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index a95cba586b8fc3..050fbcfbcd8165 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -551,6 +551,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4i8, Custom);
   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);
+
+  // Custom conversions to/from v2i8.
+  setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
+
   // Only logical ops can be done on v4i8 directly, others must be done
   // elementwise.
   setOperationAction(
@@ -2311,6 +2315,47 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }
 
+SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
+  // Custom-lower bitcasts whose result or source type is v2i8 by routing the
+  // value through i16, avoiding the default path that goes via stack memory.
+  SDNode *Node = Op.getNode();
+  SDLoc dl(Node);
+
+  auto maybeBitcast = [&](EVT vt, SDValue val) { // no-op when types match
+    if (val->getValueType(0) == vt) {
+      return val;
+    }
+    return DAG.getNode(ISD::BITCAST, dl, vt, val);
+  };
+
+  EVT VT = Op->getValueType(0); // result type
+  EVT fromVT = Op->getOperand(0)->getValueType(0); // source type
+
+  if (VT == MVT::v2i8) {
+    // View the source as i16: low byte -> element 0, high byte -> element 1.
+    SDValue reg = maybeBitcast(MVT::i16, Op->getOperand(0));
+    SDValue v0 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, reg);
+    SDValue C8 = DAG.getConstant(8, dl, MVT::i16);
+    SDValue v1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8,
+                             DAG.getNode(ISD::SRL, dl, MVT::i16, {reg, C8}));
+    return DAG.getNode(ISD::BUILD_VECTOR, dl, MVT::v2i8, {v0, v1});
+  } else if (fromVT == MVT::v2i8) {
+    // Form i16 as (elt1 << 8) | elt0, then bitcast to the result type.
+    SDValue v0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8,
+                             Op->getOperand(0), DAG.getIntPtrConstant(0, dl));
+    SDValue v1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8,
+                             Op->getOperand(0), DAG.getIntPtrConstant(1, dl));
+    SDValue E0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, v0);
+    SDValue E1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, v1);
+    SDValue C8 = DAG.getConstant(8, dl, MVT::i16);
+    SDValue reg =
+        DAG.getNode(ISD::OR, dl, MVT::i16,
+                    {E0, DAG.getNode(ISD::SHL, dl, MVT::i16, {E1, C8})});
+    return maybeBitcast(VT, reg);
+  }
+  return Op; // neither type is v2i8: hand the node back unchanged
+}
+
 // We can init constant f16x2/v2i16/v4i8 with a single .b32 move.  Normally it
 // would get lowered as two constant loads and vector-packing move.
 // Instead we want just a constant move:
@@ -2818,6 +2863,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return Op;
   case ISD::BUILD_VECTOR:
     return LowerBUILD_VECTOR(Op, DAG);
+  case ISD::BITCAST:
+    return LowerBITCAST(Op, DAG);
   case ISD::EXTRACT_SUBVECTOR:
     return Op;
   case ISD::EXTRACT_VECTOR_ELT:
@@ -6413,6 +6460,9 @@ void NVPTXTargetLowering::ReplaceNodeResults(
   switch (N->getOpcode()) {
   default:
     report_fatal_error("Unhandled custom legalization");
+  case ISD::BITCAST:
+    Results.push_back(LowerBITCAST(SDValue(N, 0), DAG));
+    return;
   case ISD::LOAD:
     ReplaceLoadVector(N, DAG, Results);
     return;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 824a659671967a..13153f4830b695 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -616,6 +616,8 @@ class NVPTXTargetLowering : public TargetLowering {
   const NVPTXSubtarget &STI; // cache the subtarget here
   SDValue getParamSymbol(SelectionDAG &DAG, int idx, EVT) const;
 
+  SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
new file mode 100644
index 00000000000000..2f5d8cfed2b7b7
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
@@ -0,0 +1,36 @@
+; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 -mattr=+ptx80 -asm-verbose=false \
+; RUN:          -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \
+; RUN: | FileCheck  %s
+; RUN: %if ptxas %{                                                           \
+; RUN:   llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 -asm-verbose=false \
+; RUN:          -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \
+; RUN:   | %ptxas-verify -arch=sm_90                                          \
+; RUN: %}
+
+target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
+
+; CHECK-LABEL: test_trunc_2xi8(
+; CHECK:      ld.param.u32 [[R1:%r[0-9]+]], [test_trunc_2xi8_param_0];
+; CHECK:      mov.b32 {[[RS1:%rs[0-9]+]], [[RS2:%rs[0-9]+]]}, [[R1]];
+; CHECK:      shl.b16 	[[RS3:%rs[0-9]+]], [[RS2]], 8;
+; CHECK:      and.b16  [[RS4:%rs[0-9]+]], [[RS1]], 255;
+; CHECK:      or.b16   [[RS5:%rs[0-9]+]], [[RS4]], [[RS3]]
+; CHECK:      cvt.u32.u16  [[R2:%r[0-9]+]], [[RS5]]
+; CHECK:      st.param.b32  [func_retval0], [[R2]];
+define i16 @test_trunc_2xi8(<2 x i16> %a) #0 {
+  %trunc = trunc <2 x i16> %a to <2 x i8>
+  %res = bitcast <2 x i8> %trunc to i16 ; expect (elt1 << 8) | (elt0 & 255)
+  ret i16 %res
+}
+
+; CHECK-LABEL: test_zext_2xi8(
+; CHECK:      ld.param.u16  [[RS1:%rs[0-9]+]], [test_zext_2xi8_param_0];
+; CHECK:      shr.u16 	[[RS2:%rs[0-9]+]], [[RS1]], 8;
+; CHECK:      mov.b32  [[R1:%r[0-9]+]], {[[RS1]], [[RS2]]}
+; CHECK:      and.b32  [[R2:%r[0-9]+]], [[R1]], 16711935;
+; CHECK:      st.param.b32  [func_retval0], [[R2]];
+define <2 x i16> @test_zext_2xi8(i16 %a) #0 {
+  %vec = bitcast i16 %a to <2 x i8> ; elt 0 = low byte, elt 1 = %a >> 8
+  %ext = zext <2 x i8> %vec to <2 x i16> ; one and.b32; 16711935 = 0x00FF00FF
+  ret <2 x i16> %ext
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/113928


More information about the llvm-commits mailing list