[llvm] [NVPTX] Don't use stack memory when bitcasting to/from v2i8 (PR #113928)

via llvm-commits llvm-commits at lists.llvm.org
Tue Oct 29 19:55:09 PDT 2024


https://github.com/peterbell10 updated https://github.com/llvm/llvm-project/pull/113928

>From 513bf2a835f7cc36e7af14bfe1d28293656a540c Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Mon, 28 Oct 2024 15:15:12 +0000
Subject: [PATCH 1/6] [NVPTX] Don't use stack memory when bitcasting to/from
 2xi8

`v2i8` is an unsupported type, so we hit the default legalization rules,
which perform the bitcast through stack memory and are very inefficient on GPU.

This adds a custom lowering that packs `v2i8` into `i16` and from there uses
another bitcast node to reach the final desired type, as well as the
inverse lowering that unpacks `i16` into `v2i8`.
---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp  | 48 ++++++++++++++++++++
 llvm/lib/Target/NVPTX/NVPTXISelLowering.h    |  2 +
 llvm/test/CodeGen/NVPTX/i8x2-instructions.ll | 36 +++++++++++++++
 3 files changed, 86 insertions(+)
 create mode 100644 llvm/test/CodeGen/NVPTX/i8x2-instructions.ll

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index a95cba586b8fc3..d71441c03e5dc3 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -551,6 +551,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4i8, Custom);
   setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
   setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);
+
+  // Custom conversions to/from v2i8.
+  setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
+
   // Only logical ops can be done on v4i8 directly, others must be done
   // elementwise.
   setOperationAction(
@@ -2311,6 +2315,45 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }
 
+SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
+  // Handle bitcasting to/from v2i8 without hitting the default promotion
+  // strategy which goes through stack memory.
+  SDNode *Node = Op.getNode();
+  SDLoc dl(Node);
+
+  auto maybeBitcast = [&](EVT vt, SDValue val) {
+    if (val->getValueType(0) == vt) {
+      return val;
+    }
+    return DAG.getNode(ISD::BITCAST, dl, vt, val);
+  };
+
+  EVT VT = Op->getValueType(0);
+  EVT fromVT = Op->getOperand(0)->getValueType(0);
+
+  if (VT == MVT::v2i8) {
+    SDValue reg = maybeBitcast(MVT::i16, Op->getOperand(0));
+    // Promote result to v2i16
+    SDValue v0 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, reg);
+    SDValue C8 = DAG.getConstant(8, dl, MVT::i16);
+    SDValue v1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, 
+                             DAG.getNode(ISD::SRL, dl, MVT::i16, {reg, C8}));
+    return DAG.getNode(ISD::BUILD_VECTOR, dl, MVT::v2i8, {v0, v1});
+  } else if (fromVT == MVT::v2i8) {
+    SDValue v0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8, Op->getOperand(0),
+                             DAG.getIntPtrConstant(0, dl));
+    SDValue v1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8, Op->getOperand(0),
+                             DAG.getIntPtrConstant(1, dl));
+    SDValue E0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, v0);
+    SDValue E1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, v1);
+    SDValue C8 = DAG.getConstant(8, dl, MVT::i16);
+    SDValue reg = DAG.getNode(ISD::OR, dl, MVT::i16, 
+                              {E0, DAG.getNode(ISD::SHL, dl, MVT::i16, {E1, C8})});
+    return maybeBitcast(VT, reg);
+  }
+  return Op;
+}
+
 // We can init constant f16x2/v2i16/v4i8 with a single .b32 move.  Normally it
 // would get lowered as two constant loads and vector-packing move.
 // Instead we want just a constant move:
@@ -2818,6 +2861,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return Op;
   case ISD::BUILD_VECTOR:
     return LowerBUILD_VECTOR(Op, DAG);
+  case ISD::BITCAST:
+    return LowerBITCAST(Op, DAG);
   case ISD::EXTRACT_SUBVECTOR:
     return Op;
   case ISD::EXTRACT_VECTOR_ELT:
