[llvm] [NVPTX] Add NVPTXIncreaseAlignmentPass to improve vectorization (PR #144958)

Alex MacLean via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 22 22:58:33 PDT 2025


https://github.com/AlexMaclean updated https://github.com/llvm/llvm-project/pull/144958

>From c4748eaab3baef63979f20158b947a05bb9ec968 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Wed, 23 Jul 2025 02:15:03 +0000
Subject: [PATCH 1/4] [DAGCombiner] Fold setcc of trunc, generalizing some
 NVPTX isel logic

---
 .../CodeGen/SelectionDAG/TargetLowering.cpp   |  26 ++
 llvm/lib/Target/NVPTX/NVPTXInstrInfo.td       |  28 +-
 llvm/test/CodeGen/NVPTX/i8x4-instructions.ll  | 168 ++++++-----
 llvm/test/CodeGen/NVPTX/sext-setcc.ll         |  13 +-
 llvm/test/CodeGen/NVPTX/trunc-setcc.ll        | 269 ++++++++++++++++++
 5 files changed, 401 insertions(+), 103 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/trunc-setcc.ll

diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 1764910861df4..a453c17877430 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -17,6 +17,7 @@
 #include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/CallingConvLower.h"
 #include "llvm/CodeGen/CodeGenCommonISel.h"
+#include "llvm/CodeGen/ISDOpcodes.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineJumpTableInfo.h"
@@ -5125,6 +5126,23 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
                           Cond == ISD::SETEQ ? ISD::SETLT : ISD::SETGE);
     }
 
+    // fold (setcc (trunc x) c) -> (setcc x c'), where c' extends c the same
+    // way the trunc's flag relates x to (trunc x): nsw => sext, nuw => zext.
+    // For eq/ne either flag qualifies, so the extension follows the flag.
+    if (N0.getOpcode() == ISD::TRUNCATE) {
+      SDNodeFlags Flags = N0->getFlags();
+      bool SExt = Flags.hasNoSignedWrap() && !ISD::isUnsignedIntSetCC(Cond);
+      bool ZExt = Flags.hasNoUnsignedWrap() && !ISD::isSignedIntSetCC(Cond);
+      EVT NewVT = N0.getOperand(0).getValueType();
+      if ((SExt || ZExt) && isTypeDesirableForOp(ISD::SETCC, NewVT)) {
+        // Per-element width, so vector (splat) compares get a valid constant.
+        unsigned Bits = NewVT.getScalarSizeInBits();
+        SDValue NewConst =
+            DAG.getConstant(SExt ? C1.sext(Bits) : C1.zext(Bits), dl, NewVT);
+        return DAG.getSetCC(dl, VT, N0.getOperand(0), NewConst, Cond);
+      }
+    }
+
     if (SDValue V =
             optimizeSetCCOfSignedTruncationCheck(VT, N0, N1, Cond, DCI, dl))
       return V;
@@ -5646,6 +5664,17 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
     return N0;
   }
 
+  // Fold (setcc (trunc x) (trunc y)) -> (setcc x y)
+  if (N0.getOpcode() == ISD::TRUNCATE && N1.getOpcode() == ISD::TRUNCATE &&
+      N0.getOperand(0).getValueType() == N1.getOperand(0).getValueType() &&
+      ((!ISD::isSignedIntSetCC(Cond) && N0->getFlags().hasNoUnsignedWrap() &&
+        N1->getFlags().hasNoUnsignedWrap()) ||
+       (!ISD::isUnsignedIntSetCC(Cond) && N0->getFlags().hasNoSignedWrap() &&
+        N1->getFlags().hasNoSignedWrap())) &&
+      isTypeDesirableForOp(ISD::SETCC, N0.getOperand(0).getValueType())) {
+    return DAG.getSetCC(dl, VT, N0.getOperand(0), N1.getOperand(0), Cond);
+  }
+
   // Could not fold it.
   return SDValue();
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index b5df4c6de7fd8..8043f678e0fcc 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -1714,40 +1714,17 @@ def cond_signed : PatLeaf<(cond), [{
   return isSignedIntSetCC(N->get());
 }]>;
 
-def cond_not_signed : PatLeaf<(cond), [{
-  return !isSignedIntSetCC(N->get());
-}]>;
-
 // comparisons of i8 extracted with PRMT as i32
 // It's faster to do comparison directly on i32 extracted by PRMT,
 // instead of the long conversion and sign extending.
-def: Pat<(setcc (i16 (sext_inreg (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))), i8)),
-                (i16 (sext_inreg (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))), i8)),
-                cond_signed:$cc),
-         (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE),
-                     (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE), 
-                     (cond2cc $cc))>;
-
+// The operands are sign-extended bytes, so the PRMTs are re-selected with a
+// sign-extending selector (to_sign_extend_selector) for the i32 comparison.
 def: Pat<(setcc (i16 (sext_inreg (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE)), i8)),
                 (i16 (sext_inreg (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE)), i8)),
                 cond_signed:$cc),
-         (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE),
-                     (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE), 
+         (SETP_i32rr (PRMT_B32rii i32:$a, 0, (to_sign_extend_selector $sel_a), PrmtNONE),
+                     (PRMT_B32rii i32:$b, 0, (to_sign_extend_selector $sel_b), PrmtNONE),
                      (cond2cc $cc))>;
-
-def: Pat<(setcc (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))),
-                (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))),
-                cond_signed:$cc),
-         (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE),
-                     (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE),
-                     (cond2cc $cc))>;
-
-def: Pat<(setcc (i16 (trunc (prmt i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE))),
-                (i16 (trunc (prmt i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE))),
-                cond_not_signed:$cc),
-         (SETP_i32rr (PRMT_B32rii i32:$a, 0, byte_extract_prmt:$sel_a, PrmtNONE),
-                     (PRMT_B32rii i32:$b, 0, byte_extract_prmt:$sel_b, PrmtNONE), 
-                     (cond2cc $cc))>;
 
 def SDTDeclareArrayParam :
   SDTypeProfile<0, 3, [SDTCisVT<0, i32>, SDTCisVT<1, i32>, SDTCisVT<2, i32>]>;
