[llvm] 89fc016 - [CodeGen][SVE] Legalisation of extends with scalable types
Kerry McLaughlin via llvm-commits
llvm-commits at lists.llvm.org
Fri Jun 5 04:23:40 PDT 2020
Author: Kerry McLaughlin
Date: 2020-06-05T12:08:42+01:00
New Revision: 89fc0166f53252956705935bfebbb70f06c47c8e
URL: https://github.com/llvm/llvm-project/commit/89fc0166f53252956705935bfebbb70f06c47c8e
DIFF: https://github.com/llvm/llvm-project/commit/89fc0166f53252956705935bfebbb70f06c47c8e.diff
LOG: [CodeGen][SVE] Legalisation of extends with scalable types
Summary:
This patch adds legalisation of extensions where the operand
of the extend is a legal scalable type but the result is not.
EXTRACT_SUBVECTOR is used to split the result, and the resulting
extract nodes are then replaced by target-specific [S|U]UNPK[HI|LO]
operations.
For example:
```
zext <vscale x 16 x i8> %a to <vscale x 16 x i16>
```
should emit:
```
uunpklo z2.h, z0.b
uunpkhi z1.h, z0.b
```
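When the result is more than one element size wider than the operand,
the unpacks are applied recursively; for example (taken from the
updated sve-sext-zext.ll test below):
```
zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
```
emits:
```
uunpklo z1.h, z0.b
uunpkhi z3.h, z0.b
uunpklo z0.s, z1.h
uunpkhi z1.s, z1.h
uunpklo z2.s, z3.h
uunpkhi z3.s, z3.h
```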
Reviewers: sdesmalen, efriedma, david-arm
Reviewed By: efriedma
Subscribers: tschuett, hiraditya, rkruppe, psnobl, huihuiz, cfe-commits, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D79587
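As a rough standalone sketch of the halving recursion described above
(this is illustrative only, not the patch's code; the helper names and
the invented "register" names are assumptions for the example):
```
// Standalone sketch, not LLVM code: models the recursive halving the
// legaliser performs when an extend's result type is too wide, where
// each half becomes an unsigned unpack of the low or high part.
#include <cstdio>
#include <string>

// SVE element-size suffix for a given element width in bits.
static const char *suffix(unsigned Bits) {
  switch (Bits) {
  case 8:  return "b";
  case 16: return "h";
  case 32: return "s";
  case 64: return "d";
  default: return "?";
  }
}

// Print the unpacks needed to widen elements from SrcBits to DstBits,
// starting from a value named Src (value names here are invented).
static void emitUnpacks(const std::string &Src, unsigned SrcBits,
                        unsigned DstBits) {
  if (SrcBits >= DstBits)
    return;
  unsigned NextBits = SrcBits * 2;
  std::string Lo = Src + "_lo", Hi = Src + "_hi";
  std::printf("uunpklo %s.%s, %s.%s\n", Lo.c_str(), suffix(NextBits),
              Src.c_str(), suffix(SrcBits));
  std::printf("uunpkhi %s.%s, %s.%s\n", Hi.c_str(), suffix(NextBits),
              Src.c_str(), suffix(SrcBits));
  emitUnpacks(Lo, NextBits, DstBits);
  emitUnpacks(Hi, NextBits, DstBits);
}

int main() {
  // e.g. zext <vscale x 16 x i8> to <vscale x 16 x i32>:
  // two levels of unpacks, four nxv4i32 results.
  emitUnpacks("z0", 8, 32);
  return 0;
}
```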
Added:
Modified:
llvm/include/llvm/CodeGen/ValueTypes.h
llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
llvm/lib/CodeGen/ValueTypes.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
llvm/test/CodeGen/AArch64/sve-sext-zext.ll
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/ValueTypes.h b/llvm/include/llvm/CodeGen/ValueTypes.h
index e4d8a04a3340..8c2ef3f9a5e1 100644
--- a/llvm/include/llvm/CodeGen/ValueTypes.h
+++ b/llvm/include/llvm/CodeGen/ValueTypes.h
@@ -103,6 +103,17 @@ namespace llvm {
return VecTy;
}
+ /// Return a VT for a vector type whose attributes match ourselves
+ /// with the exception of the element type that is chosen by the caller.
+ EVT changeVectorElementType(EVT EltVT) const {
+ if (!isSimple())
+ return changeExtendedVectorElementType(EltVT);
+ MVT VecTy = MVT::getVectorVT(EltVT.V, getVectorElementCount());
+ assert(VecTy.SimpleTy != MVT::INVALID_SIMPLE_VALUE_TYPE &&
+ "Simple vector VT not representable by simple integer vector VT!");
+ return VecTy;
+ }
+
/// Return the type converted to an equivalently sized integer or vector
/// with integer element type. Similar to changeVectorElementTypeToInteger,
/// but also handles scalars.
@@ -432,6 +443,7 @@ namespace llvm {
// These are all out-of-line to prevent users of this header file
// from having a dependency on Type.h.
EVT changeExtendedTypeToInteger() const;
+ EVT changeExtendedVectorElementType(EVT EltVT) const;
EVT changeExtendedVectorElementTypeToInteger() const;
static EVT getExtendedIntegerVT(LLVMContext &C, unsigned BitWidth);
static EVT getExtendedVectorVT(LLVMContext &C, EVT VT, unsigned NumElements,
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 453500aa9e51..70ef59338375 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -4324,6 +4324,31 @@ SDValue DAGTypeLegalizer::PromoteIntRes_EXTRACT_SUBVECTOR(SDNode *N) {
SDLoc dl(N);
SDValue BaseIdx = N->getOperand(1);
+ // TODO: We may be able to use this for types other than scalable
+ // vectors and fix those tests that expect BUILD_VECTOR to be used
+ if (OutVT.isScalableVector()) {
+ SDValue InOp0 = N->getOperand(0);
+ EVT InVT = InOp0.getValueType();
+
+ // Promote operands and see if this is handled by target lowering;
+ // otherwise, use the BUILD_VECTOR approach below
+ if (getTypeAction(InVT) == TargetLowering::TypePromoteInteger) {
+ // Collect the (promoted) operands
+ SDValue Ops[] = { GetPromotedInteger(InOp0), BaseIdx };
+
+ EVT PromEltVT = Ops[0].getValueType().getVectorElementType();
+ assert(PromEltVT.bitsLE(NOutVTElem) &&
+ "Promoted operand has an element type greater than result");
+
+ EVT ExtVT = NOutVT.changeVectorElementType(PromEltVT);
+ SDValue Ext = DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), ExtVT, Ops);
+ return DAG.getNode(ISD::ANY_EXTEND, dl, NOutVT, Ext);
+ }
+ }
+
+ if (OutVT.isScalableVector())
+ report_fatal_error("Unable to promote scalable types using BUILD_VECTOR");
+
SDValue InOp0 = N->getOperand(0);
if (getTypeAction(InOp0.getValueType()) == TargetLowering::TypePromoteInteger)
InOp0 = GetPromotedInteger(N->getOperand(0));
diff --git a/llvm/lib/CodeGen/ValueTypes.cpp b/llvm/lib/CodeGen/ValueTypes.cpp
index 2b97e9d83dd0..538f97b76341 100644
--- a/llvm/lib/CodeGen/ValueTypes.cpp
+++ b/llvm/lib/CodeGen/ValueTypes.cpp
@@ -26,6 +26,11 @@ EVT EVT::changeExtendedVectorElementTypeToInteger() const {
isScalableVector());
}
+EVT EVT::changeExtendedVectorElementType(EVT EltVT) const {
+ LLVMContext &Context = LLVMTy->getContext();
+ return getVectorVT(Context, EltVT, getVectorElementCount());
+}
+
EVT EVT::getExtendedIntegerVT(LLVMContext &Context, unsigned BitWidth) {
EVT VT;
VT.LLVMTy = IntegerType::get(Context, BitWidth);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 802df5734faf..7091a2f79db3 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -901,6 +901,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::SRA, VT, Custom);
if (VT.getScalarType() == MVT::i1)
setOperationAction(ISD::SETCC, VT, Custom);
+ } else {
+ for (auto VT : { MVT::nxv8i8, MVT::nxv4i16, MVT::nxv2i32 })
+ setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
}
}
setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
@@ -8560,6 +8563,9 @@ AArch64TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
SDValue AArch64TargetLowering::LowerEXTRACT_SUBVECTOR(SDValue Op,
SelectionDAG &DAG) const {
+ assert(!Op.getValueType().isScalableVector() &&
+ "Unexpected scalable type for custom lowering EXTRACT_SUBVECTOR");
+
EVT VT = Op.getOperand(0).getValueType();
SDLoc dl(Op);
// Just in case...
@@ -10662,7 +10668,45 @@ static SDValue performSVEAndCombine(SDNode *N,
if (DCI.isBeforeLegalizeOps())
return SDValue();
+ SelectionDAG &DAG = DCI.DAG;
SDValue Src = N->getOperand(0);
+ unsigned Opc = Src->getOpcode();
+
+ // Zero/any extend of an unsigned unpack
+ if (Opc == AArch64ISD::UUNPKHI || Opc == AArch64ISD::UUNPKLO) {
+ SDValue UnpkOp = Src->getOperand(0);
+ SDValue Dup = N->getOperand(1);
+
+ if (Dup.getOpcode() != AArch64ISD::DUP)
+ return SDValue();
+
+ SDLoc DL(N);
+ ConstantSDNode *C = dyn_cast<ConstantSDNode>(Dup->getOperand(0));
+ uint64_t ExtVal = C->getZExtValue();
+
+ // If the mask is fully covered by the unpack, we don't need to push
+ // a new AND onto the operand
+ EVT EltTy = UnpkOp->getValueType(0).getVectorElementType();
+ if ((ExtVal == 0xFF && EltTy == MVT::i8) ||
+ (ExtVal == 0xFFFF && EltTy == MVT::i16) ||
+ (ExtVal == 0xFFFFFFFF && EltTy == MVT::i32))
+ return Src;
+
+ // Truncate to prevent a DUP with an over wide constant
+ APInt Mask = C->getAPIntValue().trunc(EltTy.getSizeInBits());
+
+ // Otherwise, make sure we propagate the AND to the operand
+ // of the unpack
+ Dup = DAG.getNode(AArch64ISD::DUP, DL,
+ UnpkOp->getValueType(0),
+ DAG.getConstant(Mask.zextOrTrunc(32), DL, MVT::i32));
+
+ SDValue And = DAG.getNode(ISD::AND, DL,
+ UnpkOp->getValueType(0), UnpkOp, Dup);
+
+ return DAG.getNode(Opc, DL, N->getValueType(0), And);
+ }
+
SDValue Mask = N->getOperand(1);
if (!Src.hasOneUse())
@@ -10672,7 +10716,7 @@ static SDValue performSVEAndCombine(SDNode *N,
// SVE load instructions perform an implicit zero-extend, which makes them
// perfect candidates for combining.
- switch (Src->getOpcode()) {
+ switch (Opc) {
case AArch64ISD::LD1:
case AArch64ISD::LDNF1:
case AArch64ISD::LDFF1:
@@ -13256,9 +13300,41 @@ performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
if (DCI.isBeforeLegalizeOps())
return SDValue();
+ SDLoc DL(N);
SDValue Src = N->getOperand(0);
unsigned Opc = Src->getOpcode();
+ // Sign extend of an unsigned unpack -> signed unpack
+ if (Opc == AArch64ISD::UUNPKHI || Opc == AArch64ISD::UUNPKLO) {
+
+ unsigned SOpc = Opc == AArch64ISD::UUNPKHI ? AArch64ISD::SUNPKHI
+ : AArch64ISD::SUNPKLO;
+
+ // Push the sign extend to the operand of the unpack
+ // This is necessary where, for example, the operand of the unpack
+ // is another unpack:
+ // 4i32 sign_extend_inreg (4i32 uunpklo(8i16 uunpklo (16i8 opnd)), from 4i8)
+ // ->
+ // 4i32 sunpklo (8i16 sign_extend_inreg(8i16 uunpklo (16i8 opnd), from 8i8)
+ // ->
+ // 4i32 sunpklo(8i16 sunpklo(16i8 opnd))
+ SDValue ExtOp = Src->getOperand(0);
+ auto VT = cast<VTSDNode>(N->getOperand(1))->getVT();
+ EVT EltTy = VT.getVectorElementType();
+
+ assert((EltTy == MVT::i8 || EltTy == MVT::i16 || EltTy == MVT::i32) &&
+ "Sign extending from an invalid type");
+
+ EVT ExtVT = EVT::getVectorVT(*DAG.getContext(),
+ VT.getVectorElementType(),
+ VT.getVectorElementCount() * 2);
+
+ SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, ExtOp.getValueType(),
+ ExtOp, DAG.getValueType(ExtVT));
+
+ return DAG.getNode(SOpc, DL, N->getValueType(0), Ext);
+ }
+
// SVE load nodes (e.g. AArch64ISD::GLD1) are straightforward candidates
// for DAG Combine with SIGN_EXTEND_INREG. Bail out for all other nodes.
unsigned NewOpc;
@@ -13747,6 +13823,40 @@ static std::pair<SDValue, SDValue> splitInt128(SDValue N, SelectionDAG &DAG) {
return std::make_pair(Lo, Hi);
}
+void AArch64TargetLowering::ReplaceExtractSubVectorResults(
+ SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
+ SDValue In = N->getOperand(0);
+ EVT InVT = In.getValueType();
+
+ // Common code will handle these just fine.
+ if (!InVT.isScalableVector() || !InVT.isInteger())
+ return;
+
+ SDLoc DL(N);
+ EVT VT = N->getValueType(0);
+
+ // The following checks bail if this is not a halving operation.
+
+ ElementCount ResEC = VT.getVectorElementCount();
+
+ if (InVT.getVectorElementCount().Min != (ResEC.Min * 2))
+ return;
+
+ auto *CIndex = dyn_cast<ConstantSDNode>(N->getOperand(1));
+ if (!CIndex)
+ return;
+
+ unsigned Index = CIndex->getZExtValue();
+ if ((Index != 0) && (Index != ResEC.Min))
+ return;
+
+ unsigned Opcode = (Index == 0) ? AArch64ISD::UUNPKLO : AArch64ISD::UUNPKHI;
+ EVT ExtendedHalfVT = VT.widenIntegerVectorElementType(*DAG.getContext());
+
+ SDValue Half = DAG.getNode(Opcode, DL, ExtendedHalfVT, N->getOperand(0));
+ Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, Half));
+}
+
// Create an even/odd pair of X registers holding integer value V.
static SDValue createGPRPairNode(SelectionDAG &DAG, SDValue V) {
SDLoc dl(V.getNode());
@@ -13899,6 +14009,9 @@ void AArch64TargetLowering::ReplaceNodeResults(
Results.append({Pair, Result.getValue(2) /* Chain */});
return;
}
+ case ISD::EXTRACT_SUBVECTOR:
+ ReplaceExtractSubVectorResults(N, Results, DAG);
+ return;
case ISD::INTRINSIC_WO_CHAIN: {
EVT VT = N->getValueType(0);
assert((VT == MVT::i8 || VT == MVT::i16) &&
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 81af7e2235a2..e42c0b6e05b7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -889,6 +889,9 @@ class AArch64TargetLowering : public TargetLowering {
void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
SelectionDAG &DAG) const override;
+ void ReplaceExtractSubVectorResults(SDNode *N,
+ SmallVectorImpl<SDValue> &Results,
+ SelectionDAG &DAG) const;
bool shouldNormalizeToSelectSequence(LLVMContext &, EVT) const override;
diff --git a/llvm/test/CodeGen/AArch64/sve-sext-zext.ll b/llvm/test/CodeGen/AArch64/sve-sext-zext.ll
index f9a527c1fc8c..24cf433306bb 100644
--- a/llvm/test/CodeGen/AArch64/sve-sext-zext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-sext-zext.ll
@@ -186,3 +186,143 @@ define <vscale x 2 x i64> @zext_i32_i64(<vscale x 2 x i32> %a) {
%r = zext <vscale x 2 x i32> %a to <vscale x 2 x i64>
ret <vscale x 2 x i64> %r
}
+
+; Extending to illegal types
+
+define <vscale x 16 x i16> @sext_b_to_h(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: sext_b_to_h:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sunpklo z2.h, z0.b
+; CHECK-NEXT: sunpkhi z1.h, z0.b
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+ %ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i16>
+ ret <vscale x 16 x i16> %ext
+}
+
+define <vscale x 8 x i32> @sext_h_to_s(<vscale x 8 x i16> %a) {
+; CHECK-LABEL: sext_h_to_s:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sunpklo z2.s, z0.h
+; CHECK-NEXT: sunpkhi z1.s, z0.h
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+ %ext = sext <vscale x 8 x i16> %a to <vscale x 8 x i32>
+ ret <vscale x 8 x i32> %ext
+}
+
+define <vscale x 4 x i64> @sext_s_to_d(<vscale x 4 x i32> %a) {
+; CHECK-LABEL: sext_s_to_d:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sunpklo z2.d, z0.s
+; CHECK-NEXT: sunpkhi z1.d, z0.s
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+ %ext = sext <vscale x 4 x i32> %a to <vscale x 4 x i64>
+ ret <vscale x 4 x i64> %ext
+}
+
+define <vscale x 16 x i32> @sext_b_to_s(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: sext_b_to_s:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sunpklo z1.h, z0.b
+; CHECK-NEXT: sunpkhi z3.h, z0.b
+; CHECK-NEXT: sunpklo z0.s, z1.h
+; CHECK-NEXT: sunpkhi z1.s, z1.h
+; CHECK-NEXT: sunpklo z2.s, z3.h
+; CHECK-NEXT: sunpkhi z3.s, z3.h
+; CHECK-NEXT: ret
+ %ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ ret <vscale x 16 x i32> %ext
+}
+
+define <vscale x 16 x i64> @sext_b_to_d(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: sext_b_to_d:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sunpklo z1.h, z0.b
+; CHECK-NEXT: sunpkhi z0.h, z0.b
+; CHECK-NEXT: sunpklo z2.s, z1.h
+; CHECK-NEXT: sunpkhi z3.s, z1.h
+; CHECK-NEXT: sunpklo z5.s, z0.h
+; CHECK-NEXT: sunpkhi z7.s, z0.h
+; CHECK-NEXT: sunpklo z0.d, z2.s
+; CHECK-NEXT: sunpkhi z1.d, z2.s
+; CHECK-NEXT: sunpklo z2.d, z3.s
+; CHECK-NEXT: sunpkhi z3.d, z3.s
+; CHECK-NEXT: sunpklo z4.d, z5.s
+; CHECK-NEXT: sunpkhi z5.d, z5.s
+; CHECK-NEXT: sunpklo z6.d, z7.s
+; CHECK-NEXT: sunpkhi z7.d, z7.s
+; CHECK-NEXT: ret
+ %ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ ret <vscale x 16 x i64> %ext
+}
+
+define <vscale x 16 x i16> @zext_b_to_h(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: zext_b_to_h:
+; CHECK: // %bb.0:
+; CHECK-NEXT: uunpklo z2.h, z0.b
+; CHECK-NEXT: uunpkhi z1.h, z0.b
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+ %ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i16>
+ ret <vscale x 16 x i16> %ext
+}
+
+define <vscale x 8 x i32> @zext_h_to_s(<vscale x 8 x i16> %a) {
+; CHECK-LABEL: zext_h_to_s:
+; CHECK: // %bb.0:
+; CHECK-NEXT: uunpklo z2.s, z0.h
+; CHECK-NEXT: uunpkhi z1.s, z0.h
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+ %ext = zext <vscale x 8 x i16> %a to <vscale x 8 x i32>
+ ret <vscale x 8 x i32> %ext
+}
+
+define <vscale x 4 x i64> @zext_s_to_d(<vscale x 4 x i32> %a) {
+; CHECK-LABEL: zext_s_to_d:
+; CHECK: // %bb.0:
+; CHECK-NEXT: uunpklo z2.d, z0.s
+; CHECK-NEXT: uunpkhi z1.d, z0.s
+; CHECK-NEXT: mov z0.d, z2.d
+; CHECK-NEXT: ret
+ %ext = zext <vscale x 4 x i32> %a to <vscale x 4 x i64>
+ ret <vscale x 4 x i64> %ext
+}
+
+define <vscale x 16 x i32> @zext_b_to_s(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: zext_b_to_s:
+; CHECK: // %bb.0:
+; CHECK-NEXT: uunpklo z1.h, z0.b
+; CHECK-NEXT: uunpkhi z3.h, z0.b
+; CHECK-NEXT: uunpklo z0.s, z1.h
+; CHECK-NEXT: uunpkhi z1.s, z1.h
+; CHECK-NEXT: uunpklo z2.s, z3.h
+; CHECK-NEXT: uunpkhi z3.s, z3.h
+; CHECK-NEXT: ret
+ %ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ ret <vscale x 16 x i32> %ext
+}
+
+define <vscale x 16 x i64> @zext_b_to_d(<vscale x 16 x i8> %a) {
+; CHECK-LABEL: zext_b_to_d:
+; CHECK: // %bb.0:
+; CHECK-NEXT: uunpklo z1.h, z0.b
+; CHECK-NEXT: uunpkhi z0.h, z0.b
+; CHECK-NEXT: uunpklo z2.s, z1.h
+; CHECK-NEXT: uunpkhi z3.s, z1.h
+; CHECK-NEXT: uunpklo z5.s, z0.h
+; CHECK-NEXT: uunpkhi z7.s, z0.h
+; CHECK-NEXT: uunpklo z0.d, z2.s
+; CHECK-NEXT: uunpkhi z1.d, z2.s
+; CHECK-NEXT: uunpklo z2.d, z3.s
+; CHECK-NEXT: uunpkhi z3.d, z3.s
+; CHECK-NEXT: uunpklo z4.d, z5.s
+; CHECK-NEXT: uunpkhi z5.d, z5.s
+; CHECK-NEXT: uunpklo z6.d, z7.s
+; CHECK-NEXT: uunpkhi z7.d, z7.s
+; CHECK-NEXT: ret
+ %ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ ret <vscale x 16 x i64> %ext
+}