[llvm] 89fc016 - [CodeGen][SVE] Legalisation of extends with scalable types

Kerry McLaughlin via llvm-commits llvm-commits at lists.llvm.org
Fri Jun 5 04:23:40 PDT 2020


Author: Kerry McLaughlin
Date: 2020-06-05T12:08:42+01:00
New Revision: 89fc0166f53252956705935bfebbb70f06c47c8e

URL: https://github.com/llvm/llvm-project/commit/89fc0166f53252956705935bfebbb70f06c47c8e
DIFF: https://github.com/llvm/llvm-project/commit/89fc0166f53252956705935bfebbb70f06c47c8e.diff

LOG: [CodeGen][SVE] Legalisation of extends with scalable types

Summary:
This patch adds legalisation of extensions where the operand
of the extend is a legal scalable type but the result is not.

EXTRACT_SUBVECTOR is used to split the result into halves, which are
then replaced by target-specific [S|U]UNPK[HI|LO] operations.

For example:

```
zext <vscale x 16 x i8> %a to <vscale x 16 x i16>
```
should emit:

```
uunpklo z2.h, z0.b
uunpkhi z1.h, z0.b
```

Reviewers: sdesmalen, efriedma, david-arm

Reviewed By: efriedma

Subscribers: tschuett, hiraditya, rkruppe, psnobl, huihuiz, cfe-commits, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D79587

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/ValueTypes.h
    llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
    llvm/lib/CodeGen/ValueTypes.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h
    llvm/test/CodeGen/AArch64/sve-sext-zext.ll

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/CodeGen/ValueTypes.h b/llvm/include/llvm/CodeGen/ValueTypes.h
index e4d8a04a3340..8c2ef3f9a5e1 100644
--- a/llvm/include/llvm/CodeGen/ValueTypes.h
+++ b/llvm/include/llvm/CodeGen/ValueTypes.h
@@ -103,6 +103,21 @@ namespace llvm {
       return VecTy;
     }
 
+    /// Return a VT for a vector type whose attributes match ourselves
+    /// with the exception of the element type that is chosen by the caller.
+    /// If this type is simple, \p EltVT must be simple as well and the
+    /// resulting vector type must be representable as a simple MVT.
+    EVT changeVectorElementType(EVT EltVT) const {
+      if (!isSimple())
+        return changeExtendedVectorElementType(EltVT);
+      assert(EltVT.isSimple() &&
+             "Can't change simple vector VT to have extended element VT");
+      MVT VecTy = MVT::getVectorVT(EltVT.V, getVectorElementCount());
+      assert(VecTy.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE &&
+             "Simple vector VT not representable by simple vector VT!");
+      return VecTy;
+    }
+
     /// Return the type converted to an equivalently sized integer or vector
     /// with integer element type. Similar to changeVectorElementTypeToInteger,
     /// but also handles scalars.