@@ -6413,6 +6458,9 @@ void NVPTXTargetLowering::ReplaceNodeResults(
   switch (N->getOpcode()) {
   default:
     report_fatal_error("Unhandled custom legalization");
+  case ISD::BITCAST:
+    Results.push_back(LowerBITCAST(SDValue(N, 0), DAG));
+    return;
   case ISD::LOAD:
     ReplaceLoadVector(N, DAG, Results);
     return;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index 824a659671967a..13153f4830b695 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -616,6 +616,8 @@ class NVPTXTargetLowering : public TargetLowering {
   const NVPTXSubtarget &STI; // cache the subtarget here
   SDValue getParamSymbol(SelectionDAG &DAG, int idx, EVT) const;
 
+  SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) const;
+
   SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
new file mode 100644
index 00000000000000..4d92c41d72bfcd
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
@@ -0,0 +1,36 @@
+; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 -mattr=+ptx80 -asm-verbose=false \
+; RUN:          -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \
+; RUN: | FileCheck -allow-deprecated-dag-overlap -check-prefixes COMMON,I16x2 %s
+; RUN: %if ptxas %{                                                           \
+; RUN:   llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 -asm-verbose=false \
+; RUN:          -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \
+; RUN:   | %ptxas-verify -arch=sm_90                                          \
+; RUN: %}
+
+target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
+
+; COMMON-LABEL: test_trunc_2xi8(
+; COMMON:      ld.param.u32 [[R1:%r[0-9]+]], [test_trunc_2xi8_param_0];
+; COMMON:      mov.b32 {[[RS1:%rs[0-9]+]], [[RS2:%rs[0-9]+]]}, [[R1]];
+; COMMON:      shl.b16 	[[RS3:%rs[0-9]+]], [[RS2]], 8;
+; COMMON:      and.b16  [[RS4:%rs[0-9]+]], [[RS1]], 255;
+; COMMON:      or.b16   [[RS5:%rs[0-9]+]], [[RS4]], [[RS3]]
+; COMMON:      cvt.u32.u16  [[R2:%r[0-9]]], [[RS5]]
+; COMMON:      st.param.b32  [func_retval0+0], [[R2]];
+define i16 @test_trunc_2xi8(<2 x i16> %a) #0 {
+  %trunc = trunc <2 x i16> %a to <2 x i8>
+  %res = bitcast <2 x i8> %trunc to i16
+  ret i16 %res
+}
+
+; COMMON-LABEL: test_zext_2xi8(
+; COMMON:      ld.param.u16  [[RS1:%rs[0-9]+]], [test_zext_2xi8_param_0];
+; COMMON:      shr.u16 	[[RS2:%rs[0-9]+]], [[RS1]], 8;
+; COMMON:      mov.b32  [[R1:%r[0-9]+]], {[[RS1]], [[RS2]]}
+; COMMON:      and.b32  [[R2:%r[0-9]+]], [[R1]], 16711935;
+; COMMON:      st.param.b32  [func_retval0+0], [[R2]];
+define <2 x i16> @test_zext_2xi8(i16 %a) #0 {
+  %vec = bitcast i16 %a to <2 x i8>
+  %ext = zext <2 x i8> %vec to <2 x i16>
+  ret <2 x i16> %ext
+}

>From b0aa1db4c67c60a81ea204271d750e9f0ec0e08a Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Mon, 28 Oct 2024 15:51:24 +0000
Subject: [PATCH 2/6] Formatting

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index d71441c03e5dc3..6deb656c6e3884 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -2336,19 +2336,20 @@ SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
     // Promote result to v2i16
     SDValue v0 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, reg);
     SDValue C8 = DAG.getConstant(8, dl, MVT::i16);
-    SDValue v1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, 
+    SDValue v1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8,
                              DAG.getNode(ISD::SRL, dl, MVT::i16, {reg, C8}));
     return DAG.getNode(ISD::BUILD_VECTOR, dl, MVT::v2i8, {v0, v1});
   } else if (fromVT == MVT::v2i8) {
-    SDValue v0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8, Op->getOperand(0),
-                             DAG.getIntPtrConstant(0, dl));
-    SDValue v1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8, Op->getOperand(0),
-                             DAG.getIntPtrConstant(1, dl));
+    SDValue v0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8,
+                             Op->getOperand(0), DAG.getIntPtrConstant(0, dl));
+    SDValue v1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8,
+                             Op->getOperand(0), DAG.getIntPtrConstant(1, dl));
     SDValue E0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, v0);
     SDValue E1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, v1);
     SDValue C8 = DAG.getConstant(8, dl, MVT::i16);
