[llvm] [DAGCombiner] Use getShiftAmountConstant where possible. (PR #97683)

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Wed Jul 3 23:55:07 PDT 2024


https://github.com/topperc created https://github.com/llvm/llvm-project/pull/97683

In #97645, I proposed removing the LegalTypes operand to TargetLowering::getShiftAmountTy. This means we don't need to use the DAGCombiner wrapper for getShiftAmountTy that manages this flag. Now we can use getShiftAmountConstant and let it call TargetLowering::getShiftAmountTy.

This contains the same X86 test change as #97645.

>From 6c5cc5f634e6618e5a0eba0aa238455f6272a44e Mon Sep 17 00:00:00 2001
From: Craig Topper <craig.topper at sifive.com>
Date: Wed, 3 Jul 2024 17:22:33 -0700
Subject: [PATCH] [DAGCombiner] Use getShiftAmountConstant where possible.

In #97645, I proposed removing the LegalTypes operand to
TargetLowering::getShiftAmountTy. This means we don't need to
use the DAGCombiner wrapper for getShiftAmountTy that manages this
flag. Now we can use getShiftAmountConstant and let it call
TargetLowering::getShiftAmountTy.

This contains the same X86 test change as #97645.
---
 llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp | 92 ++++++++-----------
 llvm/test/CodeGen/X86/shift-combine.ll        | 10 +-
 2 files changed, 41 insertions(+), 61 deletions(-)

diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d81a54d2ecaaa..a7b469ea17f18 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4395,14 +4395,13 @@ template <class MatchContextClass> SDValue DAGCombiner::visitMUL(SDNode *N) {
   // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
   if (N1IsConst && !N1IsOpaqueConst && ConstValue1.isNegatedPowerOf2()) {
     unsigned Log2Val = (-ConstValue1).logBase2();
-    EVT ShiftVT = getShiftAmountTy(N0.getValueType());
 
     // FIXME: If the input is something that is easily negated (e.g. a
     // single-use add), we should put the negate there.
     return Matcher.getNode(
         ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
         Matcher.getNode(ISD::SHL, DL, VT, N0,
-                        DAG.getConstant(Log2Val, DL, ShiftVT)));
+                        DAG.getShiftAmountConstant(Log2Val, VT, DL)));
   }
 
   // Attempt to reuse an existing umul_lohi/smul_lohi node, but only if the
@@ -5101,9 +5100,9 @@ SDValue DAGCombiner::visitMULHS(SDNode *N) {
 
   // fold (mulhs x, 1) -> (sra x, size(x)-1)
   if (isOneConstant(N1))
-    return DAG.getNode(ISD::SRA, DL, VT, N0,
-                       DAG.getConstant(N0.getScalarValueSizeInBits() - 1, DL,
-                                       getShiftAmountTy(VT)));
+    return DAG.getNode(
+        ISD::SRA, DL, VT, N0,
+        DAG.getShiftAmountConstant(N0.getScalarValueSizeInBits() - 1, VT, DL));
 
   // fold (mulhs x, undef) -> 0
   if (N0.isUndef() || N1.isUndef())
@@ -5121,8 +5120,7 @@ SDValue DAGCombiner::visitMULHS(SDNode *N) {
       N1 = DAG.getNode(ISD::SIGN_EXTEND, DL, NewVT, N1);
       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
-            DAG.getConstant(SimpleSize, DL,
-                            getShiftAmountTy(N1.getValueType())));
+                       DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
     }
   }
@@ -5192,8 +5190,7 @@ SDValue DAGCombiner::visitMULHU(SDNode *N) {
       N1 = DAG.getNode(ISD::ZERO_EXTEND, DL, NewVT, N1);
       N1 = DAG.getNode(ISD::MUL, DL, NewVT, N0, N1);
       N1 = DAG.getNode(ISD::SRL, DL, NewVT, N1,
-            DAG.getConstant(SimpleSize, DL,
-                            getShiftAmountTy(N1.getValueType())));
+                       DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
       return DAG.getNode(ISD::TRUNCATE, DL, VT, N1);
     }
   }
@@ -5404,8 +5401,7 @@ SDValue DAGCombiner::visitSMUL_LOHI(SDNode *N) {
       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
       // Compute the high part as N1.
       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
-            DAG.getConstant(SimpleSize, DL,
-                            getShiftAmountTy(Lo.getValueType())));
+                       DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
       // Compute the low part as N0.
       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
