[llvm] [SelectionDAG] Use getShiftAmountConstant to simplify code. NFC (PR #80561)

via llvm-commits llvm-commits at lists.llvm.org
Sat Feb 3 12:13:27 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-selectiondag

Author: Craig Topper (topperc)

<details>
<summary>Changes</summary>

Replace calls to getShiftAmountTy+getConstant with getShiftAmountConstant.

---
Full diff: https://github.com/llvm/llvm-project/pull/80561.diff


1 Files Affected:

- (modified) llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp (+44-51) 


``````````diff
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 03b2a66989bd4..b15f62bc3aae7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -1096,7 +1096,6 @@ bool TargetLowering::SimplifyDemandedBits(
   APInt DemandedBits = OriginalDemandedBits;
   APInt DemandedElts = OriginalDemandedElts;
   SDLoc dl(Op);
-  auto &DL = TLO.DAG.getDataLayout();
 
   // Undef operand.
   if (Op.isUndef())
@@ -2288,9 +2287,8 @@ bool TargetLowering::SimplifyDemandedBits(
       // the right place.
       unsigned ShiftOpcode = NLZ > NTZ ? ISD::SRL : ISD::SHL;
       if (!TLO.LegalOperations() || isOperationLegal(ShiftOpcode, VT)) {
-        EVT ShiftAmtTy = getShiftAmountTy(VT, DL);
         unsigned ShiftAmount = NLZ > NTZ ? NLZ - NTZ : NTZ - NLZ;
-        SDValue ShAmt = TLO.DAG.getConstant(ShiftAmount, dl, ShiftAmtTy);
+        SDValue ShAmt = TLO.DAG.getShiftAmountConstant(ShiftAmount, VT, dl);
         SDValue NewOp = TLO.DAG.getNode(ShiftOpcode, dl, VT, Src, ShAmt);
         return TLO.CombineTo(Op, NewOp);
       }
@@ -2330,8 +2328,8 @@ bool TargetLowering::SimplifyDemandedBits(
       if (!AlreadySignExtended) {
         // Compute the correct shift amount type, which must be getShiftAmountTy
         // for scalar types after legalization.
-        SDValue ShiftAmt = TLO.DAG.getConstant(BitWidth - ExVTBits, dl,
-                                               getShiftAmountTy(VT, DL));
+        SDValue ShiftAmt =
+            TLO.DAG.getShiftAmountConstant(BitWidth - ExVTBits, VT, dl);
         return TLO.CombineTo(Op,
                              TLO.DAG.getNode(ISD::SHL, dl, VT, Op0, ShiftAmt));
       }
@@ -2574,8 +2572,8 @@ bool TargetLowering::SimplifyDemandedBits(
         if (!(HighBits & DemandedBits)) {
           // None of the shifted in bits are needed.  Add a truncate of the
           // shift input, then shift it.
-          SDValue NewShAmt = TLO.DAG.getConstant(
-              ShVal, dl, getShiftAmountTy(VT, DL, TLO.LegalTypes()));
+          SDValue NewShAmt =
+              TLO.DAG.getShiftAmountConstant(ShVal, VT, dl, TLO.LegalTypes());
           SDValue NewTrunc =
               TLO.DAG.getNode(ISD::TRUNCATE, dl, VT, Src.getOperand(0));
           return TLO.CombineTo(
@@ -2753,8 +2751,7 @@ bool TargetLowering::SimplifyDemandedBits(
       unsigned CTZ = DemandedBits.countr_zero();
       ConstantSDNode *C = isConstOrConstSplat(Op.getOperand(1), DemandedElts);
       if (C && C->getAPIntValue().countr_zero() == CTZ) {
-        EVT ShiftAmtTy = getShiftAmountTy(VT, TLO.DAG.getDataLayout());
-        SDValue AmtC = TLO.DAG.getConstant(CTZ, dl, ShiftAmtTy);
+        SDValue AmtC = TLO.DAG.getShiftAmountConstant(CTZ, VT, dl);
         SDValue Shl = TLO.DAG.getNode(ISD::SHL, dl, VT, Op.getOperand(0), AmtC);
         return TLO.CombineTo(Op, Shl);
       }
@@ -2852,9 +2849,9 @@ bool TargetLowering::SimplifyDemandedBits(
       return 0;
     };
 
-    auto foldMul = [&](ISD::NodeType NT, SDValue X, SDValue Y, unsigned ShlAmt) {
-      EVT ShiftAmtTy = getShiftAmountTy(VT, TLO.DAG.getDataLayout());
-      SDValue ShlAmtC = TLO.DAG.getConstant(ShlAmt, dl, ShiftAmtTy);
+    auto foldMul = [&](ISD::NodeType NT, SDValue X, SDValue Y,
+                       unsigned ShlAmt) {
+      SDValue ShlAmtC = TLO.DAG.getShiftAmountConstant(ShlAmt, VT, dl);
       SDValue Shl = TLO.DAG.getNode(ISD::SHL, dl, VT, X, ShlAmtC);
       SDValue Res = TLO.DAG.getNode(NT, dl, VT, Y, Shl);
       return TLO.CombineTo(Op, Res);
@@ -4204,9 +4201,8 @@ SDValue TargetLowering::foldSetCCWithBinOp(EVT VT, SDValue N0, SDValue N1,
     return SDValue();
 
   // (X - Y) == Y --> X == Y << 1
-  EVT ShiftVT = getShiftAmountTy(OpVT, DAG.getDataLayout(),
-                                 !DCI.isBeforeLegalize());
-  SDValue One = DAG.getConstant(1, DL, ShiftVT);
+  SDValue One =
+      DAG.getShiftAmountConstant(1, OpVT, DL, !DCI.isBeforeLegalize());
   SDValue YShl1 = DAG.getNode(ISD::SHL, DL, N1.getValueType(), Y, One);
   if (!DCI.isCalledByLegalizer())
     DCI.AddToWorklist(YShl1.getNode());
@@ -5038,16 +5034,16 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
         (VT == ShValTy || (isTypeLegal(VT) && VT.bitsLE(ShValTy))) &&
         N0.getOpcode() == ISD::AND) {
       if (auto *AndRHS = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
-        EVT ShiftTy =
-            getShiftAmountTy(ShValTy, Layout, !DCI.isBeforeLegalize());
         if (Cond == ISD::SETNE && C1 == 0) {// (X & 8) != 0  -->  (X & 8) >> 3
           // Perform the xform if the AND RHS is a single bit.
           unsigned ShCt = AndRHS->getAPIntValue().logBase2();
           if (AndRHS->getAPIntValue().isPowerOf2() &&
               !TLI.shouldAvoidTransformToShift(ShValTy, ShCt)) {
-            return DAG.getNode(ISD::TRUNCATE, dl, VT,
-                               DAG.getNode(ISD::SRL, dl, ShValTy, N0,
-                                           DAG.getConstant(ShCt, dl, ShiftTy)));
+            return DAG.getNode(
+                ISD::TRUNCATE, dl, VT,
+                DAG.getNode(ISD::SRL, dl, ShValTy, N0,
+                            DAG.getShiftAmountConstant(
+                                ShCt, ShValTy, dl, !DCI.isBeforeLegalize())));
           }
         } else if (Cond == ISD::SETEQ && C1 == AndRHS->getAPIntValue()) {
           // (X & 8) == 8  -->  (X & 8) >> 3
@@ -5055,9 +5051,11 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
           unsigned ShCt = C1.logBase2();
           if (C1.isPowerOf2() &&
               !TLI.shouldAvoidTransformToShift(ShValTy, ShCt)) {
-            return DAG.getNode(ISD::TRUNCATE, dl, VT,
-                               DAG.getNode(ISD::SRL, dl, ShValTy, N0,
-                                           DAG.getConstant(ShCt, dl, ShiftTy)));
+            return DAG.getNode(
+                ISD::TRUNCATE, dl, VT,
+                DAG.getNode(ISD::SRL, dl, ShValTy, N0,
+                            DAG.getShiftAmountConstant(
+                                ShCt, ShValTy, dl, !DCI.isBeforeLegalize())));
           }
         }
       }
@@ -5065,7 +5063,6 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
 
     if (C1.getSignificantBits() <= 64 &&
         !isLegalICmpImmediate(C1.getSExtValue())) {
-      EVT ShiftTy = getShiftAmountTy(ShValTy, Layout, !DCI.isBeforeLegalize());
       // (X & -256) == 256 -> (X >> 8) == 1
       if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
           N0.getOpcode() == ISD::AND && N0.hasOneUse()) {
@@ -5074,9 +5071,10 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
           if (AndRHSC.isNegatedPowerOf2() && (AndRHSC & C1) == C1) {
             unsigned ShiftBits = AndRHSC.countr_zero();
             if (!TLI.shouldAvoidTransformToShift(ShValTy, ShiftBits)) {
-              SDValue Shift =
-                DAG.getNode(ISD::SRL, dl, ShValTy, N0.getOperand(0),
-                            DAG.getConstant(ShiftBits, dl, ShiftTy));
+              SDValue Shift = DAG.getNode(
+                  ISD::SRL, dl, ShValTy, N0.getOperand(0),
+                  DAG.getShiftAmountConstant(ShiftBits, ShValTy, dl,
+                                             !DCI.isBeforeLegalize()));
               SDValue CmpRHS = DAG.getConstant(C1.lshr(ShiftBits), dl, ShValTy);
               return DAG.getSetCC(dl, VT, Shift, CmpRHS, Cond);
             }
@@ -5103,8 +5101,10 @@ SDValue TargetLowering::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
         if (ShiftBits && NewC.getSignificantBits() <= 64 &&
             isLegalICmpImmediate(NewC.getSExtValue()) &&
             !TLI.shouldAvoidTransformToShift(ShValTy, ShiftBits)) {
-          SDValue Shift = DAG.getNode(ISD::SRL, dl, ShValTy, N0,
-                                      DAG.getConstant(ShiftBits, dl, ShiftTy));
+          SDValue Shift =
+              DAG.getNode(ISD::SRL, dl, ShValTy, N0,
+                          DAG.getShiftAmountConstant(ShiftBits, ShValTy, dl,
+                                                     !DCI.isBeforeLegalize()));
           SDValue CmpRHS = DAG.getConstant(NewC, dl, ShValTy);
           return DAG.getSetCC(dl, VT, Shift, CmpRHS, NewCond);
         }
@@ -8944,7 +8944,6 @@ SDValue TargetLowering::expandABS(SDNode *N, SelectionDAG &DAG,
                                   bool IsNegative) const {
   SDLoc dl(N);
   EVT VT = N->getValueType(0);
-  EVT ShVT = getShiftAmountTy(VT, DAG.getDataLayout());
   SDValue Op = N->getOperand(0);
 
   // abs(x) -> smax(x,sub(0,x))
@@ -8982,9 +8981,9 @@ SDValue TargetLowering::expandABS(SDNode *N, SelectionDAG &DAG,
     return SDValue();
 
   Op = DAG.getFreeze(Op);
-  SDValue Shift =
-      DAG.getNode(ISD::SRA, dl, VT, Op,
-                  DAG.getConstant(VT.getScalarSizeInBits() - 1, dl, ShVT));
+  SDValue Shift = DAG.getNode(
+      ISD::SRA, dl, VT, Op,
+      DAG.getShiftAmountConstant(VT.getScalarSizeInBits() - 1, VT, dl));
   SDValue Xor = DAG.getNode(ISD::XOR, dl, VT, Op, Shift);
 
   // abs(x) -> Y = sra (X, size(X)-1); sub (xor (X, Y), Y)
@@ -9592,9 +9591,7 @@ TargetLowering::expandUnalignedLoad(LoadSDNode *LD, SelectionDAG &DAG) const {
   }
 
   // aggregate the two parts
-  SDValue ShiftAmount =
-      DAG.getConstant(NumBits, dl, getShiftAmountTy(Hi.getValueType(),
-                                                    DAG.getDataLayout()));
+  SDValue ShiftAmount = DAG.getShiftAmountConstant(NumBits, VT, dl);
   SDValue Result = DAG.getNode(ISD::SHL, dl, VT, Hi, ShiftAmount);
   Result = DAG.getNode(ISD::OR, dl, VT, Result, Lo);
 
@@ -9706,8 +9703,8 @@ SDValue TargetLowering::expandUnalignedStore(StoreSDNode *ST,
   unsigned IncrementSize = NumBits / 8;
 
   // Divide the stored value in two parts.
-  SDValue ShiftAmount = DAG.getConstant(
-      NumBits, dl, getShiftAmountTy(Val.getValueType(), DAG.getDataLayout()));
+  SDValue ShiftAmount =
+      DAG.getShiftAmountConstant(NumBits, Val.getValueType(), dl);
   SDValue Lo = Val;
   // If Val is a constant, replace the upper bits with 0. The SRL will constant
   // fold and not use the upper bits. A smaller constant may be easier to
@@ -10351,9 +10348,8 @@ TargetLowering::expandFixedPointMul(SDNode *Node, SelectionDAG &DAG) const {
   // The result will need to be shifted right by the scale since both operands
   // are scaled. The result is given to us in 2 halves, so we only want part of
   // both in the result.
-  EVT ShiftTy = getShiftAmountTy(VT, DAG.getDataLayout());
   SDValue Result = DAG.getNode(ISD::FSHR, dl, VT, Hi, Lo,
-                               DAG.getConstant(Scale, dl, ShiftTy));
+                               DAG.getShiftAmountConstant(Scale, VT, dl));
   if (!Saturating)
     return Result;
 
@@ -10381,7 +10377,7 @@ TargetLowering::expandFixedPointMul(SDNode *Node, SelectionDAG &DAG) const {
 
   if (Scale == 0) {
     SDValue Sign = DAG.getNode(ISD::SRA, dl, VT, Lo,
-                               DAG.getConstant(VTSize - 1, dl, ShiftTy));
+                               DAG.getShiftAmountConstant(VTSize - 1, VT, dl));
     SDValue Overflow = DAG.getSetCC(dl, BoolVT, Hi, Sign, ISD::SETNE);
     // Saturated to SatMin if wide product is negative, and SatMax if wide
     // product is positive ...
@@ -10448,13 +10444,12 @@ TargetLowering::expandFixedPointDiv(unsigned Opcode, const SDLoc &dl,
   // RHS down by RHSShift, we can emit a regular division with a final scaling
   // factor of Scale.
 
-  EVT ShiftTy = getShiftAmountTy(VT, DAG.getDataLayout());
   if (LHSShift)
     LHS = DAG.getNode(ISD::SHL, dl, VT, LHS,
-                      DAG.getConstant(LHSShift, dl, ShiftTy));
+                      DAG.getShiftAmountConstant(LHSShift, VT, dl));
   if (RHSShift)
     RHS = DAG.getNode(Signed ? ISD::SRA : ISD::SRL, dl, VT, RHS,
-                      DAG.getConstant(RHSShift, dl, ShiftTy));
+                      DAG.getShiftAmountConstant(RHSShift, VT, dl));
 
   SDValue Quot;
   if (Signed) {
@@ -10597,8 +10592,7 @@ bool TargetLowering::expandMULO(SDNode *Node, SDValue &Result,
     if (C.isPowerOf2()) {
       // smulo(x, signed_min) is same as umulo(x, signed_min).
       bool UseArithShift = isSigned && !C.isMinSignedValue();
-      EVT ShiftAmtTy = getShiftAmountTy(VT, DAG.getDataLayout());
-      SDValue ShiftAmt = DAG.getConstant(C.logBase2(), dl, ShiftAmtTy);
+      SDValue ShiftAmt = DAG.getShiftAmountConstant(C.logBase2(), VT, dl);
       Result = DAG.getNode(ISD::SHL, dl, VT, LHS, ShiftAmt);
       Overflow = DAG.getSetCC(dl, SetCCVT,
           DAG.getNode(UseArithShift ? ISD::SRA : ISD::SRL,
@@ -10630,8 +10624,8 @@ bool TargetLowering::expandMULO(SDNode *Node, SDValue &Result,
     RHS = DAG.getNode(Ops[isSigned][2], dl, WideVT, RHS);
     SDValue Mul = DAG.getNode(ISD::MUL, dl, WideVT, LHS, RHS);
     BottomHalf = DAG.getNode(ISD::TRUNCATE, dl, VT, Mul);
-    SDValue ShiftAmt = DAG.getConstant(VT.getScalarSizeInBits(), dl,
-        getShiftAmountTy(WideVT, DAG.getDataLayout()));
+    SDValue ShiftAmt =
+        DAG.getShiftAmountConstant(VT.getScalarSizeInBits(), WideVT, dl);
     TopHalf = DAG.getNode(ISD::TRUNCATE, dl, VT,
                           DAG.getNode(ISD::SRL, dl, WideVT, Mul, ShiftAmt));
   } else {
@@ -10643,9 +10637,8 @@ bool TargetLowering::expandMULO(SDNode *Node, SDValue &Result,
 
   Result = BottomHalf;
   if (isSigned) {
-    SDValue ShiftAmt = DAG.getConstant(
-        VT.getScalarSizeInBits() - 1, dl,
-        getShiftAmountTy(BottomHalf.getValueType(), DAG.getDataLayout()));
+    SDValue ShiftAmt = DAG.getShiftAmountConstant(
+        VT.getScalarSizeInBits() - 1, BottomHalf.getValueType(), dl);
     SDValue Sign = DAG.getNode(ISD::SRA, dl, VT, BottomHalf, ShiftAmt);
     Overflow = DAG.getSetCC(dl, SetCCVT, TopHalf, Sign, ISD::SETNE);
   } else {

``````````

</details>


https://github.com/llvm/llvm-project/pull/80561


More information about the llvm-commits mailing list