[llvm] 9106d04 - [RISCV][SelectionDAG] Introduce an ISD::SPLAT_VECTOR_PARTS node that can represent a splat of 2 i32 values into a nxvXi64 vector for riscv32.

Craig Topper via llvm-commits llvm-commits at lists.llvm.org
Wed Mar 10 09:54:24 PST 2021


Author: Craig Topper
Date: 2021-03-10T09:46:18-08:00
New Revision: 9106d04554026c94070528d164965e62ff54ceed

URL: https://github.com/llvm/llvm-project/commit/9106d04554026c94070528d164965e62ff54ceed
DIFF: https://github.com/llvm/llvm-project/commit/9106d04554026c94070528d164965e62ff54ceed.diff

LOG: [RISCV][SelectionDAG] Introduce an ISD::SPLAT_VECTOR_PARTS node that can represent a splat of 2 i32 values into a nxvXi64 vector for riscv32.

On riscv32, i64 isn't a legal scalar type but we would like to
support scalable vectors of i64.

This patch introduces a new node that can represent a splat made
of multiple scalar values. I've used this new node to solve the current
crashes we experience when getConstant is used after type legalization.

For RISCV, we are now default expanding SPLAT_VECTOR to SPLAT_VECTOR_PARTS
when needed and then handling the SPLAT_VECTOR_PARTS later during
LegalizeOps. I've removed the special case I previously put in for
ABS for D97991 as the default expansion is now able to successfully
use getConstant.

Reviewed By: frasercrmck