@@ -5458,8 +5454,7 @@ SDValue DAGCombiner::visitUMUL_LOHI(SDNode *N) {
       Lo = DAG.getNode(ISD::MUL, DL, NewVT, Lo, Hi);
       // Compute the high part as N1.
       Hi = DAG.getNode(ISD::SRL, DL, NewVT, Lo,
-            DAG.getConstant(SimpleSize, DL,
-                            getShiftAmountTy(Lo.getValueType())));
+                       DAG.getShiftAmountConstant(SimpleSize, NewVT, DL));
       Hi = DAG.getNode(ISD::TRUNCATE, DL, VT, Hi);
       // Compute the low part as N0.
       Lo = DAG.getNode(ISD::TRUNCATE, DL, VT, Lo);
@@ -7484,8 +7479,7 @@ SDValue DAGCombiner::MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
   if (OpSizeInBits > 16) {
     SDLoc DL(N);
     Res = DAG.getNode(ISD::SRL, DL, VT, Res,
-                      DAG.getConstant(OpSizeInBits - 16, DL,
-                                      getShiftAmountTy(VT)));
+                      DAG.getShiftAmountConstant(OpSizeInBits - 16, VT, DL));
   }
   return Res;
 }
@@ -7603,7 +7597,7 @@ static bool isBSwapHWordPair(SDValue N, MutableArrayRef<SDNode *> Parts) {
 //   (rotr (bswap A), 16)
 static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
                                        SelectionDAG &DAG, SDNode *N, SDValue N0,
-                                       SDValue N1, EVT VT, EVT ShiftAmountTy) {
+                                       SDValue N1, EVT VT) {
   assert(N->getOpcode() == ISD::OR && VT == MVT::i32 &&
          "MatchBSwapHWordOrAndAnd: expecting i32");
   if (!TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
@@ -7635,7 +7629,7 @@ static SDValue matchBSwapHWordOrAndAnd(const TargetLowering &TLI,
 
   SDLoc DL(N);
   SDValue BSwap = DAG.getNode(ISD::BSWAP, DL, VT, Shift0.getOperand(0));
-  SDValue ShAmt = DAG.getConstant(16, DL, ShiftAmountTy);
+  SDValue ShAmt = DAG.getShiftAmountConstant(16, VT, DL);
   return DAG.getNode(ISD::ROTR, DL, VT, BSwap, ShAmt);
 }
 
@@ -7655,13 +7649,11 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
   if (!TLI.isOperationLegalOrCustom(ISD::BSWAP, VT))
     return SDValue();
 
-  if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT,
-                                              getShiftAmountTy(VT)))
+  if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N0, N1, VT))
     return BSwap;
 
   // Try again with commuted operands.
-  if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N1, N0, VT,
-                                              getShiftAmountTy(VT)))
+  if (SDValue BSwap = matchBSwapHWordOrAndAnd(TLI, DAG, N, N1, N0, VT))
     return BSwap;
 
 
@@ -7698,7 +7690,7 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
 
   // Result of the bswap should be rotated by 16. If it's not legal, then
   // do  (x << 16) | (x >> 16).
-  SDValue ShAmt = DAG.getConstant(16, DL, getShiftAmountTy(VT));
+  SDValue ShAmt = DAG.getShiftAmountConstant(16, VT, DL);
   if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT))
     return DAG.getNode(ISD::ROTL, DL, VT, BSwap, ShAmt);
   if (TLI.isOperationLegalOrCustom(ISD::ROTR, VT))
