[llvm] [NVPTX] Add SimplifyDemandedBitsForTargetNode for PRMT (PR #149395)
Alex MacLean via llvm-commits
llvm-commits at lists.llvm.org
Fri Jul 18 14:29:38 PDT 2025
https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/149395
>From f63a62dbde3f486ed6f34a5a8472a80ac545d9c0 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Thu, 17 Jul 2025 18:58:39 +0000
Subject: [PATCH 1/2] [NVPTX] Add SimplifyDemandedBitsForTargetNode for PRMT
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 99 +++++++++++++++-
llvm/lib/Target/NVPTX/NVPTXISelLowering.h | 5 +
.../test/CodeGen/NVPTX/LoadStoreVectorizer.ll | 102 ++++++++---------
llvm/test/CodeGen/NVPTX/extractelement.ll | 71 ++++++------
llvm/test/CodeGen/NVPTX/i8x4-instructions.ll | 106 ++++++++----------
5 files changed, 234 insertions(+), 149 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 7aa06f9079b09..e9190b6097709 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -6544,4 +6544,101 @@ void NVPTXTargetLowering::computeKnownBitsForTargetNode(
default:
break;
}
-}
\ No newline at end of file
+}
+
+// Map the demanded bits of a PRMT result back onto its two source operands.
+//
+// SelectorVal holds one 4-bit selector nibble per result byte: nibble I
+// (bits [4*I+3:4*I]) controls result byte I. It is the selector after
+// conversion by getPRMTSelector in simplifyDemandedBitsForPRMT. In each
+// nibble, bits [2:0] pick one of eight source bytes (0-3 from operand 0,
+// 4-7 from operand 1), and bit 3 selects sign replication.
+//
+// On return, DemandedLHS and DemandedRHS are 32-bit masks of the bits of
+// operand 0 and operand 1 that can affect the demanded result bits. Any
+// previous contents of the two out-parameters are overwritten.
+static void getPRMTDemandedBits(const APInt &SelectorVal,
+ const APInt &DemandedBits, APInt &DemandedLHS,
+ APInt &DemandedRHS) {
+ DemandedLHS = APInt(32, 0);
+ DemandedRHS = APInt(32, 0);
+
+ for (unsigned I : llvm::seq(4)) {
+ // An unused result byte places no demand on either source.
+ if (DemandedBits.extractBits(8, I * 8).isZero())
+ continue;
+
+ APInt Sel = SelectorVal.extractBits(4, I * 4);
+ unsigned Idx = Sel.getLoBits(3).getZExtValue();
+ unsigned Sign = Sel.getHiBits(1).getZExtValue();
+
+ // Source bytes 0-3 come from operand 0, bytes 4-7 from operand 1.
+ APInt &Src = Idx < 4 ? DemandedLHS : DemandedRHS;
+ unsigned ByteStart = (Idx % 4) * 8;
+ // When the sign bit is set, only the top bit of the selected byte reaches
+ // the result (it is replicated across the byte), so only that bit is
+ // demanded. Otherwise the whole source byte is demanded.
+ if (Sign)
+ Src.setBit(ByteStart + 7);
+ else
+ Src.setBits(ByteStart, ByteStart + 8);
+ }
+}
+
+// Canonicalize a PRMT source operand returned by a demanded-bits
+// simplification.
+// - A null SDValue is passed through as null. Callers use null to mean
+//   "no simplification was found".
+// - Undef is replaced with the i32 constant 0, because a concrete zero is
+//   easier for later optimizations such as known-bits analysis to reason
+//   about.
+// - Any other value is returned unchanged.
+static SDValue canonicalizePRMTInput(SDValue Op, SelectionDAG &DAG) {
+ if (!Op)
+ return SDValue();
+ // PRMT sources are 32-bit, so the replacement constant is i32.
+ if (Op.isUndef())
+ return DAG.getConstant(0, SDLoc(), MVT::i32);
+ return Op;
+}
+
+// Simplify an NVPTXISD::PRMT node when only DemandedBits of its result are
+// used. Operands 0 and 1 are the byte sources, operand 2 is the selector and
+// operand 3 is the mode.
+//
+// Returns one of:
+// - a source operand, if the demanded bytes are an in-order copy of it;
+// - a new PRMT with simplified source operands;
+// - a null SDValue, if nothing could be simplified (including when the
+//   selector is not a constant).
+static SDValue simplifyDemandedBitsForPRMT(SDValue PRMT,
+ const APInt &DemandedBits,
+ SelectionDAG &DAG,
+ const TargetLowering &TLI,
+ unsigned Depth) {
+ SDValue Op0 = PRMT.getOperand(0);
+ SDValue Op1 = PRMT.getOperand(1);
+ ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(PRMT.getOperand(2));
+ // The mode is read with getConstantOperandVal, i.e. operand 3 is assumed to
+ // always be a constant. It is read before the selector check below.
+ unsigned Mode = PRMT.getConstantOperandVal(3);
+ if (!Selector)
+ return SDValue();
+
+ // Rewrite the selector into the per-byte nibble form used below.
+ // getPRMTSelector is defined elsewhere; presumably it folds Mode into an
+ // equivalent generic selector.
+ const APInt SelectorVal = getPRMTSelector(Selector->getAPIntValue(), Mode);
+
+ // Try to simplify the PRMT to one of the inputs if the used bytes are all
+ // from the same input in the correct order. The check compares only the
+ // selector nibbles for bytes up to and including the highest demanded
+ // byte. They are compared against the identity selectors 0x3210 (operand
+ // 0) and 0x7654 (operand 1). This is conservative: undemanded bytes below
+ // the highest demanded one must still match. If nothing is demanded,
+ // SelBits is 0 and operand 0 is returned.
+ const unsigned LeadingBytes = DemandedBits.countLeadingZeros() / 8;
+ const unsigned SelBits = (4 - LeadingBytes) * 4;
+ if (SelectorVal.getLoBits(SelBits) == APInt(32, 0x3210).getLoBits(SelBits))
+ return Op0;
+ if (SelectorVal.getLoBits(SelBits) == APInt(32, 0x7654).getLoBits(SelBits))
+ return Op1;
+
+ // Translate the demanded result bits into demanded bits of each source.
+ APInt DemandedLHS, DemandedRHS;
+ getPRMTDemandedBits(SelectorVal, DemandedBits, DemandedLHS, DemandedRHS);
+
+ // Attempt to avoid multi-use ops if we don't need anything from them.
+ // Each call looks for a cheaper value that agrees with the operand on the
+ // bits this PRMT reads. It returns a null SDValue when it finds none.
+ SDValue DemandedOp0 =
+ TLI.SimplifyMultipleUseDemandedBits(Op0, DemandedLHS, DAG, Depth + 1);
+ SDValue DemandedOp1 =
+ TLI.SimplifyMultipleUseDemandedBits(Op1, DemandedRHS, DAG, Depth + 1);
+
+ DemandedOp0 = canonicalizePRMTInput(DemandedOp0, DAG);
+ DemandedOp1 = canonicalizePRMTInput(DemandedOp1, DAG);
+ // NOTE(review): a null SDValue ("no simplification") never compares equal
+ // to Op0/Op1. This condition is therefore true whenever either call above
+ // returned null. In that case an equivalent PRMT with unchanged operands
+ // is rebuilt and reported to the caller as a simplification, which may
+ // make the combiner loop. The follow-up revision guards each comparison:
+ // `(DemandedOp0 && DemandedOp0 != Op0) || (DemandedOp1 && DemandedOp1 != Op1)`.
+ if (DemandedOp0 != Op0 || DemandedOp1 != Op1) {
+ Op0 = DemandedOp0 ? DemandedOp0 : Op0;
+ Op1 = DemandedOp1 ? DemandedOp1 : Op1;
+ // Rebuild using the nibble-form selector. getPRMT is defined elsewhere;
+ // presumably it emits the default (generic) mode.
+ return getPRMT(Op0, Op1, SelectorVal.getZExtValue(), SDLoc(PRMT), DAG);
+ }
+
+ return SDValue();
+}
+
+// TargetLowering hook: simplify an NVPTX-specific node given that only
+// DemandedBits of its result are used. Only NVPTXISD::PRMT is handled here.
+//
+// On success the replacement is recorded with TLO.CombineTo, true is
+// returned, and Known is left reset (nothing known). Otherwise Known is
+// filled in by computeKnownBitsForTargetNode and false is returned.
+// DemandedElts is not used by the PRMT path; it is only forwarded to the
+// known-bits fallback.
+bool NVPTXTargetLowering::SimplifyDemandedBitsForTargetNode(
+ SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
+ KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const {
+ // Start from "nothing known". Only the fallback below refines Known.
+ Known.resetAll();
+
+ switch (Op.getOpcode()) {
+ case NVPTXISD::PRMT:
+ if (SDValue Result = simplifyDemandedBitsForPRMT(Op, DemandedBits, TLO.DAG,
+ *this, Depth)) {
+ TLO.CombineTo(Op, Result);
+ return true;
+ }
+ break;
+ default:
+ break;
+ }
+
+ // No simplification: still report whatever is known about Op's bits.
+ computeKnownBitsForTargetNode(Op, Known, DemandedElts, TLO.DAG, Depth);
+ return false;
+}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index bc3548c0272bb..228e2aac47aec 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -275,6 +275,11 @@ class NVPTXTargetLowering : public TargetLowering {
const APInt &DemandedElts,
const SelectionDAG &DAG,
unsigned Depth = 0) const override;
+ /// Demanded-bits simplification hook for NVPTX-specific nodes. Returns true
+ /// (after recording the replacement in TLO) if Op was simplified; otherwise
+ /// fills in Known and returns false.
+ bool SimplifyDemandedBitsForTargetNode(SDValue Op, const APInt &DemandedBits,
+ const APInt &DemandedElts,
+ KnownBits &Known,
+ TargetLoweringOpt &TLO,
+ unsigned Depth = 0) const override;
private:
const NVPTXSubtarget &STI; // cache the subtarget here
diff --git a/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll b/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
index 23832a9cb5c58..dd9a472984c25 100644
--- a/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
+++ b/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
@@ -181,32 +181,32 @@ define void @combine_v16i8(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr
; ENABLED-NEXT: prmt.b32 %r5, %r4, 0, 0x7773U;
; ENABLED-NEXT: prmt.b32 %r6, %r4, 0, 0x7772U;
; ENABLED-NEXT: prmt.b32 %r7, %r4, 0, 0x7771U;
-; ENABLED-NEXT: prmt.b32 %r8, %r4, 0, 0x7770U;
-; ENABLED-NEXT: prmt.b32 %r9, %r3, 0, 0x7773U;
-; ENABLED-NEXT: prmt.b32 %r10, %r3, 0, 0x7772U;
-; ENABLED-NEXT: prmt.b32 %r11, %r3, 0, 0x7771U;
-; ENABLED-NEXT: prmt.b32 %r12, %r3, 0, 0x7770U;
-; ENABLED-NEXT: prmt.b32 %r13, %r2, 0, 0x7773U;
-; ENABLED-NEXT: prmt.b32 %r14, %r2, 0, 0x7772U;
-; ENABLED-NEXT: prmt.b32 %r15, %r2, 0, 0x7771U;
-; ENABLED-NEXT: prmt.b32 %r16, %r2, 0, 0x7770U;
-; ENABLED-NEXT: prmt.b32 %r17, %r1, 0, 0x7773U;
-; ENABLED-NEXT: prmt.b32 %r18, %r1, 0, 0x7772U;
-; ENABLED-NEXT: prmt.b32 %r19, %r1, 0, 0x7771U;
-; ENABLED-NEXT: prmt.b32 %r20, %r1, 0, 0x7770U;
+; ENABLED-NEXT: prmt.b32 %r8, %r3, 0, 0x7773U;
+; ENABLED-NEXT: prmt.b32 %r9, %r3, 0, 0x7772U;
+; ENABLED-NEXT: prmt.b32 %r10, %r3, 0, 0x7771U;
+; ENABLED-NEXT: prmt.b32 %r11, %r2, 0, 0x7773U;
+; ENABLED-NEXT: prmt.b32 %r12, %r2, 0, 0x7772U;
+; ENABLED-NEXT: prmt.b32 %r13, %r2, 0, 0x7771U;
+; ENABLED-NEXT: prmt.b32 %r14, %r1, 0, 0x7773U;
+; ENABLED-NEXT: prmt.b32 %r15, %r1, 0, 0x7772U;
+; ENABLED-NEXT: prmt.b32 %r16, %r1, 0, 0x7771U;
; ENABLED-NEXT: ld.param.b64 %rd2, [combine_v16i8_param_1];
-; ENABLED-NEXT: add.s32 %r21, %r20, %r19;
-; ENABLED-NEXT: add.s32 %r22, %r21, %r18;
-; ENABLED-NEXT: add.s32 %r23, %r22, %r17;
-; ENABLED-NEXT: add.s32 %r24, %r23, %r16;
-; ENABLED-NEXT: add.s32 %r25, %r24, %r15;
-; ENABLED-NEXT: add.s32 %r26, %r25, %r14;
-; ENABLED-NEXT: add.s32 %r27, %r26, %r13;
-; ENABLED-NEXT: add.s32 %r28, %r27, %r12;
-; ENABLED-NEXT: add.s32 %r29, %r28, %r11;
-; ENABLED-NEXT: add.s32 %r30, %r29, %r10;
-; ENABLED-NEXT: add.s32 %r31, %r30, %r9;
-; ENABLED-NEXT: add.s32 %r32, %r31, %r8;
+; ENABLED-NEXT: and.b32 %r17, %r1, 255;
+; ENABLED-NEXT: and.b32 %r18, %r2, 255;
+; ENABLED-NEXT: and.b32 %r19, %r3, 255;
+; ENABLED-NEXT: and.b32 %r20, %r4, 255;
+; ENABLED-NEXT: add.s32 %r21, %r17, %r16;
+; ENABLED-NEXT: add.s32 %r22, %r21, %r15;
+; ENABLED-NEXT: add.s32 %r23, %r22, %r14;
+; ENABLED-NEXT: add.s32 %r24, %r23, %r18;
+; ENABLED-NEXT: add.s32 %r25, %r24, %r13;
+; ENABLED-NEXT: add.s32 %r26, %r25, %r12;
+; ENABLED-NEXT: add.s32 %r27, %r26, %r11;
+; ENABLED-NEXT: add.s32 %r28, %r27, %r19;
+; ENABLED-NEXT: add.s32 %r29, %r28, %r10;
+; ENABLED-NEXT: add.s32 %r30, %r29, %r9;
+; ENABLED-NEXT: add.s32 %r31, %r30, %r8;
+; ENABLED-NEXT: add.s32 %r32, %r31, %r20;
; ENABLED-NEXT: add.s32 %r33, %r32, %r7;
; ENABLED-NEXT: add.s32 %r34, %r33, %r6;
; ENABLED-NEXT: add.s32 %r35, %r34, %r5;
@@ -332,36 +332,36 @@ define void @combine_v16i8_unaligned(ptr noundef align 8 %ptr1, ptr noundef alig
; ENABLED-NEXT: prmt.b32 %r3, %r2, 0, 0x7773U;
; ENABLED-NEXT: prmt.b32 %r4, %r2, 0, 0x7772U;
; ENABLED-NEXT: prmt.b32 %r5, %r2, 0, 0x7771U;
-; ENABLED-NEXT: prmt.b32 %r6, %r2, 0, 0x7770U;
-; ENABLED-NEXT: prmt.b32 %r7, %r1, 0, 0x7773U;
-; ENABLED-NEXT: prmt.b32 %r8, %r1, 0, 0x7772U;
-; ENABLED-NEXT: prmt.b32 %r9, %r1, 0, 0x7771U;
-; ENABLED-NEXT: prmt.b32 %r10, %r1, 0, 0x7770U;
+; ENABLED-NEXT: prmt.b32 %r6, %r1, 0, 0x7773U;
+; ENABLED-NEXT: prmt.b32 %r7, %r1, 0, 0x7772U;
+; ENABLED-NEXT: prmt.b32 %r8, %r1, 0, 0x7771U;
; ENABLED-NEXT: ld.param.b64 %rd2, [combine_v16i8_unaligned_param_1];
-; ENABLED-NEXT: ld.v2.b32 {%r11, %r12}, [%rd1+8];
-; ENABLED-NEXT: prmt.b32 %r13, %r12, 0, 0x7773U;
-; ENABLED-NEXT: prmt.b32 %r14, %r12, 0, 0x7772U;
-; ENABLED-NEXT: prmt.b32 %r15, %r12, 0, 0x7771U;
-; ENABLED-NEXT: prmt.b32 %r16, %r12, 0, 0x7770U;
-; ENABLED-NEXT: prmt.b32 %r17, %r11, 0, 0x7773U;
-; ENABLED-NEXT: prmt.b32 %r18, %r11, 0, 0x7772U;
-; ENABLED-NEXT: prmt.b32 %r19, %r11, 0, 0x7771U;
-; ENABLED-NEXT: prmt.b32 %r20, %r11, 0, 0x7770U;
-; ENABLED-NEXT: add.s32 %r21, %r10, %r9;
-; ENABLED-NEXT: add.s32 %r22, %r21, %r8;
-; ENABLED-NEXT: add.s32 %r23, %r22, %r7;
-; ENABLED-NEXT: add.s32 %r24, %r23, %r6;
+; ENABLED-NEXT: ld.v2.b32 {%r9, %r10}, [%rd1+8];
+; ENABLED-NEXT: prmt.b32 %r11, %r10, 0, 0x7773U;
+; ENABLED-NEXT: prmt.b32 %r12, %r10, 0, 0x7772U;
+; ENABLED-NEXT: prmt.b32 %r13, %r10, 0, 0x7771U;
+; ENABLED-NEXT: prmt.b32 %r14, %r9, 0, 0x7773U;
+; ENABLED-NEXT: prmt.b32 %r15, %r9, 0, 0x7772U;
+; ENABLED-NEXT: prmt.b32 %r16, %r9, 0, 0x7771U;
+; ENABLED-NEXT: and.b32 %r17, %r1, 255;
+; ENABLED-NEXT: and.b32 %r18, %r2, 255;
+; ENABLED-NEXT: and.b32 %r19, %r9, 255;
+; ENABLED-NEXT: and.b32 %r20, %r10, 255;
+; ENABLED-NEXT: add.s32 %r21, %r17, %r8;
+; ENABLED-NEXT: add.s32 %r22, %r21, %r7;
+; ENABLED-NEXT: add.s32 %r23, %r22, %r6;
+; ENABLED-NEXT: add.s32 %r24, %r23, %r18;
; ENABLED-NEXT: add.s32 %r25, %r24, %r5;
; ENABLED-NEXT: add.s32 %r26, %r25, %r4;
; ENABLED-NEXT: add.s32 %r27, %r26, %r3;
-; ENABLED-NEXT: add.s32 %r28, %r27, %r20;
-; ENABLED-NEXT: add.s32 %r29, %r28, %r19;
-; ENABLED-NEXT: add.s32 %r30, %r29, %r18;
-; ENABLED-NEXT: add.s32 %r31, %r30, %r17;
-; ENABLED-NEXT: add.s32 %r32, %r31, %r16;
-; ENABLED-NEXT: add.s32 %r33, %r32, %r15;
-; ENABLED-NEXT: add.s32 %r34, %r33, %r14;
-; ENABLED-NEXT: add.s32 %r35, %r34, %r13;
+; ENABLED-NEXT: add.s32 %r28, %r27, %r19;
+; ENABLED-NEXT: add.s32 %r29, %r28, %r16;
+; ENABLED-NEXT: add.s32 %r30, %r29, %r15;
+; ENABLED-NEXT: add.s32 %r31, %r30, %r14;
+; ENABLED-NEXT: add.s32 %r32, %r31, %r20;
+; ENABLED-NEXT: add.s32 %r33, %r32, %r13;
+; ENABLED-NEXT: add.s32 %r34, %r33, %r12;
+; ENABLED-NEXT: add.s32 %r35, %r34, %r11;
; ENABLED-NEXT: st.b32 [%rd2], %r35;
; ENABLED-NEXT: ret;
;
diff --git a/llvm/test/CodeGen/NVPTX/extractelement.ll b/llvm/test/CodeGen/NVPTX/extractelement.ll
index 80980efbab05b..d61a63ce24f89 100644
--- a/llvm/test/CodeGen/NVPTX/extractelement.ll
+++ b/llvm/test/CodeGen/NVPTX/extractelement.ll
@@ -56,23 +56,22 @@ define i16 @test_v4i8(i32 %a) {
; CHECK-LABEL: test_v4i8(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<8>;
-; CHECK-NEXT: .reg .b32 %r<7>;
+; CHECK-NEXT: .reg .b32 %r<6>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [test_v4i8_param_0];
-; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x8880U;
-; CHECK-NEXT: cvt.u16.u32 %rs1, %r2;
-; CHECK-NEXT: prmt.b32 %r3, %r1, 0, 0x9991U;
-; CHECK-NEXT: cvt.u16.u32 %rs2, %r3;
-; CHECK-NEXT: prmt.b32 %r4, %r1, 0, 0xaaa2U;
-; CHECK-NEXT: cvt.u16.u32 %rs3, %r4;
-; CHECK-NEXT: prmt.b32 %r5, %r1, 0, 0xbbb3U;
-; CHECK-NEXT: cvt.u16.u32 %rs4, %r5;
+; CHECK-NEXT: cvt.s8.s32 %rs1, %r1;
+; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x9991U;
+; CHECK-NEXT: cvt.u16.u32 %rs2, %r2;
+; CHECK-NEXT: prmt.b32 %r3, %r1, 0, 0xaaa2U;
+; CHECK-NEXT: cvt.u16.u32 %rs3, %r3;
+; CHECK-NEXT: prmt.b32 %r4, %r1, 0, 0xbbb3U;
+; CHECK-NEXT: cvt.u16.u32 %rs4, %r4;
; CHECK-NEXT: add.s16 %rs5, %rs1, %rs2;
; CHECK-NEXT: add.s16 %rs6, %rs3, %rs4;
; CHECK-NEXT: add.s16 %rs7, %rs5, %rs6;
-; CHECK-NEXT: cvt.u32.u16 %r6, %rs7;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r6;
+; CHECK-NEXT: cvt.u32.u16 %r5, %rs7;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r5;
; CHECK-NEXT: ret;
%v = bitcast i32 %a to <4 x i8>
%r0 = extractelement <4 x i8> %v, i64 0
@@ -96,7 +95,7 @@ define i32 @test_v4i8_s32(i32 %a) {
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [test_v4i8_s32_param_0];
-; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x8880U;
+; CHECK-NEXT: cvt.s32.s8 %r2, %r1;
; CHECK-NEXT: prmt.b32 %r3, %r1, 0, 0x9991U;
; CHECK-NEXT: prmt.b32 %r4, %r1, 0, 0xaaa2U;
; CHECK-NEXT: prmt.b32 %r5, %r1, 0, 0xbbb3U;
@@ -127,12 +126,12 @@ define i32 @test_v4i8_u32(i32 %a) {
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.b32 %r1, [test_v4i8_u32_param_0];
-; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x7770U;
-; CHECK-NEXT: prmt.b32 %r3, %r1, 0, 0x7771U;
-; CHECK-NEXT: prmt.b32 %r4, %r1, 0, 0x7772U;
-; CHECK-NEXT: prmt.b32 %r5, %r1, 0, 0x7773U;
-; CHECK-NEXT: add.s32 %r6, %r2, %r3;
-; CHECK-NEXT: add.s32 %r7, %r4, %r5;
+; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x7771U;
+; CHECK-NEXT: prmt.b32 %r3, %r1, 0, 0x7772U;
+; CHECK-NEXT: prmt.b32 %r4, %r1, 0, 0x7773U;
+; CHECK-NEXT: and.b32 %r5, %r1, 255;
+; CHECK-NEXT: add.s32 %r6, %r5, %r2;
+; CHECK-NEXT: add.s32 %r7, %r3, %r4;
; CHECK-NEXT: add.s32 %r8, %r6, %r7;
; CHECK-NEXT: st.param.b32 [func_retval0], %r8;
; CHECK-NEXT: ret;
@@ -157,26 +156,24 @@ define i16 @test_v8i8(i64 %a) {
; CHECK-LABEL: test_v8i8(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<16>;
-; CHECK-NEXT: .reg .b32 %r<12>;
+; CHECK-NEXT: .reg .b32 %r<10>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.v2.b32 {%r1, %r2}, [test_v8i8_param_0];
-; CHECK-NEXT: prmt.b32 %r3, %r1, 0, 0x8880U;
-; CHECK-NEXT: cvt.u16.u32 %rs1, %r3;
-; CHECK-NEXT: prmt.b32 %r4, %r1, 0, 0x9991U;
-; CHECK-NEXT: cvt.u16.u32 %rs2, %r4;
-; CHECK-NEXT: prmt.b32 %r5, %r1, 0, 0xaaa2U;
-; CHECK-NEXT: cvt.u16.u32 %rs3, %r5;
-; CHECK-NEXT: prmt.b32 %r6, %r1, 0, 0xbbb3U;
-; CHECK-NEXT: cvt.u16.u32 %rs4, %r6;
-; CHECK-NEXT: prmt.b32 %r7, %r2, 0, 0x8880U;
-; CHECK-NEXT: cvt.u16.u32 %rs5, %r7;
-; CHECK-NEXT: prmt.b32 %r8, %r2, 0, 0x9991U;
-; CHECK-NEXT: cvt.u16.u32 %rs6, %r8;
-; CHECK-NEXT: prmt.b32 %r9, %r2, 0, 0xaaa2U;
-; CHECK-NEXT: cvt.u16.u32 %rs7, %r9;
-; CHECK-NEXT: prmt.b32 %r10, %r2, 0, 0xbbb3U;
-; CHECK-NEXT: cvt.u16.u32 %rs8, %r10;
+; CHECK-NEXT: cvt.s8.s32 %rs1, %r1;
+; CHECK-NEXT: prmt.b32 %r3, %r1, 0, 0x9991U;
+; CHECK-NEXT: cvt.u16.u32 %rs2, %r3;
+; CHECK-NEXT: prmt.b32 %r4, %r1, 0, 0xaaa2U;
+; CHECK-NEXT: cvt.u16.u32 %rs3, %r4;
+; CHECK-NEXT: prmt.b32 %r5, %r1, 0, 0xbbb3U;
+; CHECK-NEXT: cvt.u16.u32 %rs4, %r5;
+; CHECK-NEXT: cvt.s8.s32 %rs5, %r2;
+; CHECK-NEXT: prmt.b32 %r6, %r2, 0, 0x9991U;
+; CHECK-NEXT: cvt.u16.u32 %rs6, %r6;
+; CHECK-NEXT: prmt.b32 %r7, %r2, 0, 0xaaa2U;
+; CHECK-NEXT: cvt.u16.u32 %rs7, %r7;
+; CHECK-NEXT: prmt.b32 %r8, %r2, 0, 0xbbb3U;
+; CHECK-NEXT: cvt.u16.u32 %rs8, %r8;
; CHECK-NEXT: add.s16 %rs9, %rs1, %rs2;
; CHECK-NEXT: add.s16 %rs10, %rs3, %rs4;
; CHECK-NEXT: add.s16 %rs11, %rs5, %rs6;
@@ -184,8 +181,8 @@ define i16 @test_v8i8(i64 %a) {
; CHECK-NEXT: add.s16 %rs13, %rs9, %rs10;
; CHECK-NEXT: add.s16 %rs14, %rs11, %rs12;
; CHECK-NEXT: add.s16 %rs15, %rs13, %rs14;
-; CHECK-NEXT: cvt.u32.u16 %r11, %rs15;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r11;
+; CHECK-NEXT: cvt.u32.u16 %r9, %rs15;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r9;
; CHECK-NEXT: ret;
%v = bitcast i64 %a to <8 x i8>
%r0 = extractelement <8 x i8> %v, i64 0
diff --git a/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
index aba20e6b0f27f..5a11057cf0db1 100644
--- a/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
@@ -2040,7 +2040,7 @@ define void @test_srem_v4i8(ptr %a, ptr %b, ptr %c) {
; O0-LABEL: test_srem_v4i8(
; O0: {
; O0-NEXT: .reg .b16 %rs<13>;
-; O0-NEXT: .reg .b32 %r<18>;
+; O0-NEXT: .reg .b32 %r<16>;
; O0-NEXT: .reg .b64 %rd<4>;
; O0-EMPTY:
; O0-NEXT: // %bb.0: // %entry
@@ -2062,27 +2062,25 @@ define void @test_srem_v4i8(ptr %a, ptr %b, ptr %c) {
; O0-NEXT: rem.s16 %rs6, %rs5, %rs4;
; O0-NEXT: cvt.u32.u16 %r8, %rs6;
; O0-NEXT: prmt.b32 %r9, %r8, %r5, 0x3340U;
-; O0-NEXT: prmt.b32 %r10, %r2, 0, 0x9991U;
-; O0-NEXT: cvt.u16.u32 %rs7, %r10;
-; O0-NEXT: prmt.b32 %r11, %r1, 0, 0x9991U;
-; O0-NEXT: cvt.u16.u32 %rs8, %r11;
+; O0-NEXT: cvt.s8.s32 %rs7, %r2;
+; O0-NEXT: cvt.s8.s32 %rs8, %r1;
; O0-NEXT: rem.s16 %rs9, %rs8, %rs7;
-; O0-NEXT: cvt.u32.u16 %r12, %rs9;
-; O0-NEXT: prmt.b32 %r13, %r2, 0, 0x8880U;
-; O0-NEXT: cvt.u16.u32 %rs10, %r13;
-; O0-NEXT: prmt.b32 %r14, %r1, 0, 0x8880U;
-; O0-NEXT: cvt.u16.u32 %rs11, %r14;
+; O0-NEXT: cvt.u32.u16 %r10, %rs9;
+; O0-NEXT: prmt.b32 %r11, %r2, 0, 0x9991U;
+; O0-NEXT: cvt.u16.u32 %rs10, %r11;
+; O0-NEXT: prmt.b32 %r12, %r1, 0, 0x9991U;
+; O0-NEXT: cvt.u16.u32 %rs11, %r12;
; O0-NEXT: rem.s16 %rs12, %rs11, %rs10;
-; O0-NEXT: cvt.u32.u16 %r15, %rs12;
-; O0-NEXT: prmt.b32 %r16, %r15, %r12, 0x3340U;
-; O0-NEXT: prmt.b32 %r17, %r16, %r9, 0x5410U;
-; O0-NEXT: st.b32 [%rd3], %r17;
+; O0-NEXT: cvt.u32.u16 %r13, %rs12;
+; O0-NEXT: prmt.b32 %r14, %r10, %r13, 0x3340U;
+; O0-NEXT: prmt.b32 %r15, %r14, %r9, 0x5410U;
+; O0-NEXT: st.b32 [%rd3], %r15;
; O0-NEXT: ret;
;
; O3-LABEL: test_srem_v4i8(
; O3: {
; O3-NEXT: .reg .b16 %rs<13>;
-; O3-NEXT: .reg .b32 %r<18>;
+; O3-NEXT: .reg .b32 %r<16>;
; O3-NEXT: .reg .b64 %rd<4>;
; O3-EMPTY:
; O3-NEXT: // %bb.0: // %entry
@@ -2104,21 +2102,19 @@ define void @test_srem_v4i8(ptr %a, ptr %b, ptr %c) {
; O3-NEXT: rem.s16 %rs6, %rs5, %rs4;
; O3-NEXT: cvt.u32.u16 %r8, %rs6;
; O3-NEXT: prmt.b32 %r9, %r8, %r5, 0x3340U;
-; O3-NEXT: prmt.b32 %r10, %r2, 0, 0x9991U;
-; O3-NEXT: cvt.u16.u32 %rs7, %r10;
-; O3-NEXT: prmt.b32 %r11, %r1, 0, 0x9991U;
-; O3-NEXT: cvt.u16.u32 %rs8, %r11;
+; O3-NEXT: cvt.s8.s32 %rs7, %r2;
+; O3-NEXT: cvt.s8.s32 %rs8, %r1;
; O3-NEXT: rem.s16 %rs9, %rs8, %rs7;
-; O3-NEXT: cvt.u32.u16 %r12, %rs9;
-; O3-NEXT: prmt.b32 %r13, %r2, 0, 0x8880U;
-; O3-NEXT: cvt.u16.u32 %rs10, %r13;
-; O3-NEXT: prmt.b32 %r14, %r1, 0, 0x8880U;
-; O3-NEXT: cvt.u16.u32 %rs11, %r14;
+; O3-NEXT: cvt.u32.u16 %r10, %rs9;
+; O3-NEXT: prmt.b32 %r11, %r2, 0, 0x9991U;
+; O3-NEXT: cvt.u16.u32 %rs10, %r11;
+; O3-NEXT: prmt.b32 %r12, %r1, 0, 0x9991U;
+; O3-NEXT: cvt.u16.u32 %rs11, %r12;
; O3-NEXT: rem.s16 %rs12, %rs11, %rs10;
-; O3-NEXT: cvt.u32.u16 %r15, %rs12;
-; O3-NEXT: prmt.b32 %r16, %r15, %r12, 0x3340U;
-; O3-NEXT: prmt.b32 %r17, %r16, %r9, 0x5410U;
-; O3-NEXT: st.b32 [%rd3], %r17;
+; O3-NEXT: cvt.u32.u16 %r13, %rs12;
+; O3-NEXT: prmt.b32 %r14, %r10, %r13, 0x3340U;
+; O3-NEXT: prmt.b32 %r15, %r14, %r9, 0x5410U;
+; O3-NEXT: st.b32 [%rd3], %r15;
; O3-NEXT: ret;
entry:
%t57 = load <4 x i8>, ptr %a, align 4
@@ -2138,7 +2134,7 @@ define void @test_srem_v3i8(ptr %a, ptr %b, ptr %c) {
; O0-LABEL: test_srem_v3i8(
; O0: {
; O0-NEXT: .reg .b16 %rs<20>;
-; O0-NEXT: .reg .b32 %r<14>;
+; O0-NEXT: .reg .b32 %r<8>;
; O0-NEXT: .reg .b64 %rd<4>;
; O0-EMPTY:
; O0-NEXT: // %bb.0: // %entry
@@ -2157,25 +2153,19 @@ define void @test_srem_v3i8(ptr %a, ptr %b, ptr %c) {
; O0-NEXT: or.b16 %rs9, %rs8, %rs6;
; O0-NEXT: cvt.u32.u16 %r2, %rs9;
; O0-NEXT: ld.s8 %rs10, [%rd2+2];
-; O0-NEXT: prmt.b32 %r3, %r2, 0, 0x9991U;
-; O0-NEXT: cvt.u16.u32 %rs11, %r3;
-; O0-NEXT: prmt.b32 %r4, %r1, 0, 0x9991U;
-; O0-NEXT: cvt.u16.u32 %rs12, %r4;
+; O0-NEXT: cvt.s16.s8 %rs11, %rs9;
+; O0-NEXT: cvt.s16.s8 %rs12, %rs4;
; O0-NEXT: rem.s16 %rs13, %rs12, %rs11;
-; O0-NEXT: cvt.u32.u16 %r5, %rs13;
-; O0-NEXT: prmt.b32 %r6, %r2, 0, 0x8880U;
-; O0-NEXT: cvt.u16.u32 %rs14, %r6;
-; O0-NEXT: prmt.b32 %r7, %r1, 0, 0x8880U;
-; O0-NEXT: cvt.u16.u32 %rs15, %r7;
+; O0-NEXT: cvt.u32.u16 %r3, %rs13;
+; O0-NEXT: prmt.b32 %r4, %r2, 0, 0x9991U;
+; O0-NEXT: cvt.u16.u32 %rs14, %r4;
+; O0-NEXT: prmt.b32 %r5, %r1, 0, 0x9991U;
+; O0-NEXT: cvt.u16.u32 %rs15, %r5;
; O0-NEXT: rem.s16 %rs16, %rs15, %rs14;
-; O0-NEXT: cvt.u32.u16 %r8, %rs16;
-; O0-NEXT: prmt.b32 %r9, %r8, %r5, 0x3340U;
-; O0-NEXT: // implicit-def: %r11
-; O0-NEXT: // implicit-def: %r12
-; O0-NEXT: prmt.b32 %r10, %r11, %r12, 0x3340U;
-; O0-NEXT: prmt.b32 %r13, %r9, %r10, 0x5410U;
+; O0-NEXT: cvt.u32.u16 %r6, %rs16;
+; O0-NEXT: prmt.b32 %r7, %r3, %r6, 0x3340U;
; O0-NEXT: rem.s16 %rs17, %rs5, %rs10;
-; O0-NEXT: cvt.u16.u32 %rs18, %r13;
+; O0-NEXT: cvt.u16.u32 %rs18, %r7;
; O0-NEXT: st.b8 [%rd3], %rs18;
; O0-NEXT: shr.u16 %rs19, %rs18, 8;
; O0-NEXT: st.b8 [%rd3+1], %rs19;
@@ -2185,7 +2175,7 @@ define void @test_srem_v3i8(ptr %a, ptr %b, ptr %c) {
; O3-LABEL: test_srem_v3i8(
; O3: {
; O3-NEXT: .reg .b16 %rs<20>;
-; O3-NEXT: .reg .b32 %r<14>;
+; O3-NEXT: .reg .b32 %r<8>;
; O3-NEXT: .reg .b64 %rd<4>;
; O3-EMPTY:
; O3-NEXT: // %bb.0: // %entry
@@ -2204,24 +2194,20 @@ define void @test_srem_v3i8(ptr %a, ptr %b, ptr %c) {
; O3-NEXT: cvt.u32.u16 %r2, %rs9;
; O3-NEXT: ld.s8 %rs10, [%rd2+2];
; O3-NEXT: ld.param.b64 %rd3, [test_srem_v3i8_param_2];
-; O3-NEXT: prmt.b32 %r3, %r2, 0, 0x9991U;
-; O3-NEXT: cvt.u16.u32 %rs11, %r3;
-; O3-NEXT: prmt.b32 %r4, %r1, 0, 0x9991U;
-; O3-NEXT: cvt.u16.u32 %rs12, %r4;
+; O3-NEXT: cvt.s16.s8 %rs11, %rs9;
+; O3-NEXT: cvt.s16.s8 %rs12, %rs4;
; O3-NEXT: rem.s16 %rs13, %rs12, %rs11;
-; O3-NEXT: cvt.u32.u16 %r5, %rs13;
-; O3-NEXT: prmt.b32 %r6, %r2, 0, 0x8880U;
-; O3-NEXT: cvt.u16.u32 %rs14, %r6;
-; O3-NEXT: prmt.b32 %r7, %r1, 0, 0x8880U;
-; O3-NEXT: cvt.u16.u32 %rs15, %r7;
+; O3-NEXT: cvt.u32.u16 %r3, %rs13;
+; O3-NEXT: prmt.b32 %r4, %r2, 0, 0x9991U;
+; O3-NEXT: cvt.u16.u32 %rs14, %r4;
+; O3-NEXT: prmt.b32 %r5, %r1, 0, 0x9991U;
+; O3-NEXT: cvt.u16.u32 %rs15, %r5;
; O3-NEXT: rem.s16 %rs16, %rs15, %rs14;
-; O3-NEXT: cvt.u32.u16 %r8, %rs16;
-; O3-NEXT: prmt.b32 %r9, %r8, %r5, 0x3340U;
-; O3-NEXT: prmt.b32 %r10, %r11, %r12, 0x3340U;
-; O3-NEXT: prmt.b32 %r13, %r9, %r10, 0x5410U;
+; O3-NEXT: cvt.u32.u16 %r6, %rs16;
+; O3-NEXT: prmt.b32 %r7, %r3, %r6, 0x3340U;
; O3-NEXT: rem.s16 %rs17, %rs5, %rs10;
; O3-NEXT: st.b8 [%rd3+2], %rs17;
-; O3-NEXT: cvt.u16.u32 %rs18, %r13;
+; O3-NEXT: cvt.u16.u32 %rs18, %r7;
; O3-NEXT: st.b8 [%rd3], %rs18;
; O3-NEXT: shr.u16 %rs19, %rs18, 8;
; O3-NEXT: st.b8 [%rd3+1], %rs19;
>From c97459ced4ba2962a452e958916e475c7899066b Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Fri, 18 Jul 2025 20:50:16 +0000
Subject: [PATCH 2/2] address comments
---
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 32 +++++++++++----------
1 file changed, 17 insertions(+), 15 deletions(-)
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index e9190b6097709..82cf1495841ae 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -6546,11 +6546,10 @@ void NVPTXTargetLowering::computeKnownBitsForTargetNode(
}
}
-static void getPRMTDemandedBits(const APInt &SelectorVal,
- const APInt &DemandedBits, APInt &DemandedLHS,
- APInt &DemandedRHS) {
- DemandedLHS = APInt(32, 0);
- DemandedRHS = APInt(32, 0);
+static std::pair<APInt, APInt> getPRMTDemandedBits(const APInt &SelectorVal,
+ const APInt &DemandedBits) {
+ APInt DemandedLHS = APInt(32, 0);
+ APInt DemandedRHS = APInt(32, 0);
for (unsigned I : llvm::seq(4)) {
if (DemandedBits.extractBits(8, I * 8).isZero())
@@ -6567,6 +6566,8 @@ static void getPRMTDemandedBits(const APInt &SelectorVal,
else
Src.setBits(ByteStart, ByteStart + 8);
}
+
+ return {DemandedLHS, DemandedRHS};
}
// Replace undef with 0 as this is easier for other optimizations such as
@@ -6584,26 +6585,26 @@ static SDValue simplifyDemandedBitsForPRMT(SDValue PRMT,
SelectionDAG &DAG,
const TargetLowering &TLI,
unsigned Depth) {
+ assert(PRMT.getOpcode() == NVPTXISD::PRMT);
SDValue Op0 = PRMT.getOperand(0);
SDValue Op1 = PRMT.getOperand(1);
- ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(PRMT.getOperand(2));
- unsigned Mode = PRMT.getConstantOperandVal(3);
- if (!Selector)
+ auto *SelectorConst = dyn_cast<ConstantSDNode>(PRMT.getOperand(2));
+ if (!SelectorConst)
return SDValue();
- const APInt SelectorVal = getPRMTSelector(Selector->getAPIntValue(), Mode);
+ unsigned Mode = PRMT.getConstantOperandVal(3);
+ const APInt Selector = getPRMTSelector(SelectorConst->getAPIntValue(), Mode);
// Try to simplify the PRMT to one of the inputs if the used bytes are all
// from the same input in the correct order.
const unsigned LeadingBytes = DemandedBits.countLeadingZeros() / 8;
const unsigned SelBits = (4 - LeadingBytes) * 4;
- if (SelectorVal.getLoBits(SelBits) == APInt(32, 0x3210).getLoBits(SelBits))
+ if (Selector.getLoBits(SelBits) == APInt(32, 0x3210).getLoBits(SelBits))
return Op0;
- if (SelectorVal.getLoBits(SelBits) == APInt(32, 0x7654).getLoBits(SelBits))
+ if (Selector.getLoBits(SelBits) == APInt(32, 0x7654).getLoBits(SelBits))
return Op1;
- APInt DemandedLHS, DemandedRHS;
- getPRMTDemandedBits(SelectorVal, DemandedBits, DemandedLHS, DemandedRHS);
+ auto [DemandedLHS, DemandedRHS] = getPRMTDemandedBits(Selector, DemandedBits);
// Attempt to avoid multi-use ops if we don't need anything from them.
SDValue DemandedOp0 =
@@ -6613,10 +6614,11 @@ static SDValue simplifyDemandedBitsForPRMT(SDValue PRMT,
DemandedOp0 = canonicalizePRMTInput(DemandedOp0, DAG);
DemandedOp1 = canonicalizePRMTInput(DemandedOp1, DAG);
- if (DemandedOp0 != Op0 || DemandedOp1 != Op1) {
+ if ((DemandedOp0 && DemandedOp0 != Op0) ||
+ (DemandedOp1 && DemandedOp1 != Op1)) {
Op0 = DemandedOp0 ? DemandedOp0 : Op0;
Op1 = DemandedOp1 ? DemandedOp1 : Op1;
- return getPRMT(Op0, Op1, SelectorVal.getZExtValue(), SDLoc(PRMT), DAG);
+ return getPRMT(Op0, Op1, Selector.getZExtValue(), SDLoc(PRMT), DAG);
}
return SDValue();
More information about the llvm-commits
mailing list