[llvm] [NVPTX] Implement computeKnownBitsForTargetNode for LoadV (PR #154165)

Kevin McAfee via llvm-commits llvm-commits at lists.llvm.org
Wed Aug 20 11:29:45 PDT 2025


https://github.com/kalxr updated https://github.com/llvm/llvm-project/pull/154165

From d11030d2601bad619638a674915f0860ffd77459 Mon Sep 17 00:00:00 2001
From: Kevin McAfee <kmcafee at nvidia.com>
Date: Mon, 18 Aug 2025 17:36:45 +0000
Subject: [PATCH 1/6] pre-commit test

---
 llvm/test/CodeGen/NVPTX/i8x2-instructions.ll | 34 ++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
index 53150c1a01314..7463fb5b59e02 100644
--- a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
@@ -103,5 +103,39 @@ define <2 x i8> @test_call_2xi8(<2 x i8> %a) {
   %res = call <2 x i8> @test_call_2xi8(<2 x i8> %a)
   ret <2 x i8> %res
 }
+
+define <2 x float> @test_uitofp_2xi8(<2 x i8> %a) {
+; O0-LABEL: test_uitofp_2xi8(
+; O0:       {
+; O0-NEXT:    .reg .b16 %rs<5>;
+; O0-NEXT:    .reg .b32 %r<5>;
+; O0-EMPTY:
+; O0-NEXT:  // %bb.0:
+; O0-NEXT:    ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
+; O0-NEXT:    mov.b32 %r1, {%rs1, %rs2};
+; O0-NEXT:    and.b32 %r2, %r1, 16711935;
+; O0-NEXT:    mov.b32 {%rs3, %rs4}, %r2;
+; O0-NEXT:    cvt.rn.f32.u16 %r3, %rs4;
+; O0-NEXT:    cvt.rn.f32.u16 %r4, %rs3;
+; O0-NEXT:    st.param.v2.b32 [func_retval0], {%r4, %r3};
+; O0-NEXT:    ret;
+;
+; O3-LABEL: test_uitofp_2xi8(
+; O3:       {
+; O3-NEXT:    .reg .b16 %rs<5>;
+; O3-NEXT:    .reg .b32 %r<5>;
+; O3-EMPTY:
+; O3-NEXT:  // %bb.0:
+; O3-NEXT:    ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
+; O3-NEXT:    mov.b32 %r1, {%rs1, %rs2};
+; O3-NEXT:    and.b32 %r2, %r1, 16711935;
+; O3-NEXT:    mov.b32 {%rs3, %rs4}, %r2;
+; O3-NEXT:    cvt.rn.f32.u16 %r3, %rs4;
+; O3-NEXT:    cvt.rn.f32.u16 %r4, %rs3;
+; O3-NEXT:    st.param.v2.b32 [func_retval0], {%r4, %r3};
+; O3-NEXT:    ret;
+  %1 = uitofp <2 x i8> %a to <2 x float>
+  ret <2 x float> %1
+}
 ;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
 ; COMMON: {{.*}}
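
A note on the CHECK lines above: 16711935 is 0x00FF00FF. Each i8 lane sits in
the low byte of one 16-bit half of the packed 32-bit register, so the mask
clears the two high bytes. The zero-extending v2.b8 load already leaves those
bytes zero, which is what makes the and.b32 redundant; the later patches teach
the DAG that fact. A rough sketch of the packed layout (an illustrative
reading of the PTX above, not something stated in the patch):

  // %r1 after mov.b32 %r1, {%rs1, %rs2}:
  //   bits  0..7  = a[0]    bits  8..15 = high byte of %rs1 (already zero)
  //   bits 16..23 = a[1]    bits 24..31 = high byte of %rs2 (already zero)
  // and.b32 %r2, %r1, 0x00FF00FF therefore does not change the value.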

From ea1f7d53c3635f1b8d3c15819a52b2cbf77e2a2e Mon Sep 17 00:00:00 2001
From: Kevin McAfee <kmcafee at nvidia.com>
Date: Mon, 18 Aug 2025 17:38:29 +0000
Subject: [PATCH 2/6] [NVPTX] Support vectors for AND combine

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp  | 114 ++++++++++++-------
 llvm/test/CodeGen/NVPTX/i8x2-instructions.ll |  25 ++--
 llvm/test/CodeGen/NVPTX/shift-opt.ll         |   9 +-
 3 files changed, 84 insertions(+), 64 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 74e6c139c610d..8190c23407250 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5242,6 +5242,58 @@ static SDValue PerformFADDCombine(SDNode *N,
   return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
 }
 
+// Helper function to check if an AND operation on a load can be eliminated.
+// Returns a replacement value if the AND can be eliminated, else nullopt.
+static std::optional<SDValue>
+canEliminateLoadAND(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+                    SDValue Val, SDValue Mask, SDValue AExt) {
+  if (Val->getOpcode() != NVPTXISD::LoadV2 &&
+      Val->getOpcode() != NVPTXISD::LoadV4) {
+    return std::nullopt;
+  }
+
+  ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
+  if (!MaskCnst) {
+    // Not an AND with a constant
+    return std::nullopt;
+  }
+
+  uint64_t MaskVal = MaskCnst->getZExtValue();
+  if (MaskVal != 0xff) {
+    // Not an AND that chops off top 8 bits
+    return std::nullopt;
+  }
+
+  MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
+  if (!Mem) {
+    // Not a MemSDNode
+    return std::nullopt;
+  }
+
+  EVT MemVT = Mem->getMemoryVT();
+  if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
+    // We only handle the i8 case
+    return std::nullopt;
+  }
+
+  unsigned ExtType = Val->getConstantOperandVal(Val->getNumOperands() - 1);
+  if (ExtType == ISD::SEXTLOAD) {
+    // If the load is a sextload, the AND is needed to zero out the high 8 bits
+    return std::nullopt;
+  }
+
+  SDValue Result = Val;
+
+  if (AExt) {
+    // Re-insert the ext as a zext.
+    Result =
+        DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), AExt.getValueType(), Val);
+  }
+
+  // If we get here, the AND is unnecessary. Replace it with the load.
+  return Result;
+}
+
 static SDValue PerformANDCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI) {
   // The type legalizer turns a vector load of i8 values into a zextload to i16
@@ -5252,9 +5304,8 @@ static SDValue PerformANDCombine(SDNode *N,
   SDValue Val = N->getOperand(0);
   SDValue Mask = N->getOperand(1);
 
-  if (isa<ConstantSDNode>(Val)) {
+  if (isa<ConstantSDNode>(Val))
     std::swap(Val, Mask);
-  }
 
   SDValue AExt;
 
@@ -5264,49 +5315,24 @@ static SDValue PerformANDCombine(SDNode *N,
     Val = Val->getOperand(0);
   }
 
-  if (Val->getOpcode() == NVPTXISD::LoadV2 ||
-      Val->getOpcode() == NVPTXISD::LoadV4) {
-    ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
-    if (!MaskCnst) {
-      // Not an AND with a constant
-      return SDValue();
-    }
-
-    uint64_t MaskVal = MaskCnst->getZExtValue();
-    if (MaskVal != 0xff) {
-      // Not an AND that chops off top 8 bits
-      return SDValue();
-    }
-
-    MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
-    if (!Mem) {
-      // Not a MemSDNode?!?
-      return SDValue();
-    }
-
-    EVT MemVT = Mem->getMemoryVT();
-    if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
-      // We only handle the i8 case
-      return SDValue();
-    }
-
-    unsigned ExtType = Val->getConstantOperandVal(Val->getNumOperands() - 1);
-    if (ExtType == ISD::SEXTLOAD) {
-      // If for some reason the load is a sextload, the and is needed to zero
-      // out the high 8 bits
-      return SDValue();
-    }
-
-    bool AddTo = false;
-    if (AExt.getNode() != nullptr) {
-      // Re-insert the ext as a zext.
-      Val = DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
-                            AExt.getValueType(), Val);
-      AddTo = true;
+  if (Val.getOpcode() == ISD::BUILD_VECTOR &&
+      Mask.getOpcode() == ISD::BUILD_VECTOR) {
+    assert(Val->getNumOperands() == Mask->getNumOperands() && !AExt);
+    for (unsigned I = 0; I < Val->getNumOperands(); ++I) {
+      // We know that the AExt is null and therefore the result of this call
+      // will be the BUILD_VECTOR operand or nullopt. Rather than create a new
+      // BUILD_VECTOR with the collection of operands, we can just use the
+      // original and ignore the result.
+      if (!canEliminateLoadAND(N, DCI, Val->getOperand(I), Mask->getOperand(I),
+                               AExt)
+               .has_value())
+        return SDValue();
     }
-
-    // If we get here, the AND is unnecessary.  Just replace it with the load
-    DCI.CombineTo(N, Val, AddTo);
+    DCI.CombineTo(N, Val, false);
+  } else {
+    auto Result = canEliminateLoadAND(N, DCI, Val, Mask, AExt);
+    if (Result.has_value())
+      DCI.CombineTo(N, Result.value(), AExt.getNode() != nullptr);
   }
 
   return SDValue();
diff --git a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
index 7463fb5b59e02..f4053d84593a5 100644
--- a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
@@ -107,32 +107,27 @@ define <2 x i8> @test_call_2xi8(<2 x i8> %a) {
 define <2 x float> @test_uitofp_2xi8(<2 x i8> %a) {
 ; O0-LABEL: test_uitofp_2xi8(
 ; O0:       {
-; O0-NEXT:    .reg .b16 %rs<5>;
-; O0-NEXT:    .reg .b32 %r<5>;
+; O0-NEXT:    .reg .b16 %rs<3>;
+; O0-NEXT:    .reg .b32 %r<4>;
 ; O0-EMPTY:
 ; O0-NEXT:  // %bb.0:
 ; O0-NEXT:    ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
 ; O0-NEXT:    mov.b32 %r1, {%rs1, %rs2};
-; O0-NEXT:    and.b32 %r2, %r1, 16711935;
-; O0-NEXT:    mov.b32 {%rs3, %rs4}, %r2;
-; O0-NEXT:    cvt.rn.f32.u16 %r3, %rs4;
-; O0-NEXT:    cvt.rn.f32.u16 %r4, %rs3;
-; O0-NEXT:    st.param.v2.b32 [func_retval0], {%r4, %r3};
+; O0-NEXT:    cvt.rn.f32.u16 %r2, %rs2;
+; O0-NEXT:    cvt.rn.f32.u16 %r3, %rs1;
+; O0-NEXT:    st.param.v2.b32 [func_retval0], {%r3, %r2};
 ; O0-NEXT:    ret;
 ;
 ; O3-LABEL: test_uitofp_2xi8(
 ; O3:       {
-; O3-NEXT:    .reg .b16 %rs<5>;
-; O3-NEXT:    .reg .b32 %r<5>;
+; O3-NEXT:    .reg .b16 %rs<3>;
+; O3-NEXT:    .reg .b32 %r<3>;
 ; O3-EMPTY:
 ; O3-NEXT:  // %bb.0:
 ; O3-NEXT:    ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
-; O3-NEXT:    mov.b32 %r1, {%rs1, %rs2};
-; O3-NEXT:    and.b32 %r2, %r1, 16711935;
-; O3-NEXT:    mov.b32 {%rs3, %rs4}, %r2;
-; O3-NEXT:    cvt.rn.f32.u16 %r3, %rs4;
-; O3-NEXT:    cvt.rn.f32.u16 %r4, %rs3;
-; O3-NEXT:    st.param.v2.b32 [func_retval0], {%r4, %r3};
+; O3-NEXT:    cvt.rn.f32.u16 %r1, %rs2;
+; O3-NEXT:    cvt.rn.f32.u16 %r2, %rs1;
+; O3-NEXT:    st.param.v2.b32 [func_retval0], {%r2, %r1};
 ; O3-NEXT:    ret;
   %1 = uitofp <2 x i8> %a to <2 x float>
   ret <2 x float> %1
diff --git a/llvm/test/CodeGen/NVPTX/shift-opt.ll b/llvm/test/CodeGen/NVPTX/shift-opt.ll
index e7866b01064c7..e0d22c62993ba 100644
--- a/llvm/test/CodeGen/NVPTX/shift-opt.ll
+++ b/llvm/test/CodeGen/NVPTX/shift-opt.ll
@@ -71,18 +71,17 @@ define <2 x i16> @test_vec(<2 x i16> %x, <2 x i8> %y) {
 ; CHECK-LABEL: test_vec(
 ; CHECK:       {
 ; CHECK-NEXT:    .reg .b16 %rs<7>;
-; CHECK-NEXT:    .reg .b32 %r<5>;
+; CHECK-NEXT:    .reg .b32 %r<4>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.v2.b16 {%rs1, %rs2}, [test_vec_param_0];
 ; CHECK-NEXT:    ld.param.v2.b8 {%rs3, %rs4}, [test_vec_param_1];
 ; CHECK-NEXT:    mov.b32 %r1, {%rs3, %rs4};
-; CHECK-NEXT:    and.b32 %r2, %r1, 16711935;
 ; CHECK-NEXT:    shr.u16 %rs5, %rs2, 5;
 ; CHECK-NEXT:    shr.u16 %rs6, %rs1, 5;
-; CHECK-NEXT:    mov.b32 %r3, {%rs6, %rs5};
-; CHECK-NEXT:    or.b32 %r4, %r3, %r2;
-; CHECK-NEXT:    st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT:    mov.b32 %r2, {%rs6, %rs5};
+; CHECK-NEXT:    or.b32 %r3, %r2, %r1;
+; CHECK-NEXT:    st.param.b32 [func_retval0], %r3;
 ; CHECK-NEXT:    ret;
   %ext = zext <2 x i8> %y to <2 x i16>
   %shl = shl <2 x i16> %ext, splat(i16 5)
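
The BUILD_VECTOR case added here matches DAGs of roughly the following shape
(an illustrative sketch, not compiler output):

  // (and v2i16 (build_vector (LoadV2 ...):i16, (LoadV2 ...):i16),
  //            (build_vector 0xff, 0xff))
  //
  // canEliminateLoadAND runs per lane with AExt known to be null, so each
  // lane's result can only be its original BUILD_VECTOR operand or nullopt;
  // the AND is replaced by the existing BUILD_VECTOR only when the mask is
  // redundant for every lane.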

From 885dcae41ee122734c5c77a0fcb4e5f99c903cb8 Mon Sep 17 00:00:00 2001
From: Kevin McAfee <kmcafee at nvidia.com>
Date: Mon, 18 Aug 2025 23:51:51 +0000
Subject: [PATCH 3/6] Remove AND combine and implement computeKnownBits for
 LoadV2/4

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 123 ++++----------------
 1 file changed, 25 insertions(+), 98 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 8190c23407250..67a5f2ae84c31 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -5242,102 +5242,6 @@ static SDValue PerformFADDCombine(SDNode *N,
   return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
 }
 
-// Helper function to check if an AND operation on a load can be eliminated.
-// Returns a replacement value if the AND can be eliminated, else nullopt.
-static std::optional<SDValue>
-canEliminateLoadAND(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
-                    SDValue Val, SDValue Mask, SDValue AExt) {
-  if (Val->getOpcode() != NVPTXISD::LoadV2 &&
-      Val->getOpcode() != NVPTXISD::LoadV4) {
-    return std::nullopt;
-  }
-
-  ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
-  if (!MaskCnst) {
-    // Not an AND with a constant
-    return std::nullopt;
-  }
-
-  uint64_t MaskVal = MaskCnst->getZExtValue();
-  if (MaskVal != 0xff) {
-    // Not an AND that chops off top 8 bits
-    return std::nullopt;
-  }
-
-  MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
-  if (!Mem) {
-    // Not a MemSDNode
-    return std::nullopt;
-  }
-
-  EVT MemVT = Mem->getMemoryVT();
-  if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
-    // We only handle the i8 case
-    return std::nullopt;
-  }
-
-  unsigned ExtType = Val->getConstantOperandVal(Val->getNumOperands() - 1);
-  if (ExtType == ISD::SEXTLOAD) {
-    // If the load is a sextload, the AND is needed to zero out the high 8 bits
-    return std::nullopt;
-  }
-
-  SDValue Result = Val;
-
-  if (AExt) {
-    // Re-insert the ext as a zext.
-    Result =
-        DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), AExt.getValueType(), Val);
-  }
-
-  // If we get here, the AND is unnecessary. Replace it with the load.
-  return Result;
-}
-
-static SDValue PerformANDCombine(SDNode *N,
-                                 TargetLowering::DAGCombinerInfo &DCI) {
-  // The type legalizer turns a vector load of i8 values into a zextload to i16
-  // registers, optionally ANY_EXTENDs it (if target type is integer),
-  // and ANDs off the high 8 bits. Since we turn this load into a
-  // target-specific DAG node, the DAG combiner fails to eliminate these AND
-  // nodes. Do that here.
-  SDValue Val = N->getOperand(0);
-  SDValue Mask = N->getOperand(1);
-
-  if (isa<ConstantSDNode>(Val))
-    std::swap(Val, Mask);
-
-  SDValue AExt;
-
-  // Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and
-  if (Val.getOpcode() == ISD::ANY_EXTEND) {
-    AExt = Val;
-    Val = Val->getOperand(0);
-  }
-
-  if (Val.getOpcode() == ISD::BUILD_VECTOR &&
-      Mask.getOpcode() == ISD::BUILD_VECTOR) {
-    assert(Val->getNumOperands() == Mask->getNumOperands() && !AExt);
-    for (unsigned I = 0; I < Val->getNumOperands(); ++I) {
-      // We know that the AExt is null and therefore the result of this call
-      // will be the BUILD_VECTOR operand or nullopt. Rather than create a new
-      // BUILD_VECTOR with the collection of operands, we can just use the
-      // original and ignore the result.
-      if (!canEliminateLoadAND(N, DCI, Val->getOperand(I), Mask->getOperand(I),
-                               AExt)
-               .has_value())
-        return SDValue();
-    }
-    DCI.CombineTo(N, Val, false);
-  } else {
-    auto Result = canEliminateLoadAND(N, DCI, Val, Mask, AExt);
-    if (Result.has_value())
-      DCI.CombineTo(N, Result.value(), AExt.getNode() != nullptr);
-  }
-
-  return SDValue();
-}
-
 static SDValue PerformREMCombine(SDNode *N,
                                  TargetLowering::DAGCombinerInfo &DCI,
                                  CodeGenOptLevel OptLevel) {
@@ -6009,8 +5913,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
     return PerformADDCombine(N, DCI, OptLevel);
   case ISD::ADDRSPACECAST:
     return combineADDRSPACECAST(N, DCI);
-  case ISD::AND:
-    return PerformANDCombine(N, DCI);
   case ISD::SIGN_EXTEND:
   case ISD::ZERO_EXTEND:
     return combineMulWide(N, DCI, OptLevel);
@@ -6635,6 +6537,27 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
   }
 }
 
+static void computeKnownBitsFori8VLoad(const SDValue Op, KnownBits &Known) {
+  MemSDNode *Mem = dyn_cast<MemSDNode>(Op);
+  if (!Mem) {
+    return;
+  }
+
+  EVT MemVT = Mem->getMemoryVT();
+  if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
+    return;
+  }
+
+  unsigned ExtType = Mem->getConstantOperandVal(Mem->getNumOperands() - 1);
+  if (ExtType == ISD::SEXTLOAD) {
+    Known = Known.sext(Known.getBitWidth());
+    return;
+  }
+  KnownBits HighZeros(Known.getBitWidth() - 8);
+  HighZeros.setAllZero();
+  Known.insertBits(HighZeros, 8);
+}
+
 void NVPTXTargetLowering::computeKnownBitsForTargetNode(
     const SDValue Op, KnownBits &Known, const APInt &DemandedElts,
     const SelectionDAG &DAG, unsigned Depth) const {
@@ -6644,6 +6567,10 @@ void NVPTXTargetLowering::computeKnownBitsForTargetNode(
   case NVPTXISD::PRMT:
     computeKnownBitsForPRMT(Op, Known, DAG, Depth);
     break;
+  case NVPTXISD::LoadV2:
+  case NVPTXISD::LoadV4:
+    computeKnownBitsFori8VLoad(Op, Known);
+    break;
   default:
     break;
   }
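
With known bits reported for these loads, the dedicated AND combine is no
longer needed: the generic combiner already folds (and x, mask) to x whenever
every bit the mask would clear is known zero. A minimal standalone sketch of
that test, assuming an i16 load result whose high 8 bits are known zero (the
helper name below is invented for illustration, not from the patch):

  #include "llvm/ADT/APInt.h"
  #include "llvm/Support/KnownBits.h"
  using namespace llvm;

  // True iff ANDing with Mask cannot change a value with these known bits.
  static bool maskIsRedundant(const KnownBits &Known, const APInt &Mask) {
    return (~Mask).isSubsetOf(Known.Zero);
  }

  // KnownBits Known(16);
  // Known.Zero.setHighBits(8);              // zextload of an i8 into an i16
  // maskIsRedundant(Known, APInt(16, 0xff)) // -> true, so the AND folds away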

From 2a2215ab6b8546472352ca8f04935b84b627f175 Mon Sep 17 00:00:00 2001
From: Kevin McAfee <kmcafee at nvidia.com>
Date: Wed, 20 Aug 2025 00:11:45 +0000
Subject: [PATCH 4/6] generalize for more types

---
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 23 +++++----------
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h   |  1 +
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 31 +++++++++++----------
 3 files changed, 24 insertions(+), 31 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 520ce4deb9a57..d86c0905cf943 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1309,26 +1309,17 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
   return true;
 }
 
+unsigned NVPTXDAGToDAGISel::getFromTypeWidthForLoad(const MemSDNode *Mem) {
+  EVT MemVT = Mem->getMemoryVT();
+  auto ElementBitWidth = MemVT.getSizeInBits() / (Mem->getNumValues() - 1);
+  return ElementBitWidth;
+}
+
 bool NVPTXDAGToDAGISel::tryLDU(SDNode *N) {
   auto *LD = cast<MemSDNode>(N);
 
-  unsigned NumElts;
-  switch (N->getOpcode()) {
-  default:
-    llvm_unreachable("Unexpected opcode");
-  case ISD::INTRINSIC_W_CHAIN:
-    NumElts = 1;
-    break;
-  case NVPTXISD::LDUV2:
-    NumElts = 2;
-    break;
-  case NVPTXISD::LDUV4:
-    NumElts = 4;
-    break;
-  }
-
   SDLoc DL(N);
-  const unsigned FromTypeWidth = LD->getMemoryVT().getSizeInBits() / NumElts;
+  const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);
   const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
 
   // If this is an LDU intrinsic, the address is the third operand. If it's an
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 65731722f5343..e2ad55bc1796d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -111,6 +111,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
 
 public:
   static NVPTX::AddressSpace getAddrSpace(const MemSDNode *N);
+  static unsigned getFromTypeWidthForLoad(const MemSDNode *Mem);
 };
 
 class NVPTXDAGToDAGISelLegacy : public SelectionDAGISelLegacy {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 67a5f2ae84c31..0d96c9353141a 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -14,6 +14,7 @@
 #include "NVPTXISelLowering.h"
 #include "MCTargetDesc/NVPTXBaseInfo.h"
 #include "NVPTX.h"
+#include "NVPTXISelDAGToDAG.h"
 #include "NVPTXSubtarget.h"
 #include "NVPTXTargetMachine.h"
 #include "NVPTXTargetObjectFile.h"
@@ -6537,25 +6538,24 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
   }
 }
 
-static void computeKnownBitsFori8VLoad(const SDValue Op, KnownBits &Known) {
-  MemSDNode *Mem = dyn_cast<MemSDNode>(Op);
-  if (!Mem) {
-    return;
-  }
+static void computeKnownBitsForVLoad(const SDValue Op, KnownBits &Known) {
+  MemSDNode *LD = cast<MemSDNode>(Op);
 
-  EVT MemVT = Mem->getMemoryVT();
-  if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
+  // We can't do anything without knowing the sign bit.
+  auto ExtType = LD->getConstantOperandVal(LD->getNumOperands() - 1);
+  if (ExtType == ISD::SEXTLOAD)
     return;
-  }
 
-  unsigned ExtType = Mem->getConstantOperandVal(Mem->getNumOperands() - 1);
-  if (ExtType == ISD::SEXTLOAD) {
-    Known = Known.sext(Known.getBitWidth());
+  // Extending loads with vector result types aren't modeled here; bail out.
+  auto DestVT = LD->getValueType(0);
+  if (DestVT.isVector())
     return;
-  }
-  KnownBits HighZeros(Known.getBitWidth() - 8);
+
+  assert(Known.getBitWidth() == DestVT.getSizeInBits());
+  auto ElementBitWidth = NVPTXDAGToDAGISel::getFromTypeWidthForLoad(LD);
+  KnownBits HighZeros(Known.getBitWidth() - ElementBitWidth);
   HighZeros.setAllZero();
-  Known.insertBits(HighZeros, 8);
+  Known.insertBits(HighZeros, ElementBitWidth);
 }
 
 void NVPTXTargetLowering::computeKnownBitsForTargetNode(
@@ -6569,7 +6569,8 @@ void NVPTXTargetLowering::computeKnownBitsForTargetNode(
     break;
   case NVPTXISD::LoadV2:
   case NVPTXISD::LoadV4:
-    computeKnownBitsFori8VLoad(Op, Known);
+  case NVPTXISD::LoadV8:
+    computeKnownBitsForVLoad(Op, Known);
     break;
   default:
     break;
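
To make the new width computation concrete: an NVPTXISD::LoadV2 of a v2i8
memory type that extends each element to i16 has three result values (two
data values plus the chain), so the helper arrives at the 8-bit element width
as follows (a worked example, not text from the patch):

  // MemVT = v2i8              -> MemVT.getSizeInBits() == 16
  // Mem->getNumValues() - 1   == 3 - 1 == 2 loaded elements
  // ElementBitWidth           == 16 / 2 == 8
  // computeKnownBitsForVLoad then marks the top 16 - 8 = 8 bits of each i16
  // result as known zero (unless the load sign-extends).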

From ef99d229ac828cae59088ff0d7b788b3c7e7987f Mon Sep 17 00:00:00 2001
From: Kevin McAfee <kmcafee at nvidia.com>
Date: Wed, 20 Aug 2025 18:24:27 +0000
Subject: [PATCH 5/6] Refactor/expand use of getFromTypeWidthForLoad

---
 llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 30 +++++++--------------
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp |  4 +--
 2 files changed, 11 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index d86c0905cf943..3300ed9a5a81c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1150,15 +1150,12 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
   return true;
 }
 
-static unsigned getLoadStoreVectorNumElts(SDNode *N) {
+static unsigned getStoreVectorNumElts(SDNode *N) {
   switch (N->getOpcode()) {
-  case NVPTXISD::LoadV2:
   case NVPTXISD::StoreV2:
     return 2;
-  case NVPTXISD::LoadV4:
   case NVPTXISD::StoreV4:
     return 4;
-  case NVPTXISD::LoadV8:
   case NVPTXISD::StoreV8:
     return 8;
   default:
@@ -1171,7 +1168,6 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
   const EVT MemEVT = LD->getMemoryVT();
   if (!MemEVT.isSimple())
     return false;
-  const MVT MemVT = MemEVT.getSimpleVT();
 
   // Address Space Setting
   const auto CodeAddrSpace = getAddrSpace(LD);
@@ -1191,18 +1187,15 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
   // Float  : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
   // Read at least 8 bits (predicates are stored as 8-bit values)
   // The last operand holds the original LoadSDNode::getExtensionType() value
-  const unsigned TotalWidth = MemVT.getSizeInBits();
   const unsigned ExtensionType =
       N->getConstantOperandVal(N->getNumOperands() - 1);
   const unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
                                 ? NVPTX::PTXLdStInstCode::Signed
                                 : NVPTX::PTXLdStInstCode::Untyped;
 
-  const unsigned FromTypeWidth = TotalWidth / getLoadStoreVectorNumElts(N);
+  const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);
 
   assert(!(EltVT.isVector() && ExtensionType != ISD::NON_EXTLOAD));
-  assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
-         FromTypeWidth <= 128 && TotalWidth <= 256 && "Invalid width for load");
 
   const auto [Base, Offset] = selectADDR(N->getOperand(1), CurDAG);
   SDValue Ops[] = {getI32Imm(Ordering, DL),
@@ -1247,30 +1240,23 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
   const EVT LoadedEVT = LD->getMemoryVT();
   if (!LoadedEVT.isSimple())
     return false;
-  const MVT LoadedVT = LoadedEVT.getSimpleVT();
 
   SDLoc DL(LD);
 
-  const unsigned TotalWidth = LoadedVT.getSizeInBits();
   unsigned ExtensionType;
-  unsigned NumElts;
   if (const auto *Load = dyn_cast<LoadSDNode>(LD)) {
     ExtensionType = Load->getExtensionType();
-    NumElts = 1;
   } else {
     ExtensionType = LD->getConstantOperandVal(LD->getNumOperands() - 1);
-    NumElts = getLoadStoreVectorNumElts(LD);
   }
   const unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
                                 ? NVPTX::PTXLdStInstCode::Signed
                                 : NVPTX::PTXLdStInstCode::Untyped;
 
-  const unsigned FromTypeWidth = TotalWidth / NumElts;
+  const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);
 
   assert(!(LD->getSimpleValueType(0).isVector() &&
            ExtensionType != ISD::NON_EXTLOAD));
-  assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
-         FromTypeWidth <= 128 && TotalWidth <= 256 && "Invalid width for load");
 
   const auto [Base, Offset] = selectADDR(LD->getOperand(1), CurDAG);
   SDValue Ops[] = {getI32Imm(FromType, DL), getI32Imm(FromTypeWidth, DL), Base,
@@ -1310,8 +1296,12 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
 }
 
 unsigned NVPTXDAGToDAGISel::getFromTypeWidthForLoad(const MemSDNode *Mem) {
-  EVT MemVT = Mem->getMemoryVT();
-  auto ElementBitWidth = MemVT.getSizeInBits() / (Mem->getNumValues() - 1);
+  auto TotalWidth = Mem->getMemoryVT().getSizeInBits();
+  auto NumElts = Mem->getNumValues() - 1;
+  auto ElementBitWidth = TotalWidth / NumElts;
+  assert(isPowerOf2_32(ElementBitWidth) && ElementBitWidth >= 8 &&
+         ElementBitWidth <= 128 && TotalWidth <= 256 &&
+         "Invalid width for load");
   return ElementBitWidth;
 }
 
@@ -1434,7 +1424,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
   // - for integer type, always use 'u'
   const unsigned TotalWidth = StoreVT.getSimpleVT().getSizeInBits();
 
-  const unsigned NumElts = getLoadStoreVectorNumElts(ST);
+  const unsigned NumElts = getStoreVectorNumElts(ST);
 
   SmallVector<SDValue, 16> Ops;
   for (auto &V : ST->ops().slice(1, NumElts))
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 0d96c9353141a..c1d3e9567d85d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -6553,9 +6553,7 @@ static void computeKnownBitsForVLoad(const SDValue Op, KnownBits &Known) {
 
   assert(Known.getBitWidth() == DestVT.getSizeInBits());
   auto ElementBitWidth = NVPTXDAGToDAGISel::getFromTypeWidthForLoad(LD);
-  KnownBits HighZeros(Known.getBitWidth() - ElementBitWidth);
-  HighZeros.setAllZero();
-  Known.insertBits(HighZeros, ElementBitWidth);
+  Known.Zero.setHighBits(Known.getBitWidth() - ElementBitWidth);
 }
 
 void NVPTXTargetLowering::computeKnownBitsForTargetNode(
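
The two formulations are equivalent here: inserting an all-zero
(BitWidth - ElementBitWidth)-wide KnownBits at offset ElementBitWidth and
calling setHighBits on Known.Zero both mark bits [ElementBitWidth, BitWidth)
as known zero. A small sketch of the simplified form, with illustrative
values:

  KnownBits Known(32);          // e.g. one i8 element extended into an i32
  unsigned ElementBitWidth = 8;
  Known.Zero.setHighBits(Known.getBitWidth() - ElementBitWidth);
  // Bits 8..31 are now known zero; bits 0..7 remain unknown.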

From 4098eee4606980e84a2fe6e6baae74ca2496898b Mon Sep 17 00:00:00 2001
From: Kevin McAfee <kmcafee at nvidia.com>
Date: Wed, 20 Aug 2025 18:27:47 +0000
Subject: [PATCH 6/6] rename

---
 llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index c1d3e9567d85d..ad56d2f12caf6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -6538,7 +6538,7 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
   }
 }
 
-static void computeKnownBitsForVLoad(const SDValue Op, KnownBits &Known) {
+static void computeKnownBitsForLoadV(const SDValue Op, KnownBits &Known) {
   MemSDNode *LD = cast<MemSDNode>(Op);
 
   // We can't do anything without knowing the sign bit.
@@ -6568,7 +6568,7 @@ void NVPTXTargetLowering::computeKnownBitsForTargetNode(
   case NVPTXISD::LoadV2:
   case NVPTXISD::LoadV4:
   case NVPTXISD::LoadV8:
-    computeKnownBitsForVLoad(Op, Known);
+    computeKnownBitsForLoadV(Op, Known);
     break;
   default:
     break;


