[llvm] r370569 - [CodeGen] Refactor DAGTypeLegalizer::ExpandIntRes_MULFIX. NFC

Bjorn Pettersson via llvm-commits llvm-commits at lists.llvm.org
Sat Aug 31 02:28:50 PDT 2019


Author: bjope
Date: Sat Aug 31 02:28:50 2019
New Revision: 370569

URL: http://llvm.org/viewvc/llvm-project?rev=370569&view=rev
Log:
[CodeGen] Refactor DAGTypeLegalizer::ExpandIntRes_MULFIX. NFC

Restructured the code a little in preparation for adding
UMULFIXSAT. I think the code will be easier to understand
if the codegen for the signed, unsigned, and saturated
cases is not interleaved as much.

Modified:
    llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp?rev=370569&r1=370568&r2=370569&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp Sat Aug 31 02:28:50 2019
@@ -2811,21 +2811,25 @@ void DAGTypeLegalizer::ExpandIntRes_MULF
   SDValue RHS = N->getOperand(1);
   uint64_t Scale = N->getConstantOperandVal(2);
   bool Saturating = N->getOpcode() == ISD::SMULFIXSAT;
-  EVT BoolVT = getSetCCResultType(VT);
-  SDValue Zero = DAG.getConstant(0, dl, VT);
+  bool Signed = (N->getOpcode() == ISD::SMULFIX ||
+                 N->getOpcode() == ISD::SMULFIXSAT);
+
+  // Handle special case when scale is equal to zero.
   if (!Scale) {
     SDValue Result;
     if (!Saturating) {
       Result = DAG.getNode(ISD::MUL, dl, VT, LHS, RHS);
     } else {
+      EVT BoolVT = getSetCCResultType(VT);
       Result = DAG.getNode(ISD::SMULO, dl, DAG.getVTList(VT, BoolVT), LHS, RHS);
       SDValue Product = Result.getValue(0);
       SDValue Overflow = Result.getValue(1);
-
+      assert(Signed && "Unsigned saturation not supported (yet).");
       APInt MinVal = APInt::getSignedMinValue(VTSize);
       APInt MaxVal = APInt::getSignedMaxValue(VTSize);
       SDValue SatMin = DAG.getConstant(MinVal, dl, VT);
       SDValue SatMax = DAG.getConstant(MaxVal, dl, VT);
+      SDValue Zero = DAG.getConstant(0, dl, VT);
       SDValue ProdNeg = DAG.getSetCC(dl, BoolVT, Product, Zero, ISD::SETLT);
       Result = DAG.getSelect(dl, VT, ProdNeg, SatMax, SatMin);
       Result = DAG.getSelect(dl, VT, Overflow, Result, Product);
@@ -2840,8 +2844,6 @@ void DAGTypeLegalizer::ExpandIntRes_MULF
   GetExpandedInteger(RHS, RL, RH);
   SmallVector<SDValue, 4> Result;
 
-  bool Signed = (N->getOpcode() == ISD::SMULFIX ||
-                 N->getOpcode() == ISD::SMULFIXSAT);
   unsigned LoHiOp = Signed ? ISD::SMUL_LOHI : ISD::UMUL_LOHI;
   if (!TLI.expandMUL_LOHI(LoHiOp, VT, dl, LHS, RHS, Result, NVT, DAG,
                           TargetLowering::MulExpansionKind::OnlyLegalOrCustom,
@@ -2855,19 +2857,14 @@ void DAGTypeLegalizer::ExpandIntRes_MULF
                                     "the size of the current value type");
   EVT ShiftTy = TLI.getShiftAmountTy(NVT, DAG.getDataLayout());
 
-  // Shift whole amount by scale.
   SDValue ResultLL = Result[0];
   SDValue ResultLH = Result[1];
   SDValue ResultHL = Result[2];
   SDValue ResultHH = Result[3];
 
-  SDValue SatMax, SatMin;
-  SDValue NVTZero = DAG.getConstant(0, dl, NVT);
-  SDValue NVTNeg1 = DAG.getConstant(-1, dl, NVT);
-  EVT BoolNVT = getSetCCResultType(NVT);
-
-  // After getting the multplication result in 4 parts, we need to perform a
+  // After getting the multiplication result in 4 parts, we need to perform a
   // shift right by the amount of the scale to get the result in that scale.
+  //
   // Let's say we multiply 2 64 bit numbers. The resulting value can be held in
   // 128 bits that are cut into 4 32-bit parts:
   //
@@ -2894,66 +2891,17 @@ void DAGTypeLegalizer::ExpandIntRes_MULF
     Hi = DAG.getNode(ISD::SRL, dl, NVT, ResultLH, SRLAmnt);
     Hi = DAG.getNode(ISD::OR, dl, NVT, Hi,
                      DAG.getNode(ISD::SHL, dl, NVT, ResultHL, SHLAmnt));
-
-    // We cannot overflow past HH when multiplying 2 ints of size VTSize, so the
-    // highest bit of HH determines saturation direction in the event of
-    // saturation.
-    // The number of overflow bits we can check are VTSize - Scale + 1 (we
-    // include the sign bit). If these top bits are > 0, then we overflowed past
-    // the max value. If these top bits are < -1, then we overflowed past the
-    // min value. Otherwise, we did not overflow.
-    if (Saturating) {
-      unsigned OverflowBits = VTSize - Scale + 1;
-      assert(OverflowBits <= VTSize && OverflowBits > NVTSize &&
-             "Extent of overflow bits must start within HL");
-      SDValue HLHiMask = DAG.getConstant(
-          APInt::getHighBitsSet(NVTSize, OverflowBits - NVTSize), dl, NVT);
-      SDValue HLLoMask = DAG.getConstant(
-          APInt::getLowBitsSet(NVTSize, VTSize - OverflowBits), dl, NVT);
-
-      // HH > 0 or HH == 0 && HL > HLLoMask
-      SDValue HHPos = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTZero, ISD::SETGT);
-      SDValue HHZero = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTZero, ISD::SETEQ);
-      SDValue HLPos =
-          DAG.getSetCC(dl, BoolNVT, ResultHL, HLLoMask, ISD::SETUGT);
-      SatMax = DAG.getNode(ISD::OR, dl, BoolNVT, HHPos,
-                           DAG.getNode(ISD::AND, dl, BoolNVT, HHZero, HLPos));
-
-      // HH < -1 or HH == -1 && HL < HLHiMask
-      SDValue HHNeg = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTNeg1, ISD::SETLT);
-      SDValue HHNeg1 = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTNeg1, ISD::SETEQ);
-      SDValue HLNeg =
-          DAG.getSetCC(dl, BoolNVT, ResultHL, HLHiMask, ISD::SETULT);
-      SatMin = DAG.getNode(ISD::OR, dl, BoolNVT, HHNeg,
-                           DAG.getNode(ISD::AND, dl, BoolNVT, HHNeg1, HLNeg));
-    }
   } else if (Scale == NVTSize) {
-    // If the scales are equal, Lo and Hi are ResultLH and Result HL,
+    // If the scales are equal, Lo and Hi are ResultLH and ResultHL,
     // respectively. Avoid shifting to prevent undefined behavior.
     Lo = ResultLH;
     Hi = ResultHL;
-
-    // We overflow max if HH > 0 or HH == 0 && HL sign bit is 1.
-    // We overflow min if HH < -1 or HH == -1 && HL sign bit is 0.
-    if (Saturating) {
-      SDValue HHPos = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTZero, ISD::SETGT);
-      SDValue HHZero = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTZero, ISD::SETEQ);
-      SDValue HLNeg = DAG.getSetCC(dl, BoolNVT, ResultHL, NVTZero, ISD::SETLT);
-      SatMax = DAG.getNode(ISD::OR, dl, BoolNVT, HHPos,
-                           DAG.getNode(ISD::AND, dl, BoolNVT, HHZero, HLNeg));
-
-      SDValue HHNeg = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTNeg1, ISD::SETLT);
-      SDValue HHNeg1 = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTNeg1, ISD::SETEQ);
-      SDValue HLPos = DAG.getSetCC(dl, BoolNVT, ResultHL, NVTZero, ISD::SETGE);
-      SatMin = DAG.getNode(ISD::OR, dl, BoolNVT, HHNeg,
-                           DAG.getNode(ISD::AND, dl, BoolNVT, HHNeg1, HLPos));
-    }
   } else if (Scale < VTSize) {
     // If the scale is instead less than the old VT size, but greater than or
     // equal to the expanded VT size, the first part of the result (ResultLL) is
     // no longer a part of Lo because it would be scaled out anyway. Instead we
     // can start shifting right from the fourth part (ResultHH) to the second
-    // part (ResultLH), and Result LH will be the new Lo.
+    // part (ResultLH), and ResultLH will be the new Lo.
     SDValue SRLAmnt = DAG.getConstant(Scale - NVTSize, dl, ShiftTy);
     SDValue SHLAmnt = DAG.getConstant(VTSize - Scale, dl, ShiftTy);
     Lo = DAG.getNode(ISD::SRL, dl, NVT, ResultLH, SRLAmnt);
@@ -2962,19 +2910,6 @@ void DAGTypeLegalizer::ExpandIntRes_MULF
     Hi = DAG.getNode(ISD::SRL, dl, NVT, ResultHL, SRLAmnt);
     Hi = DAG.getNode(ISD::OR, dl, NVT, Hi,
                      DAG.getNode(ISD::SHL, dl, NVT, ResultHH, SHLAmnt));
-
-    // This is similar to the case when we saturate if Scale < NVTSize, but we
-    // only need to chech HH.
-    if (Saturating) {
-      unsigned OverflowBits = VTSize - Scale + 1;
-      SDValue HHHiMask = DAG.getConstant(
-          APInt::getHighBitsSet(NVTSize, OverflowBits), dl, NVT);
-      SDValue HHLoMask = DAG.getConstant(
-          APInt::getLowBitsSet(NVTSize, NVTSize - OverflowBits), dl, NVT);
-
-      SatMax = DAG.getSetCC(dl, BoolNVT, ResultHH, HHLoMask, ISD::SETGT);
-      SatMin = DAG.getSetCC(dl, BoolNVT, ResultHH, HHHiMask, ISD::SETLT);
-    }
   } else if (Scale == VTSize) {
     assert(
         !Signed &&
@@ -2982,20 +2917,90 @@ void DAGTypeLegalizer::ExpandIntRes_MULF
 
     Lo = ResultHL;
     Hi = ResultHH;
-  } else {
+  } else
     llvm_unreachable("Expected the scale to be less than or equal to the width "
                      "of the operands");
-  }
 
