[llvm] [NVPTX] Add SimplifyDemandedBitsForTargetNode for PRMT (PR #149395)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Fri Jul 18 13:50:31 PDT 2025


https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/149395

>From bf0cf92d4a5a4754b9605afc50606c21467674e7 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Thu, 17 Jul 2025 18:58:39 +0000
Subject: [PATCH 1/2] [NVPTX] Add SimplifyDemandedBitsForTargetNode for PRMT

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   |  99 +++++++++++++++-
 llvm/lib/Target/NVPTX/NVPTXISelLowering.h     |   5 +
 .../test/CodeGen/NVPTX/LoadStoreVectorizer.ll | 102 ++++++++---------
 llvm/test/CodeGen/NVPTX/extractelement.ll     |  71 ++++++------
 llvm/test/CodeGen/NVPTX/i8x4-instructions.ll  | 106 ++++++++----------
 5 files changed, 234 insertions(+), 149 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 7aa06f9079b09..e9190b6097709 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -6544,4 +6544,101 @@ void NVPTXTargetLowering::computeKnownBitsForTargetNode(
   default:
     break;
   }
-}
\ No newline at end of file
+}
+
+static void getPRMTDemandedBits(const APInt &SelectorVal,
+                                const APInt &DemandedBits, APInt &DemandedLHS,
+                                APInt &DemandedRHS) {
+  DemandedLHS = APInt(32, 0);
+  DemandedRHS = APInt(32, 0);
+
+  for (unsigned I : llvm::seq(4)) {
+    if (DemandedBits.extractBits(8, I * 8).isZero())
+      continue;
+
+    APInt Sel = SelectorVal.extractBits(4, I * 4);
+    unsigned Idx = Sel.getLoBits(3).getZExtValue();
+    unsigned Sign = Sel.getHiBits(1).getZExtValue();
+
+    APInt &Src = Idx < 4 ? DemandedLHS : DemandedRHS;
+    unsigned ByteStart = (Idx % 4) * 8;
+    if (Sign)
+      Src.setBit(ByteStart + 7);
+    else
+      Src.setBits(ByteStart, ByteStart + 8);
+  }
+}
+
+// Replace undef with 0 as this is easier for other optimizations such as
+// known bits.
+static SDValue canonicalizePRMTInput(SDValue Op, SelectionDAG &DAG) {
+  if (!Op)
+    return SDValue();
+  if (Op.isUndef())
+    return DAG.getConstant(0, SDLoc(), MVT::i32);
+  return Op;
+}
+
+static SDValue simplifyDemandedBitsForPRMT(SDValue PRMT,
+                                           const APInt &DemandedBits,
+                                           SelectionDAG &DAG,
+                                           const TargetLowering &TLI,
+                                           unsigned Depth) {
+  SDValue Op0 = PRMT.getOperand(0);
+  SDValue Op1 = PRMT.getOperand(1);
+  ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(PRMT.getOperand(2));
+  unsigned Mode = PRMT.getConstantOperandVal(3);
+  if (!Selector)
+    return SDValue();
+
+  const APInt SelectorVal = getPRMTSelector(Selector->getAPIntValue(), Mode);
+
+  // Try to simplify the PRMT to one of the inputs if the used bytes are all
+  // from the same input in the correct order.
+  const unsigned LeadingBytes = DemandedBits.countLeadingZeros() / 8;
+  const unsigned SelBits = (4 - LeadingBytes) * 4;
+  if (SelectorVal.getLoBits(SelBits) == APInt(32, 0x3210).getLoBits(SelBits))
+    return Op0;
+  if (SelectorVal.getLoBits(SelBits) == APInt(32, 0x7654).getLoBits(SelBits))
+    return Op1;
+
+  APInt DemandedLHS, DemandedRHS;
+  getPRMTDemandedBits(SelectorVal, DemandedBits, DemandedLHS, DemandedRHS);
+
+  // Attempt to avoid multi-use ops if we don't need anything from them.
+  SDValue DemandedOp0 =
+      TLI.SimplifyMultipleUseDemandedBits(Op0, DemandedLHS, DAG, Depth + 1);
+  SDValue DemandedOp1 =
+      TLI.SimplifyMultipleUseDemandedBits(Op1, DemandedRHS, DAG, Depth + 1);
+
+  DemandedOp0 = canonicalizePRMTInput(DemandedOp0, DAG);
+  DemandedOp1 = canonicalizePRMTInput(DemandedOp1, DAG);
+  if (DemandedOp0 != Op0 || DemandedOp1 != Op1) {
+    Op0 = DemandedOp0 ? DemandedOp0 : Op0;
+    Op1 = DemandedOp1 ? DemandedOp1 : Op1;
+    return getPRMT(Op0, Op1, SelectorVal.getZExtValue(), SDLoc(PRMT), DAG);
+  }
+
+  return SDValue();
+}
+
+bool NVPTXTargetLowering::SimplifyDemandedBitsForTargetNode(
+    SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
+    KnownBits &Known, TargetLoweringOpt &TLO, unsigned Depth) const {
+  Known.resetAll();
+
+  switch (Op.getOpcode()) {
+  case NVPTXISD::PRMT:
+    if (SDValue Result = simplifyDemandedBitsForPRMT(Op, DemandedBits, TLO.DAG,
+                                                     *this, Depth)) {
+      TLO.CombineTo(Op, Result);
+      return true;
+    }
+    break;
+  default:
+    break;
+  }
+
+  computeKnownBitsForTargetNode(Op, Known, DemandedElts, TLO.DAG, Depth);
+  return false;
+}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
index bc3548c0272bb..228e2aac47aec 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h
@@ -275,6 +275,11 @@ class NVPTXTargetLowering : public TargetLowering {
                                      const APInt &DemandedElts,
                                      const SelectionDAG &DAG,
                                      unsigned Depth = 0) const override;
