[llvm] af1c8f0 - [AArch64][SVE] Folds VSELECT if the predicate is all active.

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Thu Jan 27 07:59:04 PST 2022


Author: Sander de Smalen
Date: 2022-01-27T15:58:56Z
New Revision: af1c8f0d142179826197f22c3880c980e6e47b3d

URL: https://github.com/llvm/llvm-project/commit/af1c8f0d142179826197f22c3880c980e6e47b3d
DIFF: https://github.com/llvm/llvm-project/commit/af1c8f0d142179826197f22c3880c980e6e47b3d.diff

LOG: [AArch64][SVE] Folds VSELECT if the predicate is all active.

This adds the following changes:

* Fold: vselect(<all active predicate>, x, y) => x
* Fold: vselect(<all inactive predicate>, x, y) => y
* Extend isAllActivePredicate to take vscale_range into account, e.g.
  isAllActivePredicate(vl16) for nxv16i1 and vscale == 1 => true.
  isAllActivePredicate(vl32) for nxv16i1 and vscale == 2 => true.

Differential Revision: https://reviews.llvm.org/D118147

Added: 
    llvm/test/CodeGen/AArch64/sve-vselect-fold.ll

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
index fe9b2f8883b9d..899f069abdd4b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp
@@ -5147,5 +5147,5 @@ bool AArch64DAGToDAGISel::SelectAllActivePredicate(SDValue N) {
   const AArch64TargetLowering *TLI =
       static_cast<const AArch64TargetLowering *>(getTargetLowering());
 
-  return TLI->isAllActivePredicate(N);
+  return TLI->isAllActivePredicate(*CurDAG, N);
 }

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b395231c69a75..a26bbc77f2482 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -15087,7 +15087,15 @@ static SDValue combineSVEReductionOrderedFP(SDNode *N, unsigned Opc,
                      Zero);
 }
 