@@ -432,6 +447,7 @@ namespace llvm {
     // These are all out-of-line to prevent users of this header file
     // from having a dependency on Type.h.
     EVT changeExtendedTypeToInteger() const;
+    EVT changeExtendedVectorElementType(EVT EltVT) const;
     EVT changeExtendedVectorElementTypeToInteger() const;
     static EVT getExtendedIntegerVT(LLVMContext &C, unsigned BitWidth);
     static EVT getExtendedVectorVT(LLVMContext &C, EVT VT, unsigned NumElements,

diff  --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 453500aa9e51..70ef59338375 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -4324,6 +4324,32 @@ SDValue DAGTypeLegalizer::PromoteIntRes_EXTRACT_SUBVECTOR(SDNode *N) {
   SDLoc dl(N);
   SDValue BaseIdx = N->getOperand(1);
 
+  // TODO: We may be able to use this for types other than scalable
+  // vectors and fix those tests that expect BUILD_VECTOR to be used.
+  if (OutVT.isScalableVector()) {
+    SDValue InOp0 = N->getOperand(0);
+    EVT InVT = InOp0.getValueType();
+
+    // A scalable result cannot be assembled element by element with
+    // BUILD_VECTOR (the approach used below for non-scalable vectors), so
+    // the operand must be promoted and the extract left to target lowering.
+    if (getTypeAction(InVT) != TargetLowering::TypePromoteInteger)
+      report_fatal_error("Unable to promote scalable types using BUILD_VECTOR");
+
+    // Collect the (promoted) operands.
+    SDValue Ops[] = { GetPromotedInteger(InOp0), BaseIdx };
+
+    EVT PromEltVT = Ops[0].getValueType().getVectorElementType();
+    assert(PromEltVT.bitsLE(NOutVTElem) &&
+           "Promoted operand has an element type greater than result");
+
+    // Extract at the promoted operand's element type, then any-extend the
+    // elements up to the promoted result type.
+    EVT ExtVT = NOutVT.changeVectorElementType(PromEltVT);
+    SDValue Ext = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, ExtVT, Ops);
+    return DAG.getNode(ISD::ANY_EXTEND, dl, NOutVT, Ext);
+  }
+
   SDValue InOp0 = N->getOperand(0);
   if (getTypeAction(InOp0.getValueType()) == TargetLowering::TypePromoteInteger)
     InOp0 = GetPromotedInteger(N->getOperand(0));

diff  --git a/llvm/lib/CodeGen/ValueTypes.cpp b/llvm/lib/CodeGen/ValueTypes.cpp
index 2b97e9d83dd0..538f97b76341 100644
--- a/llvm/lib/CodeGen/ValueTypes.cpp
+++ b/llvm/lib/CodeGen/ValueTypes.cpp
@@ -26,6 +26,12 @@ EVT EVT::changeExtendedVectorElementTypeToInteger() const {
                      isScalableVector());
 }
 
+EVT EVT::changeExtendedVectorElementType(EVT EltVT) const {
+  // Extended EVTs wrap an IR type; its context is needed to build the
+  // (possibly non-simple) vector type with the requested element type.
+  return getVectorVT(LLVMTy->getContext(), EltVT, getVectorElementCount());
+}
+
 EVT EVT::getExtendedIntegerVT(LLVMContext &Context, unsigned BitWidth) {
   EVT VT;
   VT.LLVMTy = IntegerType::get(Context, BitWidth);

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 802df5734faf..7091a2f79db3 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -901,6 +901,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
         setOperationAction(ISD::SRA, VT, Custom);
         if (VT.getScalarType() == MVT::i1)
           setOperationAction(ISD::SETCC, VT, Custom);
+      } else {
+        for (auto VT : { MVT::nxv8i8, MVT::nxv4i16, MVT::nxv2i32 })
+          setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
       }
     }
     setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -8560,6 +8563,9 @@ AArch64TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
 
 SDValue AArch64TargetLowering::LowerEXTRACT_SUBVECTOR(SDValue Op,
                                                       SelectionDAG &DAG) const {
+  assert(!Op.getValueType().isScalableVector() &&
+         "Unexpected scalable type for custom lowering EXTRACT_SUBVECTOR");
+
   EVT VT = Op.getOperand(0).getValueType();
   SDLoc dl(Op);
   // Just in case...
@@ -10662,7 +10668,48 @@ static SDValue performSVEAndCombine(SDNode *N,
   if (DCI.isBeforeLegalizeOps())
     return SDValue();
 
+  SelectionDAG &DAG = DCI.DAG;
   SDValue Src = N->getOperand(0);
+  unsigned Opc = Src->getOpcode();
+
+  // Zero/any extend of an unsigned unpack
+  if (Opc == AArch64ISD::UUNPKHI || Opc == AArch64ISD::UUNPKLO) {
+    SDValue UnpkOp = Src->getOperand(0);
+    SDValue Dup = N->getOperand(1);
+
+    if (Dup.getOpcode() != AArch64ISD::DUP)
+      return SDValue();
+
+    SDLoc DL(N);
+    ConstantSDNode *C = dyn_cast<ConstantSDNode>(Dup->getOperand(0));
+    // Only a constant splatted mask can be analysed; bail out otherwise.
+    if (!C)
+      return SDValue();
+    uint64_t ExtVal = C->getZExtValue();
+
+    // If the mask is fully covered by the unpack, we don't need to push
+    // a new AND onto the operand
+    EVT EltTy = UnpkOp->getValueType(0).getVectorElementType();
+    if ((ExtVal == 0xFF && EltTy == MVT::i8) ||
+        (ExtVal == 0xFFFF && EltTy == MVT::i16) ||
+        (ExtVal == 0xFFFFFFFF && EltTy == MVT::i32))
+      return Src;
+
+    // Truncate to prevent a DUP with an over wide constant
+    APInt Mask = C->getAPIntValue().trunc(EltTy.getSizeInBits());
+
+    // Otherwise, make sure we propagate the AND to the operand
+    // of the unpack
+    Dup = DAG.getNode(AArch64ISD::DUP, DL,
+                      UnpkOp->getValueType(0),
+                      DAG.getConstant(Mask.zextOrTrunc(32), DL, MVT::i32));
+
+    SDValue And = DAG.getNode(ISD::AND, DL,
+                              UnpkOp->getValueType(0), UnpkOp, Dup);
+
+    return DAG.getNode(Opc, DL, N->getValueType(0), And);
+  }
+
   SDValue Mask = N->getOperand(1);
 
   if (!Src.hasOneUse())
@@ -10672,7 +10719,7 @@ static SDValue performSVEAndCombine(SDNode *N,
 
   // SVE load instructions perform an implicit zero-extend, which makes them
   // perfect candidates for combining.
-  switch (Src->getOpcode()) {
+  switch (Opc) {
   case AArch64ISD::LD1:
   case AArch64ISD::LDNF1:
   case AArch64ISD::LDFF1:
@@ -13256,9 +13303,41 @@ performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
   if (DCI.isBeforeLegalizeOps())
     return SDValue();
 
+  SDLoc DL(N);
   SDValue Src = N->getOperand(0);
   unsigned Opc = Src->getOpcode();
 
+  // Sign extend of an unsigned unpack -> signed unpack
+  if (Opc == AArch64ISD::UUNPKHI || Opc == AArch64ISD::UUNPKLO) {
+
+    unsigned SOpc = Opc == AArch64ISD::UUNPKHI ? AArch64ISD::SUNPKHI
+                                               : AArch64ISD::SUNPKLO;
+
+    // Push the sign extend to the operand of the unpack
+    // This is necessary where, for example, the operand of the unpack
+    // is another unpack:
+    // 4i32 sign_extend_inreg (4i32 uunpklo(8i16 uunpklo (16i8 opnd)), from 4i8)
+    // ->
+    // 4i32 sunpklo (8i16 sign_extend_inreg(8i16 uunpklo (16i8 opnd), from 8i8)
+    // ->
+    // 4i32 sunpklo(8i16 sunpklo(16i8 opnd))
+    SDValue ExtOp = Src->getOperand(0);
+    auto VT = cast<VTSDNode>(N->getOperand(1))->getVT();
+    EVT EltTy = VT.getVectorElementType();
+
+    assert((EltTy == MVT::i8 || EltTy == MVT::i16 || EltTy == MVT::i32) &&
+           "Sign extending from an invalid type");
+
+    EVT ExtVT = EVT::getVectorVT(*DAG.getContext(),
+                                 VT.getVectorElementType(),
+                                 VT.getVectorElementCount() * 2);
+
+    SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, ExtOp.getValueType(),
+                              ExtOp, DAG.getValueType(ExtVT));
+
+    return DAG.getNode(SOpc, DL, N->getValueType(0), Ext);
+  }
+
   // SVE load nodes (e.g. AArch64ISD::GLD1) are straightforward candidates
   // for DAG Combine with SIGN_EXTEND_INREG. Bail out for all other nodes.
   unsigned NewOpc;
@@ -13747,6 +13826,40 @@ static std::pair<SDValue, SDValue> splitInt128(SDValue N, SelectionDAG &DAG) {
   return std::make_pair(Lo, Hi);
 }
 
+void AArch64TargetLowering::ReplaceExtractSubVectorResults(
+    SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
+  SDValue In = N->getOperand(0);
+  EVT InVT = In.getValueType();
+
+  // Common code will handle these just fine.
+  if (!InVT.isScalableVector() || !InVT.isInteger())
+    return;
+
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+
+  // The following checks bail if this is not a halving operation.
+
+  ElementCount ResEC = VT.getVectorElementCount();
+
+  if (InVT.getVectorElementCount().Min != (ResEC.Min * 2))
+    return;
+
+  auto *CIndex = dyn_cast<ConstantSDNode>(N->getOperand(1));
+  if (!CIndex)
+    return;
+
+  unsigned Index = CIndex->getZExtValue();
+  if ((Index != 0) && (Index != ResEC.Min))
+    return;
+
+  unsigned Opcode = (Index == 0) ? AArch64ISD::UUNPKLO : AArch64ISD::UUNPKHI;
+  EVT ExtendedHalfVT = VT.widenIntegerVectorElementType(*DAG.getContext());
+
+  SDValue Half = DAG.getNode(Opcode, DL, ExtendedHalfVT, N->getOperand(0));
+  Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, Half));
+}
+
 // Create an even/odd pair of X registers holding integer value V.
 static SDValue createGPRPairNode(SelectionDAG &DAG, SDValue V) {
   SDLoc dl(V.getNode());
@@ -13899,6 +14012,9 @@ void AArch64TargetLowering::ReplaceNodeResults(
     Results.append({Pair, Result.getValue(2) /* Chain */});
     return;
   }
+  case ISD::EXTRACT_SUBVECTOR:
+    ReplaceExtractSubVectorResults(N, Results, DAG);
+    return;
   case ISD::INTRINSIC_WO_CHAIN: {
     EVT VT = N->getValueType(0);
     assert((VT == MVT::i8 || VT == MVT::i16) &&

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 81af7e2235a2..e42c0b6e05b7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -889,6 +889,13 @@ class AArch64TargetLowering : public TargetLowering {
 
   void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
                           SelectionDAG &DAG) const override;
+  /// Type-legalisation helper for EXTRACT_SUBVECTOR of a scalable integer
+  /// vector: extracting exactly the low or high half is replaced by
+  /// UUNPKLO/UUNPKHI followed by a TRUNCATE. Any other extract is left to
+  /// common code by returning with \p Results unchanged.
+  void ReplaceExtractSubVectorResults(SDNode *N,
+                                      SmallVectorImpl<SDValue> &Results,
+                                      SelectionDAG &DAG) const;
 
   bool shouldNormalizeToSelectSequence(LLVMContext &, EVT) const override;
 

diff  --git a/llvm/test/CodeGen/AArch64/sve-sext-zext.ll b/llvm/test/CodeGen/AArch64/sve-sext-zext.ll
index f9a527c1fc8c..24cf433306bb 100644
--- a/llvm/test/CodeGen/AArch64/sve-sext-zext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-sext-zext.ll
@@ -186,3 +186,143 @@ define <vscale x 2 x i64> @zext_i32_i64(<vscale x 2 x i32> %a) {
   %r = zext <vscale x 2 x i32> %a to <vscale x 2 x i64>
   ret <vscale x 2 x i64> %r
 }
+
+; Extending a legal type to an illegal one: the result is split via [su]unpk{lo,hi}
+
+define <vscale x 16 x i16> @sext_b_to_h(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: sext_b_to_h:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sunpklo z2.h, z0.b
+; CHECK-NEXT:    sunpkhi z1.h, z0.b
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+  %r = sext <vscale x 16 x i8> %a to <vscale x 16 x i16>
+  ret <vscale x 16 x i16> %r
+}
+
+define <vscale x 8 x i32> @sext_h_to_s(<vscale x 8 x i16> %a) {
+; CHECK-LABEL: sext_h_to_s:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sunpklo z2.s, z0.h
+; CHECK-NEXT:    sunpkhi z1.s, z0.h
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+  %r = sext <vscale x 8 x i16> %a to <vscale x 8 x i32>
+  ret <vscale x 8 x i32> %r
+}
+
+define <vscale x 4 x i64> @sext_s_to_d(<vscale x 4 x i32> %a) {
+; CHECK-LABEL: sext_s_to_d:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sunpklo z2.d, z0.s
+; CHECK-NEXT:    sunpkhi z1.d, z0.s
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+  %r = sext <vscale x 4 x i32> %a to <vscale x 4 x i64>
+  ret <vscale x 4 x i64> %r
+}
+
+define <vscale x 16 x i32> @sext_b_to_s(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: sext_b_to_s:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sunpklo z1.h, z0.b
+; CHECK-NEXT:    sunpkhi z3.h, z0.b
+; CHECK-NEXT:    sunpklo z0.s, z1.h
+; CHECK-NEXT:    sunpkhi z1.s, z1.h
+; CHECK-NEXT:    sunpklo z2.s, z3.h
+; CHECK-NEXT:    sunpkhi z3.s, z3.h
+; CHECK-NEXT:    ret
+  %r = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  ret <vscale x 16 x i32> %r
+}
+
+define <vscale x 16 x i64> @sext_b_to_d(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: sext_b_to_d:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sunpklo z1.h, z0.b
+; CHECK-NEXT:    sunpkhi z0.h, z0.b
+; CHECK-NEXT:    sunpklo z2.s, z1.h
+; CHECK-NEXT:    sunpkhi z3.s, z1.h
+; CHECK-NEXT:    sunpklo z5.s, z0.h
+; CHECK-NEXT:    sunpkhi z7.s, z0.h
+; CHECK-NEXT:    sunpklo z0.d, z2.s
+; CHECK-NEXT:    sunpkhi z1.d, z2.s
+; CHECK-NEXT:    sunpklo z2.d, z3.s
+; CHECK-NEXT:    sunpkhi z3.d, z3.s
+; CHECK-NEXT:    sunpklo z4.d, z5.s
+; CHECK-NEXT:    sunpkhi z5.d, z5.s
+; CHECK-NEXT:    sunpklo z6.d, z7.s
+; CHECK-NEXT:    sunpkhi z7.d, z7.s
+; CHECK-NEXT:    ret
+  %r = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+  ret <vscale x 16 x i64> %r
+}
+
+define <vscale x 16 x i16> @zext_b_to_h(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: zext_b_to_h:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpklo z2.h, z0.b
+; CHECK-NEXT:    uunpkhi z1.h, z0.b
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+  %r = zext <vscale x 16 x i8> %a to <vscale x 16 x i16>
+  ret <vscale x 16 x i16> %r
+}
+
+define <vscale x 8 x i32> @zext_h_to_s(<vscale x 8 x i16> %a) {
+; CHECK-LABEL: zext_h_to_s:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpklo z2.s, z0.h
+; CHECK-NEXT:    uunpkhi z1.s, z0.h
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+  %r = zext <vscale x 8 x i16> %a to <vscale x 8 x i32>
+  ret <vscale x 8 x i32> %r
+}
+
+define <vscale x 4 x i64> @zext_s_to_d(<vscale x 4 x i32> %a) {
+; CHECK-LABEL: zext_s_to_d:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpklo z2.d, z0.s
+; CHECK-NEXT:    uunpkhi z1.d, z0.s
+; CHECK-NEXT:    mov z0.d, z2.d
+; CHECK-NEXT:    ret
+  %r = zext <vscale x 4 x i32> %a to <vscale x 4 x i64>
+  ret <vscale x 4 x i64> %r
+}
+
+define <vscale x 16 x i32> @zext_b_to_s(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: zext_b_to_s:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpklo z1.h, z0.b
+; CHECK-NEXT:    uunpkhi z3.h, z0.b
+; CHECK-NEXT:    uunpklo z0.s, z1.h
+; CHECK-NEXT:    uunpkhi z1.s, z1.h
+; CHECK-NEXT:    uunpklo z2.s, z3.h
+; CHECK-NEXT:    uunpkhi z3.s, z3.h
+; CHECK-NEXT:    ret
+  %r = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+  ret <vscale x 16 x i32> %r
+}
+
+define <vscale x 16 x i64> @zext_b_to_d(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: zext_b_to_d:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    uunpklo z1.h, z0.b
+; CHECK-NEXT:    uunpkhi z0.h, z0.b
+; CHECK-NEXT:    uunpklo z2.s, z1.h
+; CHECK-NEXT:    uunpkhi z3.s, z1.h
+; CHECK-NEXT:    uunpklo z5.s, z0.h
+; CHECK-NEXT:    uunpkhi z7.s, z0.h
+; CHECK-NEXT:    uunpklo z0.d, z2.s
+; CHECK-NEXT:    uunpkhi z1.d, z2.s
+; CHECK-NEXT:    uunpklo z2.d, z3.s
+; CHECK-NEXT:    uunpkhi z3.d, z3.s
+; CHECK-NEXT:    uunpklo z4.d, z5.s
+; CHECK-NEXT:    uunpkhi z5.d, z5.s
+; CHECK-NEXT:    uunpklo z6.d, z7.s
+; CHECK-NEXT:    uunpkhi z7.d, z7.s
+; CHECK-NEXT:    ret
+  %r = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+  ret <vscale x 16 x i64> %r
+}


        


More information about the llvm-commits mailing list