[llvm] 690db16 - [AArch64] Make nxv1i1 types a legal type for SVE.
Sander de Smalen via llvm-commits
llvm-commits at lists.llvm.org
Fri Jul 1 08:12:56 PDT 2022
Author: Sander de Smalen
Date: 2022-07-01T15:11:13Z
New Revision: 690db164226fb1d454c5e592726a8bc0de16c6b5
URL: https://github.com/llvm/llvm-project/commit/690db164226fb1d454c5e592726a8bc0de16c6b5
DIFF: https://github.com/llvm/llvm-project/commit/690db164226fb1d454c5e592726a8bc0de16c6b5.diff
LOG: [AArch64] Make nxv1i1 types a legal type for SVE.
One motivation for adding support for these types is the LD1Q/ST1Q
instructions in SME, for which we have defined a number of load/store
intrinsics which at the moment still take a `<vscale x 16 x i1>` predicate
regardless of their element type.
This patch adds basic support for the nxv1i1 type so that it can be passed to and
returned from functions, along with enough support to handle some existing tests that
result in an nxv1i1 type. It also adds support for splats.
Other operations (e.g. insert/extract subvector, logical ops, etc) will be
supported in follow-up patches.
Reviewed By: paulwalker-arm, efriedma
Differential Revision: https://reviews.llvm.org/D128665
Added:
Modified:
llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
llvm/lib/Target/AArch64/AArch64CallingConvention.td
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64RegisterInfo.td
llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
llvm/lib/Target/AArch64/SVEInstrFormats.td
llvm/test/CodeGen/AArch64/sve-extract-scalable-vector.ll
llvm/test/CodeGen/AArch64/sve-select.ll
llvm/test/CodeGen/AArch64/sve-zeroinit.ll
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
index 44521bb537917..fa555be00ded0 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
@@ -6653,7 +6653,7 @@ SDValue DAGTypeLegalizer::ModifyToType(SDValue InOp, EVT NVT,
EVT InVT = InOp.getValueType();
assert(InVT.getVectorElementType() == NVT.getVectorElementType() &&
"input and widen element type must match");
- assert(!InVT.isScalableVector() && !NVT.isScalableVector() &&
+ assert(InVT.isScalableVector() == NVT.isScalableVector() &&
"cannot modify scalable vectors in this way");
SDLoc dl(InOp);
@@ -6661,10 +6661,10 @@ SDValue DAGTypeLegalizer::ModifyToType(SDValue InOp, EVT NVT,
if (InVT == NVT)
return InOp;
- unsigned InNumElts = InVT.getVectorNumElements();
- unsigned WidenNumElts = NVT.getVectorNumElements();
- if (WidenNumElts > InNumElts && WidenNumElts % InNumElts == 0) {
- unsigned NumConcat = WidenNumElts / InNumElts;
+ ElementCount InEC = InVT.getVectorElementCount();
+ ElementCount WidenEC = NVT.getVectorElementCount();
+ if (WidenEC.hasKnownScalarFactor(InEC)) {
+ unsigned NumConcat = WidenEC.getKnownScalarFactor(InEC);
SmallVector<SDValue, 16> Ops(NumConcat);
SDValue FillVal = FillWithZeroes ? DAG.getConstant(0, dl, InVT) :
DAG.getUNDEF(InVT);
@@ -6675,10 +6675,16 @@ SDValue DAGTypeLegalizer::ModifyToType(SDValue InOp, EVT NVT,
return DAG.getNode(ISD::CONCAT_VECTORS, dl, NVT, Ops);
}
- if (WidenNumElts < InNumElts && InNumElts % WidenNumElts)
+ if (InEC.hasKnownScalarFactor(WidenEC))
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, NVT, InOp,
DAG.getVectorIdxConstant(0, dl));
+ assert(!InVT.isScalableVector() && !NVT.isScalableVector() &&
+ "Scalable vectors should have been handled already.");
+
+ unsigned InNumElts = InEC.getFixedValue();
+ unsigned WidenNumElts = WidenEC.getFixedValue();
+
// Fall back to extract and build.
SmallVector<SDValue, 16> Ops(WidenNumElts);
EVT EltVT = NVT.getVectorElementType();
diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.td b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
index f26151536a586..c0da242a26de8 100644
--- a/llvm/lib/Target/AArch64/AArch64CallingConvention.td
+++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.td
@@ -82,9 +82,9 @@ def CC_AArch64_AAPCS : CallingConv<[
nxv2bf16, nxv4bf16, nxv8bf16, nxv2f32, nxv4f32, nxv2f64],
CCPassIndirect<i64>>,
- CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1],
+ CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1],
CCAssignToReg<[P0, P1, P2, P3]>>,
- CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1],
+ CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1],
CCPassIndirect<i64>>,
// Handle i1, i8, i16, i32, i64, f32, f64 and v2f64 by passing in registers,
@@ -149,7 +149,7 @@ def RetCC_AArch64_AAPCS : CallingConv<[
nxv2bf16, nxv4bf16, nxv8bf16, nxv2f32, nxv4f32, nxv2f64],
CCAssignToReg<[Z0, Z1, Z2, Z3, Z4, Z5, Z6, Z7]>>,
- CCIfType<[nxv2i1, nxv4i1, nxv8i1, nxv16i1],
+ CCIfType<[nxv1i1, nxv2i1, nxv4i1, nxv8i1, nxv16i1],
CCAssignToReg<[P0, P1, P2, P3]>>
]>;
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 63cd8f93083cf..abfe2d5071119 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -292,6 +292,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
if (Subtarget->hasSVE() || Subtarget->hasSME()) {
// Add legal sve predicate types
+ addRegisterClass(MVT::nxv1i1, &AArch64::PPRRegClass);
addRegisterClass(MVT::nxv2i1, &AArch64::PPRRegClass);
addRegisterClass(MVT::nxv4i1, &AArch64::PPRRegClass);
addRegisterClass(MVT::nxv8i1, &AArch64::PPRRegClass);
@@ -1156,7 +1157,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
MVT::nxv4i16, MVT::nxv4i32, MVT::nxv8i8, MVT::nxv8i16 })
setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Legal);
- for (auto VT : {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1}) {
+ for (auto VT :
+ {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1, MVT::nxv1i1}) {
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
setOperationAction(ISD::SELECT, VT, Custom);
setOperationAction(ISD::SETCC, VT, Custom);
@@ -4676,7 +4678,6 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
Op.getOperand(2), Op.getOperand(3),
DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i32)),
Op.getOperand(1));
-
case Intrinsic::localaddress: {
const auto &MF = DAG.getMachineFunction();
const auto *RegInfo = Subtarget->getRegisterInfo();
@@ -10551,8 +10552,13 @@ SDValue AArch64TargetLowering::LowerSPLAT_VECTOR(SDValue Op,
DAG.getValueType(MVT::i1));
SDValue ID =
DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, ID,
- DAG.getConstant(0, DL, MVT::i64), SplatVal);
+ SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
+ if (VT == MVT::nxv1i1)
+ return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::nxv1i1,
+ DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv2i1, ID,
+ Zero, SplatVal),
+ Zero);
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, ID, Zero, SplatVal);
}
SDValue AArch64TargetLowering::LowerDUPQLane(SDValue Op,
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
index e42feea959667..7a2b165570cbc 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.td
@@ -871,7 +871,7 @@ class ZPRRegOp <string Suffix, AsmOperandClass C, ElementSizeEnum Size,
// SVE predicate register classes.
class PPRClass<int lastreg> : RegisterClass<
"AArch64",
- [ nxv16i1, nxv8i1, nxv4i1, nxv2i1 ], 16,
+ [ nxv16i1, nxv8i1, nxv4i1, nxv2i1, nxv1i1 ], 16,
(sequence "P%u", 0, lastreg)> {
let Size = 16;
}
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 6bd3a96e5fcb9..68ff1b78e84bd 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -748,6 +748,11 @@ let Predicates = [HasSVEorSME] in {
defm PUNPKLO_PP : sve_int_perm_punpk<0b0, "punpklo", int_aarch64_sve_punpklo>;
defm PUNPKHI_PP : sve_int_perm_punpk<0b1, "punpkhi", int_aarch64_sve_punpkhi>;
+ // Define pattern for `nxv1i1 splat_vector(1)`.
+ // We do this here instead of in ISelLowering such that PatFrag's can still
+ // recognize a splat.
+ def : Pat<(nxv1i1 immAllOnesV), (PUNPKLO_PP (PTRUE_D 31))>;
+
defm MOVPRFX_ZPzZ : sve_int_movprfx_pred_zero<0b000, "movprfx">;
defm MOVPRFX_ZPmZ : sve_int_movprfx_pred_merge<0b001, "movprfx">;
def MOVPRFX_ZZ : sve_int_bin_cons_misc_0_c<0b00000001, "movprfx", ZPRAny>;
@@ -1509,6 +1514,10 @@ let Predicates = [HasSVEorSME] in {
defm TRN2_PPP : sve_int_perm_bin_perm_pp<0b101, "trn2", AArch64trn2>;
// Extract lo/hi halves of legal predicate types.
+ def : Pat<(nxv1i1 (extract_subvector (nxv2i1 PPR:$Ps), (i64 0))),
+ (PUNPKLO_PP PPR:$Ps)>;
+ def : Pat<(nxv1i1 (extract_subvector (nxv2i1 PPR:$Ps), (i64 1))),
+ (PUNPKHI_PP PPR:$Ps)>;
def : Pat<(nxv2i1 (extract_subvector (nxv4i1 PPR:$Ps), (i64 0))),
(PUNPKLO_PP PPR:$Ps)>;
def : Pat<(nxv2i1 (extract_subvector (nxv4i1 PPR:$Ps), (i64 2))),
@@ -1599,6 +1608,8 @@ let Predicates = [HasSVEorSME] in {
(UUNPKHI_ZZ_D (UUNPKHI_ZZ_S ZPR:$Zs))>;
// Concatenate two predicates.
+ def : Pat<(nxv2i1 (concat_vectors nxv1i1:$p1, nxv1i1:$p2)),
+ (UZP1_PPP_D $p1, $p2)>;
def : Pat<(nxv4i1 (concat_vectors nxv2i1:$p1, nxv2i1:$p2)),
(UZP1_PPP_S $p1, $p2)>;
def : Pat<(nxv8i1 (concat_vectors nxv4i1:$p1, nxv4i1:$p2)),
@@ -2298,15 +2309,23 @@ let Predicates = [HasSVEorSME] in {
def : Pat<(nxv16i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv16i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv16i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
+ def : Pat<(nxv16i1 (reinterpret_cast (nxv1i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv8i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv8i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv8i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
+ def : Pat<(nxv8i1 (reinterpret_cast (nxv1i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv4i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv4i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv4i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
+ def : Pat<(nxv4i1 (reinterpret_cast (nxv1i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv2i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv2i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
def : Pat<(nxv2i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
+ def : Pat<(nxv2i1 (reinterpret_cast (nxv1i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
+ def : Pat<(nxv1i1 (reinterpret_cast (nxv16i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
+ def : Pat<(nxv1i1 (reinterpret_cast (nxv8i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
+ def : Pat<(nxv1i1 (reinterpret_cast (nxv4i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
+ def : Pat<(nxv1i1 (reinterpret_cast (nxv2i1 PPR:$src))), (COPY_TO_REGCLASS PPR:$src, PPR)>;
// These allow casting from/to unpacked floating-point types.
def : Pat<(nxv2f16 (reinterpret_cast (nxv8f16 ZPR:$src))), (COPY_TO_REGCLASS ZPR:$src, ZPR)>;
diff --git a/llvm/lib/Target/AArch64/SVEInstrFormats.td b/llvm/lib/Target/AArch64/SVEInstrFormats.td
index 13c04eedca2da..3631536a32b93 100644
--- a/llvm/lib/Target/AArch64/SVEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SVEInstrFormats.td
@@ -647,6 +647,7 @@ multiclass sve_int_pfalse<bits<6> opc, string asm> {
def : Pat<(nxv8i1 immAllZerosV), (!cast<Instruction>(NAME))>;
def : Pat<(nxv4i1 immAllZerosV), (!cast<Instruction>(NAME))>;
def : Pat<(nxv2i1 immAllZerosV), (!cast<Instruction>(NAME))>;
+ def : Pat<(nxv1i1 immAllZerosV), (!cast<Instruction>(NAME))>;
}
class sve_int_ptest<bits<6> opc, string asm>
@@ -1681,6 +1682,7 @@ multiclass sve_int_pred_log<bits<4> opc, string asm, SDPatternOperator op,
def : SVE_3_Op_Pat<nxv8i1, op, nxv8i1, nxv8i1, nxv8i1, !cast<Instruction>(NAME)>;
def : SVE_3_Op_Pat<nxv4i1, op, nxv4i1, nxv4i1, nxv4i1, !cast<Instruction>(NAME)>;
def : SVE_3_Op_Pat<nxv2i1, op, nxv2i1, nxv2i1, nxv2i1, !cast<Instruction>(NAME)>;
+ def : SVE_3_Op_Pat<nxv1i1, op, nxv1i1, nxv1i1, nxv1i1, !cast<Instruction>(NAME)>;
def : SVE_2_Op_AllActive_Pat<nxv16i1, op_nopred, nxv16i1, nxv16i1,
!cast<Instruction>(NAME), PTRUE_B>;
def : SVE_2_Op_AllActive_Pat<nxv8i1, op_nopred, nxv8i1, nxv8i1,
diff --git a/llvm/test/CodeGen/AArch64/sve-extract-scalable-vector.ll b/llvm/test/CodeGen/AArch64/sve-extract-scalable-vector.ll
index 0c27b4e767d11..18b18f58d7247 100644
--- a/llvm/test/CodeGen/AArch64/sve-extract-scalable-vector.ll
+++ b/llvm/test/CodeGen/AArch64/sve-extract-scalable-vector.ll
@@ -1078,3 +1078,27 @@ define <vscale x 2 x i1> @extract_nxv2i1_nxv16i1_all_zero() {
declare <vscale x 2 x float> @llvm.vector.extract.nxv2f32.nxv4f32(<vscale x 4 x float>, i64)
declare <vscale x 4 x i32> @llvm.vector.extract.nxv4i32.nxv8i32(<vscale x 8 x i32>, i64)
+
+;
+; Extract nxv1i1 type from: nxv2i1
+;
+
+define <vscale x 1 x i1> @extract_nxv1i1_nxv2i1_0(<vscale x 2 x i1> %in) {
+; CHECK-LABEL: extract_nxv1i1_nxv2i1_0:
+; CHECK: // %bb.0:
+; CHECK-NEXT: punpklo p0.h, p0.b
+; CHECK-NEXT: ret
+ %res = call <vscale x 1 x i1> @llvm.vector.extract.nxv1i1.nxv2i1(<vscale x 2 x i1> %in, i64 0)
+ ret <vscale x 1 x i1> %res
+}
+
+define <vscale x 1 x i1> @extract_nxv1i1_nxv2i1_1(<vscale x 2 x i1> %in) {
+; CHECK-LABEL: extract_nxv1i1_nxv2i1_1:
+; CHECK: // %bb.0:
+; CHECK-NEXT: punpkhi p0.h, p0.b
+; CHECK-NEXT: ret
+ %res = call <vscale x 1 x i1> @llvm.vector.extract.nxv1i1.nxv2i1(<vscale x 2 x i1> %in, i64 1)
+ ret <vscale x 1 x i1> %res
+}
+
+declare <vscale x 1 x i1> @llvm.vector.extract.nxv1i1.nxv2i1(<vscale x 2 x i1>, i64)
diff --git a/llvm/test/CodeGen/AArch64/sve-select.ll b/llvm/test/CodeGen/AArch64/sve-select.ll
index 857e057eb088f..5c1cfe639bc7e 100644
--- a/llvm/test/CodeGen/AArch64/sve-select.ll
+++ b/llvm/test/CodeGen/AArch64/sve-select.ll
@@ -187,6 +187,7 @@ define <vscale x 1 x i1> @select_nxv1i1(i1 %cond, <vscale x 1 x i1> %a, <vscal
; CHECK-NEXT: // kill: def $w0 killed $w0 def $x0
; CHECK-NEXT: sbfx x8, x0, #0, #1
; CHECK-NEXT: whilelo p2.d, xzr, x8
+; CHECK-NEXT: punpklo p2.h, p2.b
; CHECK-NEXT: sel p0.b, p2, p0.b, p1.b
; CHECK-NEXT: ret
%res = select i1 %cond, <vscale x 1 x i1> %a, <vscale x 1 x i1> %b
@@ -225,6 +226,7 @@ define <vscale x 4 x i32> @sel_nxv4i32(<vscale x 4 x i1> %p, <vscale x 4 x i32>
define <vscale x 1 x i64> @sel_nxv1i64(<vscale x 1 x i1> %p, <vscale x 1 x i64> %dst, <vscale x 1 x i64> %a) {
; CHECK-LABEL: sel_nxv1i64:
; CHECK: // %bb.0:
+; CHECK-NEXT: uzp1 p0.d, p0.d, p0.d
; CHECK-NEXT: mov z0.d, p0/m, z1.d
; CHECK-NEXT: ret
%sel = select <vscale x 1 x i1> %p, <vscale x 1 x i64> %a, <vscale x 1 x i64> %dst
@@ -483,6 +485,7 @@ define <vscale x 1 x i1> @icmp_select_nxv1i1(<vscale x 1 x i1> %a, <vscale x 1 x
; CHECK-NEXT: cset w8, eq
; CHECK-NEXT: sbfx x8, x8, #0, #1
; CHECK-NEXT: whilelo p2.d, xzr, x8
+; CHECK-NEXT: punpklo p2.h, p2.b
; CHECK-NEXT: sel p0.b, p2, p0.b, p1.b
; CHECK-NEXT: ret
%mask = icmp eq i64 %x0, 0
diff --git a/llvm/test/CodeGen/AArch64/sve-zeroinit.ll b/llvm/test/CodeGen/AArch64/sve-zeroinit.ll
index de7695a86ff90..c436bb7f822b7 100644
--- a/llvm/test/CodeGen/AArch64/sve-zeroinit.ll
+++ b/llvm/test/CodeGen/AArch64/sve-zeroinit.ll
@@ -52,6 +52,13 @@ define <vscale x 8 x half> @test_zeroinit_8xf16() {
ret <vscale x 8 x half> zeroinitializer
}
+define <vscale x 1 x i1> @test_zeroinit_1xi1() {
+; CHECK-LABEL: test_zeroinit_1xi1
+; CHECK: pfalse p0.b
+; CHECK-NEXT: ret
+ ret <vscale x 1 x i1> zeroinitializer
+}
+
define <vscale x 2 x i1> @test_zeroinit_2xi1() {
; CHECK-LABEL: test_zeroinit_2xi1
; CHECK: pfalse p0.b
More information about the llvm-commits
mailing list