-static bool isAllActivePredicate(SDValue N) {
+static bool isAllInactivePredicate(SDValue N) {
+  // Peel off reinterpret casts one at a time; the predicate counts as all
+  // inactive only if the node underneath them is a PFALSE.
+  if (N.getOpcode() == AArch64ISD::REINTERPRET_CAST)
+    return isAllInactivePredicate(N.getOperand(0));
+  return N.getOpcode() == AArch64ISD::PFALSE;
+}
+
+static bool isAllActivePredicate(SelectionDAG &DAG, SDValue N) {
   unsigned NumElts = N.getValueType().getVectorMinNumElements();
 
   // Look through cast.
@@ -15106,6 +15114,21 @@ static bool isAllActivePredicate(SDValue N) {
       N.getConstantOperandVal(0) == AArch64SVEPredPattern::all)
     return N.getValueType().getVectorMinNumElements() >= NumElts;
 
+  // For a fixed vector-length, the PTRUE must fill every lane of its own type
+  // (which can differ after looking through casts) and have >= NumElts lanes.
+  if (N.getOpcode() == AArch64ISD::PTRUE) {
+    const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
+    unsigned MinSVESize = Subtarget.getMinSVEVectorSizeInBits();
+    unsigned MaxSVESize = Subtarget.getMaxSVEVectorSizeInBits();
+    if (MaxSVESize && MinSVESize == MaxSVESize) {
+      unsigned VScale = MaxSVESize / AArch64::SVEBitsPerBlock;
+      unsigned PTrueElts = N.getValueType().getVectorMinNumElements();
+      unsigned PatNumElts =
+          getNumElementsFromSVEPredPattern(N.getConstantOperandVal(0));
+      return PTrueElts >= NumElts && PatNumElts == PTrueElts * VScale;
+    }
+  }
+
   return false;
 }
 
@@ -15122,7 +15145,7 @@ static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
   SDValue Op2 = N->getOperand(SwapOperands ? 2 : 3);
 
   // ISD way to specify an all active predicate.
-  if (isAllActivePredicate(Pg)) {
+  if (isAllActivePredicate(DAG, Pg)) {
     if (UnpredOp)
       return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Op1, Op2);
 
@@ -16793,6 +16816,12 @@ static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
   SDValue N0 = N->getOperand(0);
   EVT CCVT = N0.getValueType();
 
+  if (isAllActivePredicate(DAG, N0))
+    return N->getOperand(1);
+
+  if (isAllInactivePredicate(N0))
+    return N->getOperand(2);
+
   // Check for sign pattern (VSELECT setgt, iN lhs, -1, 1, -1) and transform
   // into (OR (ASR lhs, N-1), 1), which requires less instructions for the
   // supported types.
@@ -19364,7 +19393,7 @@ SDValue AArch64TargetLowering::LowerPredReductionToSVE(SDValue ReduceOp,
   default:
     return SDValue();
   case ISD::VECREDUCE_OR:
-    if (isAllActivePredicate(Pg))
+    if (isAllActivePredicate(DAG, Pg))
       // The predicate can be 'Op' because
       // vecreduce_or(Op & <all true>) <=> vecreduce_or(Op).
       return getPTest(DAG, VT, Op, Op, AArch64CC::ANY_ACTIVE);
@@ -19813,8 +19842,9 @@ SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
   return Op;
 }
 
-bool AArch64TargetLowering::isAllActivePredicate(SDValue N) const {
-  return ::isAllActivePredicate(N);
+bool AArch64TargetLowering::isAllActivePredicate(SelectionDAG &DAG,
+                                                 SDValue N) const {
+  return ::isAllActivePredicate(DAG, N); // Forward to the static helper.
 }
 
 EVT AArch64TargetLowering::getPromotedVTForPredicate(EVT VT) const {

diff  --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 90b947dd814b2..ca6c70297c0b4 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -844,7 +844,7 @@ class AArch64TargetLowering : public TargetLowering {
     return 128;
   }
 
-  bool isAllActivePredicate(SDValue N) const;
+  bool isAllActivePredicate(SelectionDAG &DAG, SDValue N) const;
   EVT getPromotedVTForPredicate(EVT VT) const;
 
   EVT getAsmOperandValueType(const DataLayout &DL, Type *Ty,

diff  --git a/llvm/test/CodeGen/AArch64/sve-vselect-fold.ll b/llvm/test/CodeGen/AArch64/sve-vselect-fold.ll
new file mode 100644
index 0000000000000..159b85d148b78
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-vselect-fold.ll
@@ -0,0 +1,62 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve -verify-machineinstrs < %s | FileCheck %s
+
+; Test that the select returns %true, because the predicate is all active.
+define <vscale x 4 x i32> @select_ptrue_fold_all_active(<vscale x 4 x i32> %false, <vscale x 4 x i32> %true) {
+; CHECK-LABEL: select_ptrue_fold_all_active:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z0.d, z1.d
+; CHECK-NEXT:    ret
+  %p = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+  %res = select <vscale x 4 x i1> %p, <vscale x 4 x i32> %true, <vscale x 4 x i32> %false
+  ret <vscale x 4 x i32> %res
+}
+
+; Test that the select returns %true, because the predicate is all active for vscale_range(2, 2).
+define <vscale x 4 x i32> @select_ptrue_fold_vl8(<vscale x 4 x i32> %false, <vscale x 4 x i32> %true) vscale_range(2, 2) {
+; CHECK-LABEL: select_ptrue_fold_vl8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z0.d, z1.d
+; CHECK-NEXT:    ret
+  %p = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 8)
+  %res = select <vscale x 4 x i1> %p, <vscale x 4 x i32> %true, <vscale x 4 x i32> %false
+  ret <vscale x 4 x i32> %res
+}
+
+; Test that the select returns %false, because the predicate is all inactive.
+define <vscale x 16 x i8> @select_ptrue_fold_all_inactive(<vscale x 16 x i8> %true, <vscale x 16 x i8> %false) {
+; CHECK-LABEL: select_ptrue_fold_all_inactive:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z0.d, z1.d
+; CHECK-NEXT:    ret
+  %p = call <vscale x 16 x  i1> @llvm.aarch64.sve.convert.from.svbool.nxv16i1(<vscale x 16 x i1> zeroinitializer)
+  %res = select <vscale x 16 x  i1> %p, <vscale x 16 x i8> %true, <vscale x 16 x i8> %false
+  ret <vscale x 16 x i8> %res
+}
+
+; Test that the select returns %false when the all-false svbool is converted to <vscale x 4 x i1>.
+define <vscale x 4 x i32> @select_ptrue_fold_all_inactive_reinterpret(<vscale x 4 x i32> %true, <vscale x 4 x i32> %false) {
+; CHECK-LABEL: select_ptrue_fold_all_inactive_reinterpret:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov z0.d, z1.d
+; CHECK-NEXT:    ret
+  %p = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> zeroinitializer)
+  %res = select <vscale x 4 x i1> %p, <vscale x 4 x i32> %true, <vscale x 4 x i32> %false
+  ret <vscale x 4 x i32> %res
+}
+
+; Test that the select remains, because the predicate is not all active (only half of the lanes are set for vscale_range(2, 2)).
+define <vscale x 4 x i32> @select_ptrue_no_fold_vl4(<vscale x 4 x i32> %true, <vscale x 4 x i32> %false) vscale_range(2, 2) {
+; CHECK-LABEL: select_ptrue_no_fold_vl4:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ptrue p0.s, vl4
+; CHECK-NEXT:    sel z0.s, p0, z0.s, z1.s
+; CHECK-NEXT:    ret
+  %p = call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 4)
+  %res = select <vscale x 4 x i1> %p, <vscale x 4 x i32> %true, <vscale x 4 x i32> %false
+  ret <vscale x 4 x i32> %res
+}
+
+declare <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32)
+declare <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv16i1(<vscale x 16 x i1>)


        


More information about the llvm-commits mailing list