-    SDValue reg = DAG.getNode(ISD::OR, dl, MVT::i16, 
-                              {E0, DAG.getNode(ISD::SHL, dl, MVT::i16, {E1, C8})});
+    SDValue reg =
+        DAG.getNode(ISD::OR, dl, MVT::i16,
+                    {E0, DAG.getNode(ISD::SHL, dl, MVT::i16, {E1, C8})});
     return maybeBitcast(VT, reg);
   }
   return Op;

>From fc9ec25ff8dcc9cdfbac65fb4f794a94756f765a Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Mon, 28 Oct 2024 19:33:15 +0000
Subject: [PATCH 3/6] Fix lit test

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp  |  3 +-
 llvm/test/CodeGen/NVPTX/i8x2-instructions.ll | 30 ++++++++++----------
 2 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 6deb656c6e3884..050fbcfbcd8165 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -2332,14 +2332,15 @@ SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
   EVT fromVT = Op->getOperand(0)->getValueType(0);
 
   if (VT == MVT::v2i8) {
+    // Bitcast to i16 and unpack elements into a vector
     SDValue reg = maybeBitcast(MVT::i16, Op->getOperand(0));
-    // Promote result to v2i16
     SDValue v0 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, reg);
     SDValue C8 = DAG.getConstant(8, dl, MVT::i16);
     SDValue v1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8,
                              DAG.getNode(ISD::SRL, dl, MVT::i16, {reg, C8}));
     return DAG.getNode(ISD::BUILD_VECTOR, dl, MVT::v2i8, {v0, v1});
   } else if (fromVT == MVT::v2i8) {
+    // Pack vector elements into i16 and bitcast to final type
     SDValue v0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8,
                              Op->getOperand(0), DAG.getIntPtrConstant(0, dl));
     SDValue v1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8,
diff --git a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
index 4d92c41d72bfcd..2f5d8cfed2b7b7 100644
--- a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
@@ -1,6 +1,6 @@
 ; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 -mattr=+ptx80 -asm-verbose=false \
 ; RUN:          -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \
-; RUN: | FileCheck -allow-deprecated-dag-overlap -check-prefixes COMMON,I16x2 %s
+; RUN: | FileCheck  %s
 ; RUN: %if ptxas %{                                                           \
 ; RUN:   llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 -asm-verbose=false \
 ; RUN:          -O0 -disable-post-ra -frame-pointer=all -verify-machineinstrs \
@@ -9,26 +9,26 @@
 
 target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
 
-; COMMON-LABEL: test_trunc_2xi8(
-; COMMON:      ld.param.u32 [[R1:%r[0-9]+]], [test_trunc_2xi8_param_0];
-; COMMON:      mov.b32 {[[RS1:%rs[0-9]+]], [[RS2:%rs[0-9]+]]}, [[R1]];
-; COMMON:      shl.b16 	[[RS3:%rs[0-9]+]], [[RS2]], 8;
-; COMMON:      and.b16  [[RS4:%rs[0-9]+]], [[RS1]], 255;
-; COMMON:      or.b16   [[RS5:%rs[0-9]+]], [[RS4]], [[RS3]]
-; COMMON:      cvt.u32.u16  [[R2:%r[0-9]]], [[RS5]]
-; COMMON:      st.param.b32  [func_retval0+0], [[R2]];
+; CHECK-LABEL: test_trunc_2xi8(
+; CHECK:      ld.param.u32 [[R1:%r[0-9]+]], [test_trunc_2xi8_param_0];
+; CHECK:      mov.b32 {[[RS1:%rs[0-9]+]], [[RS2:%rs[0-9]+]]}, [[R1]];
+; CHECK:      shl.b16 	[[RS3:%rs[0-9]+]], [[RS2]], 8;
+; CHECK:      and.b16  [[RS4:%rs[0-9]+]], [[RS1]], 255;
+; CHECK:      or.b16   [[RS5:%rs[0-9]+]], [[RS4]], [[RS3]]
+; CHECK:      cvt.u32.u16  [[R2:%r[0-9]]], [[RS5]]
+; CHECK:      st.param.b32  [func_retval0], [[R2]];
 define i16 @test_trunc_2xi8(<2 x i16> %a) #0 {
   %trunc = trunc <2 x i16> %a to <2 x i8>
   %res = bitcast <2 x i8> %trunc to i16
   ret i16 %res
 }
 
-; COMMON-LABEL: test_zext_2xi8(
-; COMMON:      ld.param.u16  [[RS1:%rs[0-9]+]], [test_zext_2xi8_param_0];
-; COMMON:      shr.u16 	[[RS2:%rs[0-9]+]], [[RS1]], 8;
-; COMMON:      mov.b32  [[R1:%r[0-9]+]], {[[RS1]], [[RS2]]}
-; COMMON:      and.b32  [[R2:%r[0-9]+]], [[R1]], 16711935;
-; COMMON:      st.param.b32  [func_retval0+0], [[R2]];
+; CHECK-LABEL: test_zext_2xi8(
+; CHECK:      ld.param.u16  [[RS1:%rs[0-9]+]], [test_zext_2xi8_param_0];
+; CHECK:      shr.u16 	[[RS2:%rs[0-9]+]], [[RS1]], 8;
+; CHECK:      mov.b32  [[R1:%r[0-9]+]], {[[RS1]], [[RS2]]}
+; CHECK:      and.b32  [[R2:%r[0-9]+]], [[R1]], 16711935;
+; CHECK:      st.param.b32  [func_retval0], [[R2]];
 define <2 x i16> @test_zext_2xi8(i16 %a) #0 {
   %vec = bitcast i16 %a to <2 x i8>
   %ext = zext <2 x i8> %vec to <2 x i16>

>From 6e16bb5348e9597033275af706d96c198b65dee7 Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Tue, 29 Oct 2024 14:06:22 +0000
Subject: [PATCH 4/6] Address review comments

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp  | 54 ++++++++++----------
 llvm/test/CodeGen/NVPTX/i8x2-instructions.ll | 39 +++++++-------
 2 files changed, 45 insertions(+), 48 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 050fbcfbcd8165..72a34de6b33e61 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -2319,39 +2319,39 @@ SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
   // Handle bitcasting to/from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
   SDNode *Node = Op.getNode();
-  SDLoc dl(Node);
+  SDLoc DL(Node);
 
-  auto maybeBitcast = [&](EVT vt, SDValue val) {
-    if (val->getValueType(0) == vt) {
-      return val;
-    }
-    return DAG.getNode(ISD::BITCAST, dl, vt, val);
+  auto maybeBitcast = [&](EVT VT, SDValue Value) {
+    if (Value->getValueType(0) == VT)
+      return Value;
+    return DAG.getNode(ISD::BITCAST, DL, VT, Value);
   };
 
-  EVT VT = Op->getValueType(0);
-  EVT fromVT = Op->getOperand(0)->getValueType(0);
+  EVT ToVT = Op->getValueType(0);
+  EVT FromVT = Op->getOperand(0)->getValueType(0);
 
-  if (VT == MVT::v2i8) {
+  if (ToVT == MVT::v2i8) {
     // Bitcast to i16 and unpack elements into a vector
-    SDValue reg = maybeBitcast(MVT::i16, Op->getOperand(0));
-    SDValue v0 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, reg);
-    SDValue C8 = DAG.getConstant(8, dl, MVT::i16);
-    SDValue v1 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8,
-                             DAG.getNode(ISD::SRL, dl, MVT::i16, {reg, C8}));
-    return DAG.getNode(ISD::BUILD_VECTOR, dl, MVT::v2i8, {v0, v1});
-  } else if (fromVT == MVT::v2i8) {
+    SDValue AsInt = maybeBitcast(MVT::i16, Op->getOperand(0));
+    SDValue Vec0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, AsInt);
+    SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
+    SDValue Vec1 =
+        DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
+                    DAG.getNode(ISD::SRL, DL, MVT::i16, {AsInt, Const8}));
+    return DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1});
+  } else if (FromVT == MVT::v2i8) {
     // Pack vector elements into i16 and bitcast to final type
-    SDValue v0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8,
-                             Op->getOperand(0), DAG.getIntPtrConstant(0, dl));
-    SDValue v1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i8,
-                             Op->getOperand(0), DAG.getIntPtrConstant(1, dl));
-    SDValue E0 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, v0);
-    SDValue E1 = DAG.getNode(ISD::ZERO_EXTEND, dl, MVT::i16, v1);
-    SDValue C8 = DAG.getConstant(8, dl, MVT::i16);
-    SDValue reg =
-        DAG.getNode(ISD::OR, dl, MVT::i16,
-                    {E0, DAG.getNode(ISD::SHL, dl, MVT::i16, {E1, C8})});
-    return maybeBitcast(VT, reg);
+    SDValue Vec0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
+                               Op->getOperand(0), DAG.getIntPtrConstant(0, DL));
+    SDValue Vec1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
+                               Op->getOperand(0), DAG.getIntPtrConstant(1, DL));
+    SDValue Extend0 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec0);
+    SDValue Extend1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec1);
+    SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
+    SDValue AsInt = DAG.getNode(
+        ISD::OR, DL, MVT::i16,
+        {Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})});
+    return maybeBitcast(ToVT, AsInt);
   }
   return Op;
 }