Differential Revision: https://reviews.llvm.org/D98004

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/ISDOpcodes.h
    llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
    llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
    llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
    llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.cpp
    llvm/lib/Target/RISCV/RISCVISelLowering.h

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 8cd73951de8f..9d6c45f4e5c3 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -583,6 +583,15 @@ enum NodeType {
   /// implicitly truncated to it.
   SPLAT_VECTOR,
 
+  /// SPLAT_VECTOR_PARTS(SCALAR1, SCALAR2, ...) - Returns a vector with the
+  /// scalar values joined together and then duplicated in all lanes. This
+  /// represents a SPLAT_VECTOR that has had its scalar operand expanded. This
+  /// allows representing a 64-bit splat on a target with 32-bit integers. The
+  /// total width of the scalars must cover the element width. SCALAR1 contains
+  /// the least significant bits of the value regardless of endianness and all
+  /// scalars should have the same type.
+  SPLAT_VECTOR_PARTS,
+
   /// MULHU/MULHS - Multiply high - Multiply two integers of type iN,
   /// producing an unsigned/signed value of type i[2*N], then return the top
   /// part.

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 3a33d1e7d0c8..cf970566a9d2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -4194,6 +4194,7 @@ bool DAGTypeLegalizer::ExpandIntegerOperand(SDNode *N, unsigned OpNo) {
   case ISD::EXTRACT_ELEMENT:   Res = ExpandOp_EXTRACT_ELEMENT(N); break;
   case ISD::INSERT_VECTOR_ELT: Res = ExpandOp_INSERT_VECTOR_ELT(N); break;
   case ISD::SCALAR_TO_VECTOR:  Res = ExpandOp_SCALAR_TO_VECTOR(N); break;
+  case ISD::SPLAT_VECTOR:      Res = ExpandIntOp_SPLAT_VECTOR(N); break;
   case ISD::SELECT_CC:         Res = ExpandIntOp_SELECT_CC(N); break;
   case ISD::SETCC:             Res = ExpandIntOp_SETCC(N); break;
   case ISD::SETCCCARRY:        Res = ExpandIntOp_SETCCCARRY(N); break;
@@ -4449,6 +4450,14 @@ SDValue DAGTypeLegalizer::ExpandIntOp_SETCCCARRY(SDNode *N) {
                      LowCmp.getValue(1), Cond);
 }
 
+SDValue DAGTypeLegalizer::ExpandIntOp_SPLAT_VECTOR(SDNode *N) {
+  // Split the operand and replace with SPLAT_VECTOR_PARTS.
+  SDValue Lo, Hi;
+  GetExpandedInteger(N->getOperand(0), Lo, Hi);
+  return DAG.getNode(ISD::SPLAT_VECTOR_PARTS, SDLoc(N), N->getValueType(0), Lo,
+                     Hi);
+}
+
 SDValue DAGTypeLegalizer::ExpandIntOp_Shift(SDNode *N) {
   // The value being shifted is legal, but the shift amount is too big.
   // It follows that either the result of the shift is undefined, or the

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index b38dc14763f5..de916823463f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -481,6 +481,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue ExpandIntOp_UINT_TO_FP(SDNode *N);
   SDValue ExpandIntOp_RETURNADDR(SDNode *N);
   SDValue ExpandIntOp_ATOMIC_STORE(SDNode *N);
+  SDValue ExpandIntOp_SPLAT_VECTOR(SDNode *N);
 
   void IntegerExpandSetCCOperands(SDValue &NewLHS, SDValue &NewRHS,
                                   ISD::CondCode &CCCode, const SDLoc &dl);

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index cd21d6eed994..2ad3ada7fc04 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1383,6 +1383,22 @@ SDValue SelectionDAG::getConstant(const ConstantInt &Val, const SDLoc &DL,
     const APInt &NewVal = Elt->getValue();
     EVT ViaEltVT = TLI->getTypeToTransformTo(*getContext(), EltVT);
     unsigned ViaEltSizeInBits = ViaEltVT.getSizeInBits();
+
+    // For scalable vectors, try to use a SPLAT_VECTOR_PARTS node.
+    if (VT.isScalableVector()) {
+      assert(EltVT.getSizeInBits() % ViaEltSizeInBits == 0 &&
+             "Can only handle an even split!");
+      unsigned Parts = EltVT.getSizeInBits() / ViaEltSizeInBits;
+
+      SmallVector<SDValue, 2> ScalarParts;
+      for (unsigned i = 0; i != Parts; ++i)
+        ScalarParts.push_back(getConstant(
+            NewVal.lshr(i * ViaEltSizeInBits).trunc(ViaEltSizeInBits), DL,
+            ViaEltVT, isT, isO));
+
+      return getNode(ISD::SPLAT_VECTOR_PARTS, DL, VT, ScalarParts);
+    }
+
     unsigned ViaVecNumElts = VT.getSizeInBits() / ViaEltSizeInBits;
     EVT ViaVecVT = EVT::getVectorVT(*getContext(), ViaEltVT, ViaVecNumElts);
 

diff  --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 079227412215..8402163fd92b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -290,6 +290,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
   case ISD::VECTOR_SHUFFLE:             return "vector_shuffle";
   case ISD::VECTOR_SPLICE:              return "vector_splice";
   case ISD::SPLAT_VECTOR:               return "splat_vector";
+  case ISD::SPLAT_VECTOR_PARTS:         return "splat_vector_parts";
   case ISD::VECTOR_REVERSE:             return "vector_reverse";
   case ISD::CARRY_FALSE:                return "carry_false";
   case ISD::ADDC:                       return "addc";

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 7f94190bcab0..39041dfcda0e 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -399,7 +399,6 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     if (!Subtarget.is64Bit()) {
       // We must custom-lower certain vXi64 operations on RV32 due to the vector
       // element type being illegal.
-      setOperationAction(ISD::SPLAT_VECTOR, MVT::i64, Custom);
       setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::i64, Custom);
       setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::i64, Custom);
 
@@ -424,15 +423,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
     for (MVT VT : IntVecVTs) {
       setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
+      setOperationAction(ISD::SPLAT_VECTOR_PARTS, VT, Custom);
 
       setOperationAction(ISD::SMIN, VT, Legal);
       setOperationAction(ISD::SMAX, VT, Legal);
       setOperationAction(ISD::UMIN, VT, Legal);
       setOperationAction(ISD::UMAX, VT, Legal);
 
-      if (!Subtarget.is64Bit() && VT.getVectorElementType() == MVT::i64)
-        setOperationAction(ISD::ABS, VT, Custom);
-
       setOperationAction(ISD::ROTL, VT, Expand);
       setOperationAction(ISD::ROTR, VT, Expand);
 
@@ -1313,8 +1310,8 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
         Op.getOperand(0).getValueType().getVectorElementType() == MVT::i1)
       return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ -1);
     return lowerFixedLengthVectorExtendToRVV(Op, DAG, RISCVISD::VSEXT_VL);
-  case ISD::SPLAT_VECTOR:
-    return lowerSPLATVECTOR(Op, DAG);
+  case ISD::SPLAT_VECTOR_PARTS:
+    return lowerSPLAT_VECTOR_PARTS(Op, DAG);
   case ISD::INSERT_VECTOR_ELT:
     return lowerINSERT_VECTOR_ELT(Op, DAG);
   case ISD::EXTRACT_VECTOR_ELT:
@@ -2035,30 +2032,28 @@ SDValue RISCVTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG,
   return DAG.getMergeValues(Parts, DL);
 }
 
-// Custom-lower a SPLAT_VECTOR where XLEN<SEW, as the SEW element type is
+// Custom-lower a SPLAT_VECTOR_PARTS where XLEN<SEW, as the SEW element type is
 // illegal (currently only vXi64 RV32).
 // FIXME: We could also catch non-constant sign-extended i32 values and lower
 // them to SPLAT_VECTOR_I64
