[llvm] Calculate KnownBits from Metadata correctly for vector loads (PR #128908)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Mar 7 07:51:53 PST 2025
https://github.com/LU-JOHN updated https://github.com/llvm/llvm-project/pull/128908
>From f6706d46a6e298347e9e28f1991fd965dc8127e4 Mon Sep 17 00:00:00 2001
From: John Lu <John.Lu at amd.com>
Date: Wed, 26 Feb 2025 10:35:34 -0600
Subject: [PATCH 1/2] Calculate KnownBits from Metadata correctly for vector
loads
Signed-off-by: John Lu <John.Lu at amd.com>
---
llvm/lib/Analysis/ValueTracking.cpp | 7 ++++++-
llvm/test/CodeGen/AMDGPU/shl64_reduce.ll | 3 ++-
2 files changed, 8 insertions(+), 2 deletions(-)
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index e3e026f7979da..1d6368b408811 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -429,11 +429,16 @@ void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges,
ConstantInt *Upper =
mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 1));
ConstantRange Range(Lower->getValue(), Upper->getValue());
+ unsigned RangeBitWidth = Lower->getBitWidth();
+ // BitWidth > RangeBitWidth can happen if Known is set to the width of a
+ // vector load but Ranges describes a vector element.
+ assert(BitWidth >= RangeBitWidth);
// The first CommonPrefixBits of all values in Range are equal.
unsigned CommonPrefixBits =
(Range.getUnsignedMax() ^ Range.getUnsignedMin()).countl_zero();
- APInt Mask = APInt::getHighBitsSet(BitWidth, CommonPrefixBits);
+ APInt Mask = APInt::getBitsSet(BitWidth, RangeBitWidth - CommonPrefixBits,
+ RangeBitWidth);
APInt UnsignedMax = Range.getUnsignedMax().zextOrTrunc(BitWidth);
Known.One &= UnsignedMax & Mask;
Known.Zero &= ~UnsignedMax & Mask;
diff --git a/llvm/test/CodeGen/AMDGPU/shl64_reduce.ll b/llvm/test/CodeGen/AMDGPU/shl64_reduce.ll
index 69242f4e44840..3f8553180df30 100644
--- a/llvm/test/CodeGen/AMDGPU/shl64_reduce.ll
+++ b/llvm/test/CodeGen/AMDGPU/shl64_reduce.ll
@@ -79,8 +79,9 @@ define <2 x i64> @shl_v2_metadata(<2 x i64> %arg0, ptr %arg1.ptr) {
; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; CHECK-NEXT: flat_load_dwordx4 v[4:7], v[4:5]
; CHECK-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
-; CHECK-NEXT: v_lshlrev_b64 v[0:1], v4, v[0:1]
; CHECK-NEXT: v_lshlrev_b64 v[2:3], v6, v[2:3]
+; CHECK-NEXT: v_lshlrev_b32_e32 v1, v4, v0
+; CHECK-NEXT: v_mov_b32_e32 v0, 0
; CHECK-NEXT: s_setpc_b64 s[30:31]
%shift.amt = load <2 x i64>, ptr %arg1.ptr, !range !0, !noundef !{}
%shl = shl <2 x i64> %arg0, %shift.amt
>From 1fdb2bbe04aaa41c15c0fb27b2a0a2d7596f6886 Mon Sep 17 00:00:00 2001
From: John Lu <John.Lu at amd.com>
Date: Fri, 7 Mar 2025 09:50:54 -0600
Subject: [PATCH 2/2] Set KnownBits to correct width. Reduce 64-bit shl for
all vector elts
Signed-off-by: John Lu <John.Lu at amd.com>
---
llvm/lib/Analysis/ValueTracking.cpp | 9 ++--
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 12 +++--
llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp | 46 ++++++++++++++-----
llvm/test/CodeGen/AMDGPU/shl64_reduce.ll | 30 +++++++-----
4 files changed, 66 insertions(+), 31 deletions(-)
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index 1d6368b408811..15ec8da35dcc6 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -430,15 +430,14 @@ void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges,
mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 1));
ConstantRange Range(Lower->getValue(), Upper->getValue());
unsigned RangeBitWidth = Lower->getBitWidth();
- // BitWidth > RangeBitWidth can happen if Known is set to the width of a
- // vector load but Ranges describes a vector element.
- assert(BitWidth >= RangeBitWidth);
// The first CommonPrefixBits of all values in Range are equal.
unsigned CommonPrefixBits =
(Range.getUnsignedMax() ^ Range.getUnsignedMin()).countl_zero();
- APInt Mask = APInt::getBitsSet(BitWidth, RangeBitWidth - CommonPrefixBits,
- RangeBitWidth);
+ // BitWidth must equal RangeBitWidth. Otherwise Mask will be set
+ // incorrectly.
+ assert(BitWidth == RangeBitWidth && "BitWidth must equal RangeBitWidth");
+ APInt Mask = APInt::getHighBitsSet(BitWidth, CommonPrefixBits);
APInt UnsignedMax = Range.getUnsignedMax().zextOrTrunc(BitWidth);
Known.One &= UnsignedMax & Mask;
Known.Zero &= ~UnsignedMax & Mask;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 9e61df7047d4a..b7e3ee89a9525 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -4027,15 +4027,19 @@ KnownBits SelectionDAG::computeKnownBits(SDValue Op, const APInt &DemandedElts,
// Fill in any known bits from range information. There are 3 types being
// used. The results VT (same vector elt size as BitWidth), the loaded
// MemoryVT (which may or may not be vector) and the range VTs original
- // type. The range matadata needs the full range (i.e
+ // type. The range metadata needs the full range (i.e
// MemoryVT().getSizeInBits()), which is truncated to the correct elt size
// if it is know. These are then extended to the original VT sizes below.
if (const MDNode *MD = LD->getRanges()) {
+ ConstantInt *Lower = mdconst::extract<ConstantInt>(MD->getOperand(0));
+
+ // FIXME: If loads are modified (e.g. type legalization)
+ // so that the load type no longer matches the range metadata type, the
+ // range metadata should be updated to match the new load width.
+ Known0 = Known0.trunc(Lower->getBitWidth());
computeKnownBitsFromRangeMetadata(*MD, Known0);
if (VT.isVector()) {
- // Handle truncation to the first demanded element.
- // TODO: Figure out which demanded elements are covered
- if (DemandedElts != 1 || !getDataLayout().isLittleEndian())
+ if (!getDataLayout().isLittleEndian())
break;
Known0 = Known0.trunc(BitWidth);
}
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 2d46cf3b70a34..74c0a38cb16e3 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -4084,7 +4084,7 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
}
}
- if (VT != MVT::i64)
+ if (VT.getScalarType() != MVT::i64)
return SDValue();
// i64 (shl x, C) -> (build_pair 0, (shl x, C -32))
@@ -4092,21 +4092,24 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
// On some subtargets, 64-bit shift is a quarter rate instruction. In the
// common case, splitting this into a move and a 32-bit shift is faster and
// the same code size.
- EVT TargetType = VT.getHalfSizedIntegerVT(*DAG.getContext());
- EVT TargetVecPairType = EVT::getVectorVT(*DAG.getContext(), TargetType, 2);
KnownBits Known = DAG.computeKnownBits(RHS);
- if (Known.getMinValue().getZExtValue() < TargetType.getSizeInBits())
+ EVT ElementType = VT.getScalarType();
+ EVT TargetScalarType = ElementType.getHalfSizedIntegerVT(*DAG.getContext());
+ EVT TargetType = (VT.isVector() ? VT.changeVectorElementType(TargetScalarType)
+ : TargetScalarType);
+
+ if (Known.getMinValue().getZExtValue() < TargetScalarType.getSizeInBits())
return SDValue();
SDValue ShiftAmt;
if (CRHS) {
- ShiftAmt =
- DAG.getConstant(RHSVal - TargetType.getSizeInBits(), SL, TargetType);
+ ShiftAmt = DAG.getConstant(RHSVal - TargetScalarType.getSizeInBits(), SL,
+ TargetType);
} else {
SDValue truncShiftAmt = DAG.getNode(ISD::TRUNCATE, SL, TargetType, RHS);
const SDValue ShiftMask =
- DAG.getConstant(TargetType.getSizeInBits() - 1, SL, TargetType);
+ DAG.getConstant(TargetScalarType.getSizeInBits() - 1, SL, TargetType);
// This AND instruction will clamp out of bounds shift values.
// It will also be removed during later instruction selection.
ShiftAmt = DAG.getNode(ISD::AND, SL, TargetType, truncShiftAmt, ShiftMask);
@@ -4116,9 +4119,24 @@ SDValue AMDGPUTargetLowering::performShlCombine(SDNode *N,
SDValue NewShift =
DAG.getNode(ISD::SHL, SL, TargetType, Lo, ShiftAmt, N->getFlags());
- const SDValue Zero = DAG.getConstant(0, SL, TargetType);
-
- SDValue Vec = DAG.getBuildVector(TargetVecPairType, SL, {Zero, NewShift});
+ const SDValue Zero = DAG.getConstant(0, SL, TargetScalarType);
+ SDValue Vec;
+
+ if (VT.isVector()) {
+ EVT ConcatType = TargetType.getDoubleNumVectorElementsVT(*DAG.getContext());
+ SmallVector<SDValue, 8> Ops;
+ for (unsigned I = 0, E = TargetType.getVectorNumElements(); I != E; ++I) {
+ SDValue Index = DAG.getConstant(I, SL, MVT::i32);
+ Ops.push_back(Zero);
+ SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, TargetScalarType,
+ NewShift, Index);
+ Ops.push_back(Elt);
+ }
+ Vec = DAG.getNode(ISD::BUILD_VECTOR, SL, ConcatType, Ops);
+ } else {
+ EVT ConcatType = EVT::getVectorVT(*DAG.getContext(), TargetType, 2);
+ Vec = DAG.getBuildVector(ConcatType, SL, {Zero, NewShift});
+ }
return DAG.getNode(ISD::BITCAST, SL, VT, Vec);
}
@@ -5182,7 +5200,13 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
break;
}
case ISD::SHL: {
- if (DCI.getDAGCombineLevel() < AfterLegalizeDAG)
+ // Range metadata can be invalidated when loads are converted to legal types
+ // (e.g. v2i64 -> v4i32).
+ // Try to convert vector shl before type legalization so that range metadata
+ // can be utilized.
+ if (!(N->getValueType(0).isVector() &&
+ DCI.getDAGCombineLevel() == BeforeLegalizeTypes) &&
+ DCI.getDAGCombineLevel() < AfterLegalizeDAG)
break;
return performShlCombine(N, DCI);
diff --git a/llvm/test/CodeGen/AMDGPU/shl64_reduce.ll b/llvm/test/CodeGen/AMDGPU/shl64_reduce.ll
index 3f8553180df30..2e2baa8175159 100644
--- a/llvm/test/CodeGen/AMDGPU/shl64_reduce.ll
+++ b/llvm/test/CodeGen/AMDGPU/shl64_reduce.ll
@@ -77,11 +77,12 @@ define <2 x i64> @shl_v2_metadata(<2 x i64> %arg0, ptr %arg1.ptr) {
; CHECK-LABEL: shl_v2_metadata:
; CHECK: ; %bb.0:
; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; CHECK-NEXT: flat_load_dwordx4 v[4:7], v[4:5]
+; CHECK-NEXT: flat_load_dwordx4 v[3:6], v[4:5]
; CHECK-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
-; CHECK-NEXT: v_lshlrev_b64 v[2:3], v6, v[2:3]
-; CHECK-NEXT: v_lshlrev_b32_e32 v1, v4, v0
+; CHECK-NEXT: v_lshlrev_b32_e32 v1, v3, v0
+; CHECK-NEXT: v_lshlrev_b32_e32 v3, v5, v2
; CHECK-NEXT: v_mov_b32_e32 v0, 0
+; CHECK-NEXT: v_mov_b32_e32 v2, 0
; CHECK-NEXT: s_setpc_b64 s[30:31]
%shift.amt = load <2 x i64>, ptr %arg1.ptr, !range !0, !noundef !{}
%shl = shl <2 x i64> %arg0, %shift.amt
@@ -93,12 +94,15 @@ define <3 x i64> @shl_v3_metadata(<3 x i64> %arg0, ptr %arg1.ptr) {
; CHECK-LABEL: shl_v3_metadata:
; CHECK: ; %bb.0:
; CHECK-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
-; CHECK-NEXT: flat_load_dword v12, v[6:7] offset:16
+; CHECK-NEXT: flat_load_dword v1, v[6:7] offset:16
; CHECK-NEXT: flat_load_dwordx4 v[8:11], v[6:7]
; CHECK-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
-; CHECK-NEXT: v_lshlrev_b64 v[4:5], v12, v[4:5]
-; CHECK-NEXT: v_lshlrev_b64 v[0:1], v8, v[0:1]
-; CHECK-NEXT: v_lshlrev_b64 v[2:3], v10, v[2:3]
+; CHECK-NEXT: v_lshlrev_b32_e32 v5, v1, v4
+; CHECK-NEXT: v_lshlrev_b32_e32 v1, v8, v0
+; CHECK-NEXT: v_lshlrev_b32_e32 v3, v10, v2
+; CHECK-NEXT: v_mov_b32_e32 v0, 0
+; CHECK-NEXT: v_mov_b32_e32 v2, 0
+; CHECK-NEXT: v_mov_b32_e32 v4, 0
; CHECK-NEXT: s_setpc_b64 s[30:31]
%shift.amt = load <3 x i64>, ptr %arg1.ptr, !range !0, !noundef !{}
%shl = shl <3 x i64> %arg0, %shift.amt
@@ -114,11 +118,15 @@ define <4 x i64> @shl_v4_metadata(<4 x i64> %arg0, ptr %arg1.ptr) {
; CHECK-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
; CHECK-NEXT: flat_load_dwordx4 v[13:16], v[8:9] offset:16
; CHECK-NEXT: ; kill: killed $vgpr8 killed $vgpr9
-; CHECK-NEXT: v_lshlrev_b64 v[0:1], v10, v[0:1]
-; CHECK-NEXT: v_lshlrev_b64 v[2:3], v12, v[2:3]
+; CHECK-NEXT: v_lshlrev_b32_e32 v1, v10, v0
+; CHECK-NEXT: v_lshlrev_b32_e32 v3, v12, v2
; CHECK-NEXT: s_waitcnt vmcnt(0) lgkmcnt(0)
-; CHECK-NEXT: v_lshlrev_b64 v[4:5], v13, v[4:5]
-; CHECK-NEXT: v_lshlrev_b64 v[6:7], v15, v[6:7]
+; CHECK-NEXT: v_lshlrev_b32_e32 v5, v13, v4
+; CHECK-NEXT: v_lshlrev_b32_e32 v7, v15, v6
+; CHECK-NEXT: v_mov_b32_e32 v0, 0
+; CHECK-NEXT: v_mov_b32_e32 v2, 0
+; CHECK-NEXT: v_mov_b32_e32 v4, 0
+; CHECK-NEXT: v_mov_b32_e32 v6, 0
; CHECK-NEXT: s_setpc_b64 s[30:31]
%shift.amt = load <4 x i64>, ptr %arg1.ptr, !range !0, !noundef !{}
%shl = shl <4 x i64> %arg0, %shift.amt
More information about the llvm-commits
mailing list