[llvm] a200b0f - [DAG] Introduce getSplat utility for common dispatch pattern [nfc]

Philip Reames via llvm-commits llvm-commits at lists.llvm.org
Mon Oct 3 12:50:18 PDT 2022


Author: Philip Reames
Date: 2022-10-03T12:49:39-07:00
New Revision: a200b0fc256a890b3f72014d20fce9e49d75763b

URL: https://github.com/llvm/llvm-project/commit/a200b0fc256a890b3f72014d20fce9e49d75763b
DIFF: https://github.com/llvm/llvm-project/commit/a200b0fc256a890b3f72014d20fce9e49d75763b.diff

LOG: [DAG] Introduce getSplat utility for common dispatch pattern [nfc]

We have a very common pattern of dispatching between BUILD_VECTOR and SPLAT_VECTOR creation that is repeated in many places in the code.  Factor the pattern out into a utility function.

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/SelectionDAG.h
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 969199b6f381c..b6f71ab575b95 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -862,6 +862,16 @@ class SelectionDAG {
     return getNode(ISD::SPLAT_VECTOR, DL, VT, Op);
   }
 
+  /// Returns a node that splats the single value \p Op into every lane of
+  /// the vector type \p VT (asserts that \p VT is a vector).  Scalable types
+  /// yield a SPLAT_VECTOR via getSplatVector; all other vector types yield
+  /// a BUILD_VECTOR via getSplatBuildVector.
+  SDValue getSplat(EVT VT, const SDLoc &DL, SDValue Op) {
+    assert(VT.isVector() && "Can't splat to non-vector type");
+    return VT.isScalableVector() ?
+      getSplatVector(VT, DL, Op) : getSplatBuildVector(VT, DL, Op);
+  }
+
   /// Returns a vector of type ResVT whose elements contain the linear sequence
   ///   <0, Step, Step * 2, Step * 3, ...>
   SDValue getStepVector(const SDLoc &DL, EVT ResVT, APInt StepVal);

diff  --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a6c9c46201568..9281b4eb92b05 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -3469,11 +3469,8 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
     if (VT.isVector()) {
       SDValue N1S = DAG.getSplatValue(N1, true);
       if (N1S && N1S.getOpcode() == ISD::SUB &&
-          isNullConstant(N1S.getOperand(0))) {
-        if (VT.isScalableVector())
-          return DAG.getSplatVector(VT, DL, N1S.getOperand(1));
-        return DAG.getSplatBuildVector(VT, DL, N1S.getOperand(1));
-      }
+          isNullConstant(N1S.getOperand(0)))
+        return DAG.getSplat(VT, DL, N1S.getOperand(1));
     }
   }
 
@@ -19778,11 +19775,8 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
   if (!IndexC) {
     // If this is variable insert to undef vector, it might be better to splat:
     // inselt undef, InVal, EltNo --> build_vector < InVal, InVal, ... >
-    if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT)) {
-      if (VT.isScalableVector())
-        return DAG.getSplatVector(VT, DL, InVal);
-      return DAG.getSplatBuildVector(VT, DL, InVal);
-    }
+    if (InVec.isUndef() && TLI.shouldSplatInsEltVarIndex(VT))
+      return DAG.getSplat(VT, DL, InVal);
     return SDValue();
   }
 
@@ -23817,9 +23811,7 @@ static SDValue scalarizeBinOpOfSplats(SDNode *N, SelectionDAG &DAG,
   }
 
   // bo (splat X, Index), (splat Y, Index) --> splat (bo X, Y), Index
-  if (VT.isScalableVector())
-    return DAG.getSplatVector(VT, DL, ScalarBO);
-  return DAG.getSplatBuildVector(VT, DL, ScalarBO);
+  return DAG.getSplat(VT, DL, ScalarBO);
 }
 
 /// Visit a binary vector operation, like ADD.

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
index 0132bf4a8affb..6f0cde6fbddb3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
@@ -963,10 +963,7 @@ SDValue VectorLegalizer::ExpandSELECT(SDNode *Node) {
                        DAG.getConstant(0, DL, BitTy));
 
   // Broadcast the mask so that the entire vector is all one or all zero.
-  if (VT.isFixedLengthVector())
-    Mask = DAG.getSplatBuildVector(MaskTy, DL, Mask);
-  else
-    Mask = DAG.getSplatVector(MaskTy, DL, Mask);
+  Mask = DAG.getSplat(MaskTy, DL, Mask);
 
   // Bitcast the operands to be the same type as the mask.
   // This is needed when we select between FP types because
