[llvm] f678fc7 - [LegalizeVectorOps] Improve handling of multi-result operations.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Fri Jan 10 10:15:39 PST 2020


Author: Craig Topper
Date: 2020-01-10T10:14:58-08:00
New Revision: f678fc7660b36ce0ad6ce4f05eaa28f3e9fdedb5

URL: https://github.com/llvm/llvm-project/commit/f678fc7660b36ce0ad6ce4f05eaa28f3e9fdedb5
DIFF: https://github.com/llvm/llvm-project/commit/f678fc7660b36ce0ad6ce4f05eaa28f3e9fdedb5.diff

LOG: [LegalizeVectorOps] Improve handling of multi-result operations.

This system wasn't very well designed for multi-result nodes. As
a consequence, they weren't consistently registered in the
LegalizedNodes map, leading to nodes being revisited for different
results.

I've removed the "Result" variable from the main LegalizeOp method
and used an SDNode* instead. The result number from the incoming
Op SDValue is only used to decide which result to return to the
caller. When LegalizeOp is called, it should always register a
legalized result for all of the node's results. Future calls for
any other result should be pulled from the LegalizedNodes map.

Legal nodes will now register all of their results in the map
instead of just the one we were called for.
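To make the caching change concrete, here is a minimal C++ sketch under heavy simplification. `Value`, `ToyLegalizer`, `NumResults`, and `Visits` are hypothetical stand-ins, not LLVM's `SDValue` or `VectorLegalizer`: a value is modeled as a (node id, result number) pair, and "legalization" is reduced to recording the value in a map, so the sketch isolates how often a two-result node gets processed under each caching policy.

```cpp
#include <cassert>
#include <map>
#include <utility>

// Hypothetical stand-in for SDValue: (node id, result number).
using Value = std::pair<int, unsigned>;

struct ToyLegalizer {
  std::map<Value, Value> Legalized;   // plays the role of LegalizedNodes
  std::map<int, unsigned> NumResults; // result count per node id
  int Visits = 0;                     // times a node was actually processed
  bool RegisterAllResults;            // true models the new behaviour

  explicit ToyLegalizer(bool RegisterAll) : RegisterAllResults(RegisterAll) {}

  Value legalize(Value Op) {
    auto It = Legalized.find(Op);
    if (It != Legalized.end())
      return It->second; // cache hit: the node is not revisited
    ++Visits;
    if (RegisterAllResults) {
      // New scheme: record every result of the node; Op's result number
      // only selects what is returned to the caller.
      for (unsigned I = 0, E = NumResults[Op.first]; I != E; ++I)
        Legalized[{Op.first, I}] = {Op.first, I};
    } else {
      // Old scheme: only the requested result is recorded.
      Legalized[Op] = Op;
    }
    return Legalized.at(Op);
  }
};
```

With `RegisterAllResults` false, asking for result 0 and then result 1 of a two-result node processes it twice; with it true, the second request is a cache hit. The toy omits operand legalization and node replacement entirely.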

The Expand and Promote handling now uses a vector of results, similar
to LegalizeDAG. Each of the new results is then re-legalized and
logged in the LegalizedNodes map against the corresponding result of
the node being legalized. None of the handlers register their own
results now, and none call ReplaceAllUsesOfValueWith.
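The flow described above (a vector of replacement values, each re-legalized and recorded under the matching result number of the original node) can be sketched with toy types. `ToyRecursiveLegalizer`, `recordResults`, and `LegalizeLeaf` are hypothetical names for illustration only; the real code works on `SDValue`s, takes a `MutableArrayRef<SDValue>`, and calls `LegalizeOp` recursively.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical toy: values are plain strings, and LegalizeLeaf stands in
// for the recursive LegalizeOp call applied to freshly generated code.
struct ToyRecursiveLegalizer {
  std::map<std::string, std::string> Legalized; // "node#resno" -> value
  std::function<std::string(const std::string &)> LegalizeLeaf;

  // Same shape as RecursivelyLegalizeResults: Results[I] replaces result I
  // of the original node, and ResNo only picks the value to return.
  std::string recordResults(const std::string &Node, unsigned ResNo,
                            std::vector<std::string> &Results) {
    assert(ResNo < Results.size() && "requested result out of range");
    for (std::size_t I = 0; I != Results.size(); ++I) {
      Results[I] = LegalizeLeaf(Results[I]); // new code must be legal too
      Legalized[Node + "#" + std::to_string(I)] = Results[I];
    }
    return Results[ResNo];
  }
};
```

For a two-result node such as an overflow add, both results get cached even though the caller asked only for result 1, which is what keeps later requests from reprocessing the node.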

Custom handling now always passes result number 0 to LowerOperation.
This matches what LegalizeDAG does. Since the introduction of
STRICT nodes, I've encountered several issues with X86's custom
handling being called with an SDValue pointing at the chain, and
our custom handlers using that to get a VT instead of result 0.
This should prevent any more of those issues. On return, we
update the LegalizedNodes map for all results, so we shouldn't
call the custom handler again for each result number.
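The custom-lowering contract can be summarized as a small decision function. This is a simplified sketch with hypothetical `ToyNode`/`ToyValue` types standing in for `SDNode`/`SDValue`, and a hypothetical `wrapLowering` function: it takes the value a target lowering hook returned for result 0 and reports it the way the diff's `LowerOperationWrapper` does.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-ins for SDNode and SDValue.
struct ToyNode { unsigned NumValues; };
struct ToyValue { const ToyNode *Node; unsigned ResNo; };

// Lowered is what the (hypothetical) target hook returned when given
// result 0 of N. Returns false to request default expansion; returns true
// with empty Results to keep N unchanged; otherwise fills one replacement
// per result of N.
bool wrapLowering(const ToyNode *N, ToyValue Lowered,
                  std::vector<ToyValue> &Results) {
  if (!Lowered.Node)
    return false; // target declined: fall back to expansion
  if (Lowered.Node == N && Lowered.ResNo == 0)
    return true; // target wants the node kept as is
  if (N->NumValues == 1) {
    // Single-result node: take the returned value as is, even if it is
    // not result 0 of its own node.
    Results.push_back(Lowered);
    return true;
  }
  // Multi-result node: the replacement must provide the same results.
  assert(Lowered.Node->NumValues == N->NumValues &&
         "lowering returned the wrong number of results");
  for (unsigned I = 0; I != N->NumValues; ++I)
    Results.push_back(ToyValue{Lowered.Node, I});
  return true;
}
```

Always handing the hook result 0 is the point of the change: per the message above, handlers that were given the chain result of a STRICT node sometimes derived a VT from it instead of from result 0.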

I want to push SDNode* further into the Expand and Promote
handlers, but I've left that for a follow-up to keep this patch
size down. I've created a dummy SDValue(Node, 0) to keep the
handlers working.
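The interim `SDValue(Node, 0)` shim, together with the convention that an empty result vector means "keep the node", can be sketched as follows. Everything here is hypothetical toy code: `legacyExpand` plays the role of an old-style expander that takes a value and may return a null value to defer (as `ExpandFSUB` and `ExpandBITREVERSE` now return `SDValue()` in this diff), and `expand` is the new node-based entry point.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins for SDNode and SDValue.
struct ToyNode { std::string Opcode; };
struct ToyValue { const ToyNode *Node; unsigned ResNo; };

// Legacy-style expander: an empty string signals "defer to a later pass".
std::string legacyExpand(ToyValue Op) {
  if (Op.Node->Opcode == "fsub")
    return ""; // assume a later pass (LegalizeDAG in the real code) handles it
  return "expanded(" + Op.Node->Opcode + ")";
}

// New-style entry point: receives the node, builds a dummy value for
// result 0 so the legacy expander still works, and appends only real
// replacements.
void expand(const ToyNode *N, std::vector<std::string> &Results) {
  ToyValue Op{N, 0}; // the temporary shim
  std::string Tmp = legacyExpand(Op);
  if (!Tmp.empty())
    Results.push_back(Tmp);
}
```

An empty `Results` vector then sends the caller down the same path as a legal node, registering the original node's results unchanged.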

Differential Revision: https://reviews.llvm.org/D72224

Added: 
    

Modified: 
    llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
    llvm/test/CodeGen/X86/avx512-cmp.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 13813008eff0..557bf495c85d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -75,7 +75,17 @@ class VectorLegalizer {
   SDValue LegalizeOp(SDValue Op);
 
   /// Assuming the node is legal, "legalize" the results.
-  SDValue TranslateLegalizeResults(SDValue Op, SDValue Result);
+  SDValue TranslateLegalizeResults(SDValue Op, SDNode *Result);
+
+  /// Make sure Results are legal and update the translation cache.
+  SDValue RecursivelyLegalizeResults(SDValue Op,
+                                     MutableArrayRef<SDValue> Results);
+
+  /// Wrapper to interface LowerOperation with a vector of Results.
+  /// Returns false if the target wants to use default expansion. Otherwise
+  /// returns true. If return is true and the Results are empty, then the
+  /// target wants to keep the input node as is.
+  bool LowerOperationWrapper(SDNode *N, SmallVectorImpl<SDValue> &Results);
 
   /// Implements unrolling a VSETCC.
   SDValue UnrollVSETCC(SDValue Op);
@@ -84,15 +94,15 @@ class VectorLegalizer {
   ///
   /// This is just a high-level routine to dispatch to specific code paths for
   /// operations to legalize them.
-  SDValue Expand(SDValue Op);
+  void Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
   /// Implements expansion for FP_TO_UINT; falls back to UnrollVectorOp if
   /// FP_TO_SINT isn't legal.
-  SDValue ExpandFP_TO_UINT(SDValue Op);
+  void ExpandFP_TO_UINT(SDValue Op, SmallVectorImpl<SDValue> &Results);
 
   /// Implements expansion for UINT_TO_FLOAT; falls back to UnrollVectorOp if
   /// SINT_TO_FLOAT and SHR on vectors isn't legal.
-  SDValue ExpandUINT_TO_FLOAT(SDValue Op);
+  void ExpandUINT_TO_FLOAT(SDValue Op, SmallVectorImpl<SDValue> &Results);
 
   /// Implement expansion for SIGN_EXTEND_INREG using SRL and SRA.
   SDValue ExpandSEXTINREG(SDValue Op);
@@ -130,8 +140,8 @@ class VectorLegalizer {
   /// supported by the target.
   SDValue ExpandVSELECT(SDValue Op);
   SDValue ExpandSELECT(SDValue Op);
-  std::pair<SDValue, SDValue> ExpandLoad(SDValue Op);
-  SDValue ExpandStore(SDValue Op);
+  std::pair<SDValue, SDValue> ExpandLoad(SDNode *N);
+  SDValue ExpandStore(SDNode *N);
   SDValue ExpandFNEG(SDValue Op);
   SDValue ExpandFSUB(SDValue Op);
   SDValue ExpandBITREVERSE(SDValue Op);
@@ -141,32 +151,33 @@ class VectorLegalizer {
   SDValue ExpandFunnelShift(SDValue Op);
   SDValue ExpandROT(SDValue Op);
   SDValue ExpandFMINNUM_FMAXNUM(SDValue Op);
-  SDValue ExpandUADDSUBO(SDValue Op);
-  SDValue ExpandSADDSUBO(SDValue Op);
-  SDValue ExpandMULO(SDValue Op);
+  void ExpandUADDSUBO(SDValue Op, SmallVectorImpl<SDValue> &Results);
+  void ExpandSADDSUBO(SDValue Op, SmallVectorImpl<SDValue> &Results);
+  void ExpandMULO(SDValue Op, SmallVectorImpl<SDValue> &Results);
   SDValue ExpandAddSubSat(SDValue Op);
   SDValue ExpandFixedPointMul(SDValue Op);
   SDValue ExpandFixedPointDiv(SDValue Op);
   SDValue ExpandStrictFPOp(SDValue Op);
+  void ExpandStrictFPOp(SDValue Op, SmallVectorImpl<SDValue> &Results);
 
-  SDValue UnrollStrictFPOp(SDValue Op);
+  void UnrollStrictFPOp(SDValue Op, SmallVectorImpl<SDValue> &Results);
 
   /// Implements vector promotion.
   ///
   /// This is essentially just bitcasting the operands to a different type and
   /// bitcasting the result back to the original type.
-  SDValue Promote(SDValue Op);
+  void Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results);
 
   /// Implements [SU]INT_TO_FP vector promotion.
   ///
   /// This is a [zs]ext of the input operand to a larger integer type.
-  SDValue PromoteINT_TO_FP(SDValue Op);
+  void PromoteINT_TO_FP(SDValue Op, SmallVectorImpl<SDValue> &Results);
 
   /// Implements FP_TO_[SU]INT vector promotion of the result type.
   ///
   /// It is promoted to a larger integer type.  The result is then
   /// truncated back to the original type.
-  SDValue PromoteFP_TO_INT(SDValue Op);
+  void PromoteFP_TO_INT(SDValue Op, SmallVectorImpl<SDValue> &Results);
 
 public:
   VectorLegalizer(SelectionDAG& dag) :
@@ -222,11 +233,27 @@ bool VectorLegalizer::Run() {
   return Changed;
 }
 
-SDValue VectorLegalizer::TranslateLegalizeResults(SDValue Op, SDValue Result) {
+SDValue VectorLegalizer::TranslateLegalizeResults(SDValue Op, SDNode *Result) {
+  assert(Op->getNumValues() == Result->getNumValues() &&
+         "Unexpected number of results");
   // Generic legalization: just pass the operand through.
-  for (unsigned i = 0, e = Op.getNode()->getNumValues(); i != e; ++i)
-    AddLegalizedOperand(Op.getValue(i), Result.getValue(i));
-  return Result.getValue(Op.getResNo());
+  for (unsigned i = 0, e = Op->getNumValues(); i != e; ++i)
+    AddLegalizedOperand(Op.getValue(i), SDValue(Result, i));
+  return SDValue(Result, Op.getResNo());
+}
+
+SDValue
+VectorLegalizer::RecursivelyLegalizeResults(SDValue Op,
+                                            MutableArrayRef<SDValue> Results) {
+  assert(Results.size() == Op->getNumValues() &&
+         "Unexpected number of results");
+  // Make sure that the generated code is itself legal.
+  for (unsigned i = 0, e = Results.size(); i != e; ++i) {
+    Results[i] = LegalizeOp(Results[i]);
+    AddLegalizedOperand(Op.getValue(i), Results[i]);
+  }
+
+  return Results[Op.getResNo()];
 }
 
 SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
@@ -235,18 +262,15 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   DenseMap<SDValue, SDValue>::iterator I = LegalizedNodes.find(Op);
   if (I != LegalizedNodes.end()) return I->second;
 
-  SDNode* Node = Op.getNode();
-
   // Legalize the operands
   SmallVector<SDValue, 8> Ops;
-  for (const SDValue &Op : Node->op_values())
-    Ops.push_back(LegalizeOp(Op));
+  for (const SDValue &Oper : Op->op_values())
+    Ops.push_back(LegalizeOp(Oper));
 
-  SDValue Result = SDValue(DAG.UpdateNodeOperands(Op.getNode(), Ops),
-                           Op.getResNo());
+  SDNode *Node = DAG.UpdateNodeOperands(Op.getNode(), Ops);
 
   if (Op.getOpcode() == ISD::LOAD) {
-    LoadSDNode *LD = cast<LoadSDNode>(Op.getNode());
+    LoadSDNode *LD = cast<LoadSDNode>(Node);
     ISD::LoadExtType ExtType = LD->getExtensionType();
     if (LD->getMemoryVT().isVector() && ExtType != ISD::NON_EXTLOAD) {
       LLVM_DEBUG(dbgs() << "\nLegalizing extending vector load: ";
@@ -255,22 +279,21 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
                                    LD->getMemoryVT())) {
       default: llvm_unreachable("This action is not supported yet!");
       case TargetLowering::Legal:
-        return TranslateLegalizeResults(Op, Result);
-      case TargetLowering::Custom:
-        if (SDValue Lowered = TLI.LowerOperation(Result, DAG)) {
-          assert(Lowered->getNumValues() == Op->getNumValues() &&
-                 "Unexpected number of results");
-          if (Lowered != Result) {
-            // Make sure the new code is also legal.
-            Lowered = LegalizeOp(Lowered);
-            Changed = true;
-          }
-          return TranslateLegalizeResults(Op, Lowered);
+        return TranslateLegalizeResults(Op, Node);
+      case TargetLowering::Custom: {
+        SmallVector<SDValue, 2> ResultVals;
+        if (LowerOperationWrapper(Node, ResultVals)) {
+          if (ResultVals.empty())
+            return TranslateLegalizeResults(Op, Node);
+
+          Changed = true;
+          return RecursivelyLegalizeResults(Op, ResultVals);
         }
         LLVM_FALLTHROUGH;
+      }
       case TargetLowering::Expand: {
         Changed = true;
-        std::pair<SDValue, SDValue> Tmp = ExpandLoad(Result);
+        std::pair<SDValue, SDValue> Tmp = ExpandLoad(Node);
         AddLegalizedOperand(Op.getValue(0), Tmp.first);
         AddLegalizedOperand(Op.getValue(1), Tmp.second);
         return Op.getResNo() ? Tmp.first : Tmp.second;
@@ -278,7 +301,7 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
       }
     }
   } else if (Op.getOpcode() == ISD::STORE) {
-    StoreSDNode *ST = cast<StoreSDNode>(Op.getNode());
+    StoreSDNode *ST = cast<StoreSDNode>(Node);
     EVT StVT = ST->getMemoryVT();
     MVT ValVT = ST->getValue().getSimpleValueType();
     if (StVT.isVector() && ST->isTruncatingStore()) {
@@ -287,19 +310,21 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
       switch (TLI.getTruncStoreAction(ValVT, StVT)) {
       default: llvm_unreachable("This action is not supported yet!");
       case TargetLowering::Legal:
-        return TranslateLegalizeResults(Op, Result);
+        return TranslateLegalizeResults(Op, Node);
       case TargetLowering::Custom: {
-        SDValue Lowered = TLI.LowerOperation(Result, DAG);
-        if (Lowered != Result) {
-          // Make sure the new code is also legal.
-          Lowered = LegalizeOp(Lowered);
+        SmallVector<SDValue, 1> ResultVals;
+        if (LowerOperationWrapper(Node, ResultVals)) {
+          if (ResultVals.empty())
+            return TranslateLegalizeResults(Op, Node);
+
           Changed = true;
+          return RecursivelyLegalizeResults(Op, ResultVals);
         }
-        return TranslateLegalizeResults(Op, Lowered);
+        LLVM_FALLTHROUGH;
       }
       case TargetLowering::Expand: {
         Changed = true;
-        SDValue Chain = ExpandStore(Result);
+        SDValue Chain = ExpandStore(Node);
         AddLegalizedOperand(Op, Chain);
         return Chain;
       }
@@ -310,17 +335,17 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
   bool HasVectorValueOrOp = false;
   for (auto J = Node->value_begin(), E = Node->value_end(); J != E; ++J)
     HasVectorValueOrOp |= J->isVector();
-  for (const SDValue &Op : Node->op_values())
-    HasVectorValueOrOp |= Op.getValueType().isVector();
+  for (const SDValue &Oper : Node->op_values())
+    HasVectorValueOrOp |= Oper.getValueType().isVector();
 
   if (!HasVectorValueOrOp)
-    return TranslateLegalizeResults(Op, Result);
+    return TranslateLegalizeResults(Op, Node);
 
   TargetLowering::LegalizeAction Action = TargetLowering::Legal;
   EVT ValVT;
   switch (Op.getOpcode()) {
   default:
-    return TranslateLegalizeResults(Op, Result);
+    return TranslateLegalizeResults(Op, Node);
 #define INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN)                   \
   case ISD::STRICT_##DAGN:
 #include "llvm/IR/ConstrainedOps.def"
@@ -473,42 +498,70 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
 
   LLVM_DEBUG(dbgs() << "\nLegalizing vector op: "; Node->dump(&DAG));
 
+  SmallVector<SDValue, 8> ResultVals;
   switch (Action) {
   default: llvm_unreachable("This action is not supported yet!");
   case TargetLowering::Promote:
-    Result = Promote(Op);
-    Changed = true;
+    LLVM_DEBUG(dbgs() << "Promoting\n");
+    Promote(Node, ResultVals);
+    assert(!ResultVals.empty() && "No results for promotion?");
     break;
   case TargetLowering::Legal:
     LLVM_DEBUG(dbgs() << "Legal node: nothing to do\n");
     break;
-  case TargetLowering::Custom: {
+  case TargetLowering::Custom:
     LLVM_DEBUG(dbgs() << "Trying custom legalization\n");
-    if (SDValue Tmp1 = TLI.LowerOperation(Op, DAG)) {
-      LLVM_DEBUG(dbgs() << "Successfully custom legalized node\n");
-      Result = Tmp1;
+    if (LowerOperationWrapper(Node, ResultVals))
       break;
-    }
     LLVM_DEBUG(dbgs() << "Could not custom legalize node\n");
     LLVM_FALLTHROUGH;
-  }
   case TargetLowering::Expand:
-    Result = Expand(Op);
+    LLVM_DEBUG(dbgs() << "Expanding\n");
+    Expand(Node, ResultVals);
+    break;
   }
 
-  // Make sure that the generated code is itself legal.
-  if (Result != Op) {
-    Result = LegalizeOp(Result);
-    Changed = true;
+  if (ResultVals.empty())
+    return TranslateLegalizeResults(Op, Node);
+
+  Changed = true;
+  return RecursivelyLegalizeResults(Op, ResultVals);
+}
+
+// FIME: This is very similar to the X86 override of
+// TargetLowering::LowerOperationWrapper. Can we merge them somehow?
+bool VectorLegalizer::LowerOperationWrapper(SDNode *Node,
+                                            SmallVectorImpl<SDValue> &Results) {
+  SDValue Res = TLI.LowerOperation(SDValue(Node, 0), DAG);
+
+  if (!Res.getNode())
+    return false;
+
+  if (Res == SDValue(Node, 0))
+    return true;
+
+  // If the original node has one result, take the return value from
+  // LowerOperation as is. It might not be result number 0.
+  if (Node->getNumValues() == 1) {
+    Results.push_back(Res);
+    return true;
   }
 
-  // Note that LegalizeOp may be reentered even from single-use nodes, which
-  // means that we always must cache transformed nodes.
-  AddLegalizedOperand(Op, Result);
-  return Result;
+  // If the original node has multiple results, then the return node should
+  // have the same number of results.
+  assert((Node->getNumValues() == Res->getNumValues()) &&
+         "Lowering returned the wrong number of results!");
+
+  // Places new result values base on N result number.
+  for (unsigned I = 0, E = Node->getNumValues(); I != E; ++I)
+    Results.push_back(Res.getValue(I));
+
+  return true;
 }
 
-SDValue VectorLegalizer::Promote(SDValue Op) {
+void VectorLegalizer::Promote(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
+  SDValue Op(Node, 0); // FIXME: Use Node throughout.
+
   // For a few operations there is a specific concept for promotion based on
   // the operand's type.
   switch (Op.getOpcode()) {
@@ -517,13 +570,15 @@ SDValue VectorLegalizer::Promote(SDValue Op) {
   case ISD::STRICT_SINT_TO_FP:
   case ISD::STRICT_UINT_TO_FP:
     // "Promote" the operation by extending the operand.
-    return PromoteINT_TO_FP(Op);
+    PromoteINT_TO_FP(Op, Results);
+    return;
   case ISD::FP_TO_UINT:
   case ISD::FP_TO_SINT:
   case ISD::STRICT_FP_TO_UINT:
   case ISD::STRICT_FP_TO_SINT:
     // Promote the operation by extending the operand.
-    return PromoteFP_TO_INT(Op);
+    PromoteFP_TO_INT(Op, Results);
+    return;
   case ISD::FP_ROUND:
   case ISD::FP_EXTEND:
     // These operations are used to do promotion so they can't be promoted
@@ -558,15 +613,20 @@ SDValue VectorLegalizer::Promote(SDValue Op) {
   }
 
   Op = DAG.getNode(Op.getOpcode(), dl, NVT, Operands, Op.getNode()->getFlags());
+
+  SDValue Res;
   if ((VT.isFloatingPoint() && NVT.isFloatingPoint()) ||
       (VT.isVector() && VT.getVectorElementType().isFloatingPoint() &&
        NVT.isVector() && NVT.getVectorElementType().isFloatingPoint()))
-    return DAG.getNode(ISD::FP_ROUND, dl, VT, Op, DAG.getIntPtrConstant(0, dl));
+    Res = DAG.getNode(ISD::FP_ROUND, dl, VT, Op, DAG.getIntPtrConstant(0, dl));
   else
-    return DAG.getNode(ISD::BITCAST, dl, VT, Op);
+    Res = DAG.getNode(ISD::BITCAST, dl, VT, Op);
+
+  Results.push_back(Res);
 }
 
-SDValue VectorLegalizer::PromoteINT_TO_FP(SDValue Op) {
+void VectorLegalizer::PromoteINT_TO_FP(SDValue Op,
+                                       SmallVectorImpl<SDValue> &Results) {
   // INT_TO_FP operations may require the input operand be promoted even
   // when the type is otherwise legal.
   bool IsStrict = Op->isStrictFPOpcode();
@@ -589,18 +649,24 @@ SDValue VectorLegalizer::PromoteINT_TO_FP(SDValue Op) {
       Operands[j] = Op.getOperand(j);
   }
 
-  if (IsStrict)
-    return DAG.getNode(Op.getOpcode(), dl, {Op.getValueType(), MVT::Other},
-                       Operands);
+  if (IsStrict) {
+    SDValue Res = DAG.getNode(Op.getOpcode(), dl,
+                              {Op.getValueType(), MVT::Other}, Operands);
+    Results.push_back(Res);
+    Results.push_back(Res.getValue(1));
+    return;
+  }
 
-  return DAG.getNode(Op.getOpcode(), dl, Op.getValueType(), Operands);
+  SDValue Res = DAG.getNode(Op.getOpcode(), dl, Op.getValueType(), Operands);
+  Results.push_back(Res);
 }
 
 // For FP_TO_INT we promote the result type to a vector type with wider
 // elements and then truncate the result.  This is different from the default
 // PromoteVector which uses bitcast to promote thus assumning that the
 // promoted vector type has the same overall size.
-SDValue VectorLegalizer::PromoteFP_TO_INT(SDValue Op) {
+void VectorLegalizer::PromoteFP_TO_INT(SDValue Op,
+                                       SmallVectorImpl<SDValue> &Results) {
   MVT VT = Op.getSimpleValueType();
   MVT NVT = TLI.getTypeToPromoteTo(Op.getOpcode(), VT);
   bool IsStrict = Op->isStrictFPOpcode();
@@ -639,14 +705,13 @@ SDValue VectorLegalizer::PromoteFP_TO_INT(SDValue Op) {
   Promoted = DAG.getNode(NewOpc, dl, NVT, Promoted,
                          DAG.getValueType(VT.getScalarType()));
   Promoted = DAG.getNode(ISD::TRUNCATE, dl, VT, Promoted);
+  Results.push_back(Promoted);
   if (IsStrict)
-    return DAG.getMergeValues({Promoted, Chain}, dl);
-
-  return Promoted;
+    Results.push_back(Chain);
 }
 
-std::pair<SDValue, SDValue> VectorLegalizer::ExpandLoad(SDValue Op) {
-  LoadSDNode *LD = cast<LoadSDNode>(Op.getNode());
+std::pair<SDValue, SDValue> VectorLegalizer::ExpandLoad(SDNode *N) {
+  LoadSDNode *LD = cast<LoadSDNode>(N);
 
   EVT SrcVT = LD->getMemoryVT();
   EVT SrcEltVT = SrcVT.getScalarType();
@@ -655,7 +720,7 @@ std::pair<SDValue, SDValue> VectorLegalizer::ExpandLoad(SDValue Op) {
   SDValue NewChain;
   SDValue Value;
   if (SrcVT.getVectorNumElements() > 1 && !SrcEltVT.isByteSized()) {
-    SDLoc dl(Op);
+    SDLoc dl(N);
 
     SmallVector<SDValue, 8> Vals;
     SmallVector<SDValue, 8> LoadChains;
@@ -767,7 +832,7 @@ std::pair<SDValue, SDValue> VectorLegalizer::ExpandLoad(SDValue Op) {
     }
 
     NewChain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, LoadChains);
-    Value = DAG.getBuildVector(Op.getNode()->getValueType(0), dl, Vals);
+    Value = DAG.getBuildVector(N->getValueType(0), dl, Vals);
   } else {
     std::tie(Value, NewChain) = TLI.scalarizeVectorLoad(LD, DAG);
   }
@@ -775,90 +840,122 @@ std::pair<SDValue, SDValue> VectorLegalizer::ExpandLoad(SDValue Op) {
   return std::make_pair(Value, NewChain);
 }
 
-SDValue VectorLegalizer::ExpandStore(SDValue Op) {
-  StoreSDNode *ST = cast<StoreSDNode>(Op.getNode());
+SDValue VectorLegalizer::ExpandStore(SDNode *N) {
+  StoreSDNode *ST = cast<StoreSDNode>(N);
   SDValue TF = TLI.scalarizeVectorStore(ST, DAG);
   return TF;
 }
 
-SDValue VectorLegalizer::Expand(SDValue Op) {
+void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
+  SDValue Op(Node, 0); // FIXME: Just pass Node to all the expanders.
+
   switch (Op->getOpcode()) {
   case ISD::SIGN_EXTEND_INREG:
-    return ExpandSEXTINREG(Op);
+    Results.push_back(ExpandSEXTINREG(Op));
+    return;
   case ISD::ANY_EXTEND_VECTOR_INREG:
-    return ExpandANY_EXTEND_VECTOR_INREG(Op);
+    Results.push_back(ExpandANY_EXTEND_VECTOR_INREG(Op));
+    return;
   case ISD::SIGN_EXTEND_VECTOR_INREG:
-    return ExpandSIGN_EXTEND_VECTOR_INREG(Op);
+    Results.push_back(ExpandSIGN_EXTEND_VECTOR_INREG(Op));
+    return;
   case ISD::ZERO_EXTEND_VECTOR_INREG:
-    return ExpandZERO_EXTEND_VECTOR_INREG(Op);
+    Results.push_back(ExpandZERO_EXTEND_VECTOR_INREG(Op));
+    return;
   case ISD::BSWAP:
-    return ExpandBSWAP(Op);
+    Results.push_back(ExpandBSWAP(Op));
+    return;
   case ISD::VSELECT:
-    return ExpandVSELECT(Op);
+    Results.push_back(ExpandVSELECT(Op));
+    return;
   case ISD::SELECT:
-    return ExpandSELECT(Op);
+    Results.push_back(ExpandSELECT(Op));
+    return;
   case ISD::FP_TO_UINT:
-    return ExpandFP_TO_UINT(Op);
+    ExpandFP_TO_UINT(Op, Results);
+    return;
   case ISD::UINT_TO_FP:
-    return ExpandUINT_TO_FLOAT(Op);
+    ExpandUINT_TO_FLOAT(Op, Results);
+    return;
   case ISD::FNEG:
-    return ExpandFNEG(Op);
+    Results.push_back(ExpandFNEG(Op));
+    return;
   case ISD::FSUB:
-    return ExpandFSUB(Op);
+    if (SDValue Tmp = ExpandFSUB(Op))
+      Results.push_back(Tmp);
+    return;
   case ISD::SETCC:
-    return UnrollVSETCC(Op);
+    Results.push_back(UnrollVSETCC(Op));
+    return;
   case ISD::ABS:
-    return ExpandABS(Op);
+    Results.push_back(ExpandABS(Op));
+    return;
   case ISD::BITREVERSE:
-    return ExpandBITREVERSE(Op);
+    if (SDValue Tmp = ExpandBITREVERSE(Op))
+      Results.push_back(Tmp);
+    return;
   case ISD::CTPOP:
-    return ExpandCTPOP(Op);
+    Results.push_back(ExpandCTPOP(Op));
+    return;
   case ISD::CTLZ:
   case ISD::CTLZ_ZERO_UNDEF:
-    return ExpandCTLZ(Op);
+    Results.push_back(ExpandCTLZ(Op));
+    return;
   case ISD::CTTZ:
   case ISD::CTTZ_ZERO_UNDEF:
-    return ExpandCTTZ(Op);
+    Results.push_back(ExpandCTTZ(Op));
+    return;
   case ISD::FSHL:
   case ISD::FSHR:
-    return ExpandFunnelShift(Op);
+    Results.push_back(ExpandFunnelShift(Op));
+    return;
   case ISD::ROTL:
   case ISD::ROTR:
-    return ExpandROT(Op);
+    Results.push_back(ExpandROT(Op));
+    return;
   case ISD::FMINNUM:
   case ISD::FMAXNUM:
-    return ExpandFMINNUM_FMAXNUM(Op);
+    Results.push_back(ExpandFMINNUM_FMAXNUM(Op));
+    return;
   case ISD::UADDO:
   case ISD::USUBO:
-    return ExpandUADDSUBO(Op);
+    ExpandUADDSUBO(Op, Results);
+    return;
   case ISD::SADDO:
   case ISD::SSUBO:
-    return ExpandSADDSUBO(Op);
+    ExpandSADDSUBO(Op, Results);
+    return;
   case ISD::UMULO:
   case ISD::SMULO:
-    return ExpandMULO(Op);
+    ExpandMULO(Op, Results);
+    return;
   case ISD::USUBSAT:
   case ISD::SSUBSAT:
   case ISD::UADDSAT:
   case ISD::SADDSAT:
-    return ExpandAddSubSat(Op);
+    Results.push_back(ExpandAddSubSat(Op));
+    return;
   case ISD::SMULFIX:
   case ISD::UMULFIX:
-    return ExpandFixedPointMul(Op);
+    Results.push_back(ExpandFixedPointMul(Op));
+    return;
   case ISD::SMULFIXSAT:
   case ISD::UMULFIXSAT:
     // FIXME: We do not expand SMULFIXSAT/UMULFIXSAT here yet, not sure exactly
     // why. Maybe it results in worse codegen compared to the unroll for some
     // targets? This should probably be investigated. And if we still prefer to
     // unroll an explanation could be helpful.
-    return DAG.UnrollVectorOp(Op.getNode());
+    Results.push_back(DAG.UnrollVectorOp(Op.getNode()));
+    return;
   case ISD::SDIVFIX:
   case ISD::UDIVFIX:
-    return ExpandFixedPointDiv(Op);
+    Results.push_back(ExpandFixedPointDiv(Op));
+    return;
 #define INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN)                   \
   case ISD::STRICT_##DAGN:
 #include "llvm/IR/ConstrainedOps.def"
-    return ExpandStrictFPOp(Op);
+    ExpandStrictFPOp(Op, Results);
+    return;
   case ISD::VECREDUCE_ADD:
   case ISD::VECREDUCE_MUL:
   case ISD::VECREDUCE_AND:
@@ -872,9 +969,11 @@ SDValue VectorLegalizer::Expand(SDValue Op) {
   case ISD::VECREDUCE_FMUL:
   case ISD::VECREDUCE_FMAX:
   case ISD::VECREDUCE_FMIN:
-    return TLI.expandVecReduce(Op.getNode(), DAG);
+    Results.push_back(TLI.expandVecReduce(Op.getNode(), DAG));
+    return;
   default:
-    return DAG.UnrollVectorOp(Op.getNode());
+    Results.push_back(DAG.UnrollVectorOp(Op.getNode()));
+    return;
   }
 }
 
@@ -1120,7 +1219,7 @@ SDValue VectorLegalizer::ExpandBITREVERSE(SDValue Op) {
     return DAG.UnrollVectorOp(Op.getNode());
 
   // Let LegalizeDAG handle this later.
-  return Op;
+  return SDValue();
 }
 
 SDValue VectorLegalizer::ExpandVSELECT(SDValue Op) {
@@ -1180,23 +1279,28 @@ SDValue VectorLegalizer::ExpandABS(SDValue Op) {
   return DAG.UnrollVectorOp(Op.getNode());
 }
 
-SDValue VectorLegalizer::ExpandFP_TO_UINT(SDValue Op) {
+void VectorLegalizer::ExpandFP_TO_UINT(SDValue Op,
+                                       SmallVectorImpl<SDValue> &Results) {
   // Attempt to expand using TargetLowering.
   SDValue Result, Chain;
   if (TLI.expandFP_TO_UINT(Op.getNode(), Result, Chain, DAG)) {
+    Results.push_back(Result);
     if (Op->isStrictFPOpcode())
-      // Relink the chain
-      DAG.ReplaceAllUsesOfValueWith(Op.getValue(1), Chain);
-    return Result;
+      Results.push_back(Chain);
+    return;
   }
 
   // Otherwise go ahead and unroll.
-  if (Op->isStrictFPOpcode())
-    return UnrollStrictFPOp(Op);
-  return DAG.UnrollVectorOp(Op.getNode());
+  if (Op->isStrictFPOpcode()) {
+    UnrollStrictFPOp(Op, Results);
+    return;
+  }
+
+  Results.push_back(DAG.UnrollVectorOp(Op.getNode()));
 }
 
-SDValue VectorLegalizer::ExpandUINT_TO_FLOAT(SDValue Op) {
+void VectorLegalizer::ExpandUINT_TO_FLOAT(SDValue Op,
+                                          SmallVectorImpl<SDValue> &Results) {
   bool IsStrict = Op.getNode()->isStrictFPOpcode();
   unsigned OpNo = IsStrict ? 1 : 0;
   SDValue Src = Op.getOperand(OpNo);
@@ -1207,10 +1311,10 @@ SDValue VectorLegalizer::ExpandUINT_TO_FLOAT(SDValue Op) {
   SDValue Result;
   SDValue Chain;
   if (TLI.expandUINT_TO_FP(Op.getNode(), Result, Chain, DAG)) {
+    Results.push_back(Result);
     if (IsStrict)
-      // Relink the chain
-      DAG.ReplaceAllUsesOfValueWith(Op.getValue(1), Chain);
-    return Result;
+      Results.push_back(Chain);
+    return;
   }
 
   // Make sure that the SINT_TO_FP and SRL instructions are available.
@@ -1219,9 +1323,13 @@ SDValue VectorLegalizer::ExpandUINT_TO_FLOAT(SDValue Op) {
        (IsStrict && TLI.getOperationAction(ISD::STRICT_SINT_TO_FP, VT) ==
                         TargetLowering::Expand)) ||
       TLI.getOperationAction(ISD::SRL, VT) == TargetLowering::Expand) {
-    if (IsStrict)
-      return UnrollStrictFPOp(Op);
-    return DAG.UnrollVectorOp(Op.getNode());
+    if (IsStrict) {
+      UnrollStrictFPOp(Op, Results);
+      return;
+    }
+
+    Results.push_back(DAG.UnrollVectorOp(Op.getNode()));
+    return;
   }
 
   unsigned BW = VT.getScalarSizeInBits();
@@ -1261,9 +1369,9 @@ SDValue VectorLegalizer::ExpandUINT_TO_FLOAT(SDValue Op) {
         DAG.getNode(ISD::STRICT_FADD, DL, {Op.getValueType(), MVT::Other},
                     {SDValue(fLO.getNode(), 1), fHI, fLO});
 
-    // Relink the chain
-    DAG.ReplaceAllUsesOfValueWith(Op.getValue(1), SDValue(Result.getNode(), 1));
-    return Result;
+    Results.push_back(Result);
+    Results.push_back(Result.getValue(1));
+    return;
   }
 
   // Convert hi and lo to floats
@@ -1274,7 +1382,7 @@ SDValue VectorLegalizer::ExpandUINT_TO_FLOAT(SDValue Op) {
   SDValue fLO = DAG.getNode(ISD::SINT_TO_FP, DL, Op.getValueType(), LO);
 
   // Add the two halves
-  return DAG.getNode(ISD::FADD, DL, Op.getValueType(), fHI, fLO);
+  Results.push_back(DAG.getNode(ISD::FADD, DL, Op.getValueType(), fHI, fLO));
 }
 
 SDValue VectorLegalizer::ExpandFNEG(SDValue Op) {
@@ -1295,7 +1403,7 @@ SDValue VectorLegalizer::ExpandFSUB(SDValue Op) {
   EVT VT = Op.getValueType();
   if (TLI.isOperationLegalOrCustom(ISD::FNEG, VT) &&
       TLI.isOperationLegalOrCustom(ISD::FADD, VT))
-    return Op; // Defer to LegalizeDAG
+    return SDValue(); // Defer to LegalizeDAG
 
   return DAG.UnrollVectorOp(Op.getNode());
 }
@@ -1346,44 +1454,30 @@ SDValue VectorLegalizer::ExpandFMINNUM_FMAXNUM(SDValue Op) {
   return DAG.UnrollVectorOp(Op.getNode());
 }
 
-SDValue VectorLegalizer::ExpandUADDSUBO(SDValue Op) {
+void VectorLegalizer::ExpandUADDSUBO(SDValue Op,
+                                     SmallVectorImpl<SDValue> &Results) {
   SDValue Result, Overflow;
   TLI.expandUADDSUBO(Op.getNode(), Result, Overflow, DAG);
-
-  if (Op.getResNo() == 0) {
-    AddLegalizedOperand(Op.getValue(1), LegalizeOp(Overflow));
-    return Result;
-  } else {
-    AddLegalizedOperand(Op.getValue(0), LegalizeOp(Result));
-    return Overflow;
-  }
+  Results.push_back(Result);
+  Results.push_back(Overflow);
 }
 
-SDValue VectorLegalizer::ExpandSADDSUBO(SDValue Op) {
+void VectorLegalizer::ExpandSADDSUBO(SDValue Op,
+                                     SmallVectorImpl<SDValue> &Results) {
   SDValue Result, Overflow;
   TLI.expandSADDSUBO(Op.getNode(), Result, Overflow, DAG);
-
-  if (Op.getResNo() == 0) {
-    AddLegalizedOperand(Op.getValue(1), LegalizeOp(Overflow));
-    return Result;
-  } else {
-    AddLegalizedOperand(Op.getValue(0), LegalizeOp(Result));
-    return Overflow;
-  }
+  Results.push_back(Result);
+  Results.push_back(Overflow);
 }
 
-SDValue VectorLegalizer::ExpandMULO(SDValue Op) {
+void VectorLegalizer::ExpandMULO(SDValue Op,
+                                 SmallVectorImpl<SDValue> &Results) {
   SDValue Result, Overflow;
   if (!TLI.expandMULO(Op.getNode(), Result, Overflow, DAG))
     std::tie(Result, Overflow) = DAG.UnrollVectorOverflowOp(Op.getNode());
 
-  if (Op.getResNo() == 0) {
-    AddLegalizedOperand(Op.getValue(1), LegalizeOp(Overflow));
-    return Result;
-  } else {
-    AddLegalizedOperand(Op.getValue(0), LegalizeOp(Result));
-    return Overflow;
-  }
+  Results.push_back(Result);
+  Results.push_back(Overflow);
 }
 
 SDValue VectorLegalizer::ExpandAddSubSat(SDValue Op) {
@@ -1406,16 +1500,22 @@ SDValue VectorLegalizer::ExpandFixedPointDiv(SDValue Op) {
   return DAG.UnrollVectorOp(N);
 }
 
-SDValue VectorLegalizer::ExpandStrictFPOp(SDValue Op) {
-  if (Op.getOpcode() == ISD::STRICT_UINT_TO_FP)
-    return ExpandUINT_TO_FLOAT(Op);
-  if (Op.getOpcode() == ISD::STRICT_FP_TO_UINT)
-    return ExpandFP_TO_UINT(Op);
+void VectorLegalizer::ExpandStrictFPOp(SDValue Op,
+                                       SmallVectorImpl<SDValue> &Results) {
+  if (Op.getOpcode() == ISD::STRICT_UINT_TO_FP) {
+    ExpandUINT_TO_FLOAT(Op, Results);
+    return;
+  }
+  if (Op.getOpcode() == ISD::STRICT_FP_TO_UINT) {
+    ExpandFP_TO_UINT(Op, Results);
+    return;
+  }
 
-  return UnrollStrictFPOp(Op);
+  UnrollStrictFPOp(Op, Results);
 }
 
-SDValue VectorLegalizer::UnrollStrictFPOp(SDValue Op) {
+void VectorLegalizer::UnrollStrictFPOp(SDValue Op,
+                                       SmallVectorImpl<SDValue> &Results) {
   EVT VT = Op.getValue(0).getValueType();
   EVT EltVT = VT.getVectorElementType();
   unsigned NumElems = VT.getVectorNumElements();
@@ -1472,10 +1572,8 @@ SDValue VectorLegalizer::UnrollStrictFPOp(SDValue Op) {
   SDValue Result = DAG.getBuildVector(VT, dl, OpValues);
   SDValue NewChain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other, OpChains);
 
-  AddLegalizedOperand(Op.getValue(0), Result);
-  AddLegalizedOperand(Op.getValue(1), NewChain);
-
-  return Op.getResNo() ? NewChain : Result;
+  Results.push_back(Result);
+  Results.push_back(NewChain);
 }
 
 SDValue VectorLegalizer::UnrollVSETCC(SDValue Op) {

diff  --git a/llvm/test/CodeGen/X86/avx512-cmp.ll b/llvm/test/CodeGen/X86/avx512-cmp.ll
index 3f3141e8876c..bd902dde2a26 100644
--- a/llvm/test/CodeGen/X86/avx512-cmp.ll
+++ b/llvm/test/CodeGen/X86/avx512-cmp.ll
@@ -181,3 +181,39 @@ if.then.i:
 if.end.i:
   ret i32 6
 }
+
+; This test previously caused an infinite loop in legalize vector ops. Due to
+; CSE triggering on the call to UpdateNodeOperands and the resulting node not
+; being passed to LowerOperation. The add is needed to force the zext into a
+; sext on that path. The shuffle keeps the zext alive. The xor somehow
+; influences the zext to be visited before the sext exposing the CSE opportunity
+; for the sext since zext of setcc is custom legalized to a sext and shift.
+define <8 x i32> @legalize_loop(<8 x double> %arg) {
+; KNL-LABEL: legalize_loop:
+; KNL:       ## %bb.0:
+; KNL-NEXT:    vxorpd %xmm1, %xmm1, %xmm1
+; KNL-NEXT:    vcmpnltpd %zmm0, %zmm1, %k1
+; KNL-NEXT:    vpternlogd $255, %zmm0, %zmm0, %zmm0 {%k1} {z}
+; KNL-NEXT:    vpsrld $31, %ymm0, %ymm1
+; KNL-NEXT:    vpshufd {{.*#+}} ymm1 = ymm1[3,2,1,0,7,6,5,4]
+; KNL-NEXT:    vpermq {{.*#+}} ymm1 = ymm1[2,3,0,1]
+; KNL-NEXT:    vpsubd %ymm0, %ymm1, %ymm0
+; KNL-NEXT:    retq
+;
+; SKX-LABEL: legalize_loop:
+; SKX:       ## %bb.0:
+; SKX-NEXT:    vxorpd %xmm1, %xmm1, %xmm1
+; SKX-NEXT:    vcmpnltpd %zmm0, %zmm1, %k0
+; SKX-NEXT:    vpmovm2d %k0, %ymm0
+; SKX-NEXT:    vpsrld $31, %ymm0, %ymm1
+; SKX-NEXT:    vpshufd {{.*#+}} ymm1 = ymm1[3,2,1,0,7,6,5,4]
+; SKX-NEXT:    vpermq {{.*#+}} ymm1 = ymm1[2,3,0,1]
+; SKX-NEXT:    vpsubd %ymm0, %ymm1, %ymm0
+; SKX-NEXT:    retq
+  %tmp = fcmp ogt <8 x double> %arg, zeroinitializer
+  %tmp1 = xor <8 x i1> %tmp, <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>
+  %tmp2 = zext <8 x i1> %tmp1 to <8 x i32>
+  %tmp3 = shufflevector <8 x i32> %tmp2, <8 x i32> undef, <8 x i32> <i32 7, i32 6, i32 5, i32 4, i32 3, i32 2, i32 1, i32 0>
+  %tmp4 = add <8 x i32> %tmp2, %tmp3
+  ret <8 x i32> %tmp4
+}


        

