[llvm] 599d070 - [X86] Remove dyn_casts to ConstantSDNode for operand 1 of X86ISD::VSHLI/VSRAI/VSRLI. Use getConstantOperandVal and plain unsigned compares instead of APInt compares.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Sun Dec 29 18:02:50 PST 2019


Author: Craig Topper
Date: 2019-12-29T16:53:38-08:00
New Revision: 599d07091002b20be5e2b12b256782e0dd0df998

URL: https://github.com/llvm/llvm-project/commit/599d07091002b20be5e2b12b256782e0dd0df998
DIFF: https://github.com/llvm/llvm-project/commit/599d07091002b20be5e2b12b256782e0dd0df998.diff

LOG: [X86] Remove dyn_casts to ConstantSDNode for operand 1 of X86ISD::VSHLI/VSRAI/VSRLI. Use getConstantOperandVal and plain unsigned compares instead of APInt compares.

These nodes should only ever be formed with an i8 TargetConstant
shift amount, so we don't need to check that operand 1 is a constant.
The amount is also always 8 bits wide, so we don't need to use APInt
compare functions on it.

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 95e12e55433c..65c8b02289e5 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -32369,28 +32369,26 @@ void X86TargetLowering::computeKnownBitsForTargetNode(const SDValue Op,
   case X86ISD::VSRAI:
   case X86ISD::VSHLI:
   case X86ISD::VSRLI: {
-    if (auto *ShiftImm = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
-      if (ShiftImm->getAPIntValue().uge(VT.getScalarSizeInBits())) {
-        Known.setAllZero();
-        break;
-      }
+    unsigned ShAmt = Op.getConstantOperandVal(1);
+    if (ShAmt >= VT.getScalarSizeInBits()) {
+      Known.setAllZero();
+      break;
+    }
 
-      Known = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
-      unsigned ShAmt = ShiftImm->getZExtValue();
-      if (Opc == X86ISD::VSHLI) {
-        Known.Zero <<= ShAmt;
-        Known.One <<= ShAmt;
-        // Low bits are known zero.
-        Known.Zero.setLowBits(ShAmt);
-      } else if (Opc == X86ISD::VSRLI) {
-        Known.Zero.lshrInPlace(ShAmt);
-        Known.One.lshrInPlace(ShAmt);
-        // High bits are known zero.
-        Known.Zero.setHighBits(ShAmt);
-      } else {
-        Known.Zero.ashrInPlace(ShAmt);
-        Known.One.ashrInPlace(ShAmt);
-      }
+    Known = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
+    if (Opc == X86ISD::VSHLI) {
+      Known.Zero <<= ShAmt;
+      Known.One <<= ShAmt;
+      // Low bits are known zero.
+      Known.Zero.setLowBits(ShAmt);
+    } else if (Opc == X86ISD::VSRLI) {
+      Known.Zero.lshrInPlace(ShAmt);
+      Known.One.lshrInPlace(ShAmt);
+      // High bits are known zero.
+      Known.Zero.setHighBits(ShAmt);
+    } else {
+      Known.Zero.ashrInPlace(ShAmt);
+      Known.One.ashrInPlace(ShAmt);
     }
     break;
   }
@@ -35656,117 +35654,110 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
   }
   case X86ISD::VSHLI: {
     SDValue Op0 = Op.getOperand(0);
-    SDValue Op1 = Op.getOperand(1);
 
-    if (auto *ShiftImm = dyn_cast<ConstantSDNode>(Op1)) {
-      if (ShiftImm->getAPIntValue().uge(BitWidth))
-        break;
+    unsigned ShAmt = Op.getConstantOperandVal(1);
+    if (ShAmt >= BitWidth)
+      break;
 
-      unsigned ShAmt = ShiftImm->getZExtValue();
-      APInt DemandedMask = OriginalDemandedBits.lshr(ShAmt);
-
-      // If this is ((X >>u C1) << ShAmt), see if we can simplify this into a
-      // single shift.  We can do this if the bottom bits (which are shifted
-      // out) are never demanded.
-      if (Op0.getOpcode() == X86ISD::VSRLI &&
-          OriginalDemandedBits.countTrailingZeros() >= ShAmt) {
-        if (auto *Shift2Imm = dyn_cast<ConstantSDNode>(Op0.getOperand(1))) {
-          if (Shift2Imm->getAPIntValue().ult(BitWidth)) {
-            int Diff = ShAmt - Shift2Imm->getZExtValue();
-            if (Diff == 0)
-              return TLO.CombineTo(Op, Op0.getOperand(0));
-
-            unsigned NewOpc = Diff < 0 ? X86ISD::VSRLI : X86ISD::VSHLI;
-            SDValue NewShift = TLO.DAG.getNode(
-                NewOpc, SDLoc(Op), VT, Op0.getOperand(0),
-                TLO.DAG.getTargetConstant(std::abs(Diff), SDLoc(Op), MVT::i8));
-            return TLO.CombineTo(Op, NewShift);
-          }
-        }
+    APInt DemandedMask = OriginalDemandedBits.lshr(ShAmt);
+
+    // If this is ((X >>u C1) << ShAmt), see if we can simplify this into a
+    // single shift.  We can do this if the bottom bits (which are shifted
+    // out) are never demanded.
+    if (Op0.getOpcode() == X86ISD::VSRLI &&
+        OriginalDemandedBits.countTrailingZeros() >= ShAmt) {
+      unsigned Shift2Amt = Op0.getConstantOperandVal(1);
+      if (Shift2Amt < BitWidth) {
+        int Diff = ShAmt - Shift2Amt;
+        if (Diff == 0)
+          return TLO.CombineTo(Op, Op0.getOperand(0));
+
+        unsigned NewOpc = Diff < 0 ? X86ISD::VSRLI : X86ISD::VSHLI;
+        SDValue NewShift = TLO.DAG.getNode(
+            NewOpc, SDLoc(Op), VT, Op0.getOperand(0),
+            TLO.DAG.getTargetConstant(std::abs(Diff), SDLoc(Op), MVT::i8));
+        return TLO.CombineTo(Op, NewShift);
       }
+    }
 
-      if (SimplifyDemandedBits(Op0, DemandedMask, OriginalDemandedElts, Known,
-                               TLO, Depth + 1))
-        return true;
+    if (SimplifyDemandedBits(Op0, DemandedMask, OriginalDemandedElts, Known,
+                             TLO, Depth + 1))
+      return true;
 
-      assert(!Known.hasConflict() && "Bits known to be one AND zero?");
-      Known.Zero <<= ShAmt;
-      Known.One <<= ShAmt;
+    assert(!Known.hasConflict() && "Bits known to be one AND zero?");
+    Known.Zero <<= ShAmt;
+    Known.One <<= ShAmt;
 
-      // Low bits known zero.
-      Known.Zero.setLowBits(ShAmt);
-    }
+    // Low bits known zero.
+    Known.Zero.setLowBits(ShAmt);
     break;
   }
   case X86ISD::VSRLI: {
-    if (auto *ShiftImm = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
-      if (ShiftImm->getAPIntValue().uge(BitWidth))
-        break;
+    unsigned ShAmt = Op.getConstantOperandVal(1);
+    if (ShAmt >= BitWidth)
+      break;
 
-      unsigned ShAmt = ShiftImm->getZExtValue();
-      APInt DemandedMask = OriginalDemandedBits << ShAmt;
+    APInt DemandedMask = OriginalDemandedBits << ShAmt;
 
-      if (SimplifyDemandedBits(Op.getOperand(0), DemandedMask,
-                               OriginalDemandedElts, Known, TLO, Depth + 1))
-        return true;
+    if (SimplifyDemandedBits(Op.getOperand(0), DemandedMask,
+                             OriginalDemandedElts, Known, TLO, Depth + 1))
+      return true;
 
-      assert(!Known.hasConflict() && "Bits known to be one AND zero?");
-      Known.Zero.lshrInPlace(ShAmt);
-      Known.One.lshrInPlace(ShAmt);
+    assert(!Known.hasConflict() && "Bits known to be one AND zero?");
+    Known.Zero.lshrInPlace(ShAmt);
+    Known.One.lshrInPlace(ShAmt);
 
-      // High bits known zero.
-      Known.Zero.setHighBits(ShAmt);
-    }
+    // High bits known zero.
+    Known.Zero.setHighBits(ShAmt);
     break;
   }
   case X86ISD::VSRAI: {
     SDValue Op0 = Op.getOperand(0);
     SDValue Op1 = Op.getOperand(1);
 
-    if (auto *ShiftImm = dyn_cast<ConstantSDNode>(Op1)) {
-      if (ShiftImm->getAPIntValue().uge(BitWidth))
-        break;
+    unsigned ShAmt = cast<ConstantSDNode>(Op1)->getZExtValue();
+    if (ShAmt >= BitWidth)
+      break;
 
-      unsigned ShAmt = ShiftImm->getZExtValue();
-      APInt DemandedMask = OriginalDemandedBits << ShAmt;
+    APInt DemandedMask = OriginalDemandedBits << ShAmt;
 
-      // If we just want the sign bit then we don't need to shift it.
-      if (OriginalDemandedBits.isSignMask())
-        return TLO.CombineTo(Op, Op0);
+    // If we just want the sign bit then we don't need to shift it.
+    if (OriginalDemandedBits.isSignMask())
+      return TLO.CombineTo(Op, Op0);
 
-      // fold (VSRAI (VSHLI X, C1), C1) --> X iff NumSignBits(X) > C1
-      if (Op0.getOpcode() == X86ISD::VSHLI && Op1 == Op0.getOperand(1)) {
-        SDValue Op00 = Op0.getOperand(0);
-        unsigned NumSignBits =
-            TLO.DAG.ComputeNumSignBits(Op00, OriginalDemandedElts);
-        if (ShAmt < NumSignBits)
-          return TLO.CombineTo(Op, Op00);
-      }
+    // fold (VSRAI (VSHLI X, C1), C1) --> X iff NumSignBits(X) > C1
+    if (Op0.getOpcode() == X86ISD::VSHLI &&
+        Op.getOperand(1) == Op0.getOperand(1)) {
+      SDValue Op00 = Op0.getOperand(0);
+      unsigned NumSignBits =
+          TLO.DAG.ComputeNumSignBits(Op00, OriginalDemandedElts);
+      if (ShAmt < NumSignBits)
+        return TLO.CombineTo(Op, Op00);
+    }
 
-      // If any of the demanded bits are produced by the sign extension, we also
-      // demand the input sign bit.
-      if (OriginalDemandedBits.countLeadingZeros() < ShAmt)
-        DemandedMask.setSignBit();
+    // If any of the demanded bits are produced by the sign extension, we also
+    // demand the input sign bit.
+    if (OriginalDemandedBits.countLeadingZeros() < ShAmt)
+      DemandedMask.setSignBit();
 
-      if (SimplifyDemandedBits(Op0, DemandedMask, OriginalDemandedElts, Known,
-                               TLO, Depth + 1))
-        return true;
+    if (SimplifyDemandedBits(Op0, DemandedMask, OriginalDemandedElts, Known,
+                             TLO, Depth + 1))
+      return true;
 
-      assert(!Known.hasConflict() && "Bits known to be one AND zero?");
-      Known.Zero.lshrInPlace(ShAmt);
-      Known.One.lshrInPlace(ShAmt);
+    assert(!Known.hasConflict() && "Bits known to be one AND zero?");
+    Known.Zero.lshrInPlace(ShAmt);
+    Known.One.lshrInPlace(ShAmt);
 
-      // If the input sign bit is known to be zero, or if none of the top bits
-      // are demanded, turn this into an unsigned shift right.
-      if (Known.Zero[BitWidth - ShAmt - 1] ||
-          OriginalDemandedBits.countLeadingZeros() >= ShAmt)
-        return TLO.CombineTo(
-            Op, TLO.DAG.getNode(X86ISD::VSRLI, SDLoc(Op), VT, Op0, Op1));
+    // If the input sign bit is known to be zero, or if none of the top bits
+    // are demanded, turn this into an unsigned shift right.
+    if (Known.Zero[BitWidth - ShAmt - 1] ||
+        OriginalDemandedBits.countLeadingZeros() >= ShAmt)
+      return TLO.CombineTo(
+          Op, TLO.DAG.getNode(X86ISD::VSRLI, SDLoc(Op), VT, Op0, Op1));
 
-      // High bits are known one.
-      if (Known.One[BitWidth - ShAmt - 1])
-        Known.One.setHighBits(ShAmt);
-    }
+    // High bits are known one.
+    if (Known.One[BitWidth - ShAmt - 1])
+      Known.One.setHighBits(ShAmt);
     break;
   }
   case X86ISD::PEXTRB:
@@ -39347,15 +39338,15 @@ static SDValue combineVectorShiftImm(SDNode *N, SelectionDAG &DAG,
   bool LogicalShift = X86ISD::VSHLI == Opcode || X86ISD::VSRLI == Opcode;
   EVT VT = N->getValueType(0);
   SDValue N0 = N->getOperand(0);
-  SDValue N1 = N->getOperand(1);
   unsigned NumBitsPerElt = VT.getScalarSizeInBits();
   assert(VT == N0.getValueType() && (NumBitsPerElt % 8) == 0 &&
          "Unexpected value type");
-  assert(N1.getValueType() == MVT::i8 && "Unexpected shift amount type");
+  assert(N->getOperand(1).getValueType() == MVT::i8 &&
+         "Unexpected shift amount type");
 
   // Out of range logical bit shifts are guaranteed to be zero.
   // Out of range arithmetic bit shifts splat the sign bit.
-  unsigned ShiftVal = cast<ConstantSDNode>(N1)->getZExtValue();
+  unsigned ShiftVal = N->getConstantOperandVal(1);
   if (ShiftVal >= NumBitsPerElt) {
     if (LogicalShift)
       return DAG.getConstant(0, SDLoc(N), VT);


        


More information about the llvm-commits mailing list