[llvm] r289050 - [SelectionDAG] Add expansion and promotion of [US]MUL_LOHI

Nicolai Haehnle via llvm-commits llvm-commits at lists.llvm.org
Thu Dec 8 06:08:15 PST 2016


Author: nha
Date: Thu Dec  8 08:08:14 2016
New Revision: 289050

URL: http://llvm.org/viewvc/llvm-project?rev=289050&view=rev
Log:
[SelectionDAG] Add expansion and promotion of [US]MUL_LOHI

Summary:
Most targets set the action for these nodes to Expand even though there
isn't actually any code for them in ExpandNode. Instead, targets simply
relied on the fact that no code generates these nodes as long as the
nodes aren't legal or custom.

However, generating these nodes can be useful e.g. for divide-by-constant
in wider integer types.

Expand of [US]MUL_LOHI will use MULH[US] when legal or custom, and
a sequence of half-width multiplications otherwise. Promote uses a wider
multiply.

This patch intends to not change the generated code, but indirect effects
are possible since expansions/promotions that were previously done in
DAGCombine may now be done in LegalizeDAG.

See D24822 for a change that actually uses the new expansion.

Reviewers: spatel, bkramer, venkatra, efriedma, hfinkel, ast, nadav, tstellarAMD

Subscribers: arsenm, jyknight, nemanjai, wdng, nhaehnle, llvm-commits

Differential Revision: https://reviews.llvm.org/D24956

Modified:
    llvm/trunk/include/llvm/Target/TargetLowering.h
    llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
    llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
    llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
    llvm/trunk/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Modified: llvm/trunk/include/llvm/Target/TargetLowering.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Target/TargetLowering.h?rev=289050&r1=289049&r2=289050&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Target/TargetLowering.h (original)
+++ llvm/trunk/include/llvm/Target/TargetLowering.h Thu Dec  8 08:08:14 2016
@@ -142,6 +142,13 @@ public:
     CmpXChg, // Expand the instruction into cmpxchg; used by at least X86.
   };
 
+  /// Enum that specifies when a multiplication should be expanded.
+  enum class MulExpansionKind {
+    Always,            // Always expand the instruction.
+    OnlyLegalOrCustom, // Only expand when the resulting instructions are legal
+                       // or custom.
+  };
+
   static ISD::NodeType getExtendForContent(BooleanContent Content) {
     switch (Content) {
     case UndefinedBooleanContent:
@@ -3036,9 +3043,28 @@ public:
   // Legalization utility functions
   //
 
+  /// Expand a MUL or [US]MUL_LOHI of n-bit values into two or four nodes,
+  /// respectively, each computing an n/2-bit part of the result.
+  /// \param Result A vector that will be filled with the parts of the result
+  ///        in little-endian order.
+  /// \param HiLoVT The value type to use for the result nodes.
+  /// \param Kind Whether to expand always or only to legal/custom nodes.
+  /// \param LL Low bits of the LHS of the MUL.  You can use this parameter
+  ///        if you want to control how low bits are extracted from the LHS.
+  /// \param LH High bits of the LHS of the MUL.  See LL for meaning.
+  /// \param RL Low bits of the RHS of the MUL.  See LL for meaning.
+  /// \param RH High bits of the RHS of the MUL.  See LL for meaning.
+  /// \returns true if the node has been expanded, false if it has not.
+  bool expandMUL_LOHI(unsigned Opcode, EVT VT, SDLoc dl, SDValue LHS,
+                      SDValue RHS, SmallVectorImpl<SDValue> &Result, EVT HiLoVT,
+                      SelectionDAG &DAG, MulExpansionKind Kind,
+                      SDValue LL = SDValue(), SDValue LH = SDValue(),
+                      SDValue RL = SDValue(), SDValue RH = SDValue()) const;
+
   /// Expand a MUL into two nodes.  One that computes the high bits of
   /// the result and one that computes the low bits.
   /// \param HiLoVT The value type to use for the Lo and Hi nodes.
+  /// \param Kind Whether to expand always or only to legal/custom nodes.
   /// \param LL Low bits of the LHS of the MUL.  You can use this parameter
   ///        if you want to control how low bits are extracted from the LHS.
   /// \param LH High bits of the LHS of the MUL.  See LL for meaning.
@@ -3046,9 +3072,9 @@ public:
   /// \param RH High bits of the RHS of the MUL.  See LL for meaning.
   /// \returns true if the node has been expanded. false if it has not
   bool expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
-                 SelectionDAG &DAG, SDValue LL = SDValue(),
-                 SDValue LH = SDValue(), SDValue RL = SDValue(),
-                 SDValue RH = SDValue()) const;
+                 SelectionDAG &DAG, MulExpansionKind Kind,
+                 SDValue LL = SDValue(), SDValue LH = SDValue(),
+                 SDValue RL = SDValue(), SDValue RH = SDValue()) const;
 
   /// Expand float(f32) to SINT(i64) conversion
   /// \param N Node to expand

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp?rev=289050&r1=289049&r2=289050&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp Thu Dec  8 08:08:14 2016
@@ -3312,17 +3312,49 @@ bool SelectionDAGLegalize::ExpandNode(SD
   }
   case ISD::MULHU:
   case ISD::MULHS: {
-    unsigned ExpandOpcode = Node->getOpcode() == ISD::MULHU ? ISD::UMUL_LOHI :
-                                                              ISD::SMUL_LOHI;
+    unsigned ExpandOpcode =
+        Node->getOpcode() == ISD::MULHU ? ISD::UMUL_LOHI : ISD::SMUL_LOHI;
     EVT VT = Node->getValueType(0);
     SDVTList VTs = DAG.getVTList(VT, VT);
-    assert(TLI.isOperationLegalOrCustom(ExpandOpcode, VT) &&
-           "If this wasn't legal, it shouldn't have been created!");
+
     Tmp1 = DAG.getNode(ExpandOpcode, dl, VTs, Node->getOperand(0),
                        Node->getOperand(1));
     Results.push_back(Tmp1.getValue(1));
     break;
   }
