[llvm] [NVPTX] pull in v2i32 build_vector through v2f32 bitcast (PR #153478)

Princeton Ferro via llvm-commits llvm-commits at lists.llvm.org
Thu Aug 21 11:02:12 PDT 2025


https://github.com/Prince781 updated https://github.com/llvm/llvm-project/pull/153478

>From 4bab2557ab3372a1231f6977064e42bacc6637cd Mon Sep 17 00:00:00 2001
From: Princeton Ferro <pferro at nvidia.com>
Date: Tue, 12 Aug 2025 18:01:46 -0700
Subject: [PATCH 1/4] [NVPTX] pull in v2i32 build_vector through v2f32 bitcast

Transform:
 v2f32 (bitcast (v2i32 build_vector i32:A, i32:B))
 --->
 v2f32 (build_vector (f32 (bitcast i32:A)), (f32 (bitcast i32:B)))

Since v2f32 is legal but v2i32 is not, v2i32 build_vector would be
legalized as bitwise ops on i64, when we want each 32-bit element to be
in its own register.
---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp   | 101 +++++++++++++++++-
 .../test/CodeGen/NVPTX/f32x2-convert-i32x2.ll |  95 ++++++++++++++++
 2 files changed, 192 insertions(+), 4 deletions(-)
 create mode 100644 llvm/test/CodeGen/NVPTX/f32x2-convert-i32x2.ll

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index ad56d2f12caf6..566ba840b2eee 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -883,10 +883,11 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);
 
   // We have some custom DAG combine patterns for these nodes
-  setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
-                       ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
-                       ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::LOAD,
-                       ISD::STORE, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
+  setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::BITCAST,
+                       ISD::EXTRACT_VECTOR_ELT, ISD::FADD, ISD::MUL, ISD::SHL,
+                       ISD::SREM, ISD::UREM, ISD::VSELECT, ISD::BUILD_VECTOR,
+                       ISD::ADDRSPACECAST, ISD::LOAD, ISD::STORE,
+                       ISD::ZERO_EXTEND, ISD::SIGN_EXTEND});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5243,6 +5244,94 @@ static SDValue PerformFADDCombine(SDNode *N,
   return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
 }
 
