[llvm] c41685b - [SelectionDAG] Make getZeroExtendInReg take a vector VT if the operand VT is a vector.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Tue Apr 7 11:36:01 PDT 2020


Author: Craig Topper
Date: 2020-04-07T11:34:08-07:00
New Revision: c41685b16fcceaa2078eb14eb27f6696f851eb49

URL: https://github.com/llvm/llvm-project/commit/c41685b16fcceaa2078eb14eb27f6696f851eb49
DIFF: https://github.com/llvm/llvm-project/commit/c41685b16fcceaa2078eb14eb27f6696f851eb49.diff

LOG: [SelectionDAG] Make getZeroExtendInReg take a vector VT if the operand VT is a vector.

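Previously, callers had to pass the scalar element type even when the operand
was a vector. Now the VT argument must be a vector type (with the same number
of elements as the operand) whenever the operand is a vector, and a scalar
type otherwise.

This removes a call to getScalarType() from many call sites.
It also makes the behavior consistent with SIGN_EXTEND_INREG.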

Differential Revision: https://reviews.llvm.org/D77631

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
    llvm/lib/Target/X86/X86ISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 4d23dc9c0573..10312a5059d8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10142,7 +10142,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
       if (!LegalOperations || (TLI.isOperationLegal(ISD::AND, SrcVT) &&
                                TLI.isOperationLegal(ISD::ZERO_EXTEND, VT))) {
         SDValue Op = N0.getOperand(0);
-        Op = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType());
+        Op = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT);
         AddToWorklist(Op.getNode());
         SDValue ZExtOrTrunc = DAG.getZExtOrTrunc(Op, SDLoc(N), VT);
         // Transfer the debug info; the new node is equivalent to N0.
@@ -10154,7 +10154,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
     if (!LegalOperations || TLI.isOperationLegal(ISD::AND, VT)) {
       SDValue Op = DAG.getAnyExtOrTrunc(N0.getOperand(0), SDLoc(N), VT);
       AddToWorklist(Op.getNode());
-      SDValue And = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT.getScalarType());
+      SDValue And = DAG.getZeroExtendInReg(Op, SDLoc(N), MinVT);
       // We may safely transfer the debug info describing the truncate node over
       // to the equivalent and operation.
       DAG.transferDbgValues(N0, And);
@@ -10283,7 +10283,7 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
         // zext(setcc) -> zext_in_reg(vsetcc) for vectors.
         SDValue VSetCC = DAG.getNode(ISD::SETCC, DL, VT, N0.getOperand(0),
                                      N0.getOperand(1), N0.getOperand(2));
-        return DAG.getZeroExtendInReg(VSetCC, DL, MVT::i1);
+        return DAG.getZeroExtendInReg(VSetCC, DL, N0.getValueType());
       }
 
       // If the desired elements are smaller or larger than the source
@@ -10293,8 +10293,8 @@ SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
       SDValue VsetCC =
           DAG.getNode(ISD::SETCC, DL, MatchingVectorType, N0.getOperand(0),
                       N0.getOperand(1), N0.getOperand(2));
-      return DAG.getZeroExtendInReg(DAG.getAnyExtOrTrunc(VsetCC, DL, VT),
-                                    DL, MVT::i1);
+      return DAG.getZeroExtendInReg(DAG.getAnyExtOrTrunc(VsetCC, DL, VT), DL,
+                                    N0.getValueType());
     }
 
     // zext(setcc x,y,cc) -> select_cc x, y, 1, 0, cc
@@ -10812,7 +10812,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
 
   // fold (sext_in_reg x) -> (zext_in_reg x) if the sign bit is known zero.
   if (DAG.MaskedValueIsZero(N0, APInt::getOneBitSet(VTBits, EVTBits - 1)))
-    return DAG.getZeroExtendInReg(N0, SDLoc(N), EVT.getScalarType());
+    return DAG.getZeroExtendInReg(N0, SDLoc(N), EVT);
 
   // fold operands of sext_in_reg based on knowledge that the top bits are not
   // demanded.

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
index f64d7fd87086..112a06385931 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -933,7 +933,7 @@ void SelectionDAGLegalize::LegalizeLoadOps(SDNode *Node) {
                              Result.getValueType(),
                              Result, DAG.getValueType(SrcVT));
       else
-        ValRes = DAG.getZeroExtendInReg(Result, dl, SrcVT.getScalarType());
+        ValRes = DAG.getZeroExtendInReg(Result, dl, SrcVT);
       Value = ValRes;
       Chain = Result.getValue(1);
       break;
@@ -3531,8 +3531,9 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
     SDValue Overflow = DAG.getSetCC(dl, SetCCType, Sum, LHS, CC);
 
     // Add of the sum and the carry.
+    SDValue One = DAG.getConstant(1, dl, VT);
     SDValue CarryExt =
-        DAG.getZeroExtendInReg(DAG.getZExtOrTrunc(Carry, dl, VT), dl, MVT::i1);
+        DAG.getNode(ISD::AND, dl, VT, DAG.getZExtOrTrunc(Carry, dl, VT), One);
     SDValue Sum2 = DAG.getNode(Op, dl, VT, Sum, CarryExt);
 
     // Second check for overflow. If we are adding, we can only overflow if the

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index ed67f7dc8ea3..716fe9ddd60c 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -615,8 +615,7 @@ SDValue DAGTypeLegalizer::PromoteIntRes_INT_EXTEND(SDNode *N) {
         return DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, NVT, Res,
                            DAG.getValueType(N->getOperand(0).getValueType()));
       if (N->getOpcode() == ISD::ZERO_EXTEND)
-        return DAG.getZeroExtendInReg(Res, dl,
-                      N->getOperand(0).getValueType().getScalarType());
+        return DAG.getZeroExtendInReg(Res, dl, N->getOperand(0).getValueType());
       assert(N->getOpcode() == ISD::ANY_EXTEND && "Unknown integer extension!");
       return Res;
     }
@@ -1169,7 +1168,7 @@ SDValue DAGTypeLegalizer::PromoteIntRes_UADDSUBO(SDNode *N, unsigned ResNo) {
 
   // Calculate the overflow flag: zero extend the arithmetic result from
   // the original type.
-  SDValue Ofl = DAG.getZeroExtendInReg(Res, dl, OVT.getScalarType());
+  SDValue Ofl = DAG.getZeroExtendInReg(Res, dl, OVT);
   // Overflowed if and only if this is not equal to Res.
   Ofl = DAG.getSetCC(dl, N->getValueType(1), Ofl, Res, ISD::SETNE);
 
@@ -1784,8 +1783,7 @@ SDValue DAGTypeLegalizer::PromoteIntOp_ZERO_EXTEND(SDNode *N) {
   SDLoc dl(N);
   SDValue Op = GetPromotedInteger(N->getOperand(0));
   Op = DAG.getNode(ISD::ANY_EXTEND, dl, N->getValueType(0), Op);
-  return DAG.getZeroExtendInReg(Op, dl,
-                                N->getOperand(0).getValueType().getScalarType());
+  return DAG.getZeroExtendInReg(Op, dl, N->getOperand(0).getValueType());
 }
 
 SDValue DAGTypeLegalizer::PromoteIntOp_ADDSUBCARRY(SDNode *N, unsigned OpNo) {

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index e27f1e5b3231..d858a10f4698 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -265,7 +265,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
     EVT OldVT = Op.getValueType();
     SDLoc dl(Op);
     Op = GetPromotedInteger(Op);
-    return DAG.getZeroExtendInReg(Op, dl, OldVT.getScalarType());
+    return DAG.getZeroExtendInReg(Op, dl, OldVT);
   }
 
   // Get a promoted operand and sign or zero extend it to the final size
@@ -279,7 +279,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
     if (TLI.isSExtCheaperThanZExt(OldVT, Op.getValueType()))
       return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, Op.getValueType(), Op,
                          DAG.getValueType(OldVT));
-    return DAG.getZeroExtendInReg(Op, DL, OldVT.getScalarType());
+    return DAG.getZeroExtendInReg(Op, DL, OldVT);
   }
 
   // Integer Result Promotion.

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 5f6d8bf95517..b752b14abc37 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1167,15 +1167,21 @@ SDValue SelectionDAG::getBoolExtOrTrunc(SDValue Op, const SDLoc &SL, EVT VT,
 }
 
 SDValue SelectionDAG::getZeroExtendInReg(SDValue Op, const SDLoc &DL, EVT VT) {
-  assert(!VT.isVector() &&
-         "getZeroExtendInReg should use the vector element type instead of "
-         "the vector type!");
-  if (Op.getValueType().getScalarType() == VT) return Op;
-  unsigned BitWidth = Op.getScalarValueSizeInBits();
-  APInt Imm = APInt::getLowBitsSet(BitWidth,
-                                   VT.getSizeInBits());
-  return getNode(ISD::AND, DL, Op.getValueType(), Op,
-                 getConstant(Imm, DL, Op.getValueType()));
+  EVT OpVT = Op.getValueType();
+  assert(VT.isInteger() && OpVT.isInteger() &&
+         "Cannot getZeroExtendInReg FP types");
+  assert(VT.isVector() == OpVT.isVector() &&
+         "getZeroExtendInReg type should be vector iff the operand "
+         "type is vector!");
+  assert((!VT.isVector() ||
+          VT.getVectorNumElements() == OpVT.getVectorNumElements()) &&
+         "Vector element counts must match in getZeroExtendInReg");
+  assert(VT.bitsLE(OpVT) && "Not extending!");
+  if (OpVT == VT)
+    return Op;
+  APInt Imm = APInt::getLowBitsSet(OpVT.getScalarSizeInBits(),
+                                   VT.getScalarSizeInBits());
+  return getNode(ISD::AND, DL, OpVT, Op, getConstant(Imm, DL, OpVT));
 }
 
 SDValue SelectionDAG::getPtrExtOrTrunc(SDValue Op, const SDLoc &DL, EVT VT) {

diff  --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 7e613570eaf9..f45ca102d620 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -1727,8 +1727,7 @@ bool TargetLowering::SimplifyDemandedBits(
 
     // If the input sign bit is known zero, convert this into a zero extension.
     if (Known.Zero[ExVTBits - 1])
-      return TLO.CombineTo(
-          Op, TLO.DAG.getZeroExtendInReg(Op0, dl, ExVT.getScalarType()));
+      return TLO.CombineTo(Op, TLO.DAG.getZeroExtendInReg(Op0, dl, ExVT));
 
     APInt Mask = APInt::getLowBitsSet(BitWidth, ExVTBits);
     if (Known.One[ExVTBits - 1]) { // Input sign bit known set

diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index a9db423caec2..c89e4ac3e899 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -41259,7 +41259,7 @@ static SDValue PromoteMaskArithmetic(SDNode *N, SelectionDAG &DAG,
   case ISD::ANY_EXTEND:
     return Op;
   case ISD::ZERO_EXTEND:
-    return DAG.getZeroExtendInReg(Op, DL, NarrowVT.getScalarType());
+    return DAG.getZeroExtendInReg(Op, DL, NarrowVT);
   case ISD::SIGN_EXTEND:
     return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT,
                        Op, DAG.getValueType(NarrowVT));
@@ -44870,7 +44870,7 @@ static SDValue combineExtSetcc(SDNode *N, SelectionDAG &DAG,
   SDValue Res = DAG.getSetCC(dl, VT, N0.getOperand(0), N0.getOperand(1), CC);
 
   if (N->getOpcode() == ISD::ZERO_EXTEND)
-    Res = DAG.getZeroExtendInReg(Res, dl, N0.getValueType().getScalarType());
+    Res = DAG.getZeroExtendInReg(Res, dl, N0.getValueType());
 
   return Res;
 }


        


More information about the llvm-commits mailing list