+  case ISD::UMUL_LOHI:
+  case ISD::SMUL_LOHI: {
+    SDValue LHS = Node->getOperand(0);
+    SDValue RHS = Node->getOperand(1);
+    MVT VT = LHS.getSimpleValueType();
+    unsigned MULHOpcode =
+        Node->getOpcode() == ISD::UMUL_LOHI ? ISD::MULHU : ISD::MULHS;
+
+    if (TLI.isOperationLegalOrCustom(MULHOpcode, VT)) {
+      Results.push_back(DAG.getNode(ISD::MUL, dl, VT, LHS, RHS));
+      Results.push_back(DAG.getNode(MULHOpcode, dl, VT, LHS, RHS));
+      break;
+    }
+
+    SmallVector<SDValue, 4> Halves;
+    EVT HalfType = EVT(VT).getHalfSizedIntegerVT(*DAG.getContext());
+    assert(TLI.isTypeLegal(HalfType));
+    if (TLI.expandMUL_LOHI(Node->getOpcode(), VT, Node, LHS, RHS, Halves,
+                           HalfType, DAG,
+                           TargetLowering::MulExpansionKind::Always)) {
+      for (unsigned i = 0; i < 2; ++i) {
+        SDValue Lo = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Halves[2 * i]);
+        SDValue Hi = DAG.getNode(ISD::ANY_EXTEND, dl, VT, Halves[2 * i + 1]);
+        SDValue Shift = DAG.getConstant(
+            HalfType.getScalarSizeInBits(), dl,
+            TLI.getShiftAmountTy(HalfType, DAG.getDataLayout()));
+        Hi = DAG.getNode(ISD::SHL, dl, VT, Hi, Shift);
+        Results.push_back(DAG.getNode(ISD::OR, dl, VT, Lo, Hi));
+      }
+      break;
+    }
+    break;
+  }
   case ISD::MUL: {
     EVT VT = Node->getValueType(0);
     SDVTList VTs = DAG.getVTList(VT, VT);
@@ -3357,7 +3389,8 @@ bool SelectionDAGLegalize::ExpandNode(SD
         TLI.isOperationLegalOrCustom(ISD::ANY_EXTEND, VT) &&
         TLI.isOperationLegalOrCustom(ISD::SHL, VT) &&
         TLI.isOperationLegalOrCustom(ISD::OR, VT) &&
-        TLI.expandMUL(Node, Lo, Hi, HalfType, DAG)) {
+        TLI.expandMUL(Node, Lo, Hi, HalfType, DAG,
+                      TargetLowering::MulExpansionKind::OnlyLegalOrCustom)) {
       Lo = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Lo);
       Hi = DAG.getNode(ISD::ANY_EXTEND, dl, VT, Hi);
       SDValue Shift =
@@ -4194,6 +4227,24 @@ void SelectionDAGLegalize::PromoteNode(S
     Results.push_back(DAG.getNode(TruncOp, dl, OVT, Tmp1));
     break;
   }
+  case ISD::UMUL_LOHI:
+  case ISD::SMUL_LOHI: {
+    // Promote to a multiply in a wider integer type.
+    unsigned ExtOp = Node->getOpcode() == ISD::UMUL_LOHI ? ISD::ZERO_EXTEND
+                                                         : ISD::SIGN_EXTEND;
+    Tmp1 = DAG.getNode(ExtOp, dl, NVT, Node->getOperand(0));
+    Tmp2 = DAG.getNode(ExtOp, dl, NVT, Node->getOperand(1));
+    Tmp1 = DAG.getNode(ISD::MUL, dl, NVT, Tmp1, Tmp2);
+
+    auto &DL = DAG.getDataLayout();
+    unsigned OriginalSize = OVT.getScalarSizeInBits();
+    Tmp2 = DAG.getNode(
+        ISD::SRL, dl, NVT, Tmp1,
+        DAG.getConstant(OriginalSize, dl, TLI.getScalarShiftAmountTy(DL, NVT)));
+    Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, Tmp1));
+    Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, Tmp2));
+    break;
+  }
   case ISD::SELECT: {
     unsigned ExtOp, TruncOp;
     if (Node->getValueType(0).isVector() ||

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp?rev=289050&r1=289049&r2=289050&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp Thu Dec  8 08:08:14 2016
@@ -2189,7 +2189,9 @@ void DAGTypeLegalizer::ExpandIntRes_MUL(
   GetExpandedInteger(N->getOperand(0), LL, LH);
   GetExpandedInteger(N->getOperand(1), RL, RH);
 
-  if (TLI.expandMUL(N, Lo, Hi, NVT, DAG, LL, LH, RL, RH))
+  if (TLI.expandMUL(N, Lo, Hi, NVT, DAG,
+                    TargetLowering::MulExpansionKind::OnlyLegalOrCustom,
+                    LL, LH, RL, RH))
     return;
 
   // If nothing else, we can make a libcall.

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp?rev=289050&r1=289049&r2=289050&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp Thu Dec  8 08:08:14 2016
@@ -333,6 +333,8 @@ SDValue VectorLegalizer::LegalizeOp(SDVa
   case ISD::SMAX:
   case ISD::UMIN:
   case ISD::UMAX:
+  case ISD::SMUL_LOHI:
+  case ISD::UMUL_LOHI:
     QueryType = Node->getValueType(0);
     break;
   case ISD::FP_ROUND_INREG:

Modified: llvm/trunk/lib/CodeGen/SelectionDAG/TargetLowering.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/CodeGen/SelectionDAG/TargetLowering.cpp?rev=289050&r1=289049&r2=289050&view=diff
==============================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/TargetLowering.cpp (original)
+++ llvm/trunk/lib/CodeGen/SelectionDAG/TargetLowering.cpp Thu Dec  8 08:08:14 2016
@@ -3093,24 +3093,29 @@ verifyReturnAddressArgumentIsConstant(SD
 // Legalization Utilities
 //===----------------------------------------------------------------------===//
 
-bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
-                               SelectionDAG &DAG, SDValue LL, SDValue LH,
-                               SDValue RL, SDValue RH) const {
-  EVT VT = N->getValueType(0);
-  SDLoc dl(N);
-
-  bool HasMULHS = isOperationLegalOrCustom(ISD::MULHS, HiLoVT);
-  bool HasMULHU = isOperationLegalOrCustom(ISD::MULHU, HiLoVT);
-  bool HasSMUL_LOHI = isOperationLegalOrCustom(ISD::SMUL_LOHI, HiLoVT);
-  bool HasUMUL_LOHI = isOperationLegalOrCustom(ISD::UMUL_LOHI, HiLoVT);
+bool TargetLowering::expandMUL_LOHI(unsigned Opcode, EVT VT, SDLoc dl,
+                                    SDValue LHS, SDValue RHS,
+                                    SmallVectorImpl<SDValue> &Result,
+                                    EVT HiLoVT, SelectionDAG &DAG,
+                                    MulExpansionKind Kind, SDValue LL,
+                                    SDValue LH, SDValue RL, SDValue RH) const {
+  assert(Opcode == ISD::MUL || Opcode == ISD::UMUL_LOHI ||
+         Opcode == ISD::SMUL_LOHI);
+
+  bool HasMULHS = (Kind == MulExpansionKind::Always) ||
+                  isOperationLegalOrCustom(ISD::MULHS, HiLoVT);
+  bool HasMULHU = (Kind == MulExpansionKind::Always) ||
+                  isOperationLegalOrCustom(ISD::MULHU, HiLoVT);
+  bool HasSMUL_LOHI = (Kind == MulExpansionKind::Always) ||
+                      isOperationLegalOrCustom(ISD::SMUL_LOHI, HiLoVT);
+  bool HasUMUL_LOHI = (Kind == MulExpansionKind::Always) ||
+                      isOperationLegalOrCustom(ISD::UMUL_LOHI, HiLoVT);
 
   if (!HasMULHU && !HasMULHS && !HasUMUL_LOHI && !HasSMUL_LOHI)
     return false;
 
   unsigned OuterBitSize = VT.getScalarSizeInBits();
   unsigned InnerBitSize = HiLoVT.getScalarSizeInBits();
-  SDValue LHS = N->getOperand(0);
-  SDValue RHS = N->getOperand(1);
   unsigned LHSSB = DAG.ComputeNumSignBits(LHS);
   unsigned RHSSB = DAG.ComputeNumSignBits(RHS);
 
@@ -3134,6 +3139,8 @@ bool TargetLowering::expandMUL(SDNode *N
     return false;
   };
 
+  SDValue Lo, Hi;
+
   if (!LL.getNode() && !RL.getNode() &&
       isOperationLegalOrCustom(ISD::TRUNCATE, HiLoVT)) {
     LL = DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, LHS);
@@ -3144,20 +3151,41 @@ bool TargetLowering::expandMUL(SDNode *N
     return false;
 
   APInt HighMask = APInt::getHighBitsSet(OuterBitSize, InnerBitSize);
-  if (DAG.MaskedValueIsZero(N->getOperand(0), HighMask) &&
-      DAG.MaskedValueIsZero(N->getOperand(1), HighMask)) {
+  if (DAG.MaskedValueIsZero(LHS, HighMask) &&
+      DAG.MaskedValueIsZero(RHS, HighMask)) {
     // The inputs are both zero-extended.
-    if (MakeMUL_LOHI(LL, RL, Lo, Hi, false))
+    if (MakeMUL_LOHI(LL, RL, Lo, Hi, false)) {
+      Result.push_back(Lo);
+      Result.push_back(Hi);
+      if (Opcode != ISD::MUL) {
+        SDValue Zero = DAG.getConstant(0, dl, HiLoVT);
+        Result.push_back(Zero);
+        Result.push_back(Zero);
+      }
       return true;
+    }
   }
-  if (LHSSB > InnerBitSize && RHSSB > InnerBitSize) {
+
+  if (!VT.isVector() && Opcode == ISD::MUL && LHSSB > InnerBitSize &&
+      RHSSB > InnerBitSize) {
     // The input values are both sign-extended.
-    if (MakeMUL_LOHI(LL, RL, Lo, Hi, true))
+    // TODO: Also handle the non-MUL ([US]MUL_LOHI) case here.
+    if (MakeMUL_LOHI(LL, RL, Lo, Hi, true)) {
+      Result.push_back(Lo);
+      Result.push_back(Hi);
       return true;
+    }
   }
 
-  SDValue Shift = DAG.getConstant(OuterBitSize - InnerBitSize, dl,
-                                  getShiftAmountTy(VT, DAG.getDataLayout()));
+  unsigned ShiftAmount = OuterBitSize - InnerBitSize;
+  EVT ShiftAmountTy = getShiftAmountTy(VT, DAG.getDataLayout());
+  if (APInt::getMaxValue(ShiftAmountTy.getSizeInBits()).ult(ShiftAmount)) {
+    // FIXME getShiftAmountTy does not always return a sensible result when VT
+    // is an illegal type, and so the type may be too small to fit the shift
+    // amount. Override it with i32. The shift will have to be legalized.
+    ShiftAmountTy = MVT::i32;
+  }
+  SDValue Shift = DAG.getConstant(ShiftAmount, dl, ShiftAmountTy);
 
   if (!LH.getNode() && !RH.getNode() &&
       isOperationLegalOrCustom(ISD::SRL, VT) &&
@@ -3171,15 +3199,84 @@ bool TargetLowering::expandMUL(SDNode *N
   if (!LH.getNode())
     return false;
 
-  if (MakeMUL_LOHI(LL, RL, Lo, Hi, false)) {
+  if (!MakeMUL_LOHI(LL, RL, Lo, Hi, false))
+    return false;
+
+  Result.push_back(Lo);
+
+  if (Opcode == ISD::MUL) {
     RH = DAG.getNode(ISD::MUL, dl, HiLoVT, LL, RH);
     LH = DAG.getNode(ISD::MUL, dl, HiLoVT, LH, RL);
     Hi = DAG.getNode(ISD::ADD, dl, HiLoVT, Hi, RH);
     Hi = DAG.getNode(ISD::ADD, dl, HiLoVT, Hi, LH);
+    Result.push_back(Hi);
     return true;
   }
 
-  return false;
+  // Compute the full width result.
+  auto Merge = [&](SDValue Lo, SDValue Hi) -> SDValue {
+    Lo = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Lo);
+    Hi = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Hi);
+    Hi = DAG.getNode(ISD::SHL, dl, VT, Hi, Shift);
+    return DAG.getNode(ISD::OR, dl, VT, Lo, Hi);
+  };
+
+  SDValue Next = DAG.getNode(ISD::ZERO_EXTEND, dl, VT, Hi);
+  if (!MakeMUL_LOHI(LL, RH, Lo, Hi, false))
+    return false;
+
+  // This is effectively the add part of a multiply-add of half-sized operands,
+  // so it cannot overflow.
+  Next = DAG.getNode(ISD::ADD, dl, VT, Next, Merge(Lo, Hi));
+
+  if (!MakeMUL_LOHI(LH, RL, Lo, Hi, false))
+    return false;
+
+  Next = DAG.getNode(ISD::ADDC, dl, DAG.getVTList(VT, MVT::Glue), Next,
+                     Merge(Lo, Hi));
+
+  SDValue Carry = Next.getValue(1);
+  Result.push_back(DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, Next));
+  Next = DAG.getNode(ISD::SRL, dl, VT, Next, Shift);
+
+  if (!MakeMUL_LOHI(LH, RH, Lo, Hi, Opcode == ISD::SMUL_LOHI))
+    return false;
+
+  SDValue Zero = DAG.getConstant(0, dl, HiLoVT);
+  Hi = DAG.getNode(ISD::ADDE, dl, DAG.getVTList(HiLoVT, MVT::Glue), Hi, Zero,
+                   Carry);
+  Next = DAG.getNode(ISD::ADD, dl, VT, Next, Merge(Lo, Hi));
+
+  if (Opcode == ISD::SMUL_LOHI) {
+    SDValue NextSub = DAG.getNode(ISD::SUB, dl, VT, Next,
+                                  DAG.getNode(ISD::ZERO_EXTEND, dl, VT, RL));
+    Next = DAG.getSelectCC(dl, LH, Zero, NextSub, Next, ISD::SETLT);
+
+    NextSub = DAG.getNode(ISD::SUB, dl, VT, Next,
+                          DAG.getNode(ISD::ZERO_EXTEND, dl, VT, LL));
+    Next = DAG.getSelectCC(dl, RH, Zero, NextSub, Next, ISD::SETLT);
+  }
+
+  Result.push_back(DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, Next));
+  Next = DAG.getNode(ISD::SRL, dl, VT, Next, Shift);
+  Result.push_back(DAG.getNode(ISD::TRUNCATE, dl, HiLoVT, Next));
+  return true;
+}
+
+bool TargetLowering::expandMUL(SDNode *N, SDValue &Lo, SDValue &Hi, EVT HiLoVT,
+                               SelectionDAG &DAG, MulExpansionKind Kind,
+                               SDValue LL, SDValue LH, SDValue RL,
+                               SDValue RH) const {
+  SmallVector<SDValue, 2> Result;
+  bool Ok = expandMUL_LOHI(N->getOpcode(), N->getValueType(0), N,
+                           N->getOperand(0), N->getOperand(1), Result, HiLoVT,
+                           DAG, Kind, LL, LH, RL, RH);
+  if (Ok) {
+    assert(Result.size() == 2);
+    Lo = Result[0];
+    Hi = Result[1];
+  }
+  return Ok;
 }
 
 bool TargetLowering::expandFP_TO_SINT(SDNode *Node, SDValue &Result,




More information about the llvm-commits mailing list