diff --git a/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
index da99cec0669ed..f2a2b171d5ca2 100644
--- a/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/i8x4-instructions.ll
@@ -343,61 +343,77 @@ define <4 x i8> @test_smax(<4 x i8> %a, <4 x i8> %b) #0 {
 ; O0-LABEL: test_smax(
 ; O0:       {
 ; O0-NEXT:    .reg .pred %p<5>;
-; O0-NEXT:    .reg .b32 %r<18>;
+; O0-NEXT:    .reg .b32 %r<26>;
 ; O0-EMPTY:
 ; O0-NEXT:  // %bb.0:
 ; O0-NEXT:    ld.param.b32 %r2, [test_smax_param_1];
 ; O0-NEXT:    ld.param.b32 %r1, [test_smax_param_0];
-; O0-NEXT:    prmt.b32 %r3, %r2, 0, 0x7770U;
-; O0-NEXT:    prmt.b32 %r4, %r1, 0, 0x7770U;
+; O0-NEXT:    prmt.b32 %r3, %r2, 0, 0x8880U;
+; O0-NEXT:    prmt.b32 %r4, %r1, 0, 0x8880U;
 ; O0-NEXT:    setp.gt.s32 %p1, %r4, %r3;
-; O0-NEXT:    prmt.b32 %r5, %r2, 0, 0x7771U;
-; O0-NEXT:    prmt.b32 %r6, %r1, 0, 0x7771U;
+; O0-NEXT:    prmt.b32 %r5, %r2, 0, 0x9991U;
+; O0-NEXT:    prmt.b32 %r6, %r1, 0, 0x9991U;
 ; O0-NEXT:    setp.gt.s32 %p2, %r6, %r5;
-; O0-NEXT:    prmt.b32 %r7, %r2, 0, 0x7772U;
-; O0-NEXT:    prmt.b32 %r8, %r1, 0, 0x7772U;
+; O0-NEXT:    prmt.b32 %r7, %r2, 0, 0xaaa2U;
+; O0-NEXT:    prmt.b32 %r8, %r1, 0, 0xaaa2U;
 ; O0-NEXT:    setp.gt.s32 %p3, %r8, %r7;
-; O0-NEXT:    prmt.b32 %r9, %r2, 0, 0x7773U;
-; O0-NEXT:    prmt.b32 %r10, %r1, 0, 0x7773U;
+; O0-NEXT:    prmt.b32 %r9, %r2, 0, 0xbbb3U;
+; O0-NEXT:    prmt.b32 %r10, %r1, 0, 0xbbb3U;
 ; O0-NEXT:    setp.gt.s32 %p4, %r10, %r9;
-; O0-NEXT:    selp.b32 %r11, %r10, %r9, %p4;
-; O0-NEXT:    selp.b32 %r12, %r8, %r7, %p3;
-; O0-NEXT:    prmt.b32 %r13, %r12, %r11, 0x3340U;
-; O0-NEXT:    selp.b32 %r14, %r6, %r5, %p2;
-; O0-NEXT:    selp.b32 %r15, %r4, %r3, %p1;
-; O0-NEXT:    prmt.b32 %r16, %r15, %r14, 0x3340U;
-; O0-NEXT:    prmt.b32 %r17, %r16, %r13, 0x5410U;
-; O0-NEXT:    st.param.b32 [func_retval0], %r17;
+; O0-NEXT:    prmt.b32 %r11, %r2, 0, 0x7770U;
+; O0-NEXT:    prmt.b32 %r12, %r2, 0, 0x7771U;
+; O0-NEXT:    prmt.b32 %r13, %r2, 0, 0x7772U;
+; O0-NEXT:    prmt.b32 %r14, %r2, 0, 0x7773U;
+; O0-NEXT:    prmt.b32 %r15, %r1, 0, 0x7773U;
+; O0-NEXT:    selp.b32 %r16, %r15, %r14, %p4;
+; O0-NEXT:    prmt.b32 %r17, %r1, 0, 0x7772U;
+; O0-NEXT:    selp.b32 %r18, %r17, %r13, %p3;
+; O0-NEXT:    prmt.b32 %r19, %r18, %r16, 0x3340U;
+; O0-NEXT:    prmt.b32 %r20, %r1, 0, 0x7771U;
+; O0-NEXT:    selp.b32 %r21, %r20, %r12, %p2;
+; O0-NEXT:    prmt.b32 %r22, %r1, 0, 0x7770U;
+; O0-NEXT:    selp.b32 %r23, %r22, %r11, %p1;
+; O0-NEXT:    prmt.b32 %r24, %r23, %r21, 0x3340U;
+; O0-NEXT:    prmt.b32 %r25, %r24, %r19, 0x5410U;
+; O0-NEXT:    st.param.b32 [func_retval0], %r25;
 ; O0-NEXT:    ret;
 ;
 ; O3-LABEL: test_smax(
 ; O3:       {
 ; O3-NEXT:    .reg .pred %p<5>;
-; O3-NEXT:    .reg .b32 %r<18>;
+; O3-NEXT:    .reg .b32 %r<26>;
 ; O3-EMPTY:
 ; O3-NEXT:  // %bb.0:
 ; O3-NEXT:    ld.param.b32 %r1, [test_smax_param_0];
 ; O3-NEXT:    ld.param.b32 %r2, [test_smax_param_1];
-; O3-NEXT:    prmt.b32 %r3, %r2, 0, 0x7770U;
-; O3-NEXT:    prmt.b32 %r4, %r1, 0, 0x7770U;
+; O3-NEXT:    prmt.b32 %r3, %r2, 0, 0x8880U;
+; O3-NEXT:    prmt.b32 %r4, %r1, 0, 0x8880U;
 ; O3-NEXT:    setp.gt.s32 %p1, %r4, %r3;
-; O3-NEXT:    prmt.b32 %r5, %r2, 0, 0x7771U;
-; O3-NEXT:    prmt.b32 %r6, %r1, 0, 0x7771U;
+; O3-NEXT:    prmt.b32 %r5, %r2, 0, 0x9991U;
+; O3-NEXT:    prmt.b32 %r6, %r1, 0, 0x9991U;
 ; O3-NEXT:    setp.gt.s32 %p2, %r6, %r5;
-; O3-NEXT:    prmt.b32 %r7, %r2, 0, 0x7772U;
-; O3-NEXT:    prmt.b32 %r8, %r1, 0, 0x7772U;
+; O3-NEXT:    prmt.b32 %r7, %r2, 0, 0xaaa2U;
+; O3-NEXT:    prmt.b32 %r8, %r1, 0, 0xaaa2U;
 ; O3-NEXT:    setp.gt.s32 %p3, %r8, %r7;
-; O3-NEXT:    prmt.b32 %r9, %r2, 0, 0x7773U;
-; O3-NEXT:    prmt.b32 %r10, %r1, 0, 0x7773U;
+; O3-NEXT:    prmt.b32 %r9, %r2, 0, 0xbbb3U;
+; O3-NEXT:    prmt.b32 %r10, %r1, 0, 0xbbb3U;
 ; O3-NEXT:    setp.gt.s32 %p4, %r10, %r9;
-; O3-NEXT:    selp.b32 %r11, %r10, %r9, %p4;
-; O3-NEXT:    selp.b32 %r12, %r8, %r7, %p3;
-; O3-NEXT:    prmt.b32 %r13, %r12, %r11, 0x3340U;
-; O3-NEXT:    selp.b32 %r14, %r6, %r5, %p2;
-; O3-NEXT:    selp.b32 %r15, %r4, %r3, %p1;
-; O3-NEXT:    prmt.b32 %r16, %r15, %r14, 0x3340U;
-; O3-NEXT:    prmt.b32 %r17, %r16, %r13, 0x5410U;
-; O3-NEXT:    st.param.b32 [func_retval0], %r17;
+; O3-NEXT:    prmt.b32 %r11, %r2, 0, 0x7770U;
+; O3-NEXT:    prmt.b32 %r12, %r2, 0, 0x7771U;
+; O3-NEXT:    prmt.b32 %r13, %r2, 0, 0x7772U;
+; O3-NEXT:    prmt.b32 %r14, %r2, 0, 0x7773U;
+; O3-NEXT:    prmt.b32 %r15, %r1, 0, 0x7773U;
+; O3-NEXT:    selp.b32 %r16, %r15, %r14, %p4;
+; O3-NEXT:    prmt.b32 %r17, %r1, 0, 0x7772U;
+; O3-NEXT:    selp.b32 %r18, %r17, %r13, %p3;
+; O3-NEXT:    prmt.b32 %r19, %r18, %r16, 0x3340U;
+; O3-NEXT:    prmt.b32 %r20, %r1, 0, 0x7771U;
+; O3-NEXT:    selp.b32 %r21, %r20, %r12, %p2;
+; O3-NEXT:    prmt.b32 %r22, %r1, 0, 0x7770U;
+; O3-NEXT:    selp.b32 %r23, %r22, %r11, %p1;
+; O3-NEXT:    prmt.b32 %r24, %r23, %r21, 0x3340U;
+; O3-NEXT:    prmt.b32 %r25, %r24, %r19, 0x5410U;
+; O3-NEXT:    st.param.b32 [func_retval0], %r25;
 ; O3-NEXT:    ret;
   %cmp = icmp sgt <4 x i8> %a, %b
   %r = select <4 x i1> %cmp, <4 x i8> %a, <4 x i8> %b
@@ -473,61 +489,77 @@ define <4 x i8> @test_smin(<4 x i8> %a, <4 x i8> %b) #0 {
 ; O0-LABEL: test_smin(
 ; O0:       {
 ; O0-NEXT:    .reg .pred %p<5>;
-; O0-NEXT:    .reg .b32 %r<18>;
+; O0-NEXT:    .reg .b32 %r<26>;
 ; O0-EMPTY:
 ; O0-NEXT:  // %bb.0:
 ; O0-NEXT:    ld.param.b32 %r2, [test_smin_param_1];
 ; O0-NEXT:    ld.param.b32 %r1, [test_smin_param_0];
-; O0-NEXT:    prmt.b32 %r3, %r2, 0, 0x7770U;
-; O0-NEXT:    prmt.b32 %r4, %r1, 0, 0x7770U;
+; O0-NEXT:    prmt.b32 %r3, %r2, 0, 0x8880U;
+; O0-NEXT:    prmt.b32 %r4, %r1, 0, 0x8880U;
 ; O0-NEXT:    setp.le.s32 %p1, %r4, %r3;
-; O0-NEXT:    prmt.b32 %r5, %r2, 0, 0x7771U;
-; O0-NEXT:    prmt.b32 %r6, %r1, 0, 0x7771U;
+; O0-NEXT:    prmt.b32 %r5, %r2, 0, 0x9991U;
+; O0-NEXT:    prmt.b32 %r6, %r1, 0, 0x9991U;
 ; O0-NEXT:    setp.le.s32 %p2, %r6, %r5;
-; O0-NEXT:    prmt.b32 %r7, %r2, 0, 0x7772U;
-; O0-NEXT:    prmt.b32 %r8, %r1, 0, 0x7772U;
+; O0-NEXT:    prmt.b32 %r7, %r2, 0, 0xaaa2U;
+; O0-NEXT:    prmt.b32 %r8, %r1, 0, 0xaaa2U;
 ; O0-NEXT:    setp.le.s32 %p3, %r8, %r7;
-; O0-NEXT:    prmt.b32 %r9, %r2, 0, 0x7773U;
-; O0-NEXT:    prmt.b32 %r10, %r1, 0, 0x7773U;
+; O0-NEXT:    prmt.b32 %r9, %r2, 0, 0xbbb3U;
+; O0-NEXT:    prmt.b32 %r10, %r1, 0, 0xbbb3U;
 ; O0-NEXT:    setp.le.s32 %p4, %r10, %r9;
-; O0-NEXT:    selp.b32 %r11, %r10, %r9, %p4;
-; O0-NEXT:    selp.b32 %r12, %r8, %r7, %p3;
-; O0-NEXT:    prmt.b32 %r13, %r12, %r11, 0x3340U;
-; O0-NEXT:    selp.b32 %r14, %r6, %r5, %p2;
-; O0-NEXT:    selp.b32 %r15, %r4, %r3, %p1;
-; O0-NEXT:    prmt.b32 %r16, %r15, %r14, 0x3340U;
-; O0-NEXT:    prmt.b32 %r17, %r16, %r13, 0x5410U;
-; O0-NEXT:    st.param.b32 [func_retval0], %r17;
+; O0-NEXT:    prmt.b32 %r11, %r2, 0, 0x7770U;
+; O0-NEXT:    prmt.b32 %r12, %r2, 0, 0x7771U;
+; O0-NEXT:    prmt.b32 %r13, %r2, 0, 0x7772U;
+; O0-NEXT:    prmt.b32 %r14, %r2, 0, 0x7773U;
+; O0-NEXT:    prmt.b32 %r15, %r1, 0, 0x7773U;
+; O0-NEXT:    selp.b32 %r16, %r15, %r14, %p4;
+; O0-NEXT:    prmt.b32 %r17, %r1, 0, 0x7772U;
+; O0-NEXT:    selp.b32 %r18, %r17, %r13, %p3;
+; O0-NEXT:    prmt.b32 %r19, %r18, %r16, 0x3340U;
+; O0-NEXT:    prmt.b32 %r20, %r1, 0, 0x7771U;
+; O0-NEXT:    selp.b32 %r21, %r20, %r12, %p2;
+; O0-NEXT:    prmt.b32 %r22, %r1, 0, 0x7770U;
+; O0-NEXT:    selp.b32 %r23, %r22, %r11, %p1;
+; O0-NEXT:    prmt.b32 %r24, %r23, %r21, 0x3340U;
+; O0-NEXT:    prmt.b32 %r25, %r24, %r19, 0x5410U;
+; O0-NEXT:    st.param.b32 [func_retval0], %r25;
 ; O0-NEXT:    ret;
 ;
 ; O3-LABEL: test_smin(
 ; O3:       {
 ; O3-NEXT:    .reg .pred %p<5>;
-; O3-NEXT:    .reg .b32 %r<18>;
+; O3-NEXT:    .reg .b32 %r<26>;
 ; O3-EMPTY:
 ; O3-NEXT:  // %bb.0:
 ; O3-NEXT:    ld.param.b32 %r1, [test_smin_param_0];
 ; O3-NEXT:    ld.param.b32 %r2, [test_smin_param_1];
-; O3-NEXT:    prmt.b32 %r3, %r2, 0, 0x7770U;
-; O3-NEXT:    prmt.b32 %r4, %r1, 0, 0x7770U;
+; O3-NEXT:    prmt.b32 %r3, %r2, 0, 0x8880U;
+; O3-NEXT:    prmt.b32 %r4, %r1, 0, 0x8880U;
 ; O3-NEXT:    setp.le.s32 %p1, %r4, %r3;
-; O3-NEXT:    prmt.b32 %r5, %r2, 0, 0x7771U;
-; O3-NEXT:    prmt.b32 %r6, %r1, 0, 0x7771U;
+; O3-NEXT:    prmt.b32 %r5, %r2, 0, 0x9991U;
+; O3-NEXT:    prmt.b32 %r6, %r1, 0, 0x9991U;
 ; O3-NEXT:    setp.le.s32 %p2, %r6, %r5;
-; O3-NEXT:    prmt.b32 %r7, %r2, 0, 0x7772U;
-; O3-NEXT:    prmt.b32 %r8, %r1, 0, 0x7772U;
+; O3-NEXT:    prmt.b32 %r7, %r2, 0, 0xaaa2U;
+; O3-NEXT:    prmt.b32 %r8, %r1, 0, 0xaaa2U;
 ; O3-NEXT:    setp.le.s32 %p3, %r8, %r7;
-; O3-NEXT:    prmt.b32 %r9, %r2, 0, 0x7773U;
-; O3-NEXT:    prmt.b32 %r10, %r1, 0, 0x7773U;
+; O3-NEXT:    prmt.b32 %r9, %r2, 0, 0xbbb3U;
+; O3-NEXT:    prmt.b32 %r10, %r1, 0, 0xbbb3U;
 ; O3-NEXT:    setp.le.s32 %p4, %r10, %r9;
-; O3-NEXT:    selp.b32 %r11, %r10, %r9, %p4;
-; O3-NEXT:    selp.b32 %r12, %r8, %r7, %p3;
-; O3-NEXT:    prmt.b32 %r13, %r12, %r11, 0x3340U;
-; O3-NEXT:    selp.b32 %r14, %r6, %r5, %p2;
-; O3-NEXT:    selp.b32 %r15, %r4, %r3, %p1;
-; O3-NEXT:    prmt.b32 %r16, %r15, %r14, 0x3340U;
-; O3-NEXT:    prmt.b32 %r17, %r16, %r13, 0x5410U;
-; O3-NEXT:    st.param.b32 [func_retval0], %r17;
+; O3-NEXT:    prmt.b32 %r11, %r2, 0, 0x7770U;
+; O3-NEXT:    prmt.b32 %r12, %r2, 0, 0x7771U;
+; O3-NEXT:    prmt.b32 %r13, %r2, 0, 0x7772U;
+; O3-NEXT:    prmt.b32 %r14, %r2, 0, 0x7773U;
+; O3-NEXT:    prmt.b32 %r15, %r1, 0, 0x7773U;
+; O3-NEXT:    selp.b32 %r16, %r15, %r14, %p4;
+; O3-NEXT:    prmt.b32 %r17, %r1, 0, 0x7772U;
+; O3-NEXT:    selp.b32 %r18, %r17, %r13, %p3;
+; O3-NEXT:    prmt.b32 %r19, %r18, %r16, 0x3340U;
+; O3-NEXT:    prmt.b32 %r20, %r1, 0, 0x7771U;
+; O3-NEXT:    selp.b32 %r21, %r20, %r12, %p2;
+; O3-NEXT:    prmt.b32 %r22, %r1, 0, 0x7770U;
+; O3-NEXT:    selp.b32 %r23, %r22, %r11, %p1;
+; O3-NEXT:    prmt.b32 %r24, %r23, %r21, 0x3340U;
+; O3-NEXT:    prmt.b32 %r25, %r24, %r19, 0x5410U;
+; O3-NEXT:    st.param.b32 [func_retval0], %r25;
 ; O3-NEXT:    ret;
   %cmp = icmp sle <4 x i8> %a, %b
   %r = select <4 x i1> %cmp, <4 x i8> %a, <4 x i8> %b
diff --git a/llvm/test/CodeGen/NVPTX/sext-setcc.ll b/llvm/test/CodeGen/NVPTX/sext-setcc.ll
index 9a67bdfeb067b..97918a6f26cdf 100644
--- a/llvm/test/CodeGen/NVPTX/sext-setcc.ll
+++ b/llvm/test/CodeGen/NVPTX/sext-setcc.ll
@@ -29,7 +29,6 @@ define <4 x i8> @sext_setcc_v4i1_to_v4i8(ptr %p) {
 ; CHECK-LABEL: sext_setcc_v4i1_to_v4i8(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .pred %p<5>;
-; CHECK-NEXT:    .reg .b16 %rs<5>;
 ; CHECK-NEXT:    .reg .b32 %r<13>;
 ; CHECK-NEXT:    .reg .b64 %rd<2>;
 ; CHECK-EMPTY:
@@ -37,17 +36,13 @@ define <4 x i8> @sext_setcc_v4i1_to_v4i8(ptr %p) {
 ; CHECK-NEXT:    ld.param.b64 %rd1, [sext_setcc_v4i1_to_v4i8_param_0];
 ; CHECK-NEXT:    ld.b32 %r1, [%rd1];
 ; CHECK-NEXT:    prmt.b32 %r2, %r1, 0, 0x7770U;
-; CHECK-NEXT:    cvt.u16.u32 %rs1, %r2;
-; CHECK-NEXT:    setp.eq.b16 %p1, %rs1, 0;
+; CHECK-NEXT:    setp.eq.b32 %p1, %r2, 0;
 ; CHECK-NEXT:    prmt.b32 %r3, %r1, 0, 0x7771U;
-; CHECK-NEXT:    cvt.u16.u32 %rs2, %r3;
-; CHECK-NEXT:    setp.eq.b16 %p2, %rs2, 0;
+; CHECK-NEXT:    setp.eq.b32 %p2, %r3, 0;
 ; CHECK-NEXT:    prmt.b32 %r4, %r1, 0, 0x7772U;
-; CHECK-NEXT:    cvt.u16.u32 %rs3, %r4;
-; CHECK-NEXT:    setp.eq.b16 %p3, %rs3, 0;
+; CHECK-NEXT:    setp.eq.b32 %p3, %r4, 0;
 ; CHECK-NEXT:    prmt.b32 %r5, %r1, 0, 0x7773U;
-; CHECK-NEXT:    cvt.u16.u32 %rs4, %r5;
-; CHECK-NEXT:    setp.eq.b16 %p4, %rs4, 0;
+; CHECK-NEXT:    setp.eq.b32 %p4, %r5, 0;
 ; CHECK-NEXT:    selp.b32 %r6, -1, 0, %p4;
 ; CHECK-NEXT:    selp.b32 %r7, -1, 0, %p3;
 ; CHECK-NEXT:    prmt.b32 %r8, %r7, %r6, 0x3340U;
diff --git a/llvm/test/CodeGen/NVPTX/trunc-setcc.ll b/llvm/test/CodeGen/NVPTX/trunc-setcc.ll
new file mode 100644
index 0000000000000..f22e37e203966
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/trunc-setcc.ll
@@ -0,0 +1,269 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mcpu=sm_50 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -mcpu=sm_50 | %ptxas-verify -arch=sm_50 %}
+
+target triple = "nvptx64-nvidia-cuda"
+
+define i1 @trunc_nsw_signed_const(i32 %a) {
+; CHECK-LABEL: trunc_nsw_signed_const(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [trunc_nsw_signed_const_param_0];
+; CHECK-NEXT:    add.s32 %r2, %r1, 1;
+; CHECK-NEXT:    setp.gt.s32 %p1, %r2, -1;
+; CHECK-NEXT:    selp.b32 %r3, -1, 0, %p1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT:    ret;
+  %a2 = add i32 %a, 1
+  %b = trunc nsw i32 %a2 to i8
+  %c = icmp sgt i8 %b, -1
+  ret i1 %c
+}
+
+define i1 @trunc_nuw_signed_const(i32 %a) {
+; CHECK-LABEL: trunc_nuw_signed_const(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b16 %rs<4>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b8 %rs1, [trunc_nuw_signed_const_param_0];
+; CHECK-NEXT:    add.s16 %rs2, %rs1, 1;
+; CHECK-NEXT:    cvt.s16.s8 %rs3, %rs2;
+; CHECK-NEXT:    setp.lt.s16 %p1, %rs3, 100;
+; CHECK-NEXT:    selp.b32 %r1, -1, 0, %p1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %a2 = add i32 %a, 1
+  %b = trunc nuw i32 %a2 to i8
+  %c = icmp slt i8 %b, 100
+  ret i1 %c
+}
+
+define i1 @trunc_nsw_unsigned_const(i32 %a) {
+; CHECK-LABEL: trunc_nsw_unsigned_const(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b16 %rs<4>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b8 %rs1, [trunc_nsw_unsigned_const_param_0];
+; CHECK-NEXT:    add.s16 %rs2, %rs1, 1;
+; CHECK-NEXT:    and.b16 %rs3, %rs2, 255;
+; CHECK-NEXT:    setp.lt.u16 %p1, %rs3, 236;
+; CHECK-NEXT:    selp.b32 %r1, -1, 0, %p1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %a2 = add i32 %a, 1
+  %b = trunc nsw i32 %a2 to i8
+  %c = icmp ult i8 %b, -20
+  ret i1 %c
+}
+
+define i1 @trunc_nuw_unsigned_const(i32 %a) {
+; CHECK-LABEL: trunc_nuw_unsigned_const(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b32 %r<4>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [trunc_nuw_unsigned_const_param_0];
+; CHECK-NEXT:    add.s32 %r2, %r1, 1;
+; CHECK-NEXT:    setp.gt.u32 %p1, %r2, 100;
+; CHECK-NEXT:    selp.b32 %r3, -1, 0, %p1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
+; CHECK-NEXT:    ret;
+  %a2 = add i32 %a, 1
+  %b = trunc nuw i32 %a2 to i8
+  %c = icmp ugt i8 %b, 100
+  ret i1 %c
+}
+
+
+define i1 @trunc_nsw_eq_const(i32 %a) {
+; CHECK-LABEL: trunc_nsw_eq_const(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [trunc_nsw_eq_const_param_0];
+; CHECK-NEXT:    setp.eq.b32 %p1, %r1, 99;
+; CHECK-NEXT:    selp.b32 %r2, -1, 0, %p1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT:    ret;
+  %a2 = add i32 %a, 1
+  %b = trunc nsw i32 %a2 to i8
+  %c = icmp eq i8 %b, 100
+  ret i1 %c
+}
+
+define i1 @trunc_nuw_eq_const(i32 %a) {
+; CHECK-LABEL: trunc_nuw_eq_const(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b32 %r<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [trunc_nuw_eq_const_param_0];
+; CHECK-NEXT:    setp.eq.b32 %p1, %r1, 99;
+; CHECK-NEXT:    selp.b32 %r2, -1, 0, %p1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r2;
+; CHECK-NEXT:    ret;
+  %a2 = add i32 %a, 1
+  %b = trunc nuw i32 %a2 to i8
+  %c = icmp eq i8 %b, 100
+  ret i1 %c
+}
+
+;;;
+
+define i1 @trunc_nsw_signed(i32 %a1, i32 %a2) {
+; CHECK-LABEL: trunc_nsw_signed(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b32 %r<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [trunc_nsw_signed_param_0];
+; CHECK-NEXT:    add.s32 %r2, %r1, 1;
+; CHECK-NEXT:    ld.param.b32 %r3, [trunc_nsw_signed_param_1];
+; CHECK-NEXT:    add.s32 %r4, %r3, 7;
+; CHECK-NEXT:    setp.gt.s32 %p1, %r2, %r4;
+; CHECK-NEXT:    selp.b32 %r5, -1, 0, %p1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r5;
+; CHECK-NEXT:    ret;
+  %b1 = add i32 %a1, 1
+  %b2 = add i32 %a2, 7
+  %c1 = trunc nsw i32 %b1 to i8
+  %c2 = trunc nsw i32 %b2 to i8
+  %c = icmp sgt i8 %c1, %c2
+  ret i1 %c
+}
+
+define i1 @trunc_nuw_signed(i32 %a1, i32 %a2) {
+; CHECK-LABEL: trunc_nuw_signed(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b16 %rs<7>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b8 %rs1, [trunc_nuw_signed_param_0];
+; CHECK-NEXT:    ld.param.b8 %rs2, [trunc_nuw_signed_param_1];
+; CHECK-NEXT:    add.s16 %rs3, %rs1, 1;
+; CHECK-NEXT:    cvt.s16.s8 %rs4, %rs3;
+; CHECK-NEXT:    add.s16 %rs5, %rs2, 6;
+; CHECK-NEXT:    cvt.s16.s8 %rs6, %rs5;
+; CHECK-NEXT:    setp.lt.s16 %p1, %rs4, %rs6;
+; CHECK-NEXT:    selp.b32 %r1, -1, 0, %p1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %b1 = add i32 %a1, 1
+  %b2 = add i32 %a2, 6
+  %c1 = trunc nuw i32 %b1 to i8
+  %c2 = trunc nuw i32 %b2 to i8
+  %c = icmp slt i8 %c1, %c2
+  ret i1 %c
+}
+
+define i1 @trunc_nsw_unsigned(i32 %a1, i32 %a2) {
+; CHECK-LABEL: trunc_nsw_unsigned(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b16 %rs<7>;
+; CHECK-NEXT:    .reg .b32 %r<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b8 %rs1, [trunc_nsw_unsigned_param_0];
+; CHECK-NEXT:    ld.param.b8 %rs2, [trunc_nsw_unsigned_param_1];
+; CHECK-NEXT:    add.s16 %rs3, %rs1, 1;
+; CHECK-NEXT:    and.b16 %rs4, %rs3, 255;
+; CHECK-NEXT:    add.s16 %rs5, %rs2, 4;
+; CHECK-NEXT:    and.b16 %rs6, %rs5, 255;
+; CHECK-NEXT:    setp.lt.u16 %p1, %rs4, %rs6;
+; CHECK-NEXT:    selp.b32 %r1, -1, 0, %p1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r1;
+; CHECK-NEXT:    ret;
+  %b1 = add i32 %a1, 1
+  %b2 = add i32 %a2, 4
+  %c1 = trunc nsw i32 %b1 to i8
+  %c2 = trunc nsw i32 %b2 to i8
+  %c = icmp ult i8 %c1, %c2
+  ret i1 %c
+}
+
+define i1 @trunc_nuw_unsigned(i32 %a1, i32 %a2) {
+; CHECK-LABEL: trunc_nuw_unsigned(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b32 %r<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [trunc_nuw_unsigned_param_0];
+; CHECK-NEXT:    add.s32 %r2, %r1, 1;
+; CHECK-NEXT:    ld.param.b32 %r3, [trunc_nuw_unsigned_param_1];
+; CHECK-NEXT:    add.s32 %r4, %r3, 5;
+; CHECK-NEXT:    setp.gt.u32 %p1, %r2, %r4;
+; CHECK-NEXT:    selp.b32 %r5, -1, 0, %p1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r5;
+; CHECK-NEXT:    ret;
+  %b1 = add i32 %a1, 1
+  %b2 = add i32 %a2, 5
+  %c1 = trunc nuw i32 %b1 to i8
+  %c2 = trunc nuw i32 %b2 to i8
+  %c = icmp ugt i8 %c1, %c2
+  ret i1 %c
+}
+
+
+define i1 @trunc_nsw_eq(i32 %a1, i32 %a2) {
+; CHECK-LABEL: trunc_nsw_eq(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b32 %r<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [trunc_nsw_eq_param_0];
+; CHECK-NEXT:    add.s32 %r2, %r1, 1;
+; CHECK-NEXT:    ld.param.b32 %r3, [trunc_nsw_eq_param_1];
+; CHECK-NEXT:    add.s32 %r4, %r3, 3;
+; CHECK-NEXT:    setp.eq.b32 %p1, %r2, %r4;
+; CHECK-NEXT:    selp.b32 %r5, -1, 0, %p1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r5;
+; CHECK-NEXT:    ret;
+  %b1 = add i32 %a1, 1
+  %b2 = add i32 %a2, 3
+  %c1 = trunc nsw i32 %b1 to i8
+  %c2 = trunc nsw i32 %b2 to i8
+  %c = icmp eq i8 %c1, %c2
+  ret i1 %c
+}
+
+define i1 @trunc_nuw_eq(i32 %a1, i32 %a2) {
+; CHECK-LABEL: trunc_nuw_eq(
+; CHECK:       {
+; CHECK-NEXT:    .reg .pred %p<2>;
+; CHECK-NEXT:    .reg .b32 %r<6>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b32 %r1, [trunc_nuw_eq_param_0];
+; CHECK-NEXT:    add.s32 %r2, %r1, 2;
+; CHECK-NEXT:    ld.param.b32 %r3, [trunc_nuw_eq_param_1];
+; CHECK-NEXT:    add.s32 %r4, %r3, 1;
+; CHECK-NEXT:    setp.eq.b32 %p1, %r2, %r4;
+; CHECK-NEXT:    selp.b32 %r5, -1, 0, %p1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r5;
+; CHECK-NEXT:    ret;
+  %b1 = add i32 %a1, 2
+  %b2 = add i32 %a2, 1
+  %c1 = trunc nuw i32 %b1 to i8
+  %c2 = trunc nuw i32 %b2 to i8
+  %c = icmp eq i8 %c1, %c2
+  ret i1 %c
+}

>From 2798859d1153a2cde4c1e19bd172872c7bba94b8 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Thu, 19 Jun 2025 15:29:12 +0000
Subject: [PATCH 2/4] [NVPTX] Add NVPTXIncreaseAlignmentPass to improve
 vectorization

---
 llvm/lib/Target/NVPTX/CMakeLists.txt          |   1 +
 llvm/lib/Target/NVPTX/NVPTX.h                 |   7 +
 .../Target/NVPTX/NVPTXIncreaseAlignment.cpp   | 131 ++++++++++++++++++
 llvm/lib/Target/NVPTX/NVPTXPassRegistry.def   |   1 +
 llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp  |   2 +
 .../CodeGen/NVPTX/call-with-alloca-buffer.ll  |   2 +-
 .../CodeGen/NVPTX/increase-local-align.ll     |  85 ++++++++++++
 7 files changed, 228 insertions(+), 1 deletion(-)
 create mode 100644 llvm/lib/Target/NVPTX/NVPTXIncreaseAlignment.cpp
 create mode 100644 llvm/test/CodeGen/NVPTX/increase-local-align.ll

diff --git a/llvm/lib/Target/NVPTX/CMakeLists.txt b/llvm/lib/Target/NVPTX/CMakeLists.txt
index 693f0d0b35edc..9d91100d35b3a 100644
--- a/llvm/lib/Target/NVPTX/CMakeLists.txt
+++ b/llvm/lib/Target/NVPTX/CMakeLists.txt
@@ -26,6 +26,7 @@ set(NVPTXCodeGen_sources
   NVPTXISelLowering.cpp
+  NVPTXIncreaseAlignment.cpp
   NVPTXLowerAggrCopies.cpp
   NVPTXLowerAlloca.cpp
   NVPTXLowerArgs.cpp
   NVPTXLowerUnreachable.cpp
   NVPTXMCExpr.cpp
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index 77a0e03d4075a..5b2f10be072bb 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -55,6 +55,7 @@ FunctionPass *createNVPTXTagInvariantLoadsPass();
 MachineFunctionPass *createNVPTXPeephole();
 MachineFunctionPass *createNVPTXProxyRegErasurePass();
 MachineFunctionPass *createNVPTXForwardParamsPass();
+FunctionPass *createNVPTXIncreaseLocalAlignmentPass();
 
 void initializeNVVMReflectLegacyPassPass(PassRegistry &);
 void initializeGenericToNVVMLegacyPassPass(PassRegistry &);
@@ -77,6 +78,7 @@ void initializeNVPTXExternalAAWrapperPass(PassRegistry &);
 void initializeNVPTXPeepholePass(PassRegistry &);
 void initializeNVPTXTagInvariantLoadLegacyPassPass(PassRegistry &);
 void initializeNVPTXPrologEpilogPassPass(PassRegistry &);
+void initializeNVPTXIncreaseLocalAlignmentLegacyPassPass(PassRegistry &);
 
 struct NVVMIntrRangePass : PassInfoMixin<NVVMIntrRangePass> {
   PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
@@ -112,6 +114,13 @@ struct NVPTXTagInvariantLoadsPass : PassInfoMixin<NVPTXTagInvariantLoadsPass> {
   PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
 };
 
+/// Increases the alignment of statically sized local memory arrays so that
+/// loads/stores to them can be vectorized.
+struct NVPTXIncreaseLocalAlignmentPass
+    : PassInfoMixin<NVPTXIncreaseLocalAlignmentPass> {
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
 namespace NVPTX {
 enum DrvInterface {
   NVCL,
diff --git a/llvm/lib/Target/NVPTX/NVPTXIncreaseAlignment.cpp b/llvm/lib/Target/NVPTX/NVPTXIncreaseAlignment.cpp
new file mode 100644
index 0000000000000..4078ef340970f
--- /dev/null
+++ b/llvm/lib/Target/NVPTX/NVPTXIncreaseAlignment.cpp
@@ -0,0 +1,131 @@
+//===-- NVPTXIncreaseAlignment.cpp - Increase alignment for local arrays --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// A simple pass that looks at local memory arrays that are statically
+// sized and sets an appropriate alignment for them. This enables vectorization
+// of loads/stores to these arrays if not explicitly specified by the client.
+//
+// TODO: Ideally we should do a bin-packing of local arrays to maximize
+// alignments while minimizing holes.
+//
+//===----------------------------------------------------------------------===//
+
+#include "NVPTX.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace llvm;
+
+// When set, each array gets the largest alignment that could help a single
+// vector access (getAggressiveArrayAlignment) instead of the hole-avoiding
+// choice (getConservativeArrayAlignment).
+static cl::opt<bool> MaxLocalArrayAlignment(
+    "nvptx-use-max-local-array-alignment", cl::Hidden, cl::init(false),
+    cl::desc("Use maximum alignment for local memory"));
+
+// Upper bound on any alignment this pass assigns: 16 bytes, i.e. enough for
+// a 128-bit vector access.
+static constexpr Align MaxPTXArrayAlignment = Align::Constant<16>();
+
+/// Get the maximum useful alignment for an array. This is more likely to
+/// produce holes in the local memory.
+///
+/// Choose an alignment large enough that the entire array could be loaded with
+/// a single vector load (if possible). Cap the alignment at
+/// MaxPTXArrayAlignment.
+///
+/// \p ArraySize is a uint64_t (the type of TypeSize::getFixedValue()), so a
+/// large size is never truncated (e.g. 2^32 would otherwise become 0).
+/// A zero size yields Align(1): PowerOf2Ceil(0) is 0, and Align(0) is invalid.
+static Align getAggressiveArrayAlignment(const uint64_t ArraySize) {
+  // Arrays at least as large as the cap simply get the cap. Checking this
+  // first also keeps PowerOf2Ceil away from sizes whose next power of two
+  // would overflow.
+  if (ArraySize >= MaxPTXArrayAlignment.value())
+    return MaxPTXArrayAlignment;
+  // An empty array has nothing to vectorize; use the minimum alignment.
+  if (ArraySize == 0)
+    return Align(1);
+  return Align(PowerOf2Ceil(ArraySize));
+}
+
+/// Get the alignment of arrays that reduces the chances of leaving holes when
+/// arrays are allocated within a contiguous memory buffer (like the local
+/// stack frame). Holes are still possible before and after the allocation.
+///
+/// Choose the largest alignment (capped at MaxPTXArrayAlignment) such that the
+/// array size is a multiple of the alignment. If all elements of the buffer
+/// are allocated in order of alignment (higher to lower) no holes will be
+/// left. Since 0 is a multiple of every alignment, a size of 0 yields the cap.
+///
+/// \p ArraySize is a uint64_t, matching TypeSize::getFixedValue(), so callers
+/// pass the allocation size without an implicit narrowing conversion.
+static Align getConservativeArrayAlignment(const uint64_t ArraySize) {
+  return commonAlignment(MaxPTXArrayAlignment, ArraySize);
+}
+
+/// Find a better alignment for local arrays.
+///
+/// Only static allocas that are array allocations or have an array type, and
+/// whose allocation size is fixed and non-zero, are considered. The alignment
+/// is only ever raised, never lowered.
+///
+/// \returns true if the alignment of \p Alloca was changed.
+static bool updateAllocaAlignment(const DataLayout &DL, AllocaInst *Alloca) {
+  // Looking for statically sized local arrays
+  if (!Alloca->isStaticAlloca())
+    return false;
+
+  // For now, we only support array allocas
+  if (!(Alloca->isArrayAllocation() || Alloca->getAllocatedType()->isArrayTy()))
+    return false;
+
+  const auto ArraySize = Alloca->getAllocationSize(DL);
+  if (!(ArraySize && ArraySize->isFixed()))
+    return false;
+
+  const uint64_t ArraySizeValue = ArraySize->getFixedValue();
+  // A zero-sized array holds no data, so there is nothing to vectorize.
+  // Raising its alignment would only add padding to the local frame, and the
+  // aggressive path would otherwise compute Align(PowerOf2Ceil(0)).
+  if (ArraySizeValue == 0)
+    return false;
+
+  const Align PreferredAlignment =
+      MaxLocalArrayAlignment ? getAggressiveArrayAlignment(ArraySizeValue)
+                             : getConservativeArrayAlignment(ArraySizeValue);
+
+  if (PreferredAlignment > Alloca->getAlign()) {
+    Alloca->setAlignment(PreferredAlignment);
+    return true;
+  }
+
+  return false;
+}
+
+/// Try to raise the alignment of every alloca in the entry block of \p F.
+/// Static allocas only ever live in the entry block, so no other block needs
+/// to be scanned. Returns true if any alloca was modified.
+static bool runSetLocalArrayAlignment(Function &F) {
+  const DataLayout &DL = F.getParent()->getDataLayout();
+  bool Modified = false;
+  for (Instruction &Inst : F.getEntryBlock()) {
+    auto *AI = dyn_cast<AllocaInst>(&Inst);
+    if (!AI)
+      continue;
+    if (updateAllocaAlignment(DL, AI))
+      Modified = true;
+  }
+  return Modified;
+}
+
+namespace {
+/// Legacy pass manager wrapper that runs runSetLocalArrayAlignment() on each
+/// function.
+struct NVPTXIncreaseLocalAlignmentLegacyPass : public FunctionPass {
+  static char ID;
+  NVPTXIncreaseLocalAlignmentLegacyPass() : FunctionPass(ID) {}
+
+  bool runOnFunction(Function &F) override;
+
+  // The pass only changes alloca alignments, never the control flow. Declare
+  // that, matching the new pass manager version, which preserves CFGAnalyses.
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+  }
+
+  StringRef getPassName() const override {
+    return "NVPTX Increase Local Alignment";
+  }
+};
+} // namespace
+
+char NVPTXIncreaseLocalAlignmentLegacyPass::ID = 0;
+INITIALIZE_PASS(NVPTXIncreaseLocalAlignmentLegacyPass,
+                "nvptx-increase-local-alignment",
+                "Increase alignment for statically sized alloca arrays", false,
+                false)
+
+/// Factory used by the NVPTX codegen pipeline (legacy pass manager).
+FunctionPass *llvm::createNVPTXIncreaseLocalAlignmentPass() {
+  return new NVPTXIncreaseLocalAlignmentLegacyPass();
+}
+
+bool NVPTXIncreaseLocalAlignmentLegacyPass::runOnFunction(Function &F) {
+  // Honor optnone and -opt-bisect-limit. The new pass manager skips optional
+  // passes on such functions automatically; the legacy one must ask.
+  if (skipFunction(F))
+    return false;
+  return runSetLocalArrayAlignment(F);
+}
+
+PreservedAnalyses
+NVPTXIncreaseLocalAlignmentPass::run(Function &F, FunctionAnalysisManager &AM) {
+  // Nothing was modified, so every cached analysis result is still valid.
+  if (!runSetLocalArrayAlignment(F))
+    return PreservedAnalyses::all();
+
+  // Only alloca alignments were raised; the CFG itself is untouched.
+  PreservedAnalyses Preserved;
+  Preserved.preserveSet<CFGAnalyses>();
+  return Preserved;
+}
diff --git a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
index ee37c9826012c..827cb7bba7018 100644
--- a/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
+++ b/llvm/lib/Target/NVPTX/NVPTXPassRegistry.def
@@ -40,4 +40,5 @@ FUNCTION_PASS("nvvm-intr-range", NVVMIntrRangePass())
 FUNCTION_PASS("nvptx-copy-byval-args", NVPTXCopyByValArgsPass())
 FUNCTION_PASS("nvptx-lower-args", NVPTXLowerArgsPass(*this))
 FUNCTION_PASS("nvptx-tag-invariant-loads", NVPTXTagInvariantLoadsPass())
+FUNCTION_PASS("nvptx-increase-local-alignment", NVPTXIncreaseLocalAlignmentPass())
 #undef FUNCTION_PASS
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
index 0603994606d71..7426114dd0f89 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp
@@ -393,6 +393,8 @@ void NVPTXPassConfig::addIRPasses() {
   // but EarlyCSE can do neither of them.
   if (getOptLevel() != CodeGenOptLevel::None) {
     addEarlyCSEOrGVNPass();
+    // Increase alignment for local arrays to improve vectorization.
+    addPass(createNVPTXIncreaseLocalAlignmentPass());
     if (!DisableLoadStoreVectorizer)
       addPass(createLoadStoreVectorizerPass());
     addPass(createSROAPass());
diff --git a/llvm/test/CodeGen/NVPTX/call-with-alloca-buffer.ll b/llvm/test/CodeGen/NVPTX/call-with-alloca-buffer.ll
index 0cd7058174d67..d2504ddd8e76c 100644
--- a/llvm/test/CodeGen/NVPTX/call-with-alloca-buffer.ll
+++ b/llvm/test/CodeGen/NVPTX/call-with-alloca-buffer.ll
@@ -20,7 +20,7 @@ define ptx_kernel void @kernel_func(ptr %a) {
 entry:
   %buf = alloca [16 x i8], align 4
 
-; CHECK: .local .align 4 .b8 	__local_depot0[16]
+; CHECK: .local .align 16 .b8 	__local_depot0[16]
 ; CHECK: mov.b64 %SPL
 
 ; CHECK: ld.param.b64 %rd[[A_REG:[0-9]+]], [kernel_func_param_0]
diff --git a/llvm/test/CodeGen/NVPTX/increase-local-align.ll b/llvm/test/CodeGen/NVPTX/increase-local-align.ll
new file mode 100644
index 0000000000000..605c4b5b2b77d
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/increase-local-align.ll
@@ -0,0 +1,85 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=nvptx-increase-local-alignment < %s | FileCheck %s --check-prefixes=COMMON,DEFAULT
+; RUN: opt -S -passes=nvptx-increase-local-alignment -nvptx-use-max-local-array-alignment < %s | FileCheck %s --check-prefixes=COMMON,MAX
+target triple = "nvptx64-nvidia-cuda"
+
+; A scalar (non-array) alloca is left alone.
+define void @test1() {
+; COMMON-LABEL: define void @test1() {
+; COMMON-NEXT:    [[A:%.*]] = alloca i8, align 1
+; COMMON-NEXT:    ret void
+;
+  %a = alloca i8, align 1
+  ret void
+}
+
+; 63 bytes is odd, so the default hole-avoiding mode keeps align 1, while the
+; max-alignment mode raises it to the 16-byte cap.
+define void @test2() {
+; DEFAULT-LABEL: define void @test2() {
+; DEFAULT-NEXT:    [[A:%.*]] = alloca [63 x i8], align 1
+; DEFAULT-NEXT:    ret void
+;
+; MAX-LABEL: define void @test2() {
+; MAX-NEXT:    [[A:%.*]] = alloca [63 x i8], align 16
+; MAX-NEXT:    ret void
+;
+  %a = alloca [63 x i8], align 1
+  ret void
+}
+
+; 64 bytes is a multiple of 16, so both modes pick the 16-byte cap.
+define void @test3() {
+; COMMON-LABEL: define void @test3() {
+; COMMON-NEXT:    [[A:%.*]] = alloca [64 x i8], align 16
+; COMMON-NEXT:    ret void
+;
+  %a = alloca [64 x i8], align 1
+  ret void
+}
+
+; Same as test2, but expressed as an array allocation with an element count.
+define void @test4() {
+; DEFAULT-LABEL: define void @test4() {
+; DEFAULT-NEXT:    [[A:%.*]] = alloca i8, i32 63, align 1
+; DEFAULT-NEXT:    ret void
+;
+; MAX-LABEL: define void @test4() {
+; MAX-NEXT:    [[A:%.*]] = alloca i8, i32 63, align 16
+; MAX-NEXT:    ret void
+;
+  %a = alloca i8, i32 63, align 1
+  ret void
+}
+
+; Same as test3, but expressed as an array allocation with an element count.
+define void @test5() {
+; COMMON-LABEL: define void @test5() {
+; COMMON-NEXT:    [[A:%.*]] = alloca i8, i32 64, align 16
+; COMMON-NEXT:    ret void
+;
+  %a = alloca i8, i32 64, align 1
+  ret void
+}
+
+; A scalar alloca whose alignment is already above the 16-byte cap is unchanged.
+define void @test6() {
+; COMMON-LABEL: define void @test6() {
+; COMMON-NEXT:    [[A:%.*]] = alloca i8, align 32
+; COMMON-NEXT:    ret void
+;
+  %a = alloca i8, align 32
+  ret void
+}
+
+; A scalar alloca is left alone, even when under-aligned for its type.
+define void @test7() {
+; COMMON-LABEL: define void @test7() {
+; COMMON-NEXT:    [[A:%.*]] = alloca i32, align 2
+; COMMON-NEXT:    ret void
+;
+  %a = alloca i32, align 2
+  ret void
+}
+
+; [2 x i32] is 8 bytes, so both modes raise the alignment to 8.
+define void @test8() {
+; COMMON-LABEL: define void @test8() {
+; COMMON-NEXT:    [[A:%.*]] = alloca [2 x i32], align 8
+; COMMON-NEXT:    ret void
+;
+  %a = alloca [2 x i32], align 2
+  ret void
+}
+

>From 43368da1bc42ed54aaf975167a6dd6fd97ed7aa1 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Wed, 25 Jun 2025 15:00:36 +0000
Subject: [PATCH 3/4] address comments

---
 .../Target/NVPTX/NVPTXIncreaseAlignment.cpp   | 60 +++++++++++++------
 .../CodeGen/NVPTX/increase-local-align.ll     |  2 +-
 llvm/test/CodeGen/NVPTX/lower-byval-args.ll   | 14 ++---
 llvm/test/CodeGen/NVPTX/variadics-backend.ll  | 11 +++-
 4 files changed, 60 insertions(+), 27 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXIncreaseAlignment.cpp b/llvm/lib/Target/NVPTX/NVPTXIncreaseAlignment.cpp
index 4078ef340970f..1fb1e578994e9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIncreaseAlignment.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXIncreaseAlignment.cpp
@@ -7,7 +7,7 @@
 //===----------------------------------------------------------------------===//
 //
 // A simple pass that looks at local memory arrays that are statically
-// sized and sets an appropriate alignment for them. This enables vectorization
+// sized and potentially increases their alignment. This enables vectorization
 // of loads/stores to these arrays if not explicitly specified by the client.
 //
 // TODO: Ideally we should do a bin-packing of local arrays to maximize
@@ -16,12 +16,15 @@
 //===----------------------------------------------------------------------===//
 
 #include "NVPTX.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/MathExtras.h"
+#include "llvm/Support/NVPTXAddrSpace.h"
 
 using namespace llvm;
 
@@ -30,7 +33,25 @@ static cl::opt<bool>
                            cl::init(false), cl::Hidden,
                            cl::desc("Use maximum alignment for local memory"));
 
-static constexpr Align MaxPTXArrayAlignment = Align::Constant<16>();
+static Align getMaxLocalArrayAlignment(const TargetTransformInfo &TTI) {
+  const unsigned MaxBitWidth =
+      TTI.getLoadStoreVecRegBitWidth(NVPTXAS::ADDRESS_SPACE_LOCAL);
+  return Align(MaxBitWidth / 8);
+}
+
+namespace {
+struct NVPTXIncreaseLocalAlignment {
+  const Align MaxAlign;
+
+  NVPTXIncreaseLocalAlignment(const TargetTransformInfo &TTI)
+      : MaxAlign(getMaxLocalArrayAlignment(TTI)) {}
+
+  bool run(Function &F);
+  bool updateAllocaAlignment(AllocaInst *Alloca, const DataLayout &DL);
+  Align getAggressiveArrayAlignment(unsigned ArraySize);
+  Align getConservativeArrayAlignment(unsigned ArraySize);
+};
+} // namespace
 
 /// Get the maximum useful alignment for an array. This is more likely to
 /// produce holes in the local memory.
@@ -38,8 +59,9 @@ static constexpr Align MaxPTXArrayAlignment = Align::Constant<16>();
 /// Choose an alignment large enough that the entire array could be loaded with
 /// a single vector load (if possible). Cap the alignment at
 /// MaxPTXArrayAlignment.
-static Align getAggressiveArrayAlignment(const unsigned ArraySize) {
-  return std::min(MaxPTXArrayAlignment, Align(PowerOf2Ceil(ArraySize)));
+Align NVPTXIncreaseLocalAlignment::getAggressiveArrayAlignment(
+    const unsigned ArraySize) {
+  return std::min(MaxAlign, Align(PowerOf2Ceil(ArraySize)));
 }
 
 /// Get the alignment of arrays that reduces the chances of leaving holes when
@@ -49,20 +71,18 @@ static Align getAggressiveArrayAlignment(const unsigned ArraySize) {
 /// Choose the largest alignment such that the array size is a multiple of the
 /// alignment. If all elements of the buffer are allocated in order of
 /// alignment (higher to lower) no holes will be left.
-static Align getConservativeArrayAlignment(const unsigned ArraySize) {
-  return commonAlignment(MaxPTXArrayAlignment, ArraySize);
+Align NVPTXIncreaseLocalAlignment::getConservativeArrayAlignment(
+    const unsigned ArraySize) {
+  return commonAlignment(MaxAlign, ArraySize);
 }
 
 /// Find a better alignment for local arrays
-static bool updateAllocaAlignment(const DataLayout &DL, AllocaInst *Alloca) {
+bool NVPTXIncreaseLocalAlignment::updateAllocaAlignment(AllocaInst *Alloca,
+                                                        const DataLayout &DL) {
   // Looking for statically sized local arrays
   if (!Alloca->isStaticAlloca())
     return false;
 
-  // For now, we only support array allocas
-  if (!(Alloca->isArrayAllocation() || Alloca->getAllocatedType()->isArrayTy()))
-    return false;
-
   const auto ArraySize = Alloca->getAllocationSize(DL);
   if (!(ArraySize && ArraySize->isFixed()))
     return false;
@@ -80,14 +100,14 @@ static bool updateAllocaAlignment(const DataLayout &DL, AllocaInst *Alloca) {
   return false;
 }
 
-static bool runSetLocalArrayAlignment(Function &F) {
+bool NVPTXIncreaseLocalAlignment::run(Function &F) {
   bool Changed = false;
-  const DataLayout &DL = F.getParent()->getDataLayout();
+  const auto &DL = F.getParent()->getDataLayout();
 
   BasicBlock &EntryBB = F.getEntryBlock();
   for (Instruction &I : EntryBB)
     if (AllocaInst *Alloca = dyn_cast<AllocaInst>(&I))
-      Changed |= updateAllocaAlignment(DL, Alloca);
+      Changed |= updateAllocaAlignment(Alloca, DL);
 
   return Changed;
 }
@@ -98,6 +118,9 @@ struct NVPTXIncreaseLocalAlignmentLegacyPass : public FunctionPass {
   NVPTXIncreaseLocalAlignmentLegacyPass() : FunctionPass(ID) {}
 
   bool runOnFunction(Function &F) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+  }
   StringRef getPassName() const override {
     return "NVPTX Increase Local Alignment";
   }
@@ -115,12 +138,15 @@ FunctionPass *llvm::createNVPTXIncreaseLocalAlignmentPass() {
 }
 
 bool NVPTXIncreaseLocalAlignmentLegacyPass::runOnFunction(Function &F) {
-  return runSetLocalArrayAlignment(F);
+  const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+  return NVPTXIncreaseLocalAlignment(TTI).run(F);
 }
 
 PreservedAnalyses
-NVPTXIncreaseLocalAlignmentPass::run(Function &F, FunctionAnalysisManager &AM) {
-  bool Changed = runSetLocalArrayAlignment(F);
+NVPTXIncreaseLocalAlignmentPass::run(Function &F,
+                                     FunctionAnalysisManager &FAM) {
+  const auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
+  bool Changed = NVPTXIncreaseLocalAlignment(TTI).run(F);
 
   if (!Changed)
     return PreservedAnalyses::all();
diff --git a/llvm/test/CodeGen/NVPTX/increase-local-align.ll b/llvm/test/CodeGen/NVPTX/increase-local-align.ll
index 605c4b5b2b77d..3dddcf384b81c 100644
--- a/llvm/test/CodeGen/NVPTX/increase-local-align.ll
+++ b/llvm/test/CodeGen/NVPTX/increase-local-align.ll
@@ -67,7 +67,7 @@ define void @test6() {
 
 define void @test7() {
 ; COMMON-LABEL: define void @test7() {
-; COMMON-NEXT:    [[A:%.*]] = alloca i32, align 2
+; COMMON-NEXT:    [[A:%.*]] = alloca i32, align 4
 ; COMMON-NEXT:    ret void
 ;
   %a = alloca i32, align 2
diff --git a/llvm/test/CodeGen/NVPTX/lower-byval-args.ll b/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
index 4784d7093a796..4047579eb4ea3 100644
--- a/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
+++ b/llvm/test/CodeGen/NVPTX/lower-byval-args.ll
@@ -135,7 +135,7 @@ define dso_local ptx_kernel void @escape_ptr(ptr nocapture noundef readnone %out
 ;
 ; PTX-LABEL: escape_ptr(
 ; PTX:       {
-; PTX-NEXT:    .local .align 4 .b8 __local_depot2[8];
+; PTX-NEXT:    .local .align 8 .b8 __local_depot2[8];
 ; PTX-NEXT:    .reg .b64 %SP;
 ; PTX-NEXT:    .reg .b64 %SPL;
 ; PTX-NEXT:    .reg .b32 %r<3>;
@@ -175,7 +175,7 @@ define dso_local ptx_kernel void @escape_ptr_gep(ptr nocapture noundef readnone
 ;
 ; PTX-LABEL: escape_ptr_gep(
 ; PTX:       {
-; PTX-NEXT:    .local .align 4 .b8 __local_depot3[8];
+; PTX-NEXT:    .local .align 8 .b8 __local_depot3[8];
 ; PTX-NEXT:    .reg .b64 %SP;
 ; PTX-NEXT:    .reg .b64 %SPL;
 ; PTX-NEXT:    .reg .b32 %r<3>;
@@ -190,7 +190,7 @@ define dso_local ptx_kernel void @escape_ptr_gep(ptr nocapture noundef readnone
 ; PTX-NEXT:    st.local.b32 [%rd2+4], %r1;
 ; PTX-NEXT:    ld.param.b32 %r2, [escape_ptr_gep_param_1];
 ; PTX-NEXT:    st.local.b32 [%rd2], %r2;
-; PTX-NEXT:    add.s64 %rd3, %rd1, 4;
+; PTX-NEXT:    or.b64 %rd3, %rd1, 4;
 ; PTX-NEXT:    { // callseq 1, 0
 ; PTX-NEXT:    .param .b64 param0;
 ; PTX-NEXT:    st.param.b64 [param0], %rd3;
@@ -216,7 +216,7 @@ define dso_local ptx_kernel void @escape_ptr_store(ptr nocapture noundef writeon
 ;
 ; PTX-LABEL: escape_ptr_store(
 ; PTX:       {
-; PTX-NEXT:    .local .align 4 .b8 __local_depot4[8];
+; PTX-NEXT:    .local .align 8 .b8 __local_depot4[8];
 ; PTX-NEXT:    .reg .b64 %SP;
 ; PTX-NEXT:    .reg .b64 %SPL;
 ; PTX-NEXT:    .reg .b32 %r<3>;
@@ -254,7 +254,7 @@ define dso_local ptx_kernel void @escape_ptr_gep_store(ptr nocapture noundef wri
 ;
 ; PTX-LABEL: escape_ptr_gep_store(
 ; PTX:       {
-; PTX-NEXT:    .local .align 4 .b8 __local_depot5[8];
+; PTX-NEXT:    .local .align 8 .b8 __local_depot5[8];
 ; PTX-NEXT:    .reg .b64 %SP;
 ; PTX-NEXT:    .reg .b64 %SPL;
 ; PTX-NEXT:    .reg .b32 %r<3>;
@@ -271,7 +271,7 @@ define dso_local ptx_kernel void @escape_ptr_gep_store(ptr nocapture noundef wri
 ; PTX-NEXT:    st.local.b32 [%rd4+4], %r1;
 ; PTX-NEXT:    ld.param.b32 %r2, [escape_ptr_gep_store_param_1];
 ; PTX-NEXT:    st.local.b32 [%rd4], %r2;
-; PTX-NEXT:    add.s64 %rd5, %rd3, 4;
+; PTX-NEXT:    or.b64 %rd5, %rd3, 4;
 ; PTX-NEXT:    st.global.b64 [%rd2], %rd5;
 ; PTX-NEXT:    ret;
 entry:
@@ -294,7 +294,7 @@ define dso_local ptx_kernel void @escape_ptrtoint(ptr nocapture noundef writeonl
 ;
 ; PTX-LABEL: escape_ptrtoint(
 ; PTX:       {
-; PTX-NEXT:    .local .align 4 .b8 __local_depot6[8];
+; PTX-NEXT:    .local .align 8 .b8 __local_depot6[8];
 ; PTX-NEXT:    .reg .b64 %SP;
 ; PTX-NEXT:    .reg .b64 %SPL;
 ; PTX-NEXT:    .reg .b32 %r<3>;
diff --git a/llvm/test/CodeGen/NVPTX/variadics-backend.ll b/llvm/test/CodeGen/NVPTX/variadics-backend.ll
index ad2e7044e93bc..7c028284e9db0 100644
--- a/llvm/test/CodeGen/NVPTX/variadics-backend.ll
+++ b/llvm/test/CodeGen/NVPTX/variadics-backend.ll
@@ -198,7 +198,7 @@ declare void @llvm.memcpy.p0.p0.i64(ptr noalias nocapture writeonly, ptr noalias
 define dso_local i32 @bar() {
 ; CHECK-PTX-LABEL: bar(
 ; CHECK-PTX:       {
-; CHECK-PTX-NEXT:    .local .align 8 .b8 __local_depot3[24];
+; CHECK-PTX-NEXT:    .local .align 16 .b8 __local_depot3[32];
 ; CHECK-PTX-NEXT:    .reg .b64 %SP;
 ; CHECK-PTX-NEXT:    .reg .b64 %SPL;
 ; CHECK-PTX-NEXT:    .reg .b16 %rs<4>;
@@ -219,6 +219,13 @@ define dso_local i32 @bar() {
 ; CHECK-PTX-NEXT:    st.b8 [%SP+12], 1;
 ; CHECK-PTX-NEXT:    st.b64 [%SP+16], 1;
 ; CHECK-PTX-NEXT:    add.u64 %rd3, %SP, 8;
+; CHECK-PTX-NEXT:    mov.b32 %r1, 1;
+; CHECK-PTX-NEXT:    st.b32 [%SP+16], %r1;
+; CHECK-PTX-NEXT:    mov.b16 %rs4, 1;
+; CHECK-PTX-NEXT:    st.b8 [%SP+20], %rs4;
+; CHECK-PTX-NEXT:    mov.b64 %rd3, 1;
+; CHECK-PTX-NEXT:    st.b64 [%SP+24], %rd3;
+; CHECK-PTX-NEXT:    add.u64 %rd4, %SP, 16;
 ; CHECK-PTX-NEXT:    { // callseq 1, 0
 ; CHECK-PTX-NEXT:    .param .b32 param0;
 ; CHECK-PTX-NEXT:    st.param.b32 [param0], 1;
@@ -345,7 +352,7 @@ entry:
 define dso_local void @qux() {
 ; CHECK-PTX-LABEL: qux(
 ; CHECK-PTX:       {
-; CHECK-PTX-NEXT:    .local .align 8 .b8 __local_depot7[24];
+; CHECK-PTX-NEXT:    .local .align 16 .b8 __local_depot7[32];
 ; CHECK-PTX-NEXT:    .reg .b64 %SP;
 ; CHECK-PTX-NEXT:    .reg .b64 %SPL;
 ; CHECK-PTX-NEXT:    .reg .b32 %r<2>;

>From b12f4e99f255a20932596da7fd2f1c4881ea3681 Mon Sep 17 00:00:00 2001
From: Alex Maclean <amaclean at nvidia.com>
Date: Thu, 17 Jul 2025 21:16:27 +0000
Subject: [PATCH 4/4] address comments

---
 .../Target/NVPTX/NVPTXIncreaseAlignment.cpp   | 69 ++++++++++---------
 .../CodeGen/NVPTX/increase-local-align.ll     | 45 ++++++++----
 llvm/test/CodeGen/NVPTX/local-stack-frame.ll  |  4 +-
 llvm/test/CodeGen/NVPTX/variadics-backend.ll  | 19 ++---
 4 files changed, 75 insertions(+), 62 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXIncreaseAlignment.cpp b/llvm/lib/Target/NVPTX/NVPTXIncreaseAlignment.cpp
index 1fb1e578994e9..cff2ea25c1e6c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIncreaseAlignment.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXIncreaseAlignment.cpp
@@ -6,11 +6,11 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// A simple pass that looks at local memory arrays that are statically
+// A simple pass that looks at local memory allocas that are statically
 // sized and potentially increases their alignment. This enables vectorization
-// of loads/stores to these arrays if not explicitly specified by the client.
+// of loads/stores to these allocas if not explicitly specified by the client.
 //
-// TODO: Ideally we should do a bin-packing of local arrays to maximize
+// TODO: Ideally we should do a bin-packing of local allocas to maximize
 // alignments while minimizing holes.
 //
 //===----------------------------------------------------------------------===//
@@ -28,10 +28,10 @@
 
 using namespace llvm;
 
-static cl::opt<bool>
-    MaxLocalArrayAlignment("nvptx-use-max-local-array-alignment",
-                           cl::init(false), cl::Hidden,
-                           cl::desc("Use maximum alignment for local memory"));
+static cl::opt<unsigned> MinLocalArrayAlignment(
+    "nvptx-ensure-minimum-local-alignment", cl::init(16), cl::Hidden,
+    cl::desc(
+        "Ensure local memory objects are at least this aligned (default 16)"));
 
 static Align getMaxLocalArrayAlignment(const TargetTransformInfo &TTI) {
   const unsigned MaxBitWidth =
@@ -41,45 +41,46 @@ static Align getMaxLocalArrayAlignment(const TargetTransformInfo &TTI) {
 
 namespace {
 struct NVPTXIncreaseLocalAlignment {
-  const Align MaxAlign;
+  const Align MaxUsableAlign;
 
   NVPTXIncreaseLocalAlignment(const TargetTransformInfo &TTI)
-      : MaxAlign(getMaxLocalArrayAlignment(TTI)) {}
+      : MaxUsableAlign(getMaxLocalArrayAlignment(TTI)) {}
 
   bool run(Function &F);
   bool updateAllocaAlignment(AllocaInst *Alloca, const DataLayout &DL);
-  Align getAggressiveArrayAlignment(unsigned ArraySize);
-  Align getConservativeArrayAlignment(unsigned ArraySize);
+  Align getMaxUsefulArrayAlignment(unsigned ArraySize);
+  Align getMaxSafeLocalAlignment(unsigned ArraySize);
 };
 } // namespace
 
-/// Get the maximum useful alignment for an array. This is more likely to
+/// Get the maximum useful alignment for an allocation. This is more likely to
 /// produce holes in the local memory.
 ///
-/// Choose an alignment large enough that the entire array could be loaded with
-/// a single vector load (if possible). Cap the alignment at
-/// MaxPTXArrayAlignment.
-Align NVPTXIncreaseLocalAlignment::getAggressiveArrayAlignment(
+/// Choose an alignment large enough that the entire alloca could be loaded
+/// with a single vector load (if possible). Cap the alignment at
+/// MinLocalArrayAlignment and MaxUsableAlign.
+Align NVPTXIncreaseLocalAlignment::getMaxUsefulArrayAlignment(
     const unsigned ArraySize) {
-  return std::min(MaxAlign, Align(PowerOf2Ceil(ArraySize)));
+  const Align UpperLimit =
+      std::min(MaxUsableAlign, Align(MinLocalArrayAlignment));
+  return std::min(UpperLimit, Align(PowerOf2Ceil(ArraySize)));
 }
 
-/// Get the alignment of arrays that reduces the chances of leaving holes when
-/// arrays are allocated within a contiguous memory buffer (like shared memory
-/// and stack). Holes are still possible before and after the array allocation.
+/// Get the alignment of allocas that reduces the chances of leaving holes when
+/// they are allocated within a contiguous memory buffer (like the stack).
+/// Holes are still possible before and after the allocation.
 ///
-/// Choose the largest alignment such that the array size is a multiple of the
-/// alignment. If all elements of the buffer are allocated in order of
+/// Choose the largest alignment such that the allocation size is a multiple of
+/// the alignment. If all elements of the buffer are allocated in order of
 /// alignment (higher to lower) no holes will be left.
-Align NVPTXIncreaseLocalAlignment::getConservativeArrayAlignment(
+Align NVPTXIncreaseLocalAlignment::getMaxSafeLocalAlignment(
     const unsigned ArraySize) {
-  return commonAlignment(MaxAlign, ArraySize);
+  return commonAlignment(MaxUsableAlign, ArraySize);
 }
 
-/// Find a better alignment for local arrays
+/// Find a better alignment for local allocas.
 bool NVPTXIncreaseLocalAlignment::updateAllocaAlignment(AllocaInst *Alloca,
                                                         const DataLayout &DL) {
-  // Looking for statically sized local arrays
   if (!Alloca->isStaticAlloca())
     return false;
 
@@ -88,12 +89,15 @@ bool NVPTXIncreaseLocalAlignment::updateAllocaAlignment(AllocaInst *Alloca,
     return false;
 
   const auto ArraySizeValue = ArraySize->getFixedValue();
-  const Align PreferredAlignment =
-      MaxLocalArrayAlignment ? getAggressiveArrayAlignment(ArraySizeValue)
-                             : getConservativeArrayAlignment(ArraySizeValue);
+  if (ArraySizeValue == 0)
+    return false;
+
+  const Align NewAlignment =
+      std::max(getMaxSafeLocalAlignment(ArraySizeValue),
+               getMaxUsefulArrayAlignment(ArraySizeValue));
 
-  if (PreferredAlignment > Alloca->getAlign()) {
-    Alloca->setAlignment(PreferredAlignment);
+  if (NewAlignment > Alloca->getAlign()) {
+    Alloca->setAlignment(NewAlignment);
     return true;
   }
 
@@ -130,8 +134,7 @@ struct NVPTXIncreaseLocalAlignmentLegacyPass : public FunctionPass {
 char NVPTXIncreaseLocalAlignmentLegacyPass::ID = 0;
 INITIALIZE_PASS(NVPTXIncreaseLocalAlignmentLegacyPass,
                 "nvptx-increase-local-alignment",
-                "Increase alignment for statically sized alloca arrays", false,
-                false)
+                "Increase alignment for statically sized allocas", false, false)
 
 FunctionPass *llvm::createNVPTXIncreaseLocalAlignmentPass() {
   return new NVPTXIncreaseLocalAlignmentLegacyPass();
diff --git a/llvm/test/CodeGen/NVPTX/increase-local-align.ll b/llvm/test/CodeGen/NVPTX/increase-local-align.ll
index 3dddcf384b81c..6215850a2a22b 100644
--- a/llvm/test/CodeGen/NVPTX/increase-local-align.ll
+++ b/llvm/test/CodeGen/NVPTX/increase-local-align.ll
@@ -1,6 +1,7 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
-; RUN: opt -S -passes=nvptx-increase-local-alignment < %s | FileCheck %s --check-prefixes=COMMON,DEFAULT
-; RUN: opt -S -passes=nvptx-increase-local-alignment -nvptx-use-max-local-array-alignment < %s | FileCheck %s --check-prefixes=COMMON,MAX
+; RUN: opt -S -passes=nvptx-increase-local-alignment -nvptx-ensure-minimum-local-alignment=1 < %s | FileCheck %s --check-prefixes=COMMON,MIN-1
+; RUN: opt -S -passes=nvptx-increase-local-alignment -nvptx-ensure-minimum-local-alignment=8 < %s | FileCheck %s --check-prefixes=COMMON,MIN-8
+; RUN: opt -S -passes=nvptx-increase-local-alignment -nvptx-ensure-minimum-local-alignment=16 < %s | FileCheck %s --check-prefixes=COMMON,MIN-16
 target triple = "nvptx64-nvidia-cuda"
 
 define void @test1() {
@@ -13,13 +14,17 @@ define void @test1() {
 }
 
 define void @test2() {
-; DEFAULT-LABEL: define void @test2() {
-; DEFAULT-NEXT:    [[A:%.*]] = alloca [63 x i8], align 1
-; DEFAULT-NEXT:    ret void
+; MIN-1-LABEL: define void @test2() {
+; MIN-1-NEXT:    [[A:%.*]] = alloca [63 x i8], align 1
+; MIN-1-NEXT:    ret void
 ;
-; MAX-LABEL: define void @test2() {
-; MAX-NEXT:    [[A:%.*]] = alloca [63 x i8], align 16
-; MAX-NEXT:    ret void
+; MIN-8-LABEL: define void @test2() {
+; MIN-8-NEXT:    [[A:%.*]] = alloca [63 x i8], align 8
+; MIN-8-NEXT:    ret void
+;
+; MIN-16-LABEL: define void @test2() {
+; MIN-16-NEXT:    [[A:%.*]] = alloca [63 x i8], align 16
+; MIN-16-NEXT:    ret void
 ;
   %a = alloca [63 x i8], align 1
   ret void
@@ -35,13 +40,17 @@ define void @test3() {
 }
 
 define void @test4() {
-; DEFAULT-LABEL: define void @test4() {
-; DEFAULT-NEXT:    [[A:%.*]] = alloca i8, i32 63, align 1
-; DEFAULT-NEXT:    ret void
+; MIN-1-LABEL: define void @test4() {
+; MIN-1-NEXT:    [[A:%.*]] = alloca i8, i32 63, align 1
+; MIN-1-NEXT:    ret void
+;
+; MIN-8-LABEL: define void @test4() {
+; MIN-8-NEXT:    [[A:%.*]] = alloca i8, i32 63, align 8
+; MIN-8-NEXT:    ret void
 ;
-; MAX-LABEL: define void @test4() {
-; MAX-NEXT:    [[A:%.*]] = alloca i8, i32 63, align 16
-; MAX-NEXT:    ret void
+; MIN-16-LABEL: define void @test4() {
+; MIN-16-NEXT:    [[A:%.*]] = alloca i8, i32 63, align 16
+; MIN-16-NEXT:    ret void
 ;
   %a = alloca i8, i32 63, align 1
   ret void
@@ -83,3 +92,11 @@ define void @test8() {
   ret void
 }
 
+; A zero-sized array holds no data, so its alignment is left untouched.
+define void @test9() {
+; COMMON-LABEL: define void @test9() {
+; COMMON-NEXT:    [[A:%.*]] = alloca [0 x i32], align 1
+; COMMON-NEXT:    ret void
+;
+  %a = alloca [0 x i32], align 1
+  ret void
+}
diff --git a/llvm/test/CodeGen/NVPTX/local-stack-frame.ll b/llvm/test/CodeGen/NVPTX/local-stack-frame.ll
index 5c3017310d0a3..3899b37e140eb 100644
--- a/llvm/test/CodeGen/NVPTX/local-stack-frame.ll
+++ b/llvm/test/CodeGen/NVPTX/local-stack-frame.ll
@@ -94,7 +94,7 @@ declare void @bar(ptr %a)
 define void @foo3(i32 %a) {
 ; PTX32-LABEL: foo3(
 ; PTX32:       {
-; PTX32-NEXT:    .local .align 4 .b8 __local_depot2[12];
+; PTX32-NEXT:    .local .align 16 .b8 __local_depot2[16];
 ; PTX32-NEXT:    .reg .b32 %SP;
 ; PTX32-NEXT:    .reg .b32 %SPL;
 ; PTX32-NEXT:    .reg .b32 %r<6>;
@@ -110,7 +110,7 @@ define void @foo3(i32 %a) {
 ;
 ; PTX64-LABEL: foo3(
 ; PTX64:       {
-; PTX64-NEXT:    .local .align 4 .b8 __local_depot2[12];
+; PTX64-NEXT:    .local .align 16 .b8 __local_depot2[16];
 ; PTX64-NEXT:    .reg .b64 %SP;
 ; PTX64-NEXT:    .reg .b64 %SPL;
 ; PTX64-NEXT:    .reg .b32 %r<2>;
diff --git a/llvm/test/CodeGen/NVPTX/variadics-backend.ll b/llvm/test/CodeGen/NVPTX/variadics-backend.ll
index 7c028284e9db0..17c74227cfbe6 100644
--- a/llvm/test/CodeGen/NVPTX/variadics-backend.ll
+++ b/llvm/test/CodeGen/NVPTX/variadics-backend.ll
@@ -101,7 +101,7 @@ declare void @llvm.va_end.p0(ptr)
 define dso_local i32 @foo() {
 ; CHECK-PTX-LABEL: foo(
 ; CHECK-PTX:       {
-; CHECK-PTX-NEXT:    .local .align 8 .b8 __local_depot1[40];
+; CHECK-PTX-NEXT:    .local .align 16 .b8 __local_depot1[48];
 ; CHECK-PTX-NEXT:    .reg .b64 %SP;
 ; CHECK-PTX-NEXT:    .reg .b64 %SPL;
 ; CHECK-PTX-NEXT:    .reg .b32 %r<3>;
@@ -138,7 +138,7 @@ entry:
 define dso_local i32 @variadics2(i32 noundef %first, ...) {
 ; CHECK-PTX-LABEL: variadics2(
 ; CHECK-PTX:       {
-; CHECK-PTX-NEXT:    .local .align 1 .b8 __local_depot2[3];
+; CHECK-PTX-NEXT:    .local .align 4 .b8 __local_depot2[4];
 ; CHECK-PTX-NEXT:    .reg .b64 %SP;
 ; CHECK-PTX-NEXT:    .reg .b64 %SPL;
 ; CHECK-PTX-NEXT:    .reg .b16 %rs<4>;
@@ -215,17 +215,10 @@ define dso_local i32 @bar() {
 ; CHECK-PTX-NEXT:    st.local.b8 [%rd2+1], %rs2;
 ; CHECK-PTX-NEXT:    ld.global.nc.b8 %rs3, [__const_$_bar_$_s1+5];
 ; CHECK-PTX-NEXT:    st.local.b8 [%rd2], %rs3;
-; CHECK-PTX-NEXT:    st.b32 [%SP+8], 1;
-; CHECK-PTX-NEXT:    st.b8 [%SP+12], 1;
-; CHECK-PTX-NEXT:    st.b64 [%SP+16], 1;
-; CHECK-PTX-NEXT:    add.u64 %rd3, %SP, 8;
-; CHECK-PTX-NEXT:    mov.b32 %r1, 1;
-; CHECK-PTX-NEXT:    st.b32 [%SP+16], %r1;
-; CHECK-PTX-NEXT:    mov.b16 %rs4, 1;
-; CHECK-PTX-NEXT:    st.b8 [%SP+20], %rs4;
-; CHECK-PTX-NEXT:    mov.b64 %rd3, 1;
-; CHECK-PTX-NEXT:    st.b64 [%SP+24], %rd3;
-; CHECK-PTX-NEXT:    add.u64 %rd4, %SP, 16;
+; CHECK-PTX-NEXT:    st.b32 [%SP+16], 1;
+; CHECK-PTX-NEXT:    st.b8 [%SP+20], 1;
+; CHECK-PTX-NEXT:    st.b64 [%SP+24], 1;
+; CHECK-PTX-NEXT:    add.u64 %rd3, %SP, 16;
 ; CHECK-PTX-NEXT:    { // callseq 1, 0
 ; CHECK-PTX-NEXT:    .param .b32 param0;
 ; CHECK-PTX-NEXT:    st.param.b32 [param0], 1;



More information about the llvm-commits mailing list