-SDValue RISCVTargetLowering::lowerSPLATVECTOR(SDValue Op,
-                                              SelectionDAG &DAG) const {
+SDValue RISCVTargetLowering::lowerSPLAT_VECTOR_PARTS(SDValue Op,
+                                                     SelectionDAG &DAG) const {
   SDLoc DL(Op);
   EVT VecVT = Op.getValueType();
   assert(!Subtarget.is64Bit() && VecVT.getVectorElementType() == MVT::i64 &&
-         "Unexpected SPLAT_VECTOR lowering");
-  SDValue SplatVal = Op.getOperand(0);
+         "Unexpected SPLAT_VECTOR_PARTS lowering");
 
-  // If we can prove that the value is a sign-extended 32-bit value, lower this
-  // as a custom node in order to try and match RVV vector/scalar instructions.
-  if (auto *CVal = dyn_cast<ConstantSDNode>(SplatVal)) {
-    if (isInt<32>(CVal->getSExtValue()))
-      return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT,
-                         DAG.getConstant(CVal->getSExtValue(), DL, MVT::i32));
-  }
+  assert(Op.getNumOperands() == 2 && "Unexpected number of operands!");
+  SDValue Lo = Op.getOperand(0);
+  SDValue Hi = Op.getOperand(1);
 
-  if (SplatVal.getOpcode() == ISD::SIGN_EXTEND &&
-      SplatVal.getOperand(0).getValueType() == MVT::i32) {
-    return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT,
-                       SplatVal.getOperand(0));
+  if (isa<ConstantSDNode>(Lo) && isa<ConstantSDNode>(Hi)) {
+    int32_t LoC = cast<ConstantSDNode>(Lo)->getSExtValue();
+    int32_t HiC = cast<ConstantSDNode>(Hi)->getSExtValue();
+    // If Hi constant is all the same sign bit as Lo, lower this as a custom
+    // node in order to try and match RVV vector/scalar instructions.
+    if ((LoC >> 31) == HiC)
+      return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, Lo);
   }
 
   // Else, on RV32 we lower an i64-element SPLAT_VECTOR thus, being careful not
@@ -2069,11 +2064,7 @@ SDValue RISCVTargetLowering::lowerSPLATVECTOR(SDValue Op,
   // vsll.vx vY, vY, /*32*/
   // vsrl.vx vY, vY, /*32*/
   // vor.vv vX, vX, vY
-  SDValue One = DAG.getConstant(1, DL, MVT::i32);
-  SDValue Zero = DAG.getConstant(0, DL, MVT::i32);
   SDValue ThirtyTwoV = DAG.getConstant(32, DL, VecVT);
-  SDValue Lo = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, SplatVal, Zero);
-  SDValue Hi = DAG.getNode(ISD::EXTRACT_ELEMENT, DL, MVT::i32, SplatVal, One);
 
   Lo = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, Lo);
   Lo = DAG.getNode(ISD::SHL, DL, VecVT, Lo, ThirtyTwoV);
@@ -3162,17 +3153,6 @@ SDValue RISCVTargetLowering::lowerABS(SDValue Op, SelectionDAG &DAG) const {
   MVT VT = Op.getSimpleValueType();
   SDValue X = Op.getOperand(0);
 
-  // For scalable vectors we just need to deal with i64 on RV32 since the
-  // default expansion crashes in getConstant.
-  if (VT.isScalableVector()) {
-    assert(!Subtarget.is64Bit() && VT.getVectorElementType() == MVT::i64 &&
-           "Unexpected custom lowering!");
-    SDValue SplatZero = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VT,
-                                    DAG.getConstant(0, DL, MVT::i32));
-    SDValue NegX = DAG.getNode(ISD::SUB, DL, VT, SplatZero, X);
-    return DAG.getNode(ISD::SMAX, DL, VT, X, NegX);
-  }
-
   assert(VT.isFixedLengthVector() && "Unexpected type");
 
   MVT ContainerVT =

diff  --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 3f49c2390d46..87a4df55ae2f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -437,7 +437,7 @@ class RISCVTargetLowering : public TargetLowering {
   SDValue lowerRETURNADDR(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerShiftLeftParts(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerShiftRightParts(SDValue Op, SelectionDAG &DAG, bool IsSRA) const;
-  SDValue lowerSPLATVECTOR(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerSPLAT_VECTOR_PARTS(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG,
                              int64_t ExtTrueVal) const;
   SDValue lowerVectorMaskTrunc(SDValue Op, SelectionDAG &DAG) const;


        


More information about the llvm-commits mailing list