[llvm] 42ec6fd - [TargetLowering] Apply basic shift combines before recursive SimplifyDemandedBits calls.

Simon Pilgrim via llvm-commits llvm-commits at lists.llvm.org
Fri Feb 21 08:34:46 PST 2020


Author: Simon Pilgrim
Date: 2020-02-21T16:31:20Z
New Revision: 42ec6fdce92090c02a10506fbdb2257fdbc2d1fd

URL: https://github.com/llvm/llvm-project/commit/42ec6fdce92090c02a10506fbdb2257fdbc2d1fd
DIFF: https://github.com/llvm/llvm-project/commit/42ec6fdce92090c02a10506fbdb2257fdbc2d1fd.diff

LOG: [TargetLowering] Apply basic shift combines before recursive SimplifyDemandedBits calls.

Minor refactor/cleanup before we begin adding support for non-uniform vector shift amounts.

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index f93225d44fe1..1e345caa1e17 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -1364,6 +1364,7 @@ bool TargetLowering::SimplifyDemandedBits(
   case ISD::SHL: {
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);
+    EVT ShiftVT = Op1.getValueType();
 
     if (const APInt *SA =
             TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) {
@@ -1388,7 +1389,7 @@ bool TargetLowering::SimplifyDemandedBits(
                 Opc = ISD::SRL;
               }
 
-              SDValue NewSA = TLO.DAG.getConstant(Diff, dl, Op1.getValueType());
+              SDValue NewSA = TLO.DAG.getConstant(Diff, dl, ShiftVT);
               return TLO.CombineTo(
                   Op, TLO.DAG.getNode(Opc, dl, VT, Op0.getOperand(0), NewSA));
             }
@@ -1396,19 +1397,9 @@ bool TargetLowering::SimplifyDemandedBits(
         }
       }
 
-      APInt InDemandedMask = DemandedBits.lshr(ShAmt);
-      if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO,
-                               Depth + 1))
-        return true;
-
-      // Try shrinking the operation as long as the shift amount will still be
-      // in range.
-      if ((ShAmt < DemandedBits.getActiveBits()) &&
-          ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
-        return true;
-
       // Convert (shl (anyext x, c)) to (anyext (shl x, c)) if the high bits
       // are not demanded. This will likely allow the anyext to be folded away.
+      // TODO - support non-uniform vector amounts.
       if (Op0.getOpcode() == ISD::ANY_EXTEND) {
         SDValue InnerOp = Op0.getOperand(0);
         EVT InnerVT = InnerOp.getValueType();
@@ -1424,11 +1415,13 @@ bool TargetLowering::SimplifyDemandedBits(
           return TLO.CombineTo(
               Op, TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT, NarrowShl));
         }
+
         // Repeat the SHL optimization above in cases where an extension
         // intervenes: (shl (anyext (shr x, c1)), c2) to
         // (shl (anyext x), c2-c1).  This requires that the bottom c1 bits
         // aren't demanded (as above) and that the shifted upper c1 bits of
         // x aren't demanded.
+        // TODO - support non-uniform vector amounts.
         if (Op0.hasOneUse() && InnerOp.getOpcode() == ISD::SRL &&
             InnerOp.hasOneUse()) {
           if (const APInt *SA2 =
@@ -1438,8 +1431,8 @@ bool TargetLowering::SimplifyDemandedBits(
                 DemandedBits.getActiveBits() <=
                     (InnerBits - InnerShAmt + ShAmt) &&
                 DemandedBits.countTrailingZeros() >= ShAmt) {
-              SDValue NewSA = TLO.DAG.getConstant(ShAmt - InnerShAmt, dl,
-                                                  Op1.getValueType());
+              SDValue NewSA =
+                  TLO.DAG.getConstant(ShAmt - InnerShAmt, dl, ShiftVT);
               SDValue NewExt = TLO.DAG.getNode(ISD::ANY_EXTEND, dl, VT,
                                                InnerOp.getOperand(0));
               return TLO.CombineTo(
@@ -1449,16 +1442,28 @@ bool TargetLowering::SimplifyDemandedBits(
         }
       }
 
+      APInt InDemandedMask = DemandedBits.lshr(ShAmt);
+      if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO,
+                               Depth + 1))
+        return true;
+      assert(!Known.hasConflict() && "Bits known to be one AND zero?");
       Known.Zero <<= ShAmt;
       Known.One <<= ShAmt;
       // low bits known zero.
       Known.Zero.setLowBits(ShAmt);
+
+      // Try shrinking the operation as long as the shift amount will still be
+      // in range.
+      if ((ShAmt < DemandedBits.getActiveBits()) &&
+          ShrinkDemandedOp(Op, BitWidth, DemandedBits, TLO))
+        return true;
     }
     break;
   }
   case ISD::SRL: {
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);
+    EVT ShiftVT = Op1.getValueType();
 
     if (const APInt *SA =
             TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) {
@@ -1466,14 +1471,6 @@ bool TargetLowering::SimplifyDemandedBits(
       if (ShAmt == 0)
         return TLO.CombineTo(Op, Op0);
 
-      EVT ShiftVT = Op1.getValueType();
-      APInt InDemandedMask = (DemandedBits << ShAmt);
-
-      // If the shift is exact, then it does demand the low bits (and knows that
-      // they are zero).
-      if (Op->getFlags().hasExact())
-        InDemandedMask.setLowBits(ShAmt);
-
       // If this is ((X << C1) >>u ShAmt), see if we can simplify this into a
       // single shift.  We can do this if the top bits (which are shifted out)
       // are never demanded.
@@ -1500,6 +1497,13 @@ bool TargetLowering::SimplifyDemandedBits(
         }
       }
 
+      APInt InDemandedMask = (DemandedBits << ShAmt);
+
+      // If the shift is exact, then it does demand the low bits (and knows that
+      // they are zero).
+      if (Op->getFlags().hasExact())
+        InDemandedMask.setLowBits(ShAmt);
+
       // Compute the new bits that are at the top now.
       if (SimplifyDemandedBits(Op0, InDemandedMask, DemandedElts, Known, TLO,
                                Depth + 1))
@@ -1515,6 +1519,7 @@ bool TargetLowering::SimplifyDemandedBits(
   case ISD::SRA: {
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);
+    EVT ShiftVT = Op1.getValueType();
 
     // If we only want bits that already match the signbit then we don't need
     // to shift.
@@ -1568,8 +1573,7 @@ bool TargetLowering::SimplifyDemandedBits(
       int Log2 = DemandedBits.exactLogBase2();
       if (Log2 >= 0) {
         // The bit must come from the sign.
-        SDValue NewSA =
-            TLO.DAG.getConstant(BitWidth - 1 - Log2, dl, Op1.getValueType());
+        SDValue NewSA = TLO.DAG.getConstant(BitWidth - 1 - Log2, dl, ShiftVT);
         return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, NewSA));
       }
 


        


More information about the llvm-commits mailing list