diff --git a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
index 2f5d8cfed2b7b7..df9c3e59b0e6ba 100644
--- a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
@@ -9,28 +9,25 @@
 
 target datalayout = "e-m:o-i64:64-i128:128-n32:64-S128"
 
-; CHECK-LABEL: test_trunc_2xi8(
-; CHECK:      ld.param.u32 [[R1:%r[0-9]+]], [test_trunc_2xi8_param_0];
-; CHECK:      mov.b32 {[[RS1:%rs[0-9]+]], [[RS2:%rs[0-9]+]]}, [[R1]];
-; CHECK:      shl.b16 	[[RS3:%rs[0-9]+]], [[RS2]], 8;
-; CHECK:      and.b16  [[RS4:%rs[0-9]+]], [[RS1]], 255;
-; CHECK:      or.b16   [[RS5:%rs[0-9]+]], [[RS4]], [[RS3]]
-; CHECK:      cvt.u32.u16  [[R2:%r[0-9]]], [[RS5]]
-; CHECK:      st.param.b32  [func_retval0], [[R2]];
-define i16 @test_trunc_2xi8(<2 x i16> %a) #0 {
-  %trunc = trunc <2 x i16> %a to <2 x i8>
-  %res = bitcast <2 x i8> %trunc to i16
+; CHECK-LABEL: test_bitcast_2xi8_i16(
+; CHECK: ld.param.u32 	%r1, [test_bitcast_2xi8_i16_param_0];
+; CHECK: mov.b32 	{%rs1, %rs2}, %r1;
+; CHECK: shl.b16 	%rs3, %rs2, 8;
+; CHECK: and.b16  	%rs4, %rs1, 255;
+; CHECK: or.b16  	%rs5, %rs4, %rs3;
+; CHECK: cvt.u32.u16 	%r2, %rs5;
+; CHECK: st.param.b32 	[func_retval0], %r2;
+define i16 @test_bitcast_2xi8_i16(<2 x i8> %a) {
+  %res = bitcast <2 x i8> %a to i16
   ret i16 %res
 }
 
-; CHECK-LABEL: test_zext_2xi8(
-; CHECK:      ld.param.u16  [[RS1:%rs[0-9]+]], [test_zext_2xi8_param_0];
-; CHECK:      shr.u16 	[[RS2:%rs[0-9]+]], [[RS1]], 8;
-; CHECK:      mov.b32  [[R1:%r[0-9]+]], {[[RS1]], [[RS2]]}
-; CHECK:      and.b32  [[R2:%r[0-9]+]], [[R1]], 16711935;
-; CHECK:      st.param.b32  [func_retval0], [[R2]];
-define <2 x i16> @test_zext_2xi8(i16 %a) #0 {
-  %vec = bitcast i16 %a to <2 x i8>
-  %ext = zext <2 x i8> %vec to <2 x i16>
-  ret <2 x i16> %ext
+; CHECK-LABEL: test_bitcast_i16_2xi8(
+; CHECK: ld.param.u16 	%rs1, [test_bitcast_i16_2xi8_param_0];
+; CHECK: shr.u16 	%rs2, %rs1, 8;
+; CHECK: mov.b32 	%r1, {%rs1, %rs2};
+; CHECK: st.param.b32 	[func_retval0], %r1;
+define <2 x i8> @test_bitcast_i16_2xi8(i16 %a) {
+  %res = bitcast i16 %a to <2 x i8>
+  ret <2 x i8> %res
 }

>From b611165ec44ac7e7fb53707cc77792df0aab3b1a Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Tue, 29 Oct 2024 22:53:28 +0000
Subject: [PATCH 5/6] Split ReplaceNodeResults path out from LowerBITCAST

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 53 +++++++++++++--------
 1 file changed, 32 insertions(+), 21 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 72a34de6b33e61..782dab5f0dac22 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -409,6 +409,13 @@ VectorizePTXValueVTs(const SmallVectorImpl<EVT> &ValueVTs,
   return VectorInfo;
 }
 
+static SDValue MaybeBitcast(SelectionDAG &DAG, SDLoc DL, EVT VT,
+                            SDValue Value) {
+  if (Value->getValueType(0) == VT)
+    return Value;
+  return DAG.getNode(ISD::BITCAST, DL, VT, Value);
+}
+
 // NVPTXTargetLowering Constructor.
 NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
                                          const NVPTXSubtarget &STI)
@@ -2316,30 +2323,14 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
 }
 
 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
-  // Handle bitcasting to/from v2i8 without hitting the default promotion
+  // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
-  SDNode *Node = Op.getNode();
-  SDLoc DL(Node);
-
-  auto maybeBitcast = [&](EVT VT, SDValue Value) {
-    if (Value->getValueType(0) == VT)
-      return Value;
-    return DAG.getNode(ISD::BITCAST, DL, VT, Value);
-  };
+  SDLoc DL(Op);
 
   EVT ToVT = Op->getValueType(0);
   EVT FromVT = Op->getOperand(0)->getValueType(0);
 
-  if (ToVT == MVT::v2i8) {
-    // Bitcast to i16 and unpack elements into a vector
-    SDValue AsInt = maybeBitcast(MVT::i16, Op->getOperand(0));
-    SDValue Vec0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, AsInt);
-    SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
-    SDValue Vec1 =
-        DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
-                    DAG.getNode(ISD::SRL, DL, MVT::i16, {AsInt, Const8}));
-    return DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1});
-  } else if (FromVT == MVT::v2i8) {
+  if (FromVT == MVT::v2i8) {
     // Pack vector elements into i16 and bitcast to final type
     SDValue Vec0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
                                Op->getOperand(0), DAG.getIntPtrConstant(0, DL));
@@ -2351,7 +2342,7 @@ SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
     SDValue AsInt = DAG.getNode(
         ISD::OR, DL, MVT::i16,
         {Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})});