-  if (Saturating) {
-    APInt LHMax = APInt::getSignedMaxValue(NVTSize);
-    APInt LLMax = APInt::getAllOnesValue(NVTSize);
-    APInt LHMin = APInt::getSignedMinValue(NVTSize);
-    Hi = DAG.getSelect(dl, NVT, SatMax, DAG.getConstant(LHMax, dl, NVT), Hi);
-    Hi = DAG.getSelect(dl, NVT, SatMin, DAG.getConstant(LHMin, dl, NVT), Hi);
-    Lo = DAG.getSelect(dl, NVT, SatMax, DAG.getConstant(LLMax, dl, NVT), Lo);
-    Lo = DAG.getSelect(dl, NVT, SatMin, NVTZero, Lo);
-  }
+  // Unless saturation is requested we are done. The result is in <Hi,Lo>.
+  if (!Saturating)
+    return;
+
+  // To handle saturation we must check for overflow in the multiplication.
+  //
+  // Signed overflow happened if the upper (VTSize - Scale + 1) bits (of Result)
+  // aren't all ones or all zeroes.
+  //
+  // We cannot overflow past HH when multiplying 2 ints of size VTSize, so the
+  // highest bit of HH determines saturation direction in the event of
+  // saturation.
+
+  SDValue SatMax, SatMin;
+  SDValue NVTZero = DAG.getConstant(0, dl, NVT);
+  SDValue NVTNeg1 = DAG.getConstant(-1, dl, NVT);
+  EVT BoolNVT = getSetCCResultType(NVT);
+
+  if (!Signed)
+    llvm_unreachable("Unsigned saturation not supported (yet).");
+
+  if (Scale < NVTSize) {
+    // The number of overflow bits we can check are VTSize - Scale + 1 (we
+    // include the sign bit). If these top bits are > 0, then we overflowed past
+    // the max value. If these top bits are < -1, then we overflowed past the
+    // min value. Otherwise, we did not overflow.
+    unsigned OverflowBits = VTSize - Scale + 1;
+    assert(OverflowBits <= VTSize && OverflowBits > NVTSize &&
+           "Extent of overflow bits must start within HL");
+    SDValue HLHiMask = DAG.getConstant(
+        APInt::getHighBitsSet(NVTSize, OverflowBits - NVTSize), dl, NVT);
+    SDValue HLLoMask = DAG.getConstant(
+        APInt::getLowBitsSet(NVTSize, VTSize - OverflowBits), dl, NVT);
+    // We overflow max if HH > 0 or (HH == 0 && HL > HLLoMask).
+    SDValue HHGT0 = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTZero, ISD::SETGT);
+    SDValue HHEQ0 = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTZero, ISD::SETEQ);
+    SDValue HLUGT = DAG.getSetCC(dl, BoolNVT, ResultHL, HLLoMask, ISD::SETUGT);
+    SatMax = DAG.getNode(ISD::OR, dl, BoolNVT, HHGT0,
+                         DAG.getNode(ISD::AND, dl, BoolNVT, HHEQ0, HLUGT));
+    // We overflow min if HH < -1 or (HH == -1 && HL < HLHiMask).
+    SDValue HHLT = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTNeg1, ISD::SETLT);
+    SDValue HHEQ = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTNeg1, ISD::SETEQ);
+    SDValue HLULT = DAG.getSetCC(dl, BoolNVT, ResultHL, HLHiMask, ISD::SETULT);
+    SatMin = DAG.getNode(ISD::OR, dl, BoolNVT, HHLT,
+                         DAG.getNode(ISD::AND, dl, BoolNVT, HHEQ, HLULT));
+  } else if (Scale == NVTSize) {
+    // We overflow max if HH > 0 or (HH == 0 && HL sign bit is 1).
+    SDValue HHGT0 = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTZero, ISD::SETGT);
+    SDValue HHEQ0 = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTZero, ISD::SETEQ);
+    SDValue HLNeg = DAG.getSetCC(dl, BoolNVT, ResultHL, NVTZero, ISD::SETLT);
+    SatMax = DAG.getNode(ISD::OR, dl, BoolNVT, HHGT0,
+                         DAG.getNode(ISD::AND, dl, BoolNVT, HHEQ0, HLNeg));
+    // We overflow min if HH < -1 or (HH == -1 && HL sign bit is 0).
+    SDValue HHLT = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTNeg1, ISD::SETLT);
+    SDValue HHEQ = DAG.getSetCC(dl, BoolNVT, ResultHH, NVTNeg1, ISD::SETEQ);
+    SDValue HLPos = DAG.getSetCC(dl, BoolNVT, ResultHL, NVTZero, ISD::SETGE);
+    SatMin = DAG.getNode(ISD::OR, dl, BoolNVT, HHLT,
+                         DAG.getNode(ISD::AND, dl, BoolNVT, HHEQ, HLPos));
+  } else if (Scale < VTSize) {
+    // This is similar to the case when we saturate if Scale < NVTSize, but we
+    // only need to check HH.
+    unsigned OverflowBits = VTSize - Scale + 1;
+    SDValue HHHiMask = DAG.getConstant(
+        APInt::getHighBitsSet(NVTSize, OverflowBits), dl, NVT);
+    SDValue HHLoMask = DAG.getConstant(
+        APInt::getLowBitsSet(NVTSize, NVTSize - OverflowBits), dl, NVT);
+    SatMax = DAG.getSetCC(dl, BoolNVT, ResultHH, HHLoMask, ISD::SETGT);
+    SatMin = DAG.getSetCC(dl, BoolNVT, ResultHH, HHHiMask, ISD::SETLT);
+  } else
+    llvm_unreachable("Illegal scale for signed fixed point mul.");
+
+  // Saturate to signed maximum.
+  APInt MaxHi = APInt::getSignedMaxValue(NVTSize);
+  APInt MaxLo = APInt::getAllOnesValue(NVTSize);
+  Hi = DAG.getSelect(dl, NVT, SatMax, DAG.getConstant(MaxHi, dl, NVT), Hi);
+  Lo = DAG.getSelect(dl, NVT, SatMax, DAG.getConstant(MaxLo, dl, NVT), Lo);
+  // Saturate to signed minimum.
+  APInt MinHi = APInt::getSignedMinValue(NVTSize);
+  Hi = DAG.getSelect(dl, NVT, SatMin, DAG.getConstant(MinHi, dl, NVT), Hi);
+  Lo = DAG.getSelect(dl, NVT, SatMin, NVTZero, Lo);
 }
 
 void DAGTypeLegalizer::ExpandIntRes_SADDSUBO(SDNode *Node,




More information about the llvm-commits mailing list