[llvm] [NVPTX] Eliminate `prmt`s that result from `BUILD_VECTOR` of `LoadV2` (PR #149581)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Jul 18 13:06:39 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-nvptx
Author: Justin Fargnoli (justinfargnoli)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/149581.diff
2 Files Affected:
- (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+76-1)
- (added) llvm/test/CodeGen/NVPTX/build-vector-combine.ll (+106)
``````````diff
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 7aa06f9079b09..5f98b1a27617d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5772,7 +5772,8 @@ static SDValue PerformVSELECTCombine(SDNode *N,
}
static SDValue
-PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+PerformBUILD_VECTOROfV2i16Combine(SDNode *N,
+ TargetLowering::DAGCombinerInfo &DCI) {
auto VT = N->getValueType(0);
if (!DCI.isAfterLegalizeDAG() ||
// only process v2*16 types
@@ -5833,6 +5834,80 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
return DAG.getBitcast(VT, PRMT);
}
+// Fold a BUILD_VECTOR (v4i8) whose low two lanes are the two i8 results of a
+// single NVPTXISD::LoadV2/LDUV2 and whose high two lanes are zero (or undef)
+// into a single scalar i16 load zero-extended to i32, bitcast back to v4i8.
+// This avoids the prmt that would otherwise repack the loaded lanes.
+static SDValue
+PerformBUILD_VECTOROfTargetLoadCombine(SDNode *N,
+                                       TargetLowering::DAGCombinerInfo &DCI) {
+  EVT VT = N->getValueType(0);
+  if (VT != MVT::v4i8)
+    return SDValue();
+
+  // Lanes [0] and [1] must both come from the same NVPTXISD::LoadV2 or
+  // NVPTXISD::LDUV2 with i16-typed results.
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  if (Op0.getNode() != Op1.getNode())
+    return SDValue();
+  // The lanes must be in load order (result 0 = lo, result 1 = hi);
+  // otherwise the combine would silently swap the loaded bytes.
+  if (Op0.getResNo() != 0 || Op1.getResNo() != 1)
+    return SDValue();
+  if (!(Op0.getOpcode() == NVPTXISD::LoadV2 ||
+        Op0.getOpcode() == NVPTXISD::LDUV2))
+    return SDValue();
+  if (Op0.getValueType() != MVT::i16)
+    return SDValue();
+  // Both loaded values must be consumed only by this BUILD_VECTOR so the
+  // original vector load can be dropped once rewritten.
+  if (!(Op0.hasOneUse() && Op1.hasOneUse()))
+    return SDValue();
+
+  // Lanes [2] and [3] must each be zero or undef; the zext below fills them
+  // with zero, which is a legal refinement of undef. Checking each lane
+  // independently also accepts mixed (undef, 0) pairs.
+  auto IsZeroOrUndef = [](SDValue V) {
+    if (V.isUndef())
+      return true;
+    auto *C = dyn_cast<ConstantSDNode>(V);
+    return C && C->isZero();
+  };
+  if (!IsZeroOrUndef(N->getOperand(2)) || !IsZeroOrUndef(N->getOperand(3)))
+    return SDValue();
+
+  // Rebuild the access as a single scalar i16 load (or ldu), then
+  // zext(i16 -> i32) and bitcast the result to v4i8.
+  auto &DAG = DCI.DAG;
+  auto *Load = cast<MemSDNode>(Op0.getNode());
+  SDLoc DL(Load);
+  SDValue LoadI16;
+  if (Load->getOpcode() == NVPTXISD::LoadV2) {
+    LoadI16 = DAG.getLoad(MVT::i16, DL, Load->getChain(), Load->getBasePtr(),
+                          Load->getPointerInfo(), Load->getAlign(),
+                          Load->getMemOperand()->getFlags());
+  } else {
+    assert(Load->getOpcode() == NVPTXISD::LDUV2);
+    // Re-emit as the scalar i16 ldu intrinsic, reusing the original node's
+    // operands (address, etc.) after the chain.
+    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+    SmallVector<SDValue, 4> Ops;
+    Ops.push_back(Load->getChain());
+    Ops.push_back(DAG.getConstant(Intrinsic::nvvm_ldu_global_i, DL,
+                                  TLI.getPointerTy(DAG.getDataLayout())));
+    for (unsigned I = 1; I < Load->getNumOperands(); ++I)
+      Ops.push_back(Load->getOperand(I));
+    SDVTList NodeVTList = DAG.getVTList(MVT::i16, MVT::Other);
+    LoadI16 = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, NodeVTList,
+                                      Ops, MVT::i16, Load->getPointerInfo(),
+                                      Load->getAlign());
+  }
+  // LoadV2/LDUV2 produce their chain as result 2; splice in the chain of the
+  // new scalar load so chained users stay ordered.
+  DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 2), LoadI16.getValue(1));
+  SDValue Zext = DAG.getZExtOrTrunc(LoadI16, DL, MVT::i32);
+  return DAG.getBitcast(MVT::v4i8, Zext);
+}
+
+// Dispatch the BUILD_VECTOR combines in order; the first one that fires wins.
+static SDValue
+PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+  if (SDValue V = PerformBUILD_VECTOROfV2i16Combine(N, DCI))
+    return V;
+  return PerformBUILD_VECTOROfTargetLoadCombine(N, DCI);
+}
+
static SDValue combineADDRSPACECAST(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
auto *ASCN1 = cast<AddrSpaceCastSDNode>(N);
diff --git a/llvm/test/CodeGen/NVPTX/build-vector-combine.ll b/llvm/test/CodeGen/NVPTX/build-vector-combine.ll
new file mode 100644
index 0000000000000..019bd3bde8761
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/build-vector-combine.ll
@@ -0,0 +1,106 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64-nvidia-cuda | %ptxas-verify %}
+
+target datalayout = "e-p:64:64:64-p3:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-f128:128:128-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64-a:8:8"
+target triple = "nvptx64-nvidia-cuda"
+
+; Low half of the v4i8 comes from a v2i8 load, high half is zero: the combine
+; should emit a single 16-bit load (no prmt) feeding the vector store.
+define void @t1() {
+; CHECK-LABEL: t1(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: mov.b64 %rd1, 0;
+; CHECK-NEXT: ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT: st.global.v4.b32 [%rd1], {%r1, 0, 0, 0};
+; CHECK-NEXT: ret;
+entry:
+ %0 = load <2 x i8>, ptr addrspace(1) null, align 4
+ %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+ %2 = bitcast <4 x i8> %1 to i32
+ %3 = insertelement <4 x i32> zeroinitializer, i32 %2, i64 0
+ store <4 x i32> %3, ptr addrspace(1) null, align 16
+ ret void
+}
+
+; Same pattern as t1 but the v4i8 result is stored directly (to local space):
+; expect one 16-bit load and one 32-bit store, no prmt.
+define void @t2() {
+; CHECK-LABEL: t2(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: mov.b64 %rd1, 0;
+; CHECK-NEXT: ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT: st.local.b32 [%rd1], %r1;
+; CHECK-NEXT: ret;
+entry:
+ %0 = load <2 x i8>, ptr addrspace(1) null, align 8
+ %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+ store <4 x i8> %1, ptr addrspace(5) null, align 8
+ ret void
+}
+
+declare <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 %align)
+
+; Same combine when the two bytes come from the ldg intrinsic instead of a
+; plain load.
+define void @ldg(ptr addrspace(1) %ptr) {
+; CHECK-LABEL: ldg(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b64 %rd1, [ldg_param_0];
+; CHECK-NEXT: ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT: mov.b64 %rd2, 0;
+; CHECK-NEXT: st.local.b32 [%rd2], %r1;
+; CHECK-NEXT: ret;
+entry:
+ %0 = tail call <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2)
+ %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+ store <4 x i8> %1, ptr addrspace(5) null, align 8
+ ret void
+}
+
+; Integer ('.i.') ldu variant: the declaration must match the call below and
+; the integer <2 x i8> return type (the original patch declared the float
+; '.f.' variant, which the function never calls).
+declare <2 x i8> @llvm.nvvm.ldu.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 %align)
+
+; Same combine when the two bytes come from the ldu intrinsic; the scalar
+; access is rebuilt as a 16-bit ldu.
+define void @ldu(ptr addrspace(1) %ptr) {
+; CHECK-LABEL: ldu(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<2>;
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0: // %entry
+; CHECK-NEXT: ld.param.b64 %rd1, [ldu_param_0];
+; CHECK-NEXT: ldu.global.b16 %rs1, [%rd1];
+; CHECK-NEXT: cvt.u32.u16 %r1, %rs1;
+; CHECK-NEXT: mov.b64 %rd2, 0;
+; CHECK-NEXT: st.local.b32 [%rd2], %r1;
+; CHECK-NEXT: ret;
+entry:
+ %0 = tail call <2 x i8> @llvm.nvvm.ldu.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2)
+ %1 = shufflevector <2 x i8> %0, <2 x i8> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+ store <4 x i8> %1, ptr addrspace(5) null, align 8
+ ret void
+}
+
+; Variant where the loaded <2 x i8> is bitcast to i16 and inserted into a
+; zero <4 x i16>: the backend should still fold to a single 16-bit load.
+define void @t3() {
+; CHECK-LABEL: t3(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: mov.b64 %rd1, 0;
+; CHECK-NEXT: ld.global.b16 %r1, [%rd1];
+; CHECK-NEXT: st.global.v2.b32 [%rd1], {%r1, 0};
+; CHECK-NEXT: ret;
+ %1 = load <2 x i8>, ptr addrspace(1) null, align 2
+ %insval2 = bitcast <2 x i8> %1 to i16
+ %2 = insertelement <4 x i16> zeroinitializer, i16 %insval2, i32 0
+ store <4 x i16> %2, ptr addrspace(1) null, align 8
+ ret void
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/149581
More information about the llvm-commits
mailing list