-    return maybeBitcast(ToVT, AsInt);
+    return MaybeBitcast(DAG, DL, ToVT, AsInt);
   }
   return Op;
 }
@@ -6175,6 +6166,26 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
   return SDValue();
 }
 
+static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
+                           SmallVectorImpl<SDValue> &Results) {
+  // Handle bitcasting to v2i8 without hitting the default promotion
+  // strategy which goes through stack memory.
+  SDValue Op(Node, 0);
+  SDLoc DL(Node);
+
+  EVT ToVT = Op->getValueType(0);
+  if (ToVT == MVT::v2i8) {
+    SDValue AsInt = MaybeBitcast(DAG, DL, MVT::i16, Op->getOperand(0));
+    SDValue Vec0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, AsInt);
+    SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
+    SDValue Vec1 =
+        DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
+                    DAG.getNode(ISD::SRL, DL, MVT::i16, {AsInt, Const8}));
+    Results.push_back(
+        DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1}));
+  }
+}
+
 /// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.
 static void ReplaceLoadVector(SDNode *N, SelectionDAG &DAG,
                               SmallVectorImpl<SDValue> &Results) {
@@ -6461,7 +6472,7 @@ void NVPTXTargetLowering::ReplaceNodeResults(
   default:
     report_fatal_error("Unhandled custom legalization");
   case ISD::BITCAST:
-    Results.push_back(LowerBITCAST(SDValue(N, 0), DAG));
+    ReplaceBITCAST(N, DAG, Results);
     return;
   case ISD::LOAD:
     ReplaceLoadVector(N, DAG, Results);

>From be5bc92a23205e4273fe5fc1a9c9684d57a71c6d Mon Sep 17 00:00:00 2001
From: Peter Bell <peterbell10 at openai.com>
Date: Wed, 30 Oct 2024 02:52:06 +0000
Subject: [PATCH 6/6] Use guard clauses

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 59 +++++++++++----------
 1 file changed, 30 insertions(+), 29 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 782dab5f0dac22..ba21733e961654 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -2325,26 +2325,25 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
   // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
-  SDLoc DL(Op);
-
-  EVT ToVT = Op->getValueType(0);
   EVT FromVT = Op->getOperand(0)->getValueType(0);
-
-  if (FromVT == MVT::v2i8) {
-    // Pack vector elements into i16 and bitcast to final type
-    SDValue Vec0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
-                               Op->getOperand(0), DAG.getIntPtrConstant(0, DL));
-    SDValue Vec1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
-                               Op->getOperand(0), DAG.getIntPtrConstant(1, DL));
-    SDValue Extend0 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec0);
-    SDValue Extend1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec1);
-    SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
-    SDValue AsInt = DAG.getNode(
-        ISD::OR, DL, MVT::i16,
-        {Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})});
-    return MaybeBitcast(DAG, DL, ToVT, AsInt);
+  if (FromVT != MVT::v2i8) {
+    return Op;
   }
