[llvm] [NVPTX] Eliminate `prmt`s that result from `BUILD_VECTOR` of `LoadV2` (PR #149581)

Justin Fargnoli via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 18 13:06:08 PDT 2025


https://github.com/justinfargnoli created https://github.com/llvm/llvm-project/pull/149581

None

>From 22c519a473ab88395a0a1ff8945ca3d57bc3f434 Mon Sep 17 00:00:00 2001
From: Justin Fargnoli <jfargnoli at nvidia.com>
Date: Fri, 18 Jul 2025 19:52:43 +0000
Subject: [PATCH 1/2] Initial test

---
 .../CodeGen/NVPTX/build-vector-combine.ll     | 129 ++++++++++++++++++
 1 file changed, 129 insertions(+)
 create mode 100644 llvm/test/CodeGen/NVPTX/build-vector-combine.ll

diff --git a/llvm/test/CodeGen/NVPTX/build-vector-combine.ll b/llvm/test/CodeGen/NVPTX/build-vector-combine.ll
new file mode 100644
index 0000000000000..3c475a4ee3765
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/build-vector-combine.ll
@@ -0,0 +1,129 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
+
+target datalayout = "e-p:64:64:64-p3:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-f128:128:128-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64-a:8:8"
+target triple = "nvptx64-nvidia-cuda"
+
+define void @t1() {
+; CHECK-LABEL: t1(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    mov.b64 %rd1, 0;
+; CHECK-NEXT:    ld.global.v2.b8 {%rs1, %rs2}, [%rd1];
+; CHECK-NEXT:    cvt.u32.u16 %r1, %rs2;
+; CHECK-NEXT:    cvt.u32.u16 %r2, %rs1;
+; CHECK-NEXT:    prmt.b32 %r3, %r2, %r1, 0x3340U;
+; CHECK-NEXT:    prmt.b32 %r4, %r3, 0, 0x5410U;
+; CHECK-NEXT:    st.global.v4.b32 [%rd1], {%r4, 0, 0, 0};
+; CHECK-NEXT:    ret;
+entry:
+  %0 = load <2 x i8>, ptr addrspace(1) null, align 4
+  %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %2 = bitcast <4 x i8> %1 to i32
+  %3 = insertelement <4 x i32> zeroinitializer, i32 %2, i64 0
+  store <4 x i32> %3, ptr addrspace(1) null, align 16
+  ret void
+}
+
+define void @t2() {
+; CHECK-LABEL: t2(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    mov.b64 %rd1, 0;
+; CHECK-NEXT:    ld.global.v2.b8 {%rs1, %rs2}, [%rd1];
+; CHECK-NEXT:    cvt.u32.u16 %r1, %rs2;
+; CHECK-NEXT:    cvt.u32.u16 %r2, %rs1;
+; CHECK-NEXT:    prmt.b32 %r3, %r2, %r1, 0x3340U;
+; CHECK-NEXT:    prmt.b32 %r4, %r3, 0, 0x5410U;
+; CHECK-NEXT:    st.local.b32 [%rd1], %r4;
+; CHECK-NEXT:    ret;
+entry:
+  %0 = load <2 x i8>, ptr addrspace(1) null, align 8
+  %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  store <4 x i8> %1, ptr addrspace(5) null, align 8
+  ret void
+}
+
+declare <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 %align)
+
+define void @ldg(ptr addrspace(1) %ptr) {
+; CHECK-LABEL: ldg(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b64 %rd1, [ldg_param_0];
+; CHECK-NEXT:    ld.global.v2.b8 {%rs1, %rs2}, [%rd1];
+; CHECK-NEXT:    cvt.u32.u16 %r1, %rs2;
+; CHECK-NEXT:    cvt.u32.u16 %r2, %rs1;
+; CHECK-NEXT:    prmt.b32 %r3, %r2, %r1, 0x3340U;
+; CHECK-NEXT:    prmt.b32 %r4, %r3, 0, 0x5410U;
+; CHECK-NEXT:    mov.b64 %rd2, 0;
+; CHECK-NEXT:    st.local.b32 [%rd2], %r4;
+; CHECK-NEXT:    ret;
+entry:
+  %0 = tail call <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2)
+  %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  store <4 x i8> %1, ptr addrspace(5) null, align 8
+  ret void
+}
+
+declare <2 x i8> @llvm.nvvm.ldu.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 %align)
+
+define void @ldu(ptr addrspace(1) %ptr) {
+; CHECK-LABEL: ldu(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0: // %entry
+; CHECK-NEXT:    ld.param.b64 %rd1, [ldu_param_0];
+; CHECK-NEXT:    ldu.global.v2.b8 {%rs1, %rs2}, [%rd1];
+; CHECK-NEXT:    cvt.u32.u16 %r1, %rs2;
+; CHECK-NEXT:    cvt.u32.u16 %r2, %rs1;
+; CHECK-NEXT:    prmt.b32 %r3, %r2, %r1, 0x3340U;
+; CHECK-NEXT:    prmt.b32 %r4, %r3, 0, 0x5410U;
+; CHECK-NEXT:    mov.b64 %rd2, 0;
+; CHECK-NEXT:    st.local.b32 [%rd2], %r4;
+; CHECK-NEXT:    ret;
+entry:
+  %0 = tail call <2 x i8> @llvm.nvvm.ldu.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2)
+  %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  store <4 x i8> %1, ptr addrspace(5) null, align 8
+  ret void
+}
+
+define void @t3() {
+; CHECK-LABEL: t3(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b16 %rs<3>;
+; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    mov.b64 %rd1, 0;
+; CHECK-NEXT:    ld.global.v2.b8 {%rs1, %rs2}, [%rd1];
+; CHECK-NEXT:    cvt.u32.u16 %r1, %rs2;
+; CHECK-NEXT:    cvt.u32.u16 %r2, %rs1;
+; CHECK-NEXT:    prmt.b32 %r3, %r2, %r1, 0x3340U;
+; CHECK-NEXT:    prmt.b32 %r4, %r3, 0, 0x5410U;
+; CHECK-NEXT:    st.global.v2.b32 [%rd1], {%r4, 0};
+; CHECK-NEXT:    ret;
+  %1 = load <2 x i8>, ptr addrspace(1) null, align 2
+  %insval2 = bitcast <2 x i8> %1 to i16
+  %2 = insertelement <4 x i16> zeroinitializer, i16 %insval2, i32 0
+  store <4 x i16> %2, ptr addrspace(1) null, align 8
+  ret void
+}

>From 1d2bbcf3f96833df266288b0560c4897235f932c Mon Sep 17 00:00:00 2001
From: Justin Fargnoli <jfargnoli at nvidia.com>
Date: Fri, 18 Jul 2025 19:54:19 +0000
Subject: [PATCH 2/2] [NVPTX] Eliminate prmts that result from BUILD_VECTOR of
 LoadV2

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   | 77 ++++++++++++++++++-
 .../CodeGen/NVPTX/build-vector-combine.ll     | 57 ++++----------
 2 files changed, 93 insertions(+), 41 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 7aa06f9079b09..5f98b1a27617d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5772,7 +5772,8 @@ static SDValue PerformVSELECTCombine(SDNode *N,
 }
 
 static SDValue
-PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+PerformBUILD_VECTOROfV2i16Combine(SDNode *N,
+                                  TargetLowering::DAGCombinerInfo &DCI) {
   auto VT = N->getValueType(0);
   if (!DCI.isAfterLegalizeDAG() ||
       // only process v2*16 types
@@ -5833,6 +5834,80 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   return DAG.getBitcast(VT, PRMT);
 }
 
+static SDValue
+PerformBUILD_VECTOROfTargetLoadCombine(SDNode *N,
+                                       TargetLowering::DAGCombinerInfo &DCI) {
+  // Match: BUILD_VECTOR of v4i8 whose first two elements are the two i16
+  // results of a single NVPTXISD::LoadV2 or NVPTXISD::LDUV2 of <2 x i8>, and
+  // whose last two elements are zero constants (or undef). Replace with an
+  // i16 load zero-extended to i32, returned as a bitcast to v4i8.
+  EVT VT = N->getValueType(0);
+  if (VT != MVT::v4i8)
+    return SDValue();
+  // Check operands: [0]=lo, [1]=hi
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  // Check that Op0 and Op1 are from the same NVPTXISD::LoadV2 or
+  // NVPTXISD::LDUV2
+  if (Op0.getNode() != Op1.getNode())
+    return SDValue();
+  if (!(Op0.getOpcode() == NVPTXISD::LoadV2 ||
+        Op0.getOpcode() == NVPTXISD::LDUV2))
+    return SDValue();
+  if (Op0.getValueType() != MVT::i16)
+    return SDValue();
+  if (!(Op0.hasOneUse() && Op1.hasOneUse()))
+    return SDValue();
+
+  // Check operands: [2]= 0 or undef, [3]= 0 or undef
+  SDValue Op2 = N->getOperand(2);
+  SDValue Op3 = N->getOperand(3);
+  if (Op2 != Op3)
+    return SDValue();
+  if (!Op2.isUndef()) {
+    auto *C2 = dyn_cast<ConstantSDNode>(Op2);
+    if (!(C2 && C2->isZero()))
+      return SDValue();
+  }
+
+  // Now, replace with: zext(load i16) -> i32, then bitcast to v4i8
+  auto &DAG = DCI.DAG;
+  // Rebuild the load as i16
+  auto *Load = cast<MemSDNode>(Op0.getNode());
+  SDLoc DL(Load);
+  SDValue LoadI16;
+  if (Load->getOpcode() == NVPTXISD::LoadV2) {
+    LoadI16 = DAG.getLoad(MVT::i16, DL, Load->getChain(), Load->getBasePtr(),
+                          Load->getPointerInfo(), Load->getAlign(),
+                          Load->getMemOperand()->getFlags());
+  } else {
+    assert(Load->getOpcode() == NVPTXISD::LDUV2);
+    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+    SmallVector<SDValue, 4> Ops;
+    Ops.push_back(Load->getChain());
+    Ops.push_back(DAG.getConstant(Intrinsic::nvvm_ldu_global_i, DL,
+                                  TLI.getPointerTy(DAG.getDataLayout())));
+    for (unsigned i = 1; i < Load->getNumOperands(); ++i)
+      Ops.push_back(Load->getOperand(i));
+    SDVTList NodeVTList = DAG.getVTList(MVT::i16, MVT::Other);
+    LoadI16 = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, NodeVTList,
+                                      Ops, MVT::i16, Load->getPointerInfo(),
+                                      Load->getAlign());
+  }
+  DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 2), LoadI16.getValue(1));
+  SDValue Zext = DAG.getZExtOrTrunc(LoadI16, DL, MVT::i32);
+  return DAG.getBitcast(MVT::v4i8, Zext);
+}
+
+static SDValue
+PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+  if (const auto V = PerformBUILD_VECTOROfV2i16Combine(N, DCI))
+    return V;
+  if (const auto V = PerformBUILD_VECTOROfTargetLoadCombine(N, DCI))
+    return V;
+  return SDValue();
+}
+
 static SDValue combineADDRSPACECAST(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI) {
   auto *ASCN1 = cast<AddrSpaceCastSDNode>(N);
diff --git a/llvm/test/CodeGen/NVPTX/build-vector-combine.ll b/llvm/test/CodeGen/NVPTX/build-vector-combine.ll
index 3c475a4ee3765..019bd3bde8761 100644
--- a/llvm/test/CodeGen/NVPTX/build-vector-combine.ll
+++ b/llvm/test/CodeGen/NVPTX/build-vector-combine.ll
@@ -8,18 +8,13 @@ target triple = "nvptx64-nvidia-cuda"
 define void @t1() {
 ; CHECK-LABEL: t1(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b16 %rs<3>;
-; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
 ; CHECK-NEXT:    .reg .b64 %rd<2>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0: // %entry
 ; CHECK-NEXT:    mov.b64 %rd1, 0;
-; CHECK-NEXT:    ld.global.v2.b8 {%rs1, %rs2}, [%rd1];
-; CHECK-NEXT:    cvt.u32.u16 %r1, %rs2;
-; CHECK-NEXT:    cvt.u32.u16 %r2, %rs1;
-; CHECK-NEXT:    prmt.b32 %r3, %r2, %r1, 0x3340U;
-; CHECK-NEXT:    prmt.b32 %r4, %r3, 0, 0x5410U;
-; CHECK-NEXT:    st.global.v4.b32 [%rd1], {%r4, 0, 0, 0};
+; CHECK-NEXT:    ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT:    st.global.v4.b32 [%rd1], {%r1, 0, 0, 0};
 ; CHECK-NEXT:    ret;
 entry:
   %0 = load <2 x i8>, ptr addrspace(1) null, align 4
@@ -33,18 +28,13 @@ entry:
 define void @t2() {
 ; CHECK-LABEL: t2(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b16 %rs<3>;
-; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
 ; CHECK-NEXT:    .reg .b64 %rd<2>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0: // %entry
 ; CHECK-NEXT:    mov.b64 %rd1, 0;
-; CHECK-NEXT:    ld.global.v2.b8 {%rs1, %rs2}, [%rd1];
-; CHECK-NEXT:    cvt.u32.u16 %r1, %rs2;
-; CHECK-NEXT:    cvt.u32.u16 %r2, %rs1;
-; CHECK-NEXT:    prmt.b32 %r3, %r2, %r1, 0x3340U;
-; CHECK-NEXT:    prmt.b32 %r4, %r3, 0, 0x5410U;
-; CHECK-NEXT:    st.local.b32 [%rd1], %r4;
+; CHECK-NEXT:    ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT:    st.local.b32 [%rd1], %r1;
 ; CHECK-NEXT:    ret;
 entry:
   %0 = load <2 x i8>, ptr addrspace(1) null, align 8
@@ -58,19 +48,14 @@ declare <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 %ali
 define void @ldg(ptr addrspace(1) %ptr) {
 ; CHECK-LABEL: ldg(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b16 %rs<3>;
-; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
 ; CHECK-NEXT:    .reg .b64 %rd<3>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0: // %entry
 ; CHECK-NEXT:    ld.param.b64 %rd1, [ldg_param_0];
-; CHECK-NEXT:    ld.global.v2.b8 {%rs1, %rs2}, [%rd1];
-; CHECK-NEXT:    cvt.u32.u16 %r1, %rs2;
-; CHECK-NEXT:    cvt.u32.u16 %r2, %rs1;
-; CHECK-NEXT:    prmt.b32 %r3, %r2, %r1, 0x3340U;
-; CHECK-NEXT:    prmt.b32 %r4, %r3, 0, 0x5410U;
+; CHECK-NEXT:    ld.global.b16 %r1, [%rd1];
 ; CHECK-NEXT:    mov.b64 %rd2, 0;
-; CHECK-NEXT:    st.local.b32 [%rd2], %r4;
+; CHECK-NEXT:    st.local.b32 [%rd2], %r1;
 ; CHECK-NEXT:    ret;
 entry:
   %0 = tail call <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2)
@@ -84,19 +69,16 @@ declare <2 x i8> @llvm.nvvm.ldu.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 %ali
 define void @ldu(ptr addrspace(1) %ptr) {
 ; CHECK-LABEL: ldu(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b16 %rs<3>;
-; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b16 %rs<2>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
 ; CHECK-NEXT:    .reg .b64 %rd<3>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0: // %entry
 ; CHECK-NEXT:    ld.param.b64 %rd1, [ldu_param_0];
-; CHECK-NEXT:    ldu.global.v2.b8 {%rs1, %rs2}, [%rd1];
-; CHECK-NEXT:    cvt.u32.u16 %r1, %rs2;
-; CHECK-NEXT:    cvt.u32.u16 %r2, %rs1;
-; CHECK-NEXT:    prmt.b32 %r3, %r2, %r1, 0x3340U;
-; CHECK-NEXT:    prmt.b32 %r4, %r3, 0, 0x5410U;
+; CHECK-NEXT:    ldu.global.b16 %rs1, [%rd1];
+; CHECK-NEXT:    cvt.u32.u16 %r1, %rs1;
 ; CHECK-NEXT:    mov.b64 %rd2, 0;
-; CHECK-NEXT:    st.local.b32 [%rd2], %r4;
+; CHECK-NEXT:    st.local.b32 [%rd2], %r1;
 ; CHECK-NEXT:    ret;
 entry:
   %0 = tail call <2 x i8> @llvm.nvvm.ldu.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2)
@@ -108,18 +90,13 @@ entry:
 define void @t3() {
 ; CHECK-LABEL: t3(
 ; CHECK:       {
-; CHECK-NEXT:    .reg .b16 %rs<3>;
-; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
 ; CHECK-NEXT:    .reg .b64 %rd<2>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    mov.b64 %rd1, 0;
-; CHECK-NEXT:    ld.global.v2.b8 {%rs1, %rs2}, [%rd1];
-; CHECK-NEXT:    cvt.u32.u16 %r1, %rs2;
-; CHECK-NEXT:    cvt.u32.u16 %r2, %rs1;
-; CHECK-NEXT:    prmt.b32 %r3, %r2, %r1, 0x3340U;
-; CHECK-NEXT:    prmt.b32 %r4, %r3, 0, 0x5410U;
-; CHECK-NEXT:    st.global.v2.b32 [%rd1], {%r4, 0};
+; CHECK-NEXT:    ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT:    st.global.v2.b32 [%rd1], {%r1, 0};
 ; CHECK-NEXT:    ret;
   %1 = load <2 x i8>, ptr addrspace(1) null, align 2
   %insval2 = bitcast <2 x i8> %1 to i16



More information about the llvm-commits mailing list