@@ -10430,8 +10422,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
           TLI.isOperationLegalOrCustom(ISD::SIGN_EXTEND, TruncVT) &&
           TLI.isOperationLegalOrCustom(ISD::TRUNCATE, VT) &&
           TLI.isTruncateFree(VT, TruncVT)) {
-        SDValue Amt = DAG.getConstant(ShiftAmt, DL,
-            getShiftAmountTy(N0.getOperand(0).getValueType()));
+        SDValue Amt = DAG.getShiftAmountConstant(ShiftAmt, VT, DL);
         SDValue Shift = DAG.getNode(ISD::SRL, DL, VT,
                                     N0.getOperand(0), Amt);
         SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, TruncVT,
@@ -10679,10 +10670,9 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
     if (!LegalTypes || TLI.isTypeDesirableForOp(ISD::SRL, SmallVT)) {
       uint64_t ShiftAmt = N1C->getZExtValue();
       SDLoc DL0(N0);
-      SDValue SmallShift = DAG.getNode(ISD::SRL, DL0, SmallVT,
-                                       N0.getOperand(0),
-                          DAG.getConstant(ShiftAmt, DL0,
-                                          getShiftAmountTy(SmallVT)));
+      SDValue SmallShift =
+          DAG.getNode(ISD::SRL, DL0, SmallVT, N0.getOperand(0),
+                      DAG.getShiftAmountConstant(ShiftAmt, SmallVT, DL0));
       AddToWorklist(SmallShift.getNode());
       APInt Mask = APInt::getLowBitsSet(OpSizeInBits, OpSizeInBits - ShiftAmt);
       return DAG.getNode(ISD::AND, DL, VT,
@@ -10726,8 +10716,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
       if (ShAmt) {
         SDLoc DL(N0);
         Op = DAG.getNode(ISD::SRL, DL, VT, Op,
-                  DAG.getConstant(ShAmt, DL,
-                                  getShiftAmountTy(Op.getValueType())));
+                         DAG.getShiftAmountConstant(ShAmt, VT, DL));
         AddToWorklist(Op.getNode());
       }
       return DAG.getNode(ISD::XOR, DL, VT, Op, DAG.getConstant(1, DL, VT));
@@ -11086,7 +11075,7 @@ SDValue DAGCombiner::visitBSWAP(SDNode *N) {
       SDValue Res = N0.getOperand(0);
       if (uint64_t NewShAmt = (ShAmt->getZExtValue() - (BW / 2)))
         Res = DAG.getNode(ISD::SHL, DL, VT, Res,
-                          DAG.getConstant(NewShAmt, DL, getShiftAmountTy(VT)));
+                          DAG.getShiftAmountConstant(NewShAmt, VT, DL));
       Res = DAG.getZExtOrTrunc(Res, DL, HalfVT);
       Res = DAG.getNode(ISD::BSWAP, DL, HalfVT, Res);
       return DAG.getZExtOrTrunc(Res, DL, VT);
@@ -12316,9 +12305,9 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
       if (TLI.isOperationLegalOrCustom(ISD::ABS, VT))
         return DAG.getNode(ISD::ABS, DL, VT, LHS);
 
-      SDValue Shift = DAG.getNode(ISD::SRA, DL, VT, LHS,
-                                  DAG.getConstant(VT.getScalarSizeInBits() - 1,
-                                                  DL, getShiftAmountTy(VT)));
+      SDValue Shift = DAG.getNode(
+          ISD::SRA, DL, VT, LHS,
+          DAG.getShiftAmountConstant(VT.getScalarSizeInBits() - 1, VT, DL));
       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, LHS, Shift);
       AddToWorklist(Shift.getNode());
       AddToWorklist(Add.getNode());
@@ -14625,9 +14614,6 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
   // Shift the result left, if we've swallowed a left shift.
   SDValue Result = Load;
   if (ShLeftAmt != 0) {
-    EVT ShImmTy = getShiftAmountTy(Result.getValueType());
-    if (!isUIntN(ShImmTy.getScalarSizeInBits(), ShLeftAmt))
-      ShImmTy = VT;
     // If the shift amount is as large as the result size (but, presumably,
     // no larger than the source) then the useful bits of the result are
     // zero; we can't simply return the shortened shift, because the result
@@ -14635,8 +14621,8 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
     if (ShLeftAmt >= VT.getScalarSizeInBits())
       Result = DAG.getConstant(0, DL, VT);
     else
-      Result = DAG.getNode(ISD::SHL, DL, VT,
-                          Result, DAG.getConstant(ShLeftAmt, DL, ShImmTy));
+      Result = DAG.getNode(ISD::SHL, DL, VT, Result,
+                           DAG.getShiftAmountConstant(ShLeftAmt, VT, DL));
   }
 
   if (ShiftedOffset != 0) {
@@ -16898,7 +16884,7 @@ SDValue DAGCombiner::combineFMulOrFDivWithIntPow2(SDNode *N) {
 
   // Perform actual transform.
   SDValue MantissaShiftCnt =
-      DAG.getConstant(*Mantissa, DL, getShiftAmountTy(NewIntVT));
+      DAG.getShiftAmountConstant(*Mantissa, NewIntVT, DL);
   // TODO: Sometimes Log2 is of form `(X + C)`. `(X + C) << C1` should fold to
   // `(X << C1) + (C << C1)`, but that isn't always the case because of the
   // cast. We could implement that by handle here to handle the casts.
@@ -19811,9 +19797,9 @@ ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
   // shifted by ByteShift and truncated down to NumBytes.
   if (ByteShift) {
     SDLoc DL(IVal);
-    IVal = DAG.getNode(ISD::SRL, DL, IVal.getValueType(), IVal,
-                       DAG.getConstant(ByteShift*8, DL,
-                                    DC->getShiftAmountTy(IVal.getValueType())));
+    IVal = DAG.getNode(
+        ISD::SRL, DL, IVal.getValueType(), IVal,
+        DAG.getShiftAmountConstant(ByteShift * 8, IVal.getValueType(), DL));
   }
 
   // Figure out the offset for the store and the alignment of the access.
@@ -27422,12 +27408,11 @@ SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
 
   // and (sra X, size(X)-1), A -> "and (srl X, C2), A" iff A is a single-bit
   // constant.
-  EVT ShiftAmtTy = getShiftAmountTy(N0.getValueType());
   auto *N2C = dyn_cast<ConstantSDNode>(N2.getNode());
   if (N2C && ((N2C->getAPIntValue() & (N2C->getAPIntValue() - 1)) == 0)) {
     unsigned ShCt = XType.getSizeInBits() - N2C->getAPIntValue().logBase2() - 1;
     if (!TLI.shouldAvoidTransformToShift(XType, ShCt)) {
-      SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
+      SDValue ShiftAmt = DAG.getShiftAmountConstant(ShCt, XType, DL);
       SDValue Shift = DAG.getNode(ISD::SRL, DL, XType, N0, ShiftAmt);
       AddToWorklist(Shift.getNode());
 
@@ -27447,7 +27432,7 @@ SDValue DAGCombiner::foldSelectCCToShiftAnd(const SDLoc &DL, SDValue N0,
   if (TLI.shouldAvoidTransformToShift(XType, ShCt))
     return SDValue();
 
-  SDValue ShiftAmt = DAG.getConstant(ShCt, DL, ShiftAmtTy);
+  SDValue ShiftAmt = DAG.getShiftAmountConstant(ShCt, XType, DL);
   SDValue Shift = DAG.getNode(ISD::SRA, DL, XType, N0, ShiftAmt);
   AddToWorklist(Shift.getNode());
 
@@ -27661,16 +27646,13 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
       const APInt &AndMask = ConstAndRHS->getAPIntValue();
       if (TLI.shouldFoldSelectWithSingleBitTest(VT, AndMask)) {
         unsigned ShCt = AndMask.getBitWidth() - 1;
-        SDValue ShlAmt =
-            DAG.getConstant(AndMask.countl_zero(), SDLoc(AndLHS),
-                            getShiftAmountTy(AndLHS.getValueType()));
+        SDValue ShlAmt = DAG.getShiftAmountConstant(AndMask.countl_zero(), VT,
+                                                    SDLoc(AndLHS));
         SDValue Shl = DAG.getNode(ISD::SHL, SDLoc(N0), VT, AndLHS, ShlAmt);
 
         // Now arithmetic right shift it all the way over, so the result is
         // either all-ones, or zero.
-        SDValue ShrAmt =
-          DAG.getConstant(ShCt, SDLoc(Shl),
-                          getShiftAmountTy(Shl.getValueType()));
+        SDValue ShrAmt = DAG.getShiftAmountConstant(ShCt, VT, SDLoc(Shl));
         SDValue Shr = DAG.getNode(ISD::SRA, SDLoc(N0), VT, Shl, ShrAmt);
 
         return DAG.getNode(ISD::AND, DL, VT, Shr, N3);
@@ -27718,9 +27700,9 @@ SDValue DAGCombiner::SimplifySelectCC(const SDLoc &DL, SDValue N0, SDValue N1,
       return SDValue();
 
     // shl setcc result by log2 n2c
-    return DAG.getNode(ISD::SHL, DL, N2.getValueType(), Temp,
-                       DAG.getConstant(ShCt, SDLoc(Temp),
-                                       getShiftAmountTy(Temp.getValueType())));
+    return DAG.getNode(
+        ISD::SHL, DL, N2.getValueType(), Temp,
+        DAG.getShiftAmountConstant(ShCt, N2.getValueType(), SDLoc(Temp)));
   }
 
   // select_cc seteq X, 0, sizeof(X), ctlz(X) -> ctlz(X)
diff --git a/llvm/test/CodeGen/X86/shift-combine.ll b/llvm/test/CodeGen/X86/shift-combine.ll
index 30c3d53dd37c9..c9edd3f3e9048 100644
--- a/llvm/test/CodeGen/X86/shift-combine.ll
+++ b/llvm/test/CodeGen/X86/shift-combine.ll
@@ -444,12 +444,10 @@ define i64 @ashr_add_neg_shl_i32(i64 %r) nounwind {
 define i64 @ashr_add_neg_shl_i8(i64 %r) nounwind {
 ; X86-LABEL: ashr_add_neg_shl_i8:
 ; X86:       # %bb.0:
-; X86-NEXT:    movl {{[0-9]+}}(%esp), %eax
-; X86-NEXT:    shll $24, %eax
-; X86-NEXT:    movl $33554432, %edx # imm = 0x2000000
-; X86-NEXT:    subl %eax, %edx
-; X86-NEXT:    movl %edx, %eax
-; X86-NEXT:    sarl $24, %eax
+; X86-NEXT:    movb $2, %al
+; X86-NEXT:    subb {{[0-9]+}}(%esp), %al
+; X86-NEXT:    movsbl %al, %eax
+; X86-NEXT:    movl %eax, %edx
 ; X86-NEXT:    sarl $31, %edx
 ; X86-NEXT:    retl
 ;



More information about the llvm-commits mailing list