+  bool SimplifyDemandedBitsForTargetNode(SDValue Op, const APInt &DemandedBits,
+                                         const APInt &DemandedElts,
+                                         KnownBits &Known,
+                                         TargetLoweringOpt &TLO,
+                                         unsigned Depth = 0) const override;
 
 private:
   const NVPTXSubtarget &STI; // cache the subtarget here
diff --git a/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll b/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
index 23832a9cb5c58..dd9a472984c25 100644
--- a/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
+++ b/llvm/test/CodeGen/NVPTX/LoadStoreVectorizer.ll
@@ -181,32 +181,32 @@ define void @combine_v16i8(ptr noundef align 16 %ptr1, ptr noundef align 16 %ptr
 ; ENABLED-NEXT:    prmt.b32 %r5, %r4, 0, 0x7773U;
 ; ENABLED-NEXT:    prmt.b32 %r6, %r4, 0, 0x7772U;
 ; ENABLED-NEXT:    prmt.b32 %r7, %r4, 0, 0x7771U;
-; ENABLED-NEXT:    prmt.b32 %r8, %r4, 0, 0x7770U;
-; ENABLED-NEXT:    prmt.b32 %r9, %r3, 0, 0x7773U;
-; ENABLED-NEXT:    prmt.b32 %r10, %r3, 0, 0x7772U;
-; ENABLED-NEXT:    prmt.b32 %r11, %r3, 0, 0x7771U;
-; ENABLED-NEXT:    prmt.b32 %r12, %r3, 0, 0x7770U;
-; ENABLED-NEXT:    prmt.b32 %r13, %r2, 0, 0x7773U;
-; ENABLED-NEXT:    prmt.b32 %r14, %r2, 0, 0x7772U;
-; ENABLED-NEXT:    prmt.b32 %r15, %r2, 0, 0x7771U;
-; ENABLED-NEXT:    prmt.b32 %r16, %r2, 0, 0x7770U;
-; ENABLED-NEXT:    prmt.b32 %r17, %r1, 0, 0x7773U;
-; ENABLED-NEXT:    prmt.b32 %r18, %r1, 0, 0x7772U;
-; ENABLED-NEXT:    prmt.b32 %r19, %r1, 0, 0x7771U;
-; ENABLED-NEXT:    prmt.b32 %r20, %r1, 0, 0x7770U;
+; ENABLED-NEXT:    prmt.b32 %r8, %r3, 0, 0x7773U;
+; ENABLED-NEXT:    prmt.b32 %r9, %r3, 0, 0x7772U;
+; ENABLED-NEXT:    prmt.b32 %r10, %r3, 0, 0x7771U;
+; ENABLED-NEXT:    prmt.b32 %r11, %r2, 0, 0x7773U;
+; ENABLED-NEXT:    prmt.b32 %r12, %r2, 0, 0x7772U;
+; ENABLED-NEXT:    prmt.b32 %r13, %r2, 0, 0x7771U;
+; ENABLED-NEXT:    prmt.b32 %r14, %r1, 0, 0x7773U;
+; ENABLED-NEXT:    prmt.b32 %r15, %r1, 0, 0x7772U;
+; ENABLED-NEXT:    prmt.b32 %r16, %r1, 0, 0x7771U;
 ; ENABLED-NEXT:    ld.param.b64 %rd2, [combine_v16i8_param_1];
-; ENABLED-NEXT:    add.s32 %r21, %r20, %r19;
-; ENABLED-NEXT:    add.s32 %r22, %r21, %r18;
-; ENABLED-NEXT:    add.s32 %r23, %r22, %r17;
-; ENABLED-NEXT:    add.s32 %r24, %r23, %r16;
-; ENABLED-NEXT:    add.s32 %r25, %r24, %r15;
-; ENABLED-NEXT:    add.s32 %r26, %r25, %r14;
-; ENABLED-NEXT:    add.s32 %r27, %r26, %r13;
-; ENABLED-NEXT:    add.s32 %r28, %r27, %r12;
-; ENABLED-NEXT:    add.s32 %r29, %r28, %r11;
-; ENABLED-NEXT:    add.s32 %r30, %r29, %r10;
-; ENABLED-NEXT:    add.s32 %r31, %r30, %r9;
-; ENABLED-NEXT:    add.s32 %r32, %r31, %r8;
+; ENABLED-NEXT:    and.b32 %r17, %r1, 255;
+; ENABLED-NEXT:    and.b32 %r18, %r2, 255;
+; ENABLED-NEXT:    and.b32 %r19, %r3, 255;
+; ENABLED-NEXT:    and.b32 %r20, %r4, 255;
+; ENABLED-NEXT:    add.s32 %r21, %r17, %r16;
+; ENABLED-NEXT:    add.s32 %r22, %r21, %r15;
+; ENABLED-NEXT:    add.s32 %r23, %r22, %r14;
+; ENABLED-NEXT:    add.s32 %r24, %r23, %r18;
+; ENABLED-NEXT:    add.s32 %r25, %r24, %r13;
+; ENABLED-NEXT:    add.s32 %r26, %r25, %r12;
+; ENABLED-NEXT:    add.s32 %r27, %r26, %r11;
+; ENABLED-NEXT:    add.s32 %r28, %r27, %r19;
+; ENABLED-NEXT:    add.s32 %r29, %r28, %r10;
+; ENABLED-NEXT:    add.s32 %r30, %r29, %r9;
+; ENABLED-NEXT:    add.s32 %r31, %r30, %r8;
+; ENABLED-NEXT:    add.s32 %r32, %r31, %r20;
 ; ENABLED-NEXT:    add.s32 %r33, %r32, %r7;
 ; ENABLED-NEXT:    add.s32 %r34, %r33, %r6;
 ; ENABLED-NEXT:    add.s32 %r35, %r34, %r5;
@@ -332,36 +332,36 @@ define void @combine_v16i8_unaligned(ptr noundef align 8 %ptr1, ptr noundef alig
 ; ENABLED-NEXT:    prmt.b32 %r3, %r2, 0, 0x7773U;
 ; ENABLED-NEXT:    prmt.b32 %r4, %r2, 0, 0x7772U;
 ; ENABLED-NEXT:    prmt.b32 %r5, %r2, 0, 0x7771U;
-; ENABLED-NEXT:    prmt.b32 %r6, %r2, 0, 0x7770U;
-; ENABLED-NEXT:    prmt.b32 %r7, %r1, 0, 0x7773U;
-; ENABLED-NEXT:    prmt.b32 %r8, %r1, 0, 0x7772U;
-; ENABLED-NEXT:    prmt.b32 %r9, %r1, 0, 0x7771U;
-; ENABLED-NEXT:    prmt.b32 %r10, %r1, 0, 0x7770U;
+; ENABLED-NEXT:    prmt.b32 %r6, %r1, 0, 0x7773U;
+; ENABLED-NEXT:    prmt.b32 %r7, %r1, 0, 0x7772U;
+; ENABLED-NEXT:    prmt.b32 %r8, %r1, 0, 0x7771U;
 ; ENABLED-NEXT:    ld.param.b64 %rd2, [combine_v16i8_unaligned_param_1];
-; ENABLED-NEXT:    ld.v2.b32 {%r11, %r12}, [%rd1+8];
-; ENABLED-NEXT:    prmt.b32 %r13, %r12, 0, 0x7773U;
-; ENABLED-NEXT:    prmt.b32 %r14, %r12, 0, 0x7772U;
-; ENABLED-NEXT:    prmt.b32 %r15, %r12, 0, 0x7771U;
-; ENABLED-NEXT:    prmt.b32 %r16, %r12, 0, 0x7770U;
-; ENABLED-NEXT:    prmt.b32 %r17, %r11, 0, 0x7773U;
-; ENABLED-NEXT:    prmt.b32 %r18, %r11, 0, 0x7772U;
-; ENABLED-NEXT:    prmt.b32 %r19, %r11, 0, 0x7771U;
-; ENABLED-NEXT:    prmt.b32 %r20, %r11, 0, 0x7770U;
-; ENABLED-NEXT:    add.s32 %r21, %r10, %r9;
-; ENABLED-NEXT:    add.s32 %r22, %r21, %r8;
-; ENABLED-NEXT:    add.s32 %r23, %r22, %r7;
-; ENABLED-NEXT:    add.s32 %r24, %r23, %r6;
+; ENABLED-NEXT:    ld.v2.b32 {%r9, %r10}, [%rd1+8];
+; ENABLED-NEXT:    prmt.b32 %r11, %r10, 0, 0x7773U;
+; ENABLED-NEXT:    prmt.b32 %r12, %r10, 0, 0x7772U;
+; ENABLED-NEXT:    prmt.b32 %r13, %r10, 0, 0x7771U;
+; ENABLED-NEXT:    prmt.b32 %r14, %r9, 0, 0x7773U;
+; ENABLED-NEXT:    prmt.b32 %r15, %r9, 0, 0x7772U;
+; ENABLED-NEXT:    prmt.b32 %r16, %r9, 0, 0x7771U;
+; ENABLED-NEXT:    and.b32 %r17, %r1, 255;
+; ENABLED-NEXT:    and.b32 %r18, %r2, 255;
+; ENABLED-NEXT:    and.b32 %r19, %r9, 255;
+; ENABLED-NEXT:    and.b32 %r20, %r10, 255;
+; ENABLED-NEXT:    add.s32 %r21, %r17, %r8;
+; ENABLED-NEXT:    add.s32 %r22, %r21, %r7;
+; ENABLED-NEXT:    add.s32 %r23, %r22, %r6;
+; ENABLED-NEXT:    add.s32 %r24, %r23, %r18;
 ; ENABLED-NEXT:    add.s32 %r25, %r24, %r5;
 ; ENABLED-NEXT:    add.s32 %r26, %r25, %r4;
 ; ENABLED-NEXT:    add.s32 %r27, %r26, %r3;
-; ENABLED-NEXT:    add.s32 %r28, %r27, %r20;
-; ENABLED-NEXT:    add.s32 %r29, %r28, %r19;
-; ENABLED-NEXT:    add.s32 %r30, %r29, %r18;
-; ENABLED-NEXT:    add.s32 %r31, %r30, %r17;
-; ENABLED-NEXT:    add.s32 %r32, %r31, %r16;
-; ENABLED-NEXT:    add.s32 %r33, %r32, %r15;
-; ENABLED-NEXT:    add.s32 %r34, %r33, %r14;
-; ENABLED-NEXT:    add.s32 %r35, %r34, %r13;
+; ENABLED-NEXT:    add.s32 %r28, %r27, %r19;
+; ENABLED-NEXT:    add.s32 %r29, %r28, %r16;
+; ENABLED-NEXT:    add.s32 %r30, %r29, %r15;
+; ENABLED-NEXT:    add.s32 %r31, %r30, %r14;
+; ENABLED-NEXT:    add.s32 %r32, %r31, %r20;
+; ENABLED-NEXT:    add.s32 %r33, %r32, %r13;
+; ENABLED-NEXT:    add.s32 %r34, %r33, %r12;
+; ENABLED-NEXT:    add.s32 %r35, %r34, %r11;
 ; ENABLED-NEXT:    st.b32 [%rd2], %r35;
 ; ENABLED-NEXT:    ret;
 ;
diff --git a/llvm/test/CodeGen/NVPTX/extractelement.ll b/llvm/test/CodeGen/NVPTX/extractelement.ll
index 80980efbab05b..d61a63ce24f89 100644
--- a/llvm/test/CodeGen/NVPTX/extractelement.ll
+++ b/llvm/test/CodeGen/NVPTX/extractelement.ll
@@ -56,23 +56,22 @@ define i16  @test_v4i8(i32 %a) {
 ; CHECK-LABEL: test_v4i8(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b16 %rs<8>;
-; CHECK-NEXT:    .reg .b32 %r<7>;
+; CHECK-NEXT:    .reg .b32 %r<6>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.b32 %r1, [test_v4i8_param_0];
-; CHECK-NEXT:    prmt.b32 %r2, %r1, 0, 0x8880U;
-; CHECK-NEXT:    cvt.u16.u32 %rs1, %r2;
-; CHECK-NEXT:    prmt.b32 %r3, %r1, 0, 0x9991U;
-; CHECK-NEXT:    cvt.u16.u32 %rs2, %r3;
-; CHECK-NEXT:    prmt.b32 %r4, %r1, 0, 0xaaa2U;
-; CHECK-NEXT:    cvt.u16.u32 %rs3, %r4;
-; CHECK-NEXT:    prmt.b32 %r5, %r1, 0, 0xbbb3U;
-; CHECK-NEXT:    cvt.u16.u32 %rs4, %r5;
+; CHECK-NEXT:    cvt.s8.s32 %rs1, %r1;
+; CHECK-NEXT:    prmt.b32 %r2, %r1, 0, 0x9991U;
+; CHECK-NEXT:    cvt.u16.u32 %rs2, %r2;
+; CHECK-NEXT:    prmt.b32 %r3, %r1, 0, 0xaaa2U;
+; CHECK-NEXT:    cvt.u16.u32 %rs3, %r3;
+; CHECK-NEXT:    prmt.b32 %r4, %r1, 0, 0xbbb3U;
+; CHECK-NEXT:    cvt.u16.u32 %rs4, %r4;
 ; CHECK-NEXT:    add.s16 %rs5, %rs1, %rs2;
 ; CHECK-NEXT:    add.s16 %rs6, %rs3, %rs4;
 ; CHECK-NEXT:    add.s16 %rs7, %rs5, %rs6;
-; CHECK-NEXT:    cvt.u32.u16 %r6, %rs7;
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r6;
+; CHECK-NEXT:    cvt.u32.u16 %r5, %rs7;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r5;
 ; CHECK-NEXT:    ret;
   %v = bitcast i32 %a to <4 x i8>
   %r0 = extractelement <4 x i8> %v, i64 0
@@ -96,7 +95,7 @@ define i32  @test_v4i8_s32(i32 %a) {
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.b32 %r1, [test_v4i8_s32_param_0];
-; CHECK-NEXT:    prmt.b32 %r2, %r1, 0, 0x8880U;
+; CHECK-NEXT:    cvt.s32.s8 %r2, %r1;
 ; CHECK-NEXT:    prmt.b32 %r3, %r1, 0, 0x9991U;
 ; CHECK-NEXT:    prmt.b32 %r4, %r1, 0, 0xaaa2U;
 ; CHECK-NEXT:    prmt.b32 %r5, %r1, 0, 0xbbb3U;
@@ -127,12 +126,12 @@ define i32  @test_v4i8_u32(i32 %a) {
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.b32 %r1, [test_v4i8_u32_param_0];
-; CHECK-NEXT:    prmt.b32 %r2, %r1, 0, 0x7770U;
-; CHECK-NEXT:    prmt.b32 %r3, %r1, 0, 0x7771U;
-; CHECK-NEXT:    prmt.b32 %r4, %r1, 0, 0x7772U;
-; CHECK-NEXT:    prmt.b32 %r5, %r1, 0, 0x7773U;
-; CHECK-NEXT:    add.s32 %r6, %r2, %r3;
-; CHECK-NEXT:    add.s32 %r7, %r4, %r5;
+; CHECK-NEXT:    prmt.b32 %r2, %r1, 0, 0x7771U;
+; CHECK-NEXT:    prmt.b32 %r3, %r1, 0, 0x7772U;
+; CHECK-NEXT:    prmt.b32 %r4, %r1, 0, 0x7773U;
+; CHECK-NEXT:    and.b32 %r5, %r1, 255;
+; CHECK-NEXT:    add.s32 %r6, %r5, %r2;
+; CHECK-NEXT:    add.s32 %r7, %r3, %r4;
 ; CHECK-NEXT:    add.s32 %r8, %r6, %r7;
 ; CHECK-NEXT:    st.param.b32 [func_retval0], %r8;
 ; CHECK-NEXT:    ret;
@@ -157,26 +156,24 @@ define i16  @test_v8i8(i64 %a) {
 ; CHECK-LABEL: test_v8i8(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b16 %rs<16>;
-; CHECK-NEXT:    .reg .b32 %r<12>;
+; CHECK-NEXT:    .reg .b32 %r<10>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.v2.b32 {%r1, %r2}, [test_v8i8_param_0];
-; CHECK-NEXT:    prmt.b32 %r3, %r1, 0, 0x8880U;
-; CHECK-NEXT:    cvt.u16.u32 %rs1, %r3;
-; CHECK-NEXT:    prmt.b32 %r4, %r1, 0, 0x9991U;
-; CHECK-NEXT:    cvt.u16.u32 %rs2, %r4;
-; CHECK-NEXT:    prmt.b32 %r5, %r1, 0, 0xaaa2U;
-; CHECK-NEXT:    cvt.u16.u32 %rs3, %r5;
-; CHECK-NEXT:    prmt.b32 %r6, %r1, 0, 0xbbb3U;
-; CHECK-NEXT:    cvt.u16.u32 %rs4, %r6;
-; CHECK-NEXT:    prmt.b32 %r7, %r2, 0, 0x8880U;
-; CHECK-NEXT:    cvt.u16.u32 %rs5, %r7;
-; CHECK-NEXT:    prmt.b32 %r8, %r2, 0, 0x9991U;
-; CHECK-NEXT:    cvt.u16.u32 %rs6, %r8;
-; CHECK-NEXT:    prmt.b32 %r9, %r2, 0, 0xaaa2U;
-; CHECK-NEXT:    cvt.u16.u32 %rs7, %r9;
-; CHECK-NEXT:    prmt.b32 %r10, %r2, 0, 0xbbb3U;
-; CHECK-NEXT:    cvt.u16.u32 %rs8, %r10;
+; CHECK-NEXT:    cvt.s8.s32 %rs1, %r1;
+; CHECK-NEXT:    prmt.b32 %r3, %r1, 0, 0x9991U;
+; CHECK-NEXT:    cvt.u16.u32 %rs2, %r3;
+; CHECK-NEXT:    prmt.b32 %r4, %r1, 0, 0xaaa2U;
+; CHECK-NEXT:    cvt.u16.u32 %rs3, %r4;
+; CHECK-NEXT:    prmt.b32 %r5, %r1, 0, 0xbbb3U;
+; CHECK-NEXT:    cvt.u16.u32 %rs4, %r5;
+; CHECK-NEXT:    cvt.s8.s32 %rs5, %r2;
+; CHECK-NEXT:    prmt.b32 %r6, %r2, 0, 0x9991U;
+; CHECK-NEXT:    cvt.u16.u32 %rs6, %r6;
+; CHECK-NEXT:    prmt.b32 %r7, %r2, 0, 0xaaa2U;
+; CHECK-NEXT:    cvt.u16.u32 %rs7, %r7;
+; CHECK-NEXT:    prmt.b32 %r8, %r2, 0, 0xbbb3U;
+; CHECK-NEXT:    cvt.u16.u32 %rs8, %r8;
 ; CHECK-NEXT:    add.s16 %rs9, %rs1, %rs2;
 ; CHECK-NEXT:    add.s16 %rs10, %rs3, %rs4;
 ; CHECK-NEXT:    add.s16 %rs11, %rs5, %rs6;
@@ -184,8 +181,8 @@ define i16  @test_v8i8(i64 %a) {
 ; CHECK-NEXT:    add.s16 %rs13, %rs9, %rs10;
 ; CHECK-NEXT:    add.s16 %rs14, %rs11, %rs12;
 ; CHECK-NEXT:    add.s16 %rs15, %rs13, %rs14;
-; CHECK-NEXT:    cvt.u32.u16 %r11, %rs15;
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r11;
+; CHECK-NEXT:    cvt.u32.u16 %r9, %rs15;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r9;
 ; CHECK-NEXT:    ret;
   %v = bitcast i64 %a to <8 x i8>
   %r0 = extractelement <8 x i8> %v, i64 0
diff --git a/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
index cbc9f700b1f01..78f889a548420 100644
--- a/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
@@ -2040,7 +2040,7 @@ define void @test_srem_v4i8(ptr %a, ptr %b, ptr %c) {
 ; O0-LABEL: test_srem_v4i8(
 ; O0:       {
 ; O0-NEXT:    .reg .b16 %rs<13>;
-; O0-NEXT:    .reg .b32 %r<18>;
+; O0-NEXT:    .reg .b32 %r<16>;
 ; O0-NEXT:    .reg .b64 %rd<4>;
 ; O0-EMPTY:
 ; O0-NEXT:  // %bb.0: // %entry
@@ -2062,27 +2062,25 @@ define void @test_srem_v4i8(ptr %a, ptr %b, ptr %c) {
 ; O0-NEXT:    rem.s16 %rs6, %rs5, %rs4;
 ; O0-NEXT:    cvt.u32.u16 %r8, %rs6;
 ; O0-NEXT:    prmt.b32 %r9, %r8, %r5, 0x3340U;
-; O0-NEXT:    prmt.b32 %r10, %r2, 0, 0x9991U;
-; O0-NEXT:    cvt.u16.u32 %rs7, %r10;
-; O0-NEXT:    prmt.b32 %r11, %r1, 0, 0x9991U;
-; O0-NEXT:    cvt.u16.u32 %rs8, %r11;
+; O0-NEXT:    cvt.s8.s32 %rs7, %r2;
+; O0-NEXT:    cvt.s8.s32 %rs8, %r1;
 ; O0-NEXT:    rem.s16 %rs9, %rs8, %rs7;
-; O0-NEXT:    cvt.u32.u16 %r12, %rs9;
-; O0-NEXT:    prmt.b32 %r13, %r2, 0, 0x8880U;
-; O0-NEXT:    cvt.u16.u32 %rs10, %r13;
-; O0-NEXT:    prmt.b32 %r14, %r1, 0, 0x8880U;
-; O0-NEXT:    cvt.u16.u32 %rs11, %r14;
+; O0-NEXT:    cvt.u32.u16 %r10, %rs9;
+; O0-NEXT:    prmt.b32 %r11, %r2, 0, 0x9991U;
+; O0-NEXT:    cvt.u16.u32 %rs10, %r11;
+; O0-NEXT:    prmt.b32 %r12, %r1, 0, 0x9991U;
+; O0-NEXT:    cvt.u16.u32 %rs11, %r12;
 ; O0-NEXT:    rem.s16 %rs12, %rs11, %rs10;
-; O0-NEXT:    cvt.u32.u16 %r15, %rs12;
-; O0-NEXT:    prmt.b32 %r16, %r15, %r12, 0x3340U;
-; O0-NEXT:    prmt.b32 %r17, %r16, %r9, 0x5410U;
-; O0-NEXT:    st.b32 [%rd3], %r17;
+; O0-NEXT:    cvt.u32.u16 %r13, %rs12;
+; O0-NEXT:    prmt.b32 %r14, %r10, %r13, 0x3340U;
+; O0-NEXT:    prmt.b32 %r15, %r14, %r9, 0x5410U;
+; O0-NEXT:    st.b32 [%rd3], %r15;
 ; O0-NEXT:    ret;
 ;
 ; O3-LABEL: test_srem_v4i8(
 ; O3:       {
 ; O3-NEXT:    .reg .b16 %rs<13>;
-; O3-NEXT:    .reg .b32 %r<18>;
+; O3-NEXT:    .reg .b32 %r<16>;
 ; O3-NEXT:    .reg .b64 %rd<4>;
 ; O3-EMPTY:
 ; O3-NEXT:  // %bb.0: // %entry
@@ -2104,21 +2102,19 @@ define void @test_srem_v4i8(ptr %a, ptr %b, ptr %c) {
 ; O3-NEXT:    rem.s16 %rs6, %rs5, %rs4;
 ; O3-NEXT:    cvt.u32.u16 %r8, %rs6;
 ; O3-NEXT:    prmt.b32 %r9, %r8, %r5, 0x3340U;
-; O3-NEXT:    prmt.b32 %r10, %r2, 0, 0x9991U;
-; O3-NEXT:    cvt.u16.u32 %rs7, %r10;
-; O3-NEXT:    prmt.b32 %r11, %r1, 0, 0x9991U;
-; O3-NEXT:    cvt.u16.u32 %rs8, %r11;
+; O3-NEXT:    cvt.s8.s32 %rs7, %r2;
+; O3-NEXT:    cvt.s8.s32 %rs8, %r1;
 ; O3-NEXT:    rem.s16 %rs9, %rs8, %rs7;
-; O3-NEXT:    cvt.u32.u16 %r12, %rs9;
-; O3-NEXT:    prmt.b32 %r13, %r2, 0, 0x8880U;
-; O3-NEXT:    cvt.u16.u32 %rs10, %r13;
-; O3-NEXT:    prmt.b32 %r14, %r1, 0, 0x8880U;
-; O3-NEXT:    cvt.u16.u32 %rs11, %r14;
+; O3-NEXT:    cvt.u32.u16 %r10, %rs9;
+; O3-NEXT:    prmt.b32 %r11, %r2, 0, 0x9991U;
+; O3-NEXT:    cvt.u16.u32 %rs10, %r11;
+; O3-NEXT:    prmt.b32 %r12, %r1, 0, 0x9991U;
+; O3-NEXT:    cvt.u16.u32 %rs11, %r12;
 ; O3-NEXT:    rem.s16 %rs12, %rs11, %rs10;
-; O3-NEXT:    cvt.u32.u16 %r15, %rs12;
-; O3-NEXT:    prmt.b32 %r16, %r15, %r12, 0x3340U;
-; O3-NEXT:    prmt.b32 %r17, %r16, %r9, 0x5410U;
-; O3-NEXT:    st.b32 [%rd3], %r17;
+; O3-NEXT:    cvt.u32.u16 %r13, %rs12;
+; O3-NEXT:    prmt.b32 %r14, %r10, %r13, 0x3340U;
+; O3-NEXT:    prmt.b32 %r15, %r14, %r9, 0x5410U;
+; O3-NEXT:    st.b32 [%rd3], %r15;
 ; O3-NEXT:    ret;
 entry:
   %t57 = load <4 x i8>, ptr %a, align 4
@@ -2138,7 +2134,7 @@ define void @test_srem_v3i8(ptr %a, ptr %b, ptr %c) {
 ; O0-LABEL: test_srem_v3i8(
 ; O0:       {
 ; O0-NEXT:    .reg .b16 %rs<20>;
-; O0-NEXT:    .reg .b32 %r<14>;
+; O0-NEXT:    .reg .b32 %r<8>;
 ; O0-NEXT:    .reg .b64 %rd<4>;
 ; O0-EMPTY:
 ; O0-NEXT:  // %bb.0: // %entry
@@ -2157,25 +2153,19 @@ define void @test_srem_v3i8(ptr %a, ptr %b, ptr %c) {
 ; O0-NEXT:    or.b16 %rs9, %rs8, %rs6;
 ; O0-NEXT:    cvt.u32.u16 %r2, %rs9;
 ; O0-NEXT:    ld.s8 %rs10, [%rd2+2];
-; O0-NEXT:    prmt.b32 %r3, %r2, 0, 0x9991U;
-; O0-NEXT:    cvt.u16.u32 %rs11, %r3;
-; O0-NEXT:    prmt.b32 %r4, %r1, 0, 0x9991U;
-; O0-NEXT:    cvt.u16.u32 %rs12, %r4;
+; O0-NEXT:    cvt.s16.s8 %rs11, %rs9;
+; O0-NEXT:    cvt.s16.s8 %rs12, %rs4;
 ; O0-NEXT:    rem.s16 %rs13, %rs12, %rs11;
-; O0-NEXT:    cvt.u32.u16 %r5, %rs13;
-; O0-NEXT:    prmt.b32 %r6, %r2, 0, 0x8880U;
-; O0-NEXT:    cvt.u16.u32 %rs14, %r6;
-; O0-NEXT:    prmt.b32 %r7, %r1, 0, 0x8880U;
-; O0-NEXT:    cvt.u16.u32 %rs15, %r7;
+; O0-NEXT:    cvt.u32.u16 %r3, %rs13;
+; O0-NEXT:    prmt.b32 %r4, %r2, 0, 0x9991U;
+; O0-NEXT:    cvt.u16.u32 %rs14, %r4;
+; O0-NEXT:    prmt.b32 %r5, %r1, 0, 0x9991U;
+; O0-NEXT:    cvt.u16.u32 %rs15, %r5;
 ; O0-NEXT:    rem.s16 %rs16, %rs15, %rs14;
-; O0-NEXT:    cvt.u32.u16 %r8, %rs16;
-; O0-NEXT:    prmt.b32 %r9, %r8, %r5, 0x3340U;
-; O0-NEXT:    // implicit-def: %r11
-; O0-NEXT:    // implicit-def: %r12
-; O0-NEXT:    prmt.b32 %r10, %r11, %r12, 0x3340U;
-; O0-NEXT:    prmt.b32 %r13, %r9, %r10, 0x5410U;
+; O0-NEXT:    cvt.u32.u16 %r6, %rs16;
+; O0-NEXT:    prmt.b32 %r7, %r3, %r6, 0x3340U;
 ; O0-NEXT:    rem.s16 %rs17, %rs5, %rs10;
-; O0-NEXT:    cvt.u16.u32 %rs18, %r13;
+; O0-NEXT:    cvt.u16.u32 %rs18, %r7;
 ; O0-NEXT:    st.b8 [%rd3], %rs18;
 ; O0-NEXT:    shr.u16 %rs19, %rs18, 8;
 ; O0-NEXT:    st.b8 [%rd3+1], %rs19;
@@ -2185,7 +2175,7 @@ define void @test_srem_v3i8(ptr %a, ptr %b, ptr %c) {
 ; O3-LABEL: test_srem_v3i8(
 ; O3:       {
 ; O3-NEXT:    .reg .b16 %rs<20>;
-; O3-NEXT:    .reg .b32 %r<14>;
+; O3-NEXT:    .reg .b32 %r<8>;
 ; O3-NEXT:    .reg .b64 %rd<4>;
 ; O3-EMPTY:
 ; O3-NEXT:  // %bb.0: // %entry
@@ -2204,24 +2194,20 @@ define void @test_srem_v3i8(ptr %a, ptr %b, ptr %c) {
 ; O3-NEXT:    cvt.u32.u16 %r2, %rs9;
 ; O3-NEXT:    ld.s8 %rs10, [%rd2+2];
 ; O3-NEXT:    ld.param.b64 %rd3, [test_srem_v3i8_param_2];
-; O3-NEXT:    prmt.b32 %r3, %r2, 0, 0x9991U;
-; O3-NEXT:    cvt.u16.u32 %rs11, %r3;
-; O3-NEXT:    prmt.b32 %r4, %r1, 0, 0x9991U;
-; O3-NEXT:    cvt.u16.u32 %rs12, %r4;
+; O3-NEXT:    cvt.s16.s8 %rs11, %rs9;
+; O3-NEXT:    cvt.s16.s8 %rs12, %rs4;
 ; O3-NEXT:    rem.s16 %rs13, %rs12, %rs11;
-; O3-NEXT:    cvt.u32.u16 %r5, %rs13;
-; O3-NEXT:    prmt.b32 %r6, %r2, 0, 0x8880U;
-; O3-NEXT:    cvt.u16.u32 %rs14, %r6;
-; O3-NEXT:    prmt.b32 %r7, %r1, 0, 0x8880U;
-; O3-NEXT:    cvt.u16.u32 %rs15, %r7;
+; O3-NEXT:    cvt.u32.u16 %r3, %rs13;
+; O3-NEXT:    prmt.b32 %r4, %r2, 0, 0x9991U;
+; O3-NEXT:    cvt.u16.u32 %rs14, %r4;
+; O3-NEXT:    prmt.b32 %r5, %r1, 0, 0x9991U;
+; O3-NEXT:    cvt.u16.u32 %rs15, %r5;
 ; O3-NEXT:    rem.s16 %rs16, %rs15, %rs14;
-; O3-NEXT:    cvt.u32.u16 %r8, %rs16;
-; O3-NEXT:    prmt.b32 %r9, %r8, %r5, 0x3340U;
-; O3-NEXT:    prmt.b32 %r10, %r11, %r12, 0x3340U;
-; O3-NEXT:    prmt.b32 %r13, %r9, %r10, 0x5410U;
+; O3-NEXT:    cvt.u32.u16 %r6, %rs16;
+; O3-NEXT:    prmt.b32 %r7, %r3, %r6, 0x3340U;
 ; O3-NEXT:    rem.s16 %rs17, %rs5, %rs10;
 ; O3-NEXT:    st.b8 [%rd3+2], %rs17;
-; O3-NEXT:    cvt.u16.u32 %rs18, %r13;
+; O3-NEXT:    cvt.u16.u32 %rs18, %r7;
 ; O3-NEXT:    st.b8 [%rd3], %rs18;
 ; O3-NEXT:    shr.u16 %rs19, %rs18, 8;
 ; O3-NEXT:    st.b8 [%rd3+1], %rs19;

>From 1a143a9473f77cf512e5b6c16f5562379ef083c9 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Fri, 18 Jul 2025 20:50:16 +0000
Subject: [PATCH 2/2] Address review comments

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 32 +++++++++++----------
 1 file changed, 17 insertions(+), 15 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index e9190b6097709..4414fac3793df 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -6546,11 +6546,10 @@ void NVPTXTargetLowering::computeKnownBitsForTargetNode(
   }
 }
 
-static void getPRMTDemandedBits(const APInt &SelectorVal,
-                                const APInt &DemandedBits, APInt &DemandedLHS,
-                                APInt &DemandedRHS) {
-  DemandedLHS = APInt(32, 0);
-  DemandedRHS = APInt(32, 0);
+static std::pair<APInt, APInt> getPRMTDemandedBits(const APInt &SelectorVal,
+                                                   const APInt &DemandedBits) {
+  APInt DemandedLHS = APInt(32, 0);
+  APInt DemandedRHS = APInt(32, 0);
 
   for (unsigned I : llvm::seq(4)) {
     if (DemandedBits.extractBits(8, I * 8).isZero())
@@ -6567,6 +6566,8 @@ static void getPRMTDemandedBits(const APInt &SelectorVal,
     else
       Src.setBits(ByteStart, ByteStart + 8);
   }
+
+  return {DemandedLHS, DemandedRHS};
 }
 
 // Replace undef with 0 as this is easier for other optimizations such as
@@ -6584,26 +6585,26 @@ static SDValue simplifyDemandedBitsForPRMT(SDValue PRMT,
                                            SelectionDAG &DAG,
                                            const TargetLowering &TLI,
                                            unsigned Depth) {
+  assert(PRMT.getOpcode() == NVPTXISD::PRMT);
   SDValue Op0 = PRMT.getOperand(0);
   SDValue Op1 = PRMT.getOperand(1);
-  ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(PRMT.getOperand(2));
-  unsigned Mode = PRMT.getConstantOperandVal(3);
-  if (!Selector)
+  auto *SelectorConst = dyn_cast<ConstantSDNode>(PRMT.getOperand(2));
+  if (!SelectorConst)
     return SDValue();
 
-  const APInt SelectorVal = getPRMTSelector(Selector->getAPIntValue(), Mode);
+  unsigned Mode = PRMT.getConstantOperandVal(3);
+  const APInt Selector = getPRMTSelector(SelectorConst->getAPIntValue(), Mode);
 
   // Try to simplify the PRMT to one of the inputs if the used bytes are all
   // from the same input in the correct order.
   const unsigned LeadingBytes = DemandedBits.countLeadingZeros() / 8;
   const unsigned SelBits = (4 - LeadingBytes) * 4;
-  if (SelectorVal.getLoBits(SelBits) == APInt(32, 0x3210).getLoBits(SelBits))
+  if (Selector.getLoBits(SelBits) == APInt(32, 0x3210).getLoBits(SelBits))
     return Op0;
-  if (SelectorVal.getLoBits(SelBits) == APInt(32, 0x7654).getLoBits(SelBits))
+  if (Selector.getLoBits(SelBits) == APInt(32, 0x7654).getLoBits(SelBits))
     return Op1;
 
-  APInt DemandedLHS, DemandedRHS;
-  getPRMTDemandedBits(SelectorVal, DemandedBits, DemandedLHS, DemandedRHS);
+  auto [DemandedLHS, DemandedRHS] = getPRMTDemandedBits(Selector, DemandedBits);
 
   // Attempt to avoid multi-use ops if we don't need anything from them.
   SDValue DemandedOp0 =
@@ -6613,10 +6614,11 @@ static SDValue simplifyDemandedBitsForPRMT(SDValue PRMT,
 
   DemandedOp0 = canonicalizePRMTInput(DemandedOp0, DAG);
   DemandedOp1 = canonicalizePRMTInput(DemandedOp1, DAG);
-  if (DemandedOp0 != Op0 || DemandedOp1 != Op1) {
+  if ((DemandedOp0 && DemandedOp0 != Op0) ||
+      (DemandedOp1 && DemandedOp1 != Op1)) {
     Op0 = DemandedOp0 ? DemandedOp0 : Op0;
     Op1 = DemandedOp1 ? DemandedOp1 : Op1;
-    return getPRMT(Op0, Op1, SelectorVal.getZExtValue(), SDLoc(PRMT), DAG);
+    return getPRMT(Op0, Op1, Selector.getZExtValue(), SDLoc(PRMT), DAG);
   }
 
   return SDValue();



More information about the llvm-commits mailing list