+// Eliminates a redundant `and X, 0xff` where X is an i8 element produced by an
+// NVPTXISD::LoadV2/LoadV4 node, optionally seen through an ANY_EXTEND (which is
+// re-created as a ZERO_EXTEND). Returns SDValue(N, 0) when N was replaced via
+// CombineTo, and an empty SDValue when nothing changed.
+static SDValue PerformANDCombine(SDNode *N,
+                                 TargetLowering::DAGCombinerInfo &DCI) {
+  // The type legalizer turns a vector load of i8 values into a zextload to i16
+  // registers, optionally ANY_EXTENDs it (if target type is integer),
+  // and ANDs off the high 8 bits. Since we turn this load into a
+  // target-specific DAG node, the DAG combiner fails to eliminate these AND
+  // nodes. Do that here.
+  SDValue Val = N->getOperand(0);
+  SDValue Mask = N->getOperand(1);
+
+  // AND is commutative; make sure a constant operand, if any, is in Mask.
+  if (isa<ConstantSDNode>(Val))
+    std::swap(Val, Mask);
+
+  SDValue AExt;
+
+  // Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and. Look
+  // through the ANY_EXTEND, remembering it so it can be re-created below.
+  if (Val.getOpcode() == ISD::ANY_EXTEND) {
+    AExt = Val;
+    Val = Val->getOperand(0);
+  }
+
+  if (Val->getOpcode() == NVPTXISD::LoadV2 ||
+      Val->getOpcode() == NVPTXISD::LoadV4) {
+    // Only an AND with the constant 0xff, i.e. one that keeps just the low 8
+    // bits, is made redundant by the zero-extending load.
+    auto *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
+    if (!MaskCnst || MaskCnst->getZExtValue() != 0xff)
+      return SDValue();
+
+    // LoadV2/LoadV4 are expected to be MemSDNodes; bail out defensively if not.
+    auto *Mem = dyn_cast<MemSDNode>(Val);
+    if (!Mem)
+      return SDValue();
+
+    // We only handle vector loads of i8 elements.
+    EVT MemVT = Mem->getMemoryVT();
+    if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8)
+      return SDValue();
+
+    // The extension kind is read from the load's last operand. A sextload
+    // fills the high bits with copies of the sign bit, so the AND is still
+    // needed to zero them.
+    // NOTE(review): for a plain EXTLOAD the high bits are unspecified at the
+    // DAG level; dropping the AND then relies on the emitted load
+    // zero-extending in hardware -- confirm this is intended.
+    unsigned ExtType = Val->getConstantOperandVal(Val->getNumOperands() - 1);
+    if (ExtType == ISD::SEXTLOAD)
+      return SDValue();
+
+    bool AddTo = false;
+    if (AExt.getNode() != nullptr) {
+      // Re-insert the ext as a zext.
+      Val = DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
+                            AExt.getValueType(), Val);
+      AddTo = true;
+    }
+
+    // The AND is redundant: replace N and return CombineTo's SDValue(N, 0).
+    return DCI.CombineTo(N, Val, AddTo);
+  }
+
+  return SDValue();
+}
+
+// v2f32 (bitcast (v2i32 build_vector A, B))
+//   --> v2f32 (build_vector (f32 (bitcast A)), (f32 (bitcast B)))
+// v2f32 is legal but v2i32 is not; this keeps each element in its own register.
+static SDValue combineBitcast(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
+  SDValue Input = N->getOperand(0);
+  if (N->getValueType(0) != MVT::v2f32 || Input.getValueType() != MVT::v2i32 ||
+      Input.getOpcode() != ISD::BUILD_VECTOR)
+    return SDValue();
+  // Integer BUILD_VECTOR operands may be wider than the element type (they are
+  // implicitly truncated); only an i32 operand can be bitcast directly to f32.
+  SDValue A = Input.getOperand(0), B = Input.getOperand(1);
+  if (A.getValueType() != MVT::i32 || B.getValueType() != MVT::i32)
+    return SDValue();
+  return DCI.DAG.getBuildVector(
+      MVT::v2f32, SDLoc(N),
+      {DCI.DAG.getBitcast(MVT::f32, A), DCI.DAG.getBitcast(MVT::f32, B)});
+}
+
 static SDValue PerformREMCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  CodeGenOptLevel OptLevel) {
@@ -5914,6 +6003,10 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     return PerformADDCombine(N, DCI, OptLevel);
   case ISD::ADDRSPACECAST:
     return combineADDRSPACECAST(N, DCI);
+  case ISD::AND:
+    return PerformANDCombine(N, DCI);
+  case ISD::BITCAST:
+    return combineBitcast(N, DCI);
   case ISD::SIGN_EXTEND:
   case ISD::ZERO_EXTEND:
     return combineMulWide(N, DCI, OptLevel);
diff --git a/llvm/test/CodeGen/NVPTX/f32x2-convert-i32x2.ll b/llvm/test/CodeGen/NVPTX/f32x2-convert-i32x2.ll
new file mode 100644
index 0000000000000..13f8dd34b5ce3
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/f32x2-convert-i32x2.ll
@@ -0,0 +1,95 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mcpu=sm_90a -O0 -disable-post-ra -frame-pointer=all          \
+; RUN: -verify-machineinstrs | FileCheck --check-prefixes=CHECK,CHECK-SM90A %s
+; RUN: %if ptxas-12.7 %{                                                      \
+; RUN:  llc < %s -mcpu=sm_90a -O0 -disable-post-ra -frame-pointer=all         \
+; RUN:  -verify-machineinstrs | %ptxas-verify -arch=sm_90a                    \
+; RUN: %}
+; RUN: llc < %s -mcpu=sm_100 -O0 -disable-post-ra -frame-pointer=all          \
+; RUN: -verify-machineinstrs | FileCheck --check-prefixes=CHECK,CHECK-SM100 %s
+; RUN: %if ptxas-12.7 %{                                                      \
+; RUN:  llc < %s -mcpu=sm_100 -O0 -disable-post-ra -frame-pointer=all         \
+; RUN:  -verify-machineinstrs | %ptxas-verify -arch=sm_100                    \
+; RUN: %}
+
+; Test that v2i32 -> v2f32 bitcasts keep each 32-bit element in its own
+; register instead of packing the elements with bitwise operations on i64.
+target triple = "nvptx64-nvidia-cuda"
+
+declare <2 x i32> @return_i32x2(i32 %0)
+
+; Test with v2i32.
+define ptx_kernel void @store_i32x2(i32 %0, ptr %p) {
+; CHECK-LABEL: store_i32x2(
+; CHECK:       {
+; CHECK-NEXT:    .reg .b32 %r<6>;
+; CHECK-NEXT:    .reg .b64 %rd<2>;
+; CHECK-EMPTY:
+; CHECK-NEXT:  // %bb.0:
+; CHECK-NEXT:    ld.param.b64 %rd1, [store_i32x2_param_1];
+; CHECK-NEXT:    ld.param.b32 %r1, [store_i32x2_param_0];
+; CHECK-NEXT:    { // callseq 0, 0
+; CHECK-NEXT:    .param .b32 param0;
+; CHECK-NEXT:    .param .align 8 .b8 retval0[8];
+; CHECK-NEXT:    st.param.b32 [param0], %r1;
+; CHECK-NEXT:    call.uni (retval0), return_i32x2, (param0);
+; CHECK-NEXT:    ld.param.v2.b32 {%r2, %r3}, [retval0];
+; CHECK-NEXT:    } // callseq 0
+; CHECK-NEXT:    st.v2.b32 [%rd1], {%r2, %r3};
+; CHECK-NEXT:    ret;
+  %v = call <2 x i32> @return_i32x2(i32 %0)
+  %v.f32x2 = bitcast <2 x i32> %v to <2 x float>
+  store <2 x float> %v.f32x2, ptr %p, align 8
+  ret void
+}
+
+; Test with inline ASM returning { <1 x float>, <1 x float> }, which decays to
+; v2i32.
+define ptx_kernel void @inlineasm(ptr %p) {
+; CHECK-SM90A-LABEL: inlineasm(
+; CHECK-SM90A:       {
+; CHECK-SM90A-NEXT:    .reg .b32 %r<7>;
+; CHECK-SM90A-NEXT:    .reg .b64 %rd<2>;
+; CHECK-SM90A-EMPTY:
+; CHECK-SM90A-NEXT:  // %bb.0:
+; CHECK-SM90A-NEXT:    ld.param.b64 %rd1, [inlineasm_param_0];
+; CHECK-SM90A-NEXT:    mov.b32 %r3, 0;
+; CHECK-SM90A-NEXT:    mov.b32 %r4, %r3;
+; CHECK-SM90A-NEXT:    mov.b32 %r2, %r4;
+; CHECK-SM90A-NEXT:    mov.b32 %r1, %r3;
+; CHECK-SM90A-NEXT:    // begin inline asm
+; CHECK-SM90A-NEXT:    // nop
+; CHECK-SM90A-NEXT:    // end inline asm
+; CHECK-SM90A-NEXT:    mul.rn.f32 %r5, %r2, 0f00000000;
+; CHECK-SM90A-NEXT:    mul.rn.f32 %r6, %r1, 0f00000000;
+; CHECK-SM90A-NEXT:    st.v2.b32 [%rd1], {%r6, %r5};
+; CHECK-SM90A-NEXT:    ret;
+;
+; CHECK-SM100-LABEL: inlineasm(
+; CHECK-SM100:       {
+; CHECK-SM100-NEXT:    .reg .b32 %r<6>;
+; CHECK-SM100-NEXT:    .reg .b64 %rd<5>;
+; CHECK-SM100-EMPTY:
+; CHECK-SM100-NEXT:  // %bb.0:
+; CHECK-SM100-NEXT:    ld.param.b64 %rd1, [inlineasm_param_0];
+; CHECK-SM100-NEXT:    mov.b32 %r3, 0;
+; CHECK-SM100-NEXT:    mov.b32 %r4, %r3;
+; CHECK-SM100-NEXT:    mov.b32 %r2, %r4;
+; CHECK-SM100-NEXT:    mov.b32 %r1, %r3;
+; CHECK-SM100-NEXT:    // begin inline asm
+; CHECK-SM100-NEXT:    // nop
+; CHECK-SM100-NEXT:    // end inline asm
+; CHECK-SM100-NEXT:    mov.b64 %rd2, {%r1, %r2};
+; CHECK-SM100-NEXT:    mov.b32 %r5, 0f00000000;
+; CHECK-SM100-NEXT:    mov.b64 %rd3, {%r5, %r5};
+; CHECK-SM100-NEXT:    mul.rn.f32x2 %rd4, %rd2, %rd3;
+; CHECK-SM100-NEXT:    st.b64 [%rd1], %rd4;
+; CHECK-SM100-NEXT:    ret;
+  %r = call { <1 x float>, <1 x float> } asm sideeffect "// nop", "=f,=f,0,1"(<1 x float> zeroinitializer, <1 x float> zeroinitializer)
+  %i0 = extractvalue { <1 x float>, <1 x float> } %r, 0
+  %i1 = extractvalue { <1 x float>, <1 x float> } %r, 1
+  %i4 = shufflevector <1 x float> %i0, <1 x float> %i1, <2 x i32> <i32 0, i32 1>
+  %mul = fmul <2 x float> %i4, zeroinitializer
+  store <2 x float> %mul, ptr %p, align 8
+  ret void
+}

>From f5a7bc72bdc84ef7624267871b7165d27949afc0 Mon Sep 17 00:00:00 2001
From: Princeton Ferro <pferro at nvidia.com>
Date: Wed, 13 Aug 2025 15:13:28 -0700
Subject: [PATCH 2/4] use fadd v2f32 to keep bitcast pattern in isel

---
 .../test/CodeGen/NVPTX/f32x2-convert-i32x2.ll | 61 +++++++++++++------
 1 file changed, 43 insertions(+), 18 deletions(-)

diff --git a/llvm/test/CodeGen/NVPTX/f32x2-convert-i32x2.ll b/llvm/test/CodeGen/NVPTX/f32x2-convert-i32x2.ll
index 13f8dd34b5ce3..38ca8b2bc98da 100644
--- a/llvm/test/CodeGen/NVPTX/f32x2-convert-i32x2.ll
+++ b/llvm/test/CodeGen/NVPTX/f32x2-convert-i32x2.ll
@@ -20,26 +20,49 @@ declare <2 x i32> @return_i32x2(i32 %0)
 
 ; Test with v2i32.
 define ptx_kernel void @store_i32x2(i32 %0, ptr %p) {
-; CHECK-LABEL: store_i32x2(
-; CHECK:       {
-; CHECK-NEXT:    .reg .b32 %r<6>;
-; CHECK-NEXT:    .reg .b64 %rd<2>;
-; CHECK-EMPTY:
-; CHECK-NEXT:  // %bb.0:
-; CHECK-NEXT:    ld.param.b64 %rd1, [store_i32x2_param_1];
-; CHECK-NEXT:    ld.param.b32 %r1, [store_i32x2_param_0];
-; CHECK-NEXT:    { // callseq 0, 0
-; CHECK-NEXT:    .param .b32 param0;
-; CHECK-NEXT:    .param .align 8 .b8 retval0[8];
-; CHECK-NEXT:    st.param.b32 [param0], %r1;
-; CHECK-NEXT:    call.uni (retval0), return_i32x2, (param0);
-; CHECK-NEXT:    ld.param.v2.b32 {%r2, %r3}, [retval0];
-; CHECK-NEXT:    } // callseq 0
-; CHECK-NEXT:    st.v2.b32 [%rd1], {%r2, %r3};
-; CHECK-NEXT:    ret;
+; CHECK-SM90A-LABEL: store_i32x2(
+; CHECK-SM90A:       {
+; CHECK-SM90A-NEXT:    .reg .b32 %r<8>;
+; CHECK-SM90A-NEXT:    .reg .b64 %rd<2>;
+; CHECK-SM90A-EMPTY:
+; CHECK-SM90A-NEXT:  // %bb.0:
+; CHECK-SM90A-NEXT:    ld.param.b64 %rd1, [store_i32x2_param_1];
+; CHECK-SM90A-NEXT:    ld.param.b32 %r1, [store_i32x2_param_0];
+; CHECK-SM90A-NEXT:    { // callseq 0, 0
+; CHECK-SM90A-NEXT:    .param .b32 param0;
+; CHECK-SM90A-NEXT:    .param .align 8 .b8 retval0[8];
+; CHECK-SM90A-NEXT:    st.param.b32 [param0], %r1;
+; CHECK-SM90A-NEXT:    call.uni (retval0), return_i32x2, (param0);
+; CHECK-SM90A-NEXT:    ld.param.v2.b32 {%r2, %r3}, [retval0];
+; CHECK-SM90A-NEXT:    } // callseq 0
+; CHECK-SM90A-NEXT:    add.rn.f32 %r6, %r3, %r3;
+; CHECK-SM90A-NEXT:    add.rn.f32 %r7, %r2, %r2;
+; CHECK-SM90A-NEXT:    st.v2.b32 [%rd1], {%r7, %r6};
+; CHECK-SM90A-NEXT:    ret;
+;
+; CHECK-SM100-LABEL: store_i32x2(
+; CHECK-SM100:       {
+; CHECK-SM100-NEXT:    .reg .b32 %r<6>;
+; CHECK-SM100-NEXT:    .reg .b64 %rd<4>;
+; CHECK-SM100-EMPTY:
+; CHECK-SM100-NEXT:  // %bb.0:
+; CHECK-SM100-NEXT:    ld.param.b64 %rd1, [store_i32x2_param_1];
+; CHECK-SM100-NEXT:    ld.param.b32 %r1, [store_i32x2_param_0];
+; CHECK-SM100-NEXT:    { // callseq 0, 0
+; CHECK-SM100-NEXT:    .param .b32 param0;
+; CHECK-SM100-NEXT:    .param .align 8 .b8 retval0[8];
+; CHECK-SM100-NEXT:    st.param.b32 [param0], %r1;
+; CHECK-SM100-NEXT:    call.uni (retval0), return_i32x2, (param0);
+; CHECK-SM100-NEXT:    ld.param.v2.b32 {%r2, %r3}, [retval0];
+; CHECK-SM100-NEXT:    } // callseq 0
+; CHECK-SM100-NEXT:    mov.b64 %rd2, {%r2, %r3};
+; CHECK-SM100-NEXT:    add.rn.f32x2 %rd3, %rd2, %rd2;
+; CHECK-SM100-NEXT:    st.b64 [%rd1], %rd3;
+; CHECK-SM100-NEXT:    ret;
   %v = call <2 x i32> @return_i32x2(i32 %0)
   %v.f32x2 = bitcast <2 x i32> %v to <2 x float>
-  store <2 x float> %v.f32x2, ptr %p, align 8
+  %res = fadd <2 x float> %v.f32x2, %v.f32x2 ; v2f32 fadd keeps the bitcast pattern in isel
+  store <2 x float> %res, ptr %p, align 8
   ret void
 }
 
@@ -93,3 +116,5 @@ define ptx_kernel void @inlineasm(ptr %p) {
   store <2 x float> %mul, ptr %p, align 8
   ret void
 }
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; CHECK: {{.*}}

>From 3155ed29d8403f9141dffca6394d8056cdcfefd6 Mon Sep 17 00:00:00 2001
From: Princeton Ferro <pferro at nvidia.com>
Date: Thu, 21 Aug 2025 10:07:13 -0700
Subject: [PATCH 3/4] regenerate f32x2-convert-i32x2.ll check lines

---
 llvm/test/CodeGen/NVPTX/f32x2-convert-i32x2.ll | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/llvm/test/CodeGen/NVPTX/f32x2-convert-i32x2.ll b/llvm/test/CodeGen/NVPTX/f32x2-convert-i32x2.ll
index 38ca8b2bc98da..2bb1cade466bd 100644
--- a/llvm/test/CodeGen/NVPTX/f32x2-convert-i32x2.ll
+++ b/llvm/test/CodeGen/NVPTX/f32x2-convert-i32x2.ll
@@ -22,7 +22,7 @@ declare <2 x i32> @return_i32x2(i32 %0)
 define ptx_kernel void @store_i32x2(i32 %0, ptr %p) {
 ; CHECK-SM90A-LABEL: store_i32x2(
 ; CHECK-SM90A:       {
-; CHECK-SM90A-NEXT:    .reg .b32 %r<8>;
+; CHECK-SM90A-NEXT:    .reg .b32 %r<6>;
 ; CHECK-SM90A-NEXT:    .reg .b64 %rd<2>;
 ; CHECK-SM90A-EMPTY:
 ; CHECK-SM90A-NEXT:  // %bb.0:
@@ -35,14 +35,14 @@ define ptx_kernel void @store_i32x2(i32 %0, ptr %p) {
 ; CHECK-SM90A-NEXT:    call.uni (retval0), return_i32x2, (param0);
 ; CHECK-SM90A-NEXT:    ld.param.v2.b32 {%r2, %r3}, [retval0];
 ; CHECK-SM90A-NEXT:    } // callseq 0
-; CHECK-SM90A-NEXT:    add.rn.f32 %r6, %r3, %r3;
-; CHECK-SM90A-NEXT:    add.rn.f32 %r7, %r2, %r2;
-; CHECK-SM90A-NEXT:    st.v2.b32 [%rd1], {%r7, %r6};
+; CHECK-SM90A-NEXT:    add.rn.f32 %r4, %r3, %r3;
+; CHECK-SM90A-NEXT:    add.rn.f32 %r5, %r2, %r2;
+; CHECK-SM90A-NEXT:    st.v2.b32 [%rd1], {%r5, %r4};
 ; CHECK-SM90A-NEXT:    ret;
 ;
 ; CHECK-SM100-LABEL: store_i32x2(
 ; CHECK-SM100:       {
-; CHECK-SM100-NEXT:    .reg .b32 %r<6>;
+; CHECK-SM100-NEXT:    .reg .b32 %r<4>;
 ; CHECK-SM100-NEXT:    .reg .b64 %rd<4>;
 ; CHECK-SM100-EMPTY:
 ; CHECK-SM100-NEXT:  // %bb.0:

>From b554bea523eb9f35837edf0315bcba232da30fdf Mon Sep 17 00:00:00 2001
From: Princeton Ferro <pferro at nvidia.com>
Date: Thu, 21 Aug 2025 11:01:56 -0700
Subject: [PATCH 4/4] formatting: rewrap ZERO_EXTEND call in PerformANDCombine

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 566ba840b2eee..e746cffc188b2 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5302,8 +5302,8 @@ static SDValue PerformANDCombine(SDNode *N,
     bool AddTo = false;
     if (AExt.getNode() != nullptr) {
       // Re-insert the ext as a zext.
-      Val = DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
-                            AExt.getValueType(), Val);
+      Val =
+          DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), AExt.getValueType(), Val);
       AddTo = true;
     }
 



More information about the llvm-commits mailing list