[llvm] bf89d24 - [AArch64] NFC: Move safe predicate casting to a separate function.

Sander de Smalen via llvm-commits llvm-commits at lists.llvm.org
Mon Jul 4 03:58:56 PDT 2022


Author: Sander de Smalen
Date: 2022-07-04T10:32:54Z
New Revision: bf89d24f5319cb57e8458c1192480d17f00d4540

URL: https://github.com/llvm/llvm-project/commit/bf89d24f5319cb57e8458c1192480d17f00d4540
DIFF: https://github.com/llvm/llvm-project/commit/bf89d24f5319cb57e8458c1192480d17f00d4540.diff

LOG: [AArch64] NFC: Move safe predicate casting to a separate function.

This patch moves the code that safely bitcasts a predicate (zeroing any
undefined lanes where necessary when the cast is widening) into a single
function, getSVEPredicateBitCast, and merges the functionality of
lowerConvertToSVBool into it.

This is some cleanup inspired by D128665.

Reviewed By: paulwalker-arm

Differential Revision: https://reviews.llvm.org/D128926

Added: 
    

Modified: 
    llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
    llvm/lib/Target/AArch64/AArch64ISelLowering.h

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index abfe2d5071119..58a78c2e3c245 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1082,6 +1082,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
   }
 
+  // FIXME: Move lowering for more nodes here if those are common between
+  // SVE and SME.
+  if (Subtarget->hasSVE() || Subtarget->hasSME()) {
+    for (auto VT :
+         {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1, MVT::nxv1i1})
+      setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+  }
+
   if (Subtarget->hasSVE()) {
     for (auto VT : {MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64}) {
       setOperationAction(ISD::BITREVERSE, VT, Custom);
@@ -1162,7 +1170,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
       setOperationAction(ISD::SELECT, VT, Custom);
       setOperationAction(ISD::SETCC, VT, Custom);
-      setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
       setOperationAction(ISD::TRUNCATE, VT, Custom);
       setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
       setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
@@ -4333,27 +4340,47 @@ static inline SDValue getPTrue(SelectionDAG &DAG, SDLoc DL, EVT VT,
                      DAG.getTargetConstant(Pattern, DL, MVT::i32));
 }
 
-static SDValue lowerConvertToSVBool(SDValue Op, SelectionDAG &DAG) {
+SDValue AArch64TargetLowering::getSVEPredicateBitCast(EVT VT, SDValue Op,
+                                                      SelectionDAG &DAG) const {
   SDLoc DL(Op);
-  EVT OutVT = Op.getValueType();
-  SDValue InOp = Op.getOperand(1);
-  EVT InVT = InOp.getValueType();
+  EVT InVT = Op.getValueType();
+
+  assert(InVT.getVectorElementType() == MVT::i1 &&
+         VT.getVectorElementType() == MVT::i1 &&
+         "Expected a predicate-to-predicate bitcast");
+  assert(VT.isScalableVector() && isTypeLegal(VT) &&
+         InVT.isScalableVector() && isTypeLegal(InVT) &&
+         "Only expect to cast between legal scalable predicate types!");
 
   // Return the operand if the cast isn't changing type,
-  // i.e. <n x 16 x i1> -> <n x 16 x i1>
-  if (InVT == OutVT)
-    return InOp;
+  // e.g. <n x 16 x i1> -> <n x 16 x i1>
+  if (InVT == VT)
+    return Op;
+
+  SDValue Reinterpret = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
 
-  SDValue Reinterpret =
-      DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, OutVT, InOp);
+  // We only have to zero the lanes if new lanes are being defined, e.g. when
+  // casting from <vscale x 2 x i1> to <vscale x 16 x i1>. If this is not the
+  // case (e.g. when casting from <vscale x 16 x i1> -> <vscale x 2 x i1>) then
+  // we can return here.
+  if (InVT.bitsGT(VT))
+    return Reinterpret;
 
-  // If the argument converted to an svbool is a ptrue or a comparison, the
-  // lanes introduced by the widening are zero by construction.
-  switch (InOp.getOpcode()) {
+  // Check if the other lanes are already known to be zeroed by
+  // construction.
+  switch (Op.getOpcode()) {
+  default:
+    // We guarantee i1 splat_vectors to zero the other lanes by
+    // implementing it with ptrue and possibly a punpklo for nxv1i1.
+    if (ISD::isConstantSplatVectorAllOnes(Op.getNode()))
+      return Reinterpret;
+    break;
   case AArch64ISD::SETCC_MERGE_ZERO:
     return Reinterpret;
   case ISD::INTRINSIC_WO_CHAIN:
-    switch (InOp.getConstantOperandVal(0)) {
+    switch (Op.getConstantOperandVal(0)) {
+    default:
+      break;
     case Intrinsic::aarch64_sve_ptrue:
     case Intrinsic::aarch64_sve_cmpeq_wide:
     case Intrinsic::aarch64_sve_cmpne_wide:
@@ -4369,15 +4396,10 @@ static SDValue lowerConvertToSVBool(SDValue Op, SelectionDAG &DAG) {
     }
   }
 
-  // Splat vectors of one will generate ptrue instructions
-  if (ISD::isConstantSplatVectorAllOnes(InOp.getNode()))
-    return Reinterpret;
-
-  // Otherwise, zero the newly introduced lanes.
-  SDValue Mask = getPTrue(DAG, DL, InVT, AArch64SVEPredPattern::all);
-  SDValue MaskReinterpret =
-      DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, OutVT, Mask);
-  return DAG.getNode(ISD::AND, DL, OutVT, Reinterpret, MaskReinterpret);
+  // Zero the newly introduced lanes.
+  SDValue Mask = DAG.getConstant(1, DL, InVT);
+  Mask = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Mask);
+  return DAG.getNode(ISD::AND, DL, VT, Reinterpret, Mask);
 }
 
 SDValue AArch64TargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
@@ -4546,10 +4568,9 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
   case Intrinsic::aarch64_sve_dupq_lane:
     return LowerDUPQLane(Op, DAG);
   case Intrinsic::aarch64_sve_convert_from_svbool:
-    return DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, Op.getValueType(),
-                       Op.getOperand(1));
+    return getSVEPredicateBitCast(Op.getValueType(), Op.getOperand(1), DAG);
   case Intrinsic::aarch64_sve_convert_to_svbool:
-    return lowerConvertToSVBool(Op, DAG);
+    return getSVEPredicateBitCast(MVT::nxv16i1, Op.getOperand(1), DAG);
   case Intrinsic::aarch64_sve_fneg:
     return DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, dl, Op.getValueType(),
                        Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
@@ -21464,22 +21485,17 @@ SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
                                                  SelectionDAG &DAG) const {
   SDLoc DL(Op);
   EVT InVT = Op.getValueType();
-  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
-  (void)TLI;
 
-  assert(VT.isScalableVector() && TLI.isTypeLegal(VT) &&
-         InVT.isScalableVector() && TLI.isTypeLegal(InVT) &&
+  assert(VT.isScalableVector() && isTypeLegal(VT) &&
+         InVT.isScalableVector() && isTypeLegal(InVT) &&
          "Only expect to cast between legal scalable vector types!");
-  assert((VT.getVectorElementType() == MVT::i1) ==
-             (InVT.getVectorElementType() == MVT::i1) &&
-         "Cannot cast between data and predicate scalable vector types!");
+  assert(VT.getVectorElementType() != MVT::i1 &&
+         InVT.getVectorElementType() != MVT::i1 &&
+         "For predicate bitcasts, use getSVEPredicateBitCast");
 
   if (InVT == VT)
     return Op;
 
-  if (VT.getVectorElementType() == MVT::i1)
-    return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
-
   EVT PackedVT = getPackedSVEVectorVT(VT.getVectorElementType());
   EVT PackedInVT = getPackedSVEVectorVT(InVT.getVectorElementType());
 

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 06ea918ea32e8..c7a6acc394d76 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1148,8 +1148,13 @@ class AArch64TargetLowering : public TargetLowering {
   // These can make "bitcasting" a multiphase process. REINTERPRET_CAST is used
   // to transition between unpacked and packed types of the same element type,
   // with BITCAST used otherwise.
+  // This function does not handle predicate bitcasts.
   SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
 
+  // Returns a safe bitcast between two scalable vector predicates, where
+  // any newly created lanes from a widening bitcast are defined as zero.
+  SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
+
   bool isConstantUnsignedBitfieldExtractLegal(unsigned Opc, LLT Ty1,
                                               LLT Ty2) const override;
 };


        


More information about the llvm-commits mailing list