[llvm] 86c52af - [TargetLowering] SimplifyDemandedBits - use getValidShiftAmountConstant helper.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 21 06:24:11 PST 2020


Author: Simon Pilgrim
Date: 2020-02-21T14:23:53Z
New Revision: 86c52af05a64c4aa9d61984eeda8fb7849a4b0fa

URL: https://github.com/llvm/llvm-project/commit/86c52af05a64c4aa9d61984eeda8fb7849a4b0fa
DIFF: https://github.com/llvm/llvm-project/commit/86c52af05a64c4aa9d61984eeda8fb7849a4b0fa.diff

LOG: [TargetLowering] SimplifyDemandedBits - use getValidShiftAmountConstant helper.

Use the SelectionDAG::getValidShiftAmountConstant helper to get constant/constant-splat shift amounts. Because the helper only returns shift amounts that are in range (less than the bit width), we can drop the early-out for out-of-range shift amounts.

First step towards better non-uniform shift amount support in SimplifyDemandedBits.

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 34f60e38b620..4a684efcb395 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -7571,8 +7571,7 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
       return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0, NewOp1);
   }
 
-  // TODO - support non-uniform vector shift amounts.
-  if (N1C && SimplifyDemandedBits(SDValue(N, 0)))
+  if (SimplifyDemandedBits(SDValue(N, 0)))
     return SDValue(N, 0);
 
   // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
@@ -7938,8 +7937,7 @@ SDValue DAGCombiner::visitSRA(SDNode *N) {
   }
 
   // Simplify, based on bits shifted out of the LHS.
-  // TODO - support non-uniform vector shift amounts.
-  if (N1C && SimplifyDemandedBits(SDValue(N, 0)))
+  if (SimplifyDemandedBits(SDValue(N, 0)))
     return SDValue(N, 0);
 
   // If the sign bit is known to be zero, switch this to a SRL.
@@ -8135,8 +8133,7 @@ SDValue DAGCombiner::visitSRL(SDNode *N) {
 
   // fold operands of srl based on knowledge that the low bits are not
   // demanded.
-  // TODO - support non-uniform vector shift amounts.
-  if (N1C && SimplifyDemandedBits(SDValue(N, 0)))
+  if (SimplifyDemandedBits(SDValue(N, 0)))
     return SDValue(N, 0);
 
   if (N1C && !N1C->isOpaque())

diff  --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 5a38207e41cf..f93225d44fe1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -1365,11 +1365,8 @@ bool TargetLowering::SimplifyDemandedBits(
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);
 
-    if (ConstantSDNode *SA = isConstOrConstSplat(Op1, DemandedElts)) {
-      // If the shift count is an invalid immediate, don't do anything.
-      if (SA->getAPIntValue().uge(BitWidth))
-        break;
-
+    if (const APInt *SA =
+            TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) {
       unsigned ShAmt = SA->getZExtValue();
       if (ShAmt == 0)
         return TLO.CombineTo(Op, Op0);
@@ -1380,9 +1377,9 @@ bool TargetLowering::SimplifyDemandedBits(
       // TODO - support non-uniform vector amounts.
       if (Op0.getOpcode() == ISD::SRL) {
         if (!DemandedBits.intersects(APInt::getLowBitsSet(BitWidth, ShAmt))) {
-          if (ConstantSDNode *SA2 =
-                  isConstOrConstSplat(Op0.getOperand(1), DemandedElts)) {
-            if (SA2->getAPIntValue().ult(BitWidth)) {
+          if (const APInt *SA2 =
+                  TLO.DAG.getValidShiftAmountConstant(Op0, DemandedElts)) {
+            if (SA2->ult(BitWidth)) {
               unsigned C1 = SA2->getZExtValue();
               unsigned Opc = ISD::SHL;
               int Diff = ShAmt - C1;
@@ -1434,8 +1431,8 @@ bool TargetLowering::SimplifyDemandedBits(
         // x aren't demanded.
         if (Op0.hasOneUse() && InnerOp.getOpcode() == ISD::SRL &&
             InnerOp.hasOneUse()) {
-          if (ConstantSDNode *SA2 =
-                  isConstOrConstSplat(InnerOp.getOperand(1))) {
+          if (const APInt *SA2 =
+                  TLO.DAG.getValidShiftAmountConstant(InnerOp, DemandedElts)) {
             unsigned InnerShAmt = SA2->getLimitedValue(InnerBits);
             if (InnerShAmt < ShAmt && InnerShAmt < InnerBits &&
                 DemandedBits.getActiveBits() <=
@@ -1463,11 +1460,8 @@ bool TargetLowering::SimplifyDemandedBits(
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);
 
-    if (ConstantSDNode *SA = isConstOrConstSplat(Op1, DemandedElts)) {
-      // If the shift count is an invalid immediate, don't do anything.
-      if (SA->getAPIntValue().uge(BitWidth))
-        break;
-
+    if (const APInt *SA =
+            TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) {
       unsigned ShAmt = SA->getZExtValue();
       if (ShAmt == 0)
         return TLO.CombineTo(Op, Op0);
@@ -1485,11 +1479,11 @@ bool TargetLowering::SimplifyDemandedBits(
       // are never demanded.
       // TODO - support non-uniform vector amounts.
       if (Op0.getOpcode() == ISD::SHL) {
-        if (ConstantSDNode *SA2 =
-                isConstOrConstSplat(Op0.getOperand(1), DemandedElts)) {
+        if (const APInt *SA2 =
+                TLO.DAG.getValidShiftAmountConstant(Op0, DemandedElts)) {
           if (!DemandedBits.intersects(
                   APInt::getHighBitsSet(BitWidth, ShAmt))) {
-            if (SA2->getAPIntValue().ult(BitWidth)) {
+            if (SA2->ult(BitWidth)) {
               unsigned C1 = SA2->getZExtValue();
               unsigned Opc = ISD::SRL;
               int Diff = ShAmt - C1;
@@ -1513,8 +1507,8 @@ bool TargetLowering::SimplifyDemandedBits(
       assert(!Known.hasConflict() && "Bits known to be one AND zero?");
       Known.Zero.lshrInPlace(ShAmt);
       Known.One.lshrInPlace(ShAmt);
-
-      Known.Zero.setHighBits(ShAmt); // High bits known zero.
+      // High bits known zero.
+      Known.Zero.setHighBits(ShAmt);
     }
     break;
   }
@@ -1536,11 +1530,8 @@ bool TargetLowering::SimplifyDemandedBits(
     if (DemandedBits.isOneValue())
       return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
 
-    if (ConstantSDNode *SA = isConstOrConstSplat(Op1, DemandedElts)) {
-      // If the shift count is an invalid immediate, don't do anything.
-      if (SA->getAPIntValue().uge(BitWidth))
-        break;
-
+    if (const APInt *SA =
+            TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) {
       unsigned ShAmt = SA->getZExtValue();
       if (ShAmt == 0)
         return TLO.CombineTo(Op, Op0);


        


More information about the llvm-commits mailing list