[llvm] [DAGCombiner] Use getShiftAmountConstant where possible. (PR #97683)
Craig Topper via llvm-commits
llvm-commits at lists.llvm.org
Wed Jul 3 23:55:07 PDT 2024
https://github.com/topperc created https://github.com/llvm/llvm-project/pull/97683
In #97645, I proposed removing the LegalTypes operand to TargetLowering::getShiftAmountTy. This means we don't need to use the DAGCombiner wrapper for getShiftAmountTy that manages this flag. Now we can use getShiftAmountConstant and let it call TargetLowering::getShiftAmountTy.
This contains the same X86 test change as #97645.
From 6c5cc5f634e6618e5a0eba0aa238455f6272a44e Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Wed, 3 Jul 2024 17:22:33 -0700
Subject: [PATCH] [DAGCombiner] Use getShiftAmountConstant where possible.
In #97645, I proposed removing the LegalTypes operand to
TargetLowering::getShiftAmountTy. This means we don't need to
use the DAGCombiner wrapper for getShiftAmountTy that manages this
flag. Now we can use getShiftAmountConstant and let it call
TargetLowering::getShiftAmountTy.
This contains the same X86 test change as #97645.
---
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 92 ++++++++-----------
llvm/test/CodeGen/X86/shift-combine.ll | 10 +-
2 files changed, 41 insertions(+), 61 deletions(-)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d81a54d2ecaaa..a7b469ea17f18 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4395,14 +4395,13 @@ template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
// fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isNegatedPowerOf2()) {
unsigned Log2Val = (-ConstValue1).logBase2();
- EVT ShiftVT = getShiftAmountTy(N0.getValueType());
// FIXME: If the input is something that is easily negated (e.g. a
// single-use add), we should put the negate there.
return Matcher.getNode(
ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
Matcher.getNode(ISD::SHL, DL, VT, N0,
- DAG.getConstant(Log2Val, DL, ShiftVT)));
+ DAG.getShiftAmountConstant(Log2Val, VT, DL)));
}
// Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
@@ -5101,9 +5100,9 @@ SDValue DAGCombiner::visitMULHS(SDNode *N) {
// fold (mulhs x, 1) -> (sra x, size(x)-1)
if (isOneConstant(N1))
- return DAG.getNode(ISD::SRA, DL, VT, N0,
- DAG.getConstant(N0.getScalarValueSizeInBits() - 1, DL,
- getShiftAmountTy(VT)));
+ return DAG.getNode(
+ ISD::SRA, DL, VT, N0,
+ DAG.getShiftAmountConstant(N0.getScalarValueSizeInBits() - 1, VT, DL));
// fold (mulhs x, undef) -> 0
if (N0.isUndef() || N1.isUndef())
@@ -5121,8 +5120,7 @@ SDValue DAGCombiner::visitMULHS(SDNode *N) {
N1 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
- DAG.getConstant(SimpleSize, DL,
- getShiftAmountTy(N1.getValueType())));
+ DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
}
}
@@ -5192,8 +5190,7 @@ SDValue DAGCombiner::visitMULHU(SDNode *N) {
N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
- DAG.getConstant(SimpleSize, DL,
- getShiftAmountTy(N1.getValueType())));
+ DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
}
}
@@ -5404,8 +5401,7 @@ SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
// Compute the high part as N1.
Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
- DAG.getConstant(SimpleSize, DL,
- getShiftAmountTy(Lo.getValueType())));
+ DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
// Compute the low part as N0.
Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
@@ -5458,8 +5454,7 @@ SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
// Compute the high part as N1.
Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
- DAG.getConstant(SimpleSize, DL,
- getShiftAmountTy(Lo.getValueType())));
+ DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
// Compute the low part as N0.
Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
@@ -7484,8 +7479,7 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
if (OpSizeInBits > 16) {
SDLoc DL(N);
Res = DAG.getNode(ISD::SRL, DL, VT, Res,
- DAG.getConstant(OpSizeInBits - 16, DL,
- getShiftAmountTy(VT)));
+ DAG.getShiftAmountConstant(OpSizeInBits - 16, VT, DL));
}
return Res;
}
@@ -7603,7 +7597,7 @@ static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
// (rotr (bswap A), 16)
static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
SelectionDAG &DAG, SDNode *N, SDValue N0,
- SDValue N1, EVT VT, EVT ShiftAmountTy) {
+ SDValue N1, EVT VT) {
assert(N->getOpcode() == ISD::OR && VT == MVT::i32 &&
"MatchBSwapHWordOrAndAnd: expecting i32");
if (!TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
@@ -7635,7 +7629,7 @@ static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
SDLoc DL(N);
SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, Shift0.getOperand(0));
- SDValue ShAmt = DAG.getConstant(16, DL, ShiftAmountTy);
+ SDValue ShAmt = DAG.getShiftAmountConstant(16, VT, DL);
return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
}
@@ -7655,13 +7649,11 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
return SDValue();
- if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT,
- getShiftAmountTy(VT)))
+ if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT))
return BSwap;
// Try again with commuted operands.
- if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N1, N0, VT,
- getShiftAmountTy(VT)))
+ if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N1, N0, VT))
return BSwap;
@@ -7698,7 +7690,7 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
// Result of the bswap should be rotated by 16. If it's not legal, then
// do (x << 16) | (x >> 16).
- SDValue ShAmt = DAG.getConstant(16, DL, getShiftAmountTy(VT));
+ SDValue ShAmt = DAG.getShiftAmountConstant(16, VT, DL);
if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT))
return DAG.getNode(ISD::ROTL, DL, VT, BSwap, ShAmt);
if (TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
@@ -10430,8 +10422,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
TLI.isOperationLegalOrCustom(ISD::SIGN_EXTEND, TruncVT) &&
TLI.isOperationLegalOrCustom(ISD::TRUNCATE, VT) &&
TLI.isTruncateFree(VT, TruncVT)) {
- SDValue Amt = DAG.getConstant(ShiftAmt, DL,
- getShiftAmountTy(N0.getOperand(0).getValueType()));
+ SDValue Amt = DAG.getShiftAmountConstant(ShiftAmt, VT, DL);
SDValue Shift = DAG.getNode(ISD::SRL, DL, VT,
N0.getOperand(0), Amt);
SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT,
@@ -10679,10 +10670,9 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) {
uint64_t ShiftAmt = N1C->getZExtValue();
SDLoc DL0(N0);
- SDValue SmallShift = DAG.getNode(ISD::SRL, DL0, SmallVT,
- N0.getOperand(0),
- DAG.getConstant(ShiftAmt, DL0,
- getShiftAmountTy(SmallVT)));
+ SDValue SmallShift =
+ DAG.getNode(ISD::SRL, DL0, SmallVT, N0.getOperand(0),
+ DAG.getShiftAmountConstant(ShiftAmt, SmallVT, DL0));
AddToWorklist(SmallShift.getNode());
APInt Mask = APInt::getLowBitsSet(OpSizeInBits, OpSizeInBits - ShiftAmt);
return DAG.getNode(ISD::AND, DL, VT,
@@ -10726,8 +10716,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
if (ShAmt) {
SDLoc DL(N0);
Op = DAG.getNode(ISD::SRL, DL, VT, Op,
- DAG.getConstant(ShAmt, DL,
- getShiftAmountTy(Op.getValueType())));
+ DAG.getShiftAmountConstant(ShAmt, VT, DL));
AddToWorklist(Op.getNode());
}
return DAG.getNode(ISD::XOR, DL, VT, Op, DAG.getConstant(1, DL, VT));
@@ -11086,7 +11075,7 @@ SDValue DAGCombiner::visitBSWAP(SDNode *N) {
SDValue Res = N0.getOperand(0);
if (uint64_t NewShAmt = (ShAmt->getZExtValue() - (BW / 2)))
Res = DAG.getNode(ISD::SHL, DL, VT, Res,
- DAG.getConstant(NewShAmt, DL, getShiftAmountTy(VT)));
+ DAG.getShiftAmountConstant(NewShAmt, VT, DL));
Res = DAG.getZExtOrTrunc(Res, DL, HalfVT);
Res = DAG.getNode(ISD::BSWAP, DL, HalfVT, Res);
return DAG.getZExtOrTrunc(Res, DL, VT);
@@ -12316,9 +12305,9 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
if (TLI.isOperationLegalOrCustom(ISD::ABS, VT))
return DAG.getNode(ISD::ABS, DL, VT, LHS);
- SDValue Shift = DAG.getNode(ISD::SRA, DL, VT, LHS,
- DAG.getConstant(VT.getScalarSizeInBits() - 1,
- DL, getShiftAmountTy(VT)));
+ SDValue Shift = DAG.getNode(
+ ISD::SRA, DL, VT, LHS,
+ DAG.getShiftAmountConstant(VT.getScalarSizeInBits() - 1, VT, DL));
SDValue Add = DAG.getNode(ISD::ADD, DL, VT, LHS, Shift);
AddToWorklist(Shift.getNode());
AddToWorklist(Add.getNode());
@@ -14625,9 +14614,6 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
// Shift the result left, if we've swallowed a left shift.
SDValue Result = Load;
if (ShLeftAmt != 0) {
- EVT ShImmTy = getShiftAmountTy(Result.getValueType());
- if (!isUIntN(ShImmTy.getScalarSizeInBits(), ShLeftAmt))
- ShImmTy = VT;
// If the shift amount is as large as the result size (but, presumably,
// no larger than the source) then the useful bits of the result are
// zero; we can't simply return the shortened shift, because the result
@@ -14635,8 +14621,8 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
if (ShLeftAmt >= VT.getScalarSizeInBits())
Result = DAG.getConstant(0, DL, VT);
else
- Result = DAG.getNode(ISD::SHL, DL, VT,
- Result, DAG.getConstant(ShLeftAmt, DL, ShImmTy));
+ Result = DAG.getNode(ISD::SHL, DL, VT, Result,
+ DAG.getShiftAmountConstant(ShLeftAmt, VT, DL));
}
if (ShiftedOffset != 0) {
@@ -16898,7 +16884,7 @@ SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
// Perform actual transform.
SDValue MantissaShiftCnt =
- DAG.getConstant(*Mantissa, DL, getShiftAmountTy(NewIntVT));
+ DAG.getShiftAmountConstant(*Mantissa, NewIntVT, DL);
// TODO: Sometimes Log2 is of form `(X + C)`. `(X + C) << C1` should fold to
// `(X << C1) + (C << C1)`, but that isn't always the case because of the
// cast. We could implement that by handle here to handle the casts.
@@ -19811,9 +19797,9 @@ ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
// shifted by ByteShift and truncated down to NumBytes.
if (ByteShift) {
SDLoc DL(IVal);
- IVal = DAG.getNode(ISD::SRL, DL, IVal.getValueType(), IVal,
- DAG.getConstant(ByteShift*8, DL,
- DC->getShiftAmountTy(IVal.getValueType())));
+ IVal = DAG.getNode(
+ ISD::SRL, DL, IVal.getValueType(), IVal,
+ DAG.getShiftAmountConstant(ByteShift * 8, IVal.getValueType(), DL));
}
// Figure out the offset for the store and the alignment of the access.
@@ -27422,12 +27408,11 @@ SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
// and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
// constant.
- EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
if (!TLI.shouldAvoidTransformToShift(XType, ShCt)) {
- SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
+ SDValue ShiftAmt = DAG.getShiftAmountConstant(ShCt, XType, DL);
SDValue Shift = DAG.getNode(ISD::SRL, DL, XType, N0, ShiftAmt);
AddToWorklist(Shift.getNode());
@@ -27447,7 +27432,7 @@ SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
if (TLI.shouldAvoidTransformToShift(XType, ShCt))
return SDValue();
- SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
+ SDValue ShiftAmt = DAG.getShiftAmountConstant(ShCt, XType, DL);
SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, N0, ShiftAmt);
AddToWorklist(Shift.getNode());
@@ -27661,16 +27646,13 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
const APInt &AndMask = ConstAndRHS->getAPIntValue();
if (TLI.shouldFoldSelectWithSingleBitTest(VT, AndMask)) {
unsigned ShCt = AndMask.getBitWidth() - 1;
- SDValue ShlAmt =
- DAG.getConstant(AndMask.countl_zero(), SDLoc(AndLHS),
- getShiftAmountTy(AndLHS.getValueType()));
+ SDValue ShlAmt = DAG.getShiftAmountConstant(AndMask.countl_zero(), VT,
+ SDLoc(AndLHS));
SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N0), VT, AndLHS, ShlAmt);
// Now arithmetic right shift it all the way over, so the result is
// either all-ones, or zero.
- SDValue ShrAmt =
- DAG.getConstant(ShCt, SDLoc(Shl),
- getShiftAmountTy(Shl.getValueType()));
+ SDValue ShrAmt = DAG.getShiftAmountConstant(ShCt, VT, SDLoc(Shl));
SDValue Shr = DAG.getNode(ISD::SRA, SDLoc(N0), VT, Shl, ShrAmt);
return DAG.getNode(ISD::AND, DL, VT, Shr, N3);
@@ -27718,9 +27700,9 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
return SDValue();
// shl setcc result by log2 n2c
- return DAG.getNode(ISD::SHL, DL, N2.getValueType(), Temp,
- DAG.getConstant(ShCt, SDLoc(Temp),
- getShiftAmountTy(Temp.getValueType())));
+ return DAG.getNode(
+ ISD::SHL, DL, N2.getValueType(), Temp,
+ DAG.getShiftAmountConstant(ShCt, N2.getValueType(), SDLoc(Temp)));
}
// select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
diff --git a/llvm/test/CodeGen/X86/shift-combine.ll b/llvm/test/CodeGen/X86/shift-combine.ll
index 30c3d53dd37c9..c9edd3f3e9048 100644
--- a/llvm/test/CodeGen/X86/shift-combine.ll
+++ b/llvm/test/CodeGen/X86/shift-combine.ll
@@ -444,12 +444,10 @@ define i64 @ashr_add_neg_shl_i32(i64 %r) nounwind {
define i64 @ashr_add_neg_shl_i8(i64 %r) nounwind {
; X86-LABEL: ashr_add_neg_shl_i8:
; X86: # %bb.0:
-; X86-NEXT: movl {{[0-9]+}}(%esp), %eax
-; X86-NEXT: shll $24, %eax
-; X86-NEXT: movl $33554432, %edx # imm = 0x2000000
-; X86-NEXT: subl %eax, %edx
-; X86-NEXT: movl %edx, %eax
-; X86-NEXT: sarl $24, %eax
+; X86-NEXT: movb $2, %al
+; X86-NEXT: subb {{[0-9]+}}(%esp), %al
+; X86-NEXT: movsbl %al, %eax
+; X86-NEXT: movl %eax, %edx
; X86-NEXT: sarl $31, %edx
; X86-NEXT: retl
;
More information about the llvm-commits
mailing list