[llvm] af1c8f0 - [AArch64][SVE] Folds VSELECT if the predicate is all active.
Sander de Smalen via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 27 07:59:04 PST 2022
Author: Sander de Smalen
Date: 2022-01-27T15:58:56Z
New Revision: af1c8f0d142179826197f22c3880c980e6e47b3d
URL: https://github.com/llvm/llvm-project/commit/af1c8f0d142179826197f22c3880c980e6e47b3d
DIFF: https://github.com/llvm/llvm-project/commit/af1c8f0d142179826197f22c3880c980e6e47b3d.diff
LOG: [AArch64][SVE] Folds VSELECT if the predicate is all active.
This adds the following changes:
* Fold: vselect(<all active predicate>, x, y) => x
* Fold: vselect(<all inactive predicate>, x, y) => y
* Extend isAllActivePredicate to take vscale_range into account when the
  minimum and maximum SVE vector sizes are equal, e.g.
    isAllActivePredicate(vl16) for nxv16i1 and vscale == 1 => true.
    isAllActivePredicate(vl32) for nxv16i1 and vscale == 2 => true.
Differential Revision: https://reviews.llvm.org/D118147
Added:
llvm/test/CodeGen/AArch64/sve-vselect-fold.ll
Modified:
llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index fe9b2f8883b9d..899f069abdd4b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -5147,5 +5147,5 @@ bool AArch64DAGToDAGISel::SelectAllActivePredicate(SDValue N) {
const AArch64TargetLowering *TLI =
static_cast<const AArch64TargetLowering *>(getTargetLowering());
- return TLI->isAllActivePredicate(N);
+ return TLI->isAllActivePredicate(*CurDAG, N);
}
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b395231c69a75..a26bbc77f2482 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -15087,7 +15087,15 @@ static SDValue combineSVEReductionOrderedFP(SDNode *N, unsigned Opc,
Zero);
}
-static bool isAllActivePredicate(SDValue N) {
+static bool isAllInactivePredicate(SDValue N) {
+ // Look through cast.
+ while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST)
+ N = N.getOperand(0);
+
+ return N.getOpcode() == AArch64ISD::PFALSE;
+}
+
+static bool isAllActivePredicate(SelectionDAG &DAG, SDValue N) {
unsigned NumElts = N.getValueType().getVectorMinNumElements();
// Look through cast.
@@ -15106,6 +15114,21 @@ static bool isAllActivePredicate(SDValue N) {
N.getConstantOperandVal(0) == AArch64SVEPredPattern::all)
return N.getValueType().getVectorMinNumElements() >= NumElts;
+ // If we're compiling for a specific vector-length, we can check if the
+ // pattern's VL equals that of the scalable vector at runtime.
+ if (N.getOpcode() == AArch64ISD::PTRUE) {
+ const auto &Subtarget =
+ static_cast<const AArch64Subtarget &>(DAG.getSubtarget());
+ unsigned MinSVESize = Subtarget.getMinSVEVectorSizeInBits();
+ unsigned MaxSVESize = Subtarget.getMaxSVEVectorSizeInBits();
+ if (MaxSVESize && MinSVESize == MaxSVESize) {
+ unsigned VScale = MaxSVESize / AArch64::SVEBitsPerBlock;
+ unsigned PatNumElts =
+ getNumElementsFromSVEPredPattern(N.getConstantOperandVal(0));
+ return PatNumElts == (NumElts * VScale);
+ }
+ }
+
return false;
}
@@ -15122,7 +15145,7 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
SDValue Op2 = N->getOperand(SwapOperands ? 2 : 3);
// ISD way to specify an all active predicate.
- if (isAllActivePredicate(Pg)) {
+ if (isAllActivePredicate(DAG, Pg)) {
if (UnpredOp)
return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Op1, Op2);
@@ -16793,6 +16816,12 @@ static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
SDValue N0 = N->getOperand(0);
EVT CCVT = N0.getValueType();
+ if (isAllActivePredicate(DAG, N0))
+ return N->getOperand(1);
+
+ if (isAllInactivePredicate(N0))
+ return N->getOperand(2);
+
// Check for sign pattern (VSELECT setgt, iN lhs, -1, 1, -1) and transform
// into (OR (ASR lhs, N-1), 1), which requires less instructions for the
// supported types.
@@ -19364,7 +19393,7 @@ SDValue AArch64TargetLowering::LowerPredReductionToSVE(SDValue ReduceOp,
default:
return SDValue();
case ISD::VECREDUCE_OR:
- if (isAllActivePredicate(Pg))
+ if (isAllActivePredicate(DAG, Pg))
// The predicate can be 'Op' because
// vecreduce_or(Op & <all true>) <=> vecreduce_or(Op).
return getPTest(DAG, VT, Op, Op, AArch64CC::ANY_ACTIVE);
@@ -19813,8 +19842,9 @@ SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
return Op;
}
-bool AArch64TargetLowering::isAllActivePredicate(SDValue N) const {
- return ::isAllActivePredicate(N);
+bool AArch64TargetLowering::isAllActivePredicate(SelectionDAG &DAG,
+ SDValue N) const {
+ return ::isAllActivePredicate(DAG, N);
}
EVT AArch64TargetLowering::getPromotedVTForPredicate(EVT VT) const {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 90b947dd814b2..ca6c70297c0b4 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -844,7 +844,7 @@ class AArch64TargetLowering : public TargetLowering {
return 128;
}
- bool isAllActivePredicate(SDValue N) const;
+ bool isAllActivePredicate(SelectionDAG &DAG, SDValue N) const;
EVT getPromotedVTForPredicate(EVT VT) const;
EVT getAsmOperandValueType(const DataLayout &DL, Type *Ty,
diff --git a/llvm/test/CodeGen/AArch64/sve-vselect-fold.ll b/llvm/test/CodeGen/AArch64/sve-vselect-fold.ll
new file mode 100644
index 0000000000000..159b85d148b78
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-vselect-fold.ll
@@ -0,0 +1,60 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -verify-machineinstrs < %s | FileCheck %s
+
+; Test that the select returns %true, because the predicate is all active.
+define <vscale x 4 x i32> @select_ptrue_fold_all_active(<vscale x 4 x i32> %false, <vscale x 4 x i32> %true) {
+; CHECK-LABEL: select_ptrue_fold_all_active:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z0.d, z1.d
+; CHECK-NEXT: ret
+ %p = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+ %res = select <vscale x 4 x i1> %p, <vscale x 4 x i32> %true, <vscale x 4 x i32> %false
+ ret <vscale x 4 x i32> %res
+}
+
+; Test that the select returns %true, because the predicate is all active for vscale_range(2, 2)
+define <vscale x 4 x i32> @select_ptrue_fold_vl8(<vscale x 4 x i32> %false, <vscale x 4 x i32> %true) vscale_range(2, 2) {
+; CHECK-LABEL: select_ptrue_fold_vl8:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z0.d, z1.d
+; CHECK-NEXT: ret
+ %p = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 8)
+ %res = select <vscale x 4 x i1> %p, <vscale x 4 x i32> %true, <vscale x 4 x i32> %false
+ ret <vscale x 4 x i32> %res
+}
+
+define <vscale x 16 x i8> @select_ptrue_fold_all_inactive(<vscale x 16 x i8> %true, <vscale x 16 x i8> %false) {
+; CHECK-LABEL: select_ptrue_fold_all_inactive:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z0.d, z1.d
+; CHECK-NEXT: ret
+ %p = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv16i1(<vscale x 16 x i1> zeroinitializer)
+ %res = select <vscale x 16 x i1> %p, <vscale x 16 x i8> %true, <vscale x 16 x i8> %false
+ ret <vscale x 16 x i8> %res
+}
+
+define <vscale x 4 x i32> @select_ptrue_fold_all_inactive_reinterpret(<vscale x 4 x i32> %true, <vscale x 4 x i32> %false) {
+; CHECK-LABEL: select_ptrue_fold_all_inactive_reinterpret:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z0.d, z1.d
+; CHECK-NEXT: ret
+ %p = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> zeroinitializer)
+ %res = select <vscale x 4 x i1> %p, <vscale x 4 x i32> %true, <vscale x 4 x i32> %false
+ ret <vscale x 4 x i32> %res
+}
+
+; Test that the select remains, because predicate is not all active (only half lanes are set for vscale_range(2, 2))
+define <vscale x 4 x i32> @select_ptrue_no_fold_vl4(<vscale x 4 x i32> %true, <vscale x 4 x i32> %false) vscale_range(2, 2) {
+; CHECK-LABEL: select_ptrue_no_fold_vl4:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ptrue p0.s, vl4
+; CHECK-NEXT: sel z0.s, p0, z0.s, z1.s
+; CHECK-NEXT: ret
+ %p = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 4)
+ %res = select <vscale x 4 x i1> %p, <vscale x 4 x i32> %true, <vscale x 4 x i32> %false
+ ret <vscale x 4 x i32> %res
+}
+
+declare <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32)
+declare <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv16i1(<vscale x 16 x i1>)
More information about the llvm-commits mailing list