-  return Op;
+
+  // Pack vector elements into i16 and bitcast to final type
+  SDLoc DL(Op);
+  SDValue Vec0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
+                             Op->getOperand(0), DAG.getIntPtrConstant(0, DL));
+  SDValue Vec1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
+                             Op->getOperand(0), DAG.getIntPtrConstant(1, DL));
+  SDValue Extend0 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec0);
+  SDValue Extend1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec1);
+  SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
+  SDValue AsInt = DAG.getNode(
+      ISD::OR, DL, MVT::i16,
+      {Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})});
+  EVT ToVT = Op->getValueType(0);
+  return MaybeBitcast(DAG, DL, ToVT, AsInt);
 }
 
 // We can init constant f16x2/v2i16/v4i8 with a single .b32 move.  Normally it
@@ -6171,19 +6170,21 @@ static void ReplaceBITCAST(SDNode *Node, SelectionDAG &DAG,
   // Handle bitcasting to v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
   SDValue Op(Node, 0);
-  SDLoc DL(Node);
-
   EVT ToVT = Op->getValueType(0);
-  if (ToVT == MVT::v2i8) {
-    SDValue AsInt = MaybeBitcast(DAG, DL, MVT::i16, Op->getOperand(0));
-    SDValue Vec0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, AsInt);
-    SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
-    SDValue Vec1 =
-        DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
-                    DAG.getNode(ISD::SRL, DL, MVT::i16, {AsInt, Const8}));
-    Results.push_back(
-        DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1}));
+  if (ToVT != MVT::v2i8) {
+    return;
   }
+
+  // Bitcast to i16 and unpack elements into a vector
+  SDLoc DL(Node);
+  SDValue AsInt = MaybeBitcast(DAG, DL, MVT::i16, Op->getOperand(0));
+  SDValue Vec0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::i8, AsInt);
+  SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
+  SDValue Vec1 =
+      DAG.getNode(ISD::TRUNCATE, DL, MVT::i8,
+                  DAG.getNode(ISD::SRL, DL, MVT::i16, {AsInt, Const8}));
+  Results.push_back(
+      DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2i8, {Vec0, Vec1}));
 }
 
 /// ReplaceVectorLoad - Convert vector loads into multi-output scalar loads.



More information about the llvm-commits mailing list