[llvm] 691ccf2 - [NVPTX] Implement computeKnownBitsForTargetNode for LoadV (#154165)
via llvm-commits
llvm-commits at lists.llvm.org
Wed Aug 20 11:57:18 PDT 2025
Author: Kevin McAfee
Date: 2025-08-20T18:57:15Z
New Revision: 691ccf263aede14209b10ab2d16f8002767c217b
URL: https://github.com/llvm/llvm-project/commit/691ccf263aede14209b10ab2d16f8002767c217b
DIFF: https://github.com/llvm/llvm-project/commit/691ccf263aede14209b10ab2d16f8002767c217b.diff
LOG: [NVPTX] Implement computeKnownBitsForTargetNode for LoadV (#154165)
Remove AND combines as they are no longer needed after this.
Added:
Modified:
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
llvm/test/CodeGen/NVPTX/shift-opt.ll
Removed:
################################################################################
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 520ce4deb9a57..3300ed9a5a81c 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -1150,15 +1150,12 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
return true;
}
-static unsigned getLoadStoreVectorNumElts(SDNode *N) {
+static unsigned getStoreVectorNumElts(SDNode *N) {
switch (N->getOpcode()) {
- case NVPTXISD::LoadV2:
case NVPTXISD::StoreV2:
return 2;
- case NVPTXISD::LoadV4:
case NVPTXISD::StoreV4:
return 4;
- case NVPTXISD::LoadV8:
case NVPTXISD::StoreV8:
return 8;
default:
@@ -1171,7 +1168,6 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
const EVT MemEVT = LD->getMemoryVT();
if (!MemEVT.isSimple())
return false;
- const MVT MemVT = MemEVT.getSimpleVT();
// Address Space Setting
const auto CodeAddrSpace = getAddrSpace(LD);
@@ -1191,18 +1187,15 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
// Read at least 8 bits (predicates are stored as 8-bit values)
// The last operand holds the original LoadSDNode::getExtensionType() value
- const unsigned TotalWidth = MemVT.getSizeInBits();
const unsigned ExtensionType =
N->getConstantOperandVal(N->getNumOperands() - 1);
const unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
? NVPTX::PTXLdStInstCode::Signed
: NVPTX::PTXLdStInstCode::Untyped;
- const unsigned FromTypeWidth = TotalWidth / getLoadStoreVectorNumElts(N);
+ const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);
assert(!(EltVT.isVector() && ExtensionType != ISD::NON_EXTLOAD));
- assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
- FromTypeWidth <= 128 && TotalWidth <= 256 && "Invalid width for load");
const auto [Base, Offset] = selectADDR(N->getOperand(1), CurDAG);
SDValue Ops[] = {getI32Imm(Ordering, DL),
@@ -1247,30 +1240,23 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
const EVT LoadedEVT = LD->getMemoryVT();
if (!LoadedEVT.isSimple())
return false;
- const MVT LoadedVT = LoadedEVT.getSimpleVT();
SDLoc DL(LD);
- const unsigned TotalWidth = LoadedVT.getSizeInBits();
unsigned ExtensionType;
- unsigned NumElts;
if (const auto *Load = dyn_cast<LoadSDNode>(LD)) {
ExtensionType = Load->getExtensionType();
- NumElts = 1;
} else {
ExtensionType = LD->getConstantOperandVal(LD->getNumOperands() - 1);
- NumElts = getLoadStoreVectorNumElts(LD);
}
const unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
? NVPTX::PTXLdStInstCode::Signed
: NVPTX::PTXLdStInstCode::Untyped;
- const unsigned FromTypeWidth = TotalWidth / NumElts;
+ const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);
assert(!(LD->getSimpleValueType(0).isVector() &&
ExtensionType != ISD::NON_EXTLOAD));
- assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
- FromTypeWidth <= 128 && TotalWidth <= 256 && "Invalid width for load");
const auto [Base, Offset] = selectADDR(LD->getOperand(1), CurDAG);
SDValue Ops[] = {getI32Imm(FromType, DL), getI32Imm(FromTypeWidth, DL), Base,
@@ -1309,26 +1295,21 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
return true;
}
+unsigned NVPTXDAGToDAGISel::getFromTypeWidthForLoad(const MemSDNode *Mem) {
+ auto TotalWidth = Mem->getMemoryVT().getSizeInBits(); // Bits in memory.
+ auto NumElts = Mem->getNumValues() - 1; // Assumes elements + trailing chain.
+ auto ElementBitWidth = TotalWidth / NumElts; // Per-element memory width.
+ assert(isPowerOf2_32(ElementBitWidth) && ElementBitWidth >= 8 &&
+ ElementBitWidth <= 128 && TotalWidth <= 256 &&
+ "Invalid width for load");
+ return ElementBitWidth;
+}
+
bool NVPTXDAGToDAGISel::tryLDU(SDNode *N) {
auto *LD = cast<MemSDNode>(N);
- unsigned NumElts;
- switch (N->getOpcode()) {
- default:
- llvm_unreachable("Unexpected opcode");
- case ISD::INTRINSIC_W_CHAIN:
- NumElts = 1;
- break;
- case NVPTXISD::LDUV2:
- NumElts = 2;
- break;
- case NVPTXISD::LDUV4:
- NumElts = 4;
- break;
- }
-
SDLoc DL(N);
- const unsigned FromTypeWidth = LD->getMemoryVT().getSizeInBits() / NumElts;
+ const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);
const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
// If this is an LDU intrinsic, the address is the third operand. If its an
@@ -1443,7 +1424,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
// - for integer type, always use 'u'
const unsigned TotalWidth = StoreVT.getSimpleVT().getSizeInBits();
- const unsigned NumElts = getLoadStoreVectorNumElts(ST);
+ const unsigned NumElts = getStoreVectorNumElts(ST);
SmallVector<SDValue, 16> Ops;
for (auto &V : ST->ops().slice(1, NumElts))
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 65731722f5343..e2ad55bc1796d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -111,6 +111,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
public:
static NVPTX::AddressSpace getAddrSpace(const MemSDNode *N);
+ static unsigned getFromTypeWidthForLoad(const MemSDNode *Mem);
};
class NVPTXDAGToDAGISelLegacy : public SelectionDAGISelLegacy {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 74e6c139c610d..ad56d2f12caf6 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -14,6 +14,7 @@
#include "NVPTXISelLowering.h"
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTX.h"
+#include "NVPTXISelDAGToDAG.h"
#include "NVPTXSubtarget.h"
#include "NVPTXTargetMachine.h"
#include "NVPTXTargetObjectFile.h"
@@ -5242,76 +5243,6 @@ static SDValue PerformFADDCombine(SDNode *N,
return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
}
-static SDValue PerformANDCombine(SDNode *N,
- TargetLowering::DAGCombinerInfo &DCI) {
- // The type legalizer turns a vector load of i8 values into a zextload to i16
- // registers, optionally ANY_EXTENDs it (if target type is integer),
- // and ANDs off the high 8 bits. Since we turn this load into a
- // target-specific DAG node, the DAG combiner fails to eliminate these AND
- // nodes. Do that here.
- SDValue Val = N->getOperand(0);
- SDValue Mask = N->getOperand(1);
-
- if (isa<ConstantSDNode>(Val)) {
- std::swap(Val, Mask);
- }
-
- SDValue AExt;
-
- // Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and
- if (Val.getOpcode() == ISD::ANY_EXTEND) {
- AExt = Val;
- Val = Val->getOperand(0);
- }
-
- if (Val->getOpcode() == NVPTXISD::LoadV2 ||
- Val->getOpcode() == NVPTXISD::LoadV4) {
- ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
- if (!MaskCnst) {
- // Not an AND with a constant
- return SDValue();
- }
-
- uint64_t MaskVal = MaskCnst->getZExtValue();
- if (MaskVal != 0xff) {
- // Not an AND that chops off top 8 bits
- return SDValue();
- }
-
- MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
- if (!Mem) {
- // Not a MemSDNode?!?
- return SDValue();
- }
-
- EVT MemVT = Mem->getMemoryVT();
- if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
- // We only handle the i8 case
- return SDValue();
- }
-
- unsigned ExtType = Val->getConstantOperandVal(Val->getNumOperands() - 1);
- if (ExtType == ISD::SEXTLOAD) {
- // If for some reason the load is a sextload, the and is needed to zero
- // out the high 8 bits
- return SDValue();
- }
-
- bool AddTo = false;
- if (AExt.getNode() != nullptr) {
- // Re-insert the ext as a zext.
- Val = DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
- AExt.getValueType(), Val);
- AddTo = true;
- }
-
- // If we get here, the AND is unnecessary. Just replace it with the load
- DCI.CombineTo(N, Val, AddTo);
- }
-
- return SDValue();
-}
-
static SDValue PerformREMCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
CodeGenOptLevel OptLevel) {
@@ -5983,8 +5914,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
return PerformADDCombine(N, DCI, OptLevel);
case ISD::ADDRSPACECAST:
return combineADDRSPACECAST(N, DCI);
- case ISD::AND:
- return PerformANDCombine(N, DCI);
case ISD::SIGN_EXTEND:
case ISD::ZERO_EXTEND:
return combineMulWide(N, DCI, OptLevel);
@@ -6609,6 +6538,24 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
}
}
+static void computeKnownBitsForLoadV(const SDValue Op, KnownBits &Known) {
+ MemSDNode *LD = cast<MemSDNode>(Op);
+
+ // SEXTLOAD fills the high bits with the unknown sign bit; nothing known.
+ auto ExtType = LD->getConstantOperandVal(LD->getNumOperands() - 1);
+ if (ExtType == ISD::SEXTLOAD)
+ return;
+
+ // Vector-typed results are not modeled; leave Known unchanged.
+ auto DestVT = LD->getValueType(0);
+ if (DestVT.isVector())
+ return;
+ // Others select to Untyped PTX lds, which zero-extend (per PTX ISA).
+ assert(Known.getBitWidth() == DestVT.getSizeInBits());
+ auto ElementBitWidth = NVPTXDAGToDAGISel::getFromTypeWidthForLoad(LD);
+ Known.Zero.setHighBits(Known.getBitWidth() - ElementBitWidth);
+}
+
void NVPTXTargetLowering::computeKnownBitsForTargetNode(
const SDValue Op, KnownBits &Known, const APInt &DemandedElts,
const SelectionDAG &DAG, unsigned Depth) const {
@@ -6618,6 +6565,11 @@ void NVPTXTargetLowering::computeKnownBitsForTargetNode(
case NVPTXISD::PRMT:
computeKnownBitsForPRMT(Op, Known, DAG, Depth);
break;
+ case NVPTXISD::LoadV2:
+ case NVPTXISD::LoadV4:
+ case NVPTXISD::LoadV8:
+ computeKnownBitsForLoadV(Op, Known);
+ break;
default:
break;
}
diff --git a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
index 53150c1a01314..f4053d84593a5 100644
--- a/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
+++ b/llvm/test/CodeGen/NVPTX/i8x2-instructions.ll
@@ -103,5 +103,34 @@ define <2 x i8> @test_call_2xi8(<2 x i8> %a) {
%res = call <2 x i8> @test_call_2xi8(<2 x i8> %a)
ret <2 x i8> %res
}
+
+define <2 x float> @test_uitofp_2xi8(<2 x i8> %a) {
+; O0-LABEL: test_uitofp_2xi8(
+; O0: {
+; O0-NEXT: .reg .b16 %rs<3>;
+; O0-NEXT: .reg .b32 %r<4>;
+; O0-EMPTY:
+; O0-NEXT: // %bb.0:
+; O0-NEXT: ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
+; O0-NEXT: mov.b32 %r1, {%rs1, %rs2};
+; O0-NEXT: cvt.rn.f32.u16 %r2, %rs2;
+; O0-NEXT: cvt.rn.f32.u16 %r3, %rs1;
+; O0-NEXT: st.param.v2.b32 [func_retval0], {%r3, %r2};
+; O0-NEXT: ret;
+;
+; O3-LABEL: test_uitofp_2xi8(
+; O3: {
+; O3-NEXT: .reg .b16 %rs<3>;
+; O3-NEXT: .reg .b32 %r<3>;
+; O3-EMPTY:
+; O3-NEXT: // %bb.0:
+; O3-NEXT: ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
+; O3-NEXT: cvt.rn.f32.u16 %r1, %rs2;
+; O3-NEXT: cvt.rn.f32.u16 %r2, %rs1;
+; O3-NEXT: st.param.v2.b32 [func_retval0], {%r2, %r1};
+; O3-NEXT: ret;
+ %1 = uitofp <2 x i8> %a to <2 x float>
+ ret <2 x float> %1
+}
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; COMMON: {{.*}}
diff --git a/llvm/test/CodeGen/NVPTX/shift-opt.ll b/llvm/test/CodeGen/NVPTX/shift-opt.ll
index e7866b01064c7..e0d22c62993ba 100644
--- a/llvm/test/CodeGen/NVPTX/shift-opt.ll
+++ b/llvm/test/CodeGen/NVPTX/shift-opt.ll
@@ -71,18 +71,17 @@ define <2 x i16> @test_vec(<2 x i16> %x, <2 x i8> %y) {
; CHECK-LABEL: test_vec(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<7>;
-; CHECK-NEXT: .reg .b32 %r<5>;
+; CHECK-NEXT: .reg .b32 %r<4>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.v2.b16 {%rs1, %rs2}, [test_vec_param_0];
; CHECK-NEXT: ld.param.v2.b8 {%rs3, %rs4}, [test_vec_param_1];
; CHECK-NEXT: mov.b32 %r1, {%rs3, %rs4};
-; CHECK-NEXT: and.b32 %r2, %r1, 16711935;
; CHECK-NEXT: shr.u16 %rs5, %rs2, 5;
; CHECK-NEXT: shr.u16 %rs6, %rs1, 5;
-; CHECK-NEXT: mov.b32 %r3, {%rs6, %rs5};
-; CHECK-NEXT: or.b32 %r4, %r3, %r2;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
+; CHECK-NEXT: mov.b32 %r2, {%rs6, %rs5};
+; CHECK-NEXT: or.b32 %r3, %r2, %r1;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-NEXT: ret;
%ext = zext <2 x i8> %y to <2 x i16>
%shl = shl <2 x i16> %ext, splat(i16 5)
More information about the llvm-commits mailing list