@@ -1309,8 +1306,7 @@ SDValue VectorLegalizer::ExpandVP_MERGE(SDNode *Node) {
     return DAG.UnrollVectorOp(Node);
 
   SDValue StepVec = DAG.getStepVector(DL, EVLVecVT);
-  SDValue SplatEVL = IsFixedLen ? DAG.getSplatBuildVector(EVLVecVT, DL, EVL)
-                                : DAG.getSplatVector(EVLVecVT, DL, EVL);
+  SDValue SplatEVL = DAG.getSplat(EVLVecVT, DL, EVL);
   SDValue EVLMask =
       DAG.getSetCC(DL, MaskVT, StepVec, SplatEVL, ISD::CondCode::SETULT);
 

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 3c2a1166bb63c..80703066d5f02 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1607,11 +1607,8 @@ SDValue SelectionDAG::getConstant(const ConstantInt &Val, const SDLoc &DL,
   }
 
   SDValue Result(N, 0);
-  if (VT.isScalableVector())
-    Result = getSplatVector(VT, DL, Result);
-  else if (VT.isVector())
-    Result = getSplatBuildVector(VT, DL, Result);
-
+  if (VT.isVector())
+    Result = getSplat(VT, DL, Result);
   return Result;
 }
 
@@ -1663,10 +1660,8 @@ SDValue SelectionDAG::getConstantFP(const ConstantFP &V, const SDLoc &DL,
   }
 
   SDValue Result(N, 0);
-  if (VT.isScalableVector())
-    Result = getSplatVector(VT, DL, Result);
-  else if (VT.isVector())
-    Result = getSplatBuildVector(VT, DL, Result);
+  if (VT.isVector())
+    Result = getSplat(VT, DL, Result);
   NewSDValueDbgMsg(Result, "Creating fp constant: ", this);
   return Result;
 }

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 60e63035a8987..3308134044f27 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -1695,9 +1695,7 @@ SDValue SelectionDAGBuilder::getValueImpl(const Value *V) {
       else
         Op = DAG.getConstant(0, getCurSDLoc(), EltVT);
 
-      if (isa<ScalableVectorType>(VecTy))
-        return NodeMap[V] = DAG.getSplatVector(VT, getCurSDLoc(), Op);
-      return NodeMap[V] = DAG.getSplatBuildVector(VT, getCurSDLoc(), Op);
+      return NodeMap[V] = DAG.getSplat(VT, getCurSDLoc(), Op);
     }
 
     llvm_unreachable("Unknown vector constant");
@@ -3904,10 +3902,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
   if (IsVectorGEP && !N.getValueType().isVector()) {
     LLVMContext &Context = *DAG.getContext();
     EVT VT = EVT::getVectorVT(Context, N.getValueType(), VectorElementCount);
-    if (VectorElementCount.isScalable())
-      N = DAG.getSplatVector(VT, dl, N);
-    else
-      N = DAG.getSplatBuildVector(VT, dl, N);
+    N = DAG.getSplat(VT, dl, N);
   }
 
   for (gep_type_iterator GTI = gep_type_begin(&I), E = gep_type_end(&I);
@@ -3979,10 +3974,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
       if (!IdxN.getValueType().isVector() && IsVectorGEP) {
         EVT VT = EVT::getVectorVT(*Context, IdxN.getValueType(),
                                   VectorElementCount);
-        if (VectorElementCount.isScalable())
-          IdxN = DAG.getSplatVector(VT, dl, IdxN);
-        else
-          IdxN = DAG.getSplatBuildVector(VT, dl, IdxN);
+        IdxN = DAG.getSplat(VT, dl, IdxN);
       }
 
       // If the index is smaller or larger than intptr_t, truncate or extend
@@ -7247,14 +7239,8 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
     SDValue TripCount = getValue(I.getOperand(1));
     auto VecTy = CCVT.changeVectorElementType(ElementVT);
 
-    SDValue VectorIndex, VectorTripCount;
-    if (VecTy.isScalableVector()) {
-      VectorIndex = DAG.getSplatVector(VecTy, sdl, Index);
-      VectorTripCount = DAG.getSplatVector(VecTy, sdl, TripCount);
-    } else {
-      VectorIndex = DAG.getSplatBuildVector(VecTy, sdl, Index);
-      VectorTripCount = DAG.getSplatBuildVector(VecTy, sdl, TripCount);
-    }
+    SDValue VectorIndex = DAG.getSplat(VecTy, sdl, Index);
+    SDValue VectorTripCount = DAG.getSplat(VecTy, sdl, TripCount);
     SDValue VectorStep = DAG.getStepVector(sdl, VecTy);
     SDValue VectorInduction = DAG.getNode(
         ISD::UADDSAT, sdl, VecTy, VectorIndex, VectorStep);

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index af8fdb77c9dd1..3b36521463d02 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -4192,9 +4192,7 @@ SDValue RISCVTargetLowering::lowerSELECT(SDValue Op, SelectionDAG &DAG) const {
   // Lower vector SELECTs to VSELECTs by splatting the condition.
   if (VT.isVector()) {
     MVT SplatCondVT = VT.changeVectorElementType(MVT::i1);
-    SDValue CondSplat = VT.isScalableVector()
-                            ? DAG.getSplatVector(SplatCondVT, DL, CondV)
-                            : DAG.getSplatBuildVector(SplatCondVT, DL, CondV);
+    SDValue CondSplat = DAG.getSplat(SplatCondVT, DL, CondV);
     return DAG.getNode(ISD::VSELECT, DL, VT, CondSplat, TrueV, FalseV);
   }
 


        


More information about the llvm-commits mailing list