[llvm] [NVPTX] Eliminate `prmt`s that result from `BUILD_VECTOR` of `LoadV2` (PR #149581)
Justin Fargnoli via llvm-commits
llvm-commits at lists.llvm.org
Fri Jul 18 13:06:08 PDT 2025
https://github.com/justinfargnoli created https://github.com/llvm/llvm-project/pull/149581
None
>From 22c519a473ab88395a0a1ff8945ca3d57bc3f434 Mon Sep 17 00:00:00 2001
From: Justin Fargnoli <jfargnoli at nvidia.com>
Date: Fri, 18 Jul 2025 19:52:43 +0000
Subject: [PATCH 1/2] Initial test
---
.../CodeGen/NVPTX/build-vector-combine.ll | 129 ++++++++++++++++++
1 file changed, 129 insertions(+)
create mode 100644 llvm/test/CodeGen/NVPTX/build-vector-combine.ll
diff --git a/llvm/test/CodeGen/NVPTX/build-vector-combine.ll b/llvm/test/CodeGen/NVPTX/build-vector-combine.ll
new file mode 100644
index 0000000000000..3c475a4ee3765
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/build-vector-combine.ll
@@ -0,0 +1,129 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
+
+target datalayout = "e-p:64:64:64-p3:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-f128:128:128-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64-a:8:8"
+target triple = "nvptx64-nvidia-cuda"
+
+define void @t1() {
+; CHECK-LABEL: t1(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: mov.b64 %rd1, 0;
+; CHECK-NEXT: ld.global.v2.b8 {%rs1, %rs2}, [%rd1];
+; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
+; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
+; CHECK-NEXT: prmt.b32 %r3, %r2, %r1, 0x3340U;
+; CHECK-NEXT: prmt.b32 %r4, %r3, 0, 0x5410U;
+; CHECK-NEXT: st.global.v4.b32 [%rd1], {%r4, 0, 0, 0};
+; CHECK-NEXT: ret;
+entry:
+ %0 = load <2 x i8>, ptr addrspace(1) null, align 4
+ %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+ %2 = bitcast <4 x i8> %1 to i32
+ %3 = insertelement <4 x i32> zeroinitializer, i32 %2, i64 0
+ store <4 x i32> %3, ptr addrspace(1) null, align 16
+ ret void
+}
+
+define void @t2() {
+; CHECK-LABEL: t2(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: mov.b64 %rd1, 0;
+; CHECK-NEXT: ld.global.v2.b8 {%rs1, %rs2}, [%rd1];
+; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
+; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
+; CHECK-NEXT: prmt.b32 %r3, %r2, %r1, 0x3340U;
+; CHECK-NEXT: prmt.b32 %r4, %r3, 0, 0x5410U;
+; CHECK-NEXT: st.local.b32 [%rd1], %r4;
+; CHECK-NEXT: ret;
+entry:
+ %0 = load <2 x i8>, ptr addrspace(1) null, align 8
+ %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+ store <4 x i8> %1, ptr addrspace(5) null, align 8
+ ret void
+}
+
+declare <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 %align)
+
+define void @ldg(ptr addrspace(1) %ptr) {
+; CHECK-LABEL: ldg(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b64 %rd1, [ldg_param_0];
+; CHECK-NEXT: ld.global.v2.b8 {%rs1, %rs2}, [%rd1];
+; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
+; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
+; CHECK-NEXT: prmt.b32 %r3, %r2, %r1, 0x3340U;
+; CHECK-NEXT: prmt.b32 %r4, %r3, 0, 0x5410U;
+; CHECK-NEXT: mov.b64 %rd2, 0;
+; CHECK-NEXT: st.local.b32 [%rd2], %r4;
+; CHECK-NEXT: ret;
+entry:
+ %0 = tail call <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2)
+ %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+ store <4 x i8> %1, ptr addrspace(5) null, align 8
+ ret void
+}
+
+declare <2 x i8> @llvm.nvvm.ldu.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 %align)
+
+define void @ldu(ptr addrspace(1) %ptr) {
+; CHECK-LABEL: ldu(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b64 %rd1, [ldu_param_0];
+; CHECK-NEXT: ldu.global.v2.b8 {%rs1, %rs2}, [%rd1];
+; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
+; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
+; CHECK-NEXT: prmt.b32 %r3, %r2, %r1, 0x3340U;
+; CHECK-NEXT: prmt.b32 %r4, %r3, 0, 0x5410U;
+; CHECK-NEXT: mov.b64 %rd2, 0;
+; CHECK-NEXT: st.local.b32 [%rd2], %r4;
+; CHECK-NEXT: ret;
+entry:
+ %0 = tail call <2 x i8> @llvm.nvvm.ldu.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2)
+ %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+ store <4 x i8> %1, ptr addrspace(5) null, align 8
+ ret void
+}
+
+define void @t3() {
+; CHECK-LABEL: t3(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: mov.b64 %rd1, 0;
+; CHECK-NEXT: ld.global.v2.b8 {%rs1, %rs2}, [%rd1];
+; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
+; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
+; CHECK-NEXT: prmt.b32 %r3, %r2, %r1, 0x3340U;
+; CHECK-NEXT: prmt.b32 %r4, %r3, 0, 0x5410U;
+; CHECK-NEXT: st.global.v2.b32 [%rd1], {%r4, 0};
+; CHECK-NEXT: ret;
+ %1 = load <2 x i8>, ptr addrspace(1) null, align 2
+ %insval2 = bitcast <2 x i8> %1 to i16
+ %2 = insertelement <4 x i16> zeroinitializer, i16 %insval2, i32 0
+ store <4 x i16> %2, ptr addrspace(1) null, align 8
+ ret void
+}
>From 1d2bbcf3f96833df266288b0560c4897235f932c Mon Sep 17 00:00:00 2001
From: Justin Fargnoli <jfargnoli at nvidia.com>
Date: Fri, 18 Jul 2025 19:54:19 +0000
Subject: [PATCH 2/2] [NVPTX] Eliminate prmts that result from BUILD_VECTOR of
LoadV2
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 77 ++++++++++++++++++-
.../CodeGen/NVPTX/build-vector-combine.ll | 57 ++++----------
2 files changed, 93 insertions(+), 41 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 7aa06f9079b09..5f98b1a27617d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5772,7 +5772,8 @@ static SDValue PerformVSELECTCombine(SDNode *N,
}
static SDValue
-PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+PerformBUILD_VECTOROfV2i16Combine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
auto VT = N->getValueType(0);
if (!DCI.isAfterLegalizeDAG() ||
// only process v2*16 types
@@ -5833,6 +5834,80 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
return DAG.getBitcast(VT, PRMT);
}
+static SDValue
+PerformBUILD_VECTOROfTargetLoadCombine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
+ // Match: BUILD_VECTOR of v4i8, where first two elements are from a
+ // NVPTXISD::LoadV2 or NVPTXISD::LDUV2 of i8, and the last two elements are
+ // zero constants. Replace with: zext the loaded i16 to i32, and return as a
+ // bitcast to v4i8.
+ EVT VT = N->getValueType(0);
+ if (VT != MVT::v4i8)
+ return SDValue();
+ // Check operands: [0]=lo, [1]=hi
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ // Check that Op0 and Op1 are from the same NVPTXISD::LoadV2 or
+ // NVPTXISD::LDUV2
+ if (Op0.getNode() != Op1.getNode())
+ return SDValue();
+ if (!(Op0.getOpcode() == NVPTXISD::LoadV2 ||
+ Op0.getOpcode() == NVPTXISD::LDUV2))
+ return SDValue();
+ if (Op0.getValueType() != MVT::i16)
+ return SDValue();
+ if (!(Op0.hasOneUse() && Op1.hasOneUse()))
+ return SDValue();
+
+ // Check operands: [2]= 0 or undef, [3]= 0 or undef
+ SDValue Op2 = N->getOperand(2);
+ SDValue Op3 = N->getOperand(3);
+ if (Op2 != Op3)
+ return SDValue();
+ if (!Op2.isUndef()) {
+ auto *C2 = dyn_cast<ConstantSDNode>(Op2);
+ if (!(C2 && C2->isZero()))
+ return SDValue();
+ }
+
+ // Now, replace with: zext(load i16) -> i32, then bitcast to v4i8
+ auto &DAG = DCI.DAG;
+ // Rebuild the load as i16
+ auto *Load = cast<MemSDNode>(Op0.getNode());
+ SDLoc DL(Load);
+ SDValue LoadI16;
+ if (Load->getOpcode() == NVPTXISD::LoadV2) {
+ LoadI16 = DAG.getLoad(MVT::i16, DL, Load->getChain(), Load->getBasePtr(),
+ Load->getPointerInfo(), Load->getAlign(),
+ Load->getMemOperand()->getFlags());
+ } else {
+ assert(Load->getOpcode() == NVPTXISD::LDUV2);
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+ SmallVector<SDValue, 4> Ops;
+ Ops.push_back(Load->getChain());
+ Ops.push_back(DAG.getConstant(Intrinsic::nvvm_ldu_global_i, DL,
+ TLI.getPointerTy(DAG.getDataLayout())));
+ for (unsigned i = 1; i < Load->getNumOperands(); ++i)
+ Ops.push_back(Load->getOperand(i));
+ SDVTList NodeVTList = DAG.getVTList(MVT::i16, MVT::Other);
+ LoadI16 = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, NodeVTList,
+ Ops, MVT::i16, Load->getPointerInfo(),
+ Load->getAlign());
+ }
+ DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 2), LoadI16.getValue(1));
+ SDValue Zext = DAG.getZExtOrTrunc(LoadI16, DL, MVT::i32);
+ return DAG.getBitcast(MVT::v4i8, Zext);
+}
+
+static SDValue
+PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+ if (const auto V = PerformBUILD_VECTOROfV2i16Combine(N, DCI))
+ return V;
+ if (const auto V = PerformBUILD_VECTOROfTargetLoadCombine(N, DCI))
+ return V;
+ return SDValue();
+}
+
static SDValue combineADDRSPACECAST(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
auto *ASCN1 = cast<AddrSpaceCastSDNode>(N);
diff --git a/llvm/test/CodeGen/NVPTX/build-vector-combine.ll b/llvm/test/CodeGen/NVPTX/build-vector-combine.ll
index 3c475a4ee3765..019bd3bde8761 100644
--- a/llvm/test/CodeGen/NVPTX/build-vector-combine.ll
+++ b/llvm/test/CodeGen/NVPTX/build-vector-combine.ll
@@ -8,18 +8,13 @@ target triple = "nvptx64-nvidia-cuda"
define void @t1() {
; CHECK-LABEL: t1(
; CHECK: {
-; CHECK-NEXT: .reg .b16 %rs<3>;
-; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: mov.b64 %rd1, 0;
-; CHECK-NEXT: ld.global.v2.b8 {%rs1, %rs2}, [%rd1];
-; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
-; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
-; CHECK-NEXT: prmt.b32 %r3, %r2, %r1, 0x3340U;
-; CHECK-NEXT: prmt.b32 %r4, %r3, 0, 0x5410U;
-; CHECK-NEXT: st.global.v4.b32 [%rd1], {%r4, 0, 0, 0};
+; CHECK-NEXT: ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT: st.global.v4.b32 [%rd1], {%r1, 0, 0, 0};
; CHECK-NEXT: ret;
entry:
%0 = load <2 x i8>, ptr addrspace(1) null, align 4
@@ -33,18 +28,13 @@ entry:
define void @t2() {
; CHECK-LABEL: t2(
; CHECK: {
-; CHECK-NEXT: .reg .b16 %rs<3>;
-; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: mov.b64 %rd1, 0;
-; CHECK-NEXT: ld.global.v2.b8 {%rs1, %rs2}, [%rd1];
-; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
-; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
-; CHECK-NEXT: prmt.b32 %r3, %r2, %r1, 0x3340U;
-; CHECK-NEXT: prmt.b32 %r4, %r3, 0, 0x5410U;
-; CHECK-NEXT: st.local.b32 [%rd1], %r4;
+; CHECK-NEXT: ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT: st.local.b32 [%rd1], %r1;
; CHECK-NEXT: ret;
entry:
%0 = load <2 x i8>, ptr addrspace(1) null, align 8
@@ -58,19 +48,14 @@ declare <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 %ali
define void @ldg(ptr addrspace(1) %ptr) {
; CHECK-LABEL: ldg(
; CHECK: {
-; CHECK-NEXT: .reg .b16 %rs<3>;
-; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: ld.param.b64 %rd1, [ldg_param_0];
-; CHECK-NEXT: ld.global.v2.b8 {%rs1, %rs2}, [%rd1];
-; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
-; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
-; CHECK-NEXT: prmt.b32 %r3, %r2, %r1, 0x3340U;
-; CHECK-NEXT: prmt.b32 %r4, %r3, 0, 0x5410U;
+; CHECK-NEXT: ld.global.b16 %r1, [%rd1];
; CHECK-NEXT: mov.b64 %rd2, 0;
-; CHECK-NEXT: st.local.b32 [%rd2], %r4;
+; CHECK-NEXT: st.local.b32 [%rd2], %r1;
; CHECK-NEXT: ret;
entry:
%0 = tail call <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2)
@@ -84,19 +69,16 @@ declare <2 x i8> @llvm.nvvm.ldu.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 %ali
define void @ldu(ptr addrspace(1) %ptr) {
; CHECK-LABEL: ldu(
; CHECK: {
-; CHECK-NEXT: .reg .b16 %rs<3>;
-; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-NEXT: .reg .b64 %rd<3>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0: // %entry
; CHECK-NEXT: ld.param.b64 %rd1, [ldu_param_0];
-; CHECK-NEXT: ldu.global.v2.b8 {%rs1, %rs2}, [%rd1];
-; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
-; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
-; CHECK-NEXT: prmt.b32 %r3, %r2, %r1, 0x3340U;
-; CHECK-NEXT: prmt.b32 %r4, %r3, 0, 0x5410U;
+; CHECK-NEXT: ldu.global.b16 %rs1, [%rd1];
+; CHECK-NEXT: cvt.u32.u16 %r1, %rs1;
; CHECK-NEXT: mov.b64 %rd2, 0;
-; CHECK-NEXT: st.local.b32 [%rd2], %r4;
+; CHECK-NEXT: st.local.b32 [%rd2], %r1;
; CHECK-NEXT: ret;
entry:
%0 = tail call <2 x i8> @llvm.nvvm.ldu.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2)
@@ -108,18 +90,13 @@ entry:
define void @t3() {
; CHECK-LABEL: t3(
; CHECK: {
-; CHECK-NEXT: .reg .b16 %rs<3>;
-; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-NEXT: .reg .b64 %rd<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: mov.b64 %rd1, 0;
-; CHECK-NEXT: ld.global.v2.b8 {%rs1, %rs2}, [%rd1];
-; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
-; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
-; CHECK-NEXT: prmt.b32 %r3, %r2, %r1, 0x3340U;
-; CHECK-NEXT: prmt.b32 %r4, %r3, 0, 0x5410U;
-; CHECK-NEXT: st.global.v2.b32 [%rd1], {%r4, 0};
+; CHECK-NEXT: ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT: st.global.v2.b32 [%rd1], {%r1, 0};
; CHECK-NEXT: ret;
%1 = load <2 x i8>, ptr addrspace(1) null, align 2
%insval2 = bitcast <2 x i8> %1 to i16
More information about the llvm-commits
mailing list