[llvm] [SelectionDAG] Let ComputeKnownSignBits handle (shl (ext X), C) (PR #97695)
Björn Pettersson via llvm-commits
llvm-commits at lists.llvm.org
Thu Jul 4 02:47:59 PDT 2024
https://github.com/bjope updated https://github.com/llvm/llvm-project/pull/97695
>From 2460c6fdd156210021d8aae73a477588fbe4d7f2 Mon Sep 17 00:00:00 2001
From: Bjorn Pettersson <bjorn.a.pettersson at ericsson.com>
Date: Thu, 4 Jul 2024 10:34:04 +0200
Subject: [PATCH 1/2] [SelectionDAG] Let ComputeKnownSignBits handle (shl (ext
X), C)
Add simple support for looking through ZEXT/ANYEXT/SEXT when doing
ComputeNumSignBits for SHL. This is valid when all extended bits are
shifted out, i.e. when every shift amount is at least the number of
bits added by the extension, because then the number of sign bits can
be found by analysing the EXT operand.
A future improvement could be to pass along the "shifted left by"
information in the recursive calls to ComputeNumSignBits, allowing
us to handle this more generically.
---
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 15 +++++++++++++++
1 file changed, 15 insertions(+)
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 96242305e9eab..991df60a5e650 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -4617,6 +4617,21 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
case ISD::SHL:
if (std::optional<uint64_t> ShAmt =
getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
+ if (Op.getOperand(0).getOpcode() == ISD::ANY_EXTEND ||
+ Op.getOperand(0).getOpcode() == ISD::ZERO_EXTEND ||
+ Op.getOperand(0).getOpcode() == ISD::SIGN_EXTEND) {
+ SDValue Src = Op.getOperand(0);
+ EVT SrcVT = Src.getValueType();
+ SDValue ExtendedOp = Op.getOperand(0).getOperand(0);
+ EVT ExtendedOpVT = ExtendedOp.getValueType();
+ uint64_t ExtendedWidth =
+ SrcVT.getScalarSizeInBits() - ExtendedOpVT.getScalarSizeInBits();
+ if (ExtendedWidth <= *ShAmt) {
+ Tmp = ComputeNumSignBits(ExtendedOp, DemandedElts, Depth + 1);
+ if (*ShAmt - ExtendedWidth < Tmp)
+ return Tmp - (*ShAmt - ExtendedWidth);
+ }
+ }
// shl destroys sign bits, ensure it doesn't shift out all sign bits.
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
if (*ShAmt < Tmp)
>From 85ad35da9dfaa2b1a25cb67721180d6af945a307 Mon Sep 17 00:00:00 2001
From: Bjorn Pettersson <bjorn.a.pettersson at ericsson.com>
Date: Thu, 4 Jul 2024 11:47:33 +0200
Subject: [PATCH 2/2] fixup: Also check MinShAmt
---
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 33 ++++++++++---------
1 file changed, 18 insertions(+), 15 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 991df60a5e650..152dee9c2f78f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -4615,27 +4615,30 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
Tmp = std::min<uint64_t>(Tmp + *ShAmt, VTBits);
return Tmp;
case ISD::SHL:
- if (std::optional<uint64_t> ShAmt =
+ if (std::optional<uint64_t> MaxShAmt =
getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
if (Op.getOperand(0).getOpcode() == ISD::ANY_EXTEND ||
Op.getOperand(0).getOpcode() == ISD::ZERO_EXTEND ||
- Op.getOperand(0).getOpcode() == ISD::SIGN_EXTEND) {
- SDValue Src = Op.getOperand(0);
- EVT SrcVT = Src.getValueType();
- SDValue ExtendedOp = Op.getOperand(0).getOperand(0);
- EVT ExtendedOpVT = ExtendedOp.getValueType();
- uint64_t ExtendedWidth =
- SrcVT.getScalarSizeInBits() - ExtendedOpVT.getScalarSizeInBits();
- if (ExtendedWidth <= *ShAmt) {
- Tmp = ComputeNumSignBits(ExtendedOp, DemandedElts, Depth + 1);
- if (*ShAmt - ExtendedWidth < Tmp)
- return Tmp - (*ShAmt - ExtendedWidth);
+ Op.getOperand(0).getOpcode() == ISD::SIGN_EXTEND)
+ if (std::optional<uint64_t> MinShAmt =
+ getValidMinimumShiftAmount(Op, DemandedElts, Depth + 1)) {
+ SDValue Src = Op.getOperand(0);
+ EVT SrcVT = Src.getValueType();
+ SDValue ExtendedOp = Op.getOperand(0).getOperand(0);
+ EVT ExtendedOpVT = ExtendedOp.getValueType();
+ uint64_t ExtendedWidth =
+ SrcVT.getScalarSizeInBits() - ExtendedOpVT.getScalarSizeInBits();
+ if (ExtendedWidth <= *MinShAmt) {
+ Tmp = ComputeNumSignBits(ExtendedOp, DemandedElts, Depth + 1);
+ Tmp += ExtendedWidth;
+ if (*MaxShAmt < Tmp)
+ return Tmp - *MaxShAmt;
+ }
}
- }
// shl destroys sign bits, ensure it doesn't shift out all sign bits.
Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
- if (*ShAmt < Tmp)
- return Tmp - *ShAmt;
+ if (*MaxShAmt < Tmp)
+ return Tmp - *MaxShAmt;
}
break;
case ISD::AND:
More information about the llvm-commits
mailing list