[llvm] bf89d24 - [AArch64] NFC: Move safe predicate casting to a separate function.
Sander de Smalen via llvm-commits
llvm-commits at lists.llvm.org
Mon Jul 4 03:58:56 PDT 2022
Author: Sander de Smalen
Date: 2022-07-04T10:32:54Z
New Revision: bf89d24f5319cb57e8458c1192480d17f00d4540
URL: https://github.com/llvm/llvm-project/commit/bf89d24f5319cb57e8458c1192480d17f00d4540
DIFF: https://github.com/llvm/llvm-project/commit/bf89d24f5319cb57e8458c1192480d17f00d4540.diff
LOG: [AArch64] NFC: Move safe predicate casting to a separate function.
This patch moves the code that safely bitcasts a predicate, zeroing any
lanes left undefined by a widening cast, into a single function and merges
that functionality with lowerConvertToSVBool.
This is some cleanup inspired by D128665.
Reviewed By: paulwalker-arm
Differential Revision: https://reviews.llvm.org/D128926
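
[Editorial note] As background for the lane-zeroing described above, here is a
minimal standalone sketch (plain C++, deliberately not using LLVM APIs; the
fixed 16-lane granule, the lane step of 8, and all names are illustrative
assumptions, not part of the patch). In the SVE predicate register layout, an
<vscale x 2 x i1> value defines only every 8th lane of the underlying
<vscale x 16 x i1> view, so a widening reinterpret exposes undefined
in-between lanes; ANDing with a reinterpreted all-ones mask of the narrow
type clears exactly those lanes, which is what getSVEPredicateBitCast does
in the general case below.

#include <array>
#include <cstdio>

int main() {
  constexpr int FullLanes = 16; // one 128-bit granule, i.e. vscale == 1
  constexpr int Step = 8;       // nxv2i1: one defined lane per 8

  // Wide view of the register after a REINTERPRET_CAST from nxv2i1 to
  // nxv16i1: the defined lanes (every 8th) are set, the rest hold garbage.
  std::array<int, FullLanes> Reinterpret{};
  for (int I = 0; I < FullLanes; ++I)
    Reinterpret[I] = (I % Step == 0) ? 1 : (I % 3 == 0 ? 1 : 0); // garbage

  // Splat of 1 in the narrow type, reinterpreted to the wide type: set
  // only in the lanes the narrow type actually defines.
  std::array<int, FullLanes> Mask{};
  for (int I = 0; I < FullLanes; I += Step)
    Mask[I] = 1;

  // The AND zeroes every lane the narrow type did not define.
  for (int I = 0; I < FullLanes; ++I)
    std::printf("%d", Reinterpret[I] & Mask[I]);
  std::printf("\n"); // prints 1000000010000000
  return 0;
}
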
Added:
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index abfe2d5071119..58a78c2e3c245 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1082,6 +1082,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
}
+ // FIXME: Move lowering for more nodes here if those are common between
+ // SVE and SME.
+ if (Subtarget->hasSVE() || Subtarget->hasSME()) {
+ for (auto VT :
+ {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1, MVT::nxv1i1})
+ setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
+ }
+
if (Subtarget->hasSVE()) {
for (auto VT : {MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64}) {
setOperationAction(ISD::BITREVERSE, VT, Custom);
@@ -1162,7 +1170,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
setOperationAction(ISD::SELECT, VT, Custom);
setOperationAction(ISD::SETCC, VT, Custom);
- setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
setOperationAction(ISD::TRUNCATE, VT, Custom);
setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
@@ -4333,27 +4340,47 @@ static inline SDValue getPTrue(SelectionDAG &DAG, SDLoc DL, EVT VT,
DAG.getTargetConstant(Pattern, DL, MVT::i32));
}
-static SDValue lowerConvertToSVBool(SDValue Op, SelectionDAG &DAG) {
+SDValue AArch64TargetLowering::getSVEPredicateBitCast(EVT VT, SDValue Op,
+ SelectionDAG &DAG) const {
SDLoc DL(Op);
- EVT OutVT = Op.getValueType();
- SDValue InOp = Op.getOperand(1);
- EVT InVT = InOp.getValueType();
+ EVT InVT = Op.getValueType();
+
+ assert(InVT.getVectorElementType() == MVT::i1 &&
+ VT.getVectorElementType() == MVT::i1 &&
+ "Expected a predicate-to-predicate bitcast");
+ assert(VT.isScalableVector() && isTypeLegal(VT) &&
+ InVT.isScalableVector() && isTypeLegal(InVT) &&
+ "Only expect to cast between legal scalable predicate types!");
// Return the operand if the cast isn't changing type,
- // i.e. <n x 16 x i1> -> <n x 16 x i1>
- if (InVT == OutVT)
- return InOp;
+ // e.g. <n x 16 x i1> -> <n x 16 x i1>
+ if (InVT == VT)
+ return Op;
+
+ SDValue Reinterpret = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
- SDValue Reinterpret =
- DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, OutVT, InOp);
+ // We only have to zero the lanes if new lanes are being defined, e.g. when
+ // casting from <vscale x 2 x i1> to <vscale x 16 x i1>. If this is not the
+ // case (e.g. when casting from <vscale x 16 x i1> -> <vscale x 2 x i1>) then
+ // we can return here.
+ if (InVT.bitsGT(VT))
+ return Reinterpret;
- // If the argument converted to an svbool is a ptrue or a comparison, the
- // lanes introduced by the widening are zero by construction.
- switch (InOp.getOpcode()) {
+ // Check if the other lanes are already known to be zeroed by
+ // construction.
+ switch (Op.getOpcode()) {
+ default:
+ // We guarantee i1 splat_vectors to zero the other lanes by
+ // implementing it with ptrue and possibly a punpklo for nxv1i1.
+ if (ISD::isConstantSplatVectorAllOnes(Op.getNode()))
+ return Reinterpret;
+ break;
case AArch64ISD::SETCC_MERGE_ZERO:
return Reinterpret;
case ISD::INTRINSIC_WO_CHAIN:
- switch (InOp.getConstantOperandVal(0)) {
+ switch (Op.getConstantOperandVal(0)) {
+ default:
+ break;
case Intrinsic::aarch64_sve_ptrue:
case Intrinsic::aarch64_sve_cmpeq_wide:
case Intrinsic::aarch64_sve_cmpne_wide:
@@ -4369,15 +4396,10 @@ static SDValue lowerConvertToSVBool(SDValue Op, SelectionDAG &DAG) {
}
}
- // Splat vectors of one will generate ptrue instructions
- if (ISD::isConstantSplatVectorAllOnes(InOp.getNode()))
- return Reinterpret;
-
- // Otherwise, zero the newly introduced lanes.
- SDValue Mask = getPTrue(DAG, DL, InVT, AArch64SVEPredPattern::all);
- SDValue MaskReinterpret =
- DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, OutVT, Mask);
- return DAG.getNode(ISD::AND, DL, OutVT, Reinterpret, MaskReinterpret);
+ // Zero the newly introduced lanes.
+ SDValue Mask = DAG.getConstant(1, DL, InVT);
+ Mask = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Mask);
+ return DAG.getNode(ISD::AND, DL, VT, Reinterpret, Mask);
}
SDValue AArch64TargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
@@ -4546,10 +4568,9 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
case Intrinsic::aarch64_sve_dupq_lane:
return LowerDUPQLane(Op, DAG);
case Intrinsic::aarch64_sve_convert_from_svbool:
- return DAG.getNode(AArch64ISD::REINTERPRET_CAST, dl, Op.getValueType(),
- Op.getOperand(1));
+ return getSVEPredicateBitCast(Op.getValueType(), Op.getOperand(1), DAG);
case Intrinsic::aarch64_sve_convert_to_svbool:
- return lowerConvertToSVBool(Op, DAG);
+ return getSVEPredicateBitCast(MVT::nxv16i1, Op.getOperand(1), DAG);
case Intrinsic::aarch64_sve_fneg:
return DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, dl, Op.getValueType(),
Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
@@ -21464,22 +21485,17 @@ SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
EVT InVT = Op.getValueType();
- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
- (void)TLI;
- assert(VT.isScalableVector() && TLI.isTypeLegal(VT) &&
- InVT.isScalableVector() && TLI.isTypeLegal(InVT) &&
+ assert(VT.isScalableVector() && isTypeLegal(VT) &&
+ InVT.isScalableVector() && isTypeLegal(InVT) &&
"Only expect to cast between legal scalable vector types!");
- assert((VT.getVectorElementType() == MVT::i1) ==
- (InVT.getVectorElementType() == MVT::i1) &&
- "Cannot cast between data and predicate scalable vector types!");
+ assert(VT.getVectorElementType() != MVT::i1 &&
+ InVT.getVectorElementType() != MVT::i1 &&
+ "For predicate bitcasts, use getSVEPredicateBitCast");
if (InVT == VT)
return Op;
- if (VT.getVectorElementType() == MVT::i1)
- return DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
-
EVT PackedVT = getPackedSVEVectorVT(VT.getVectorElementType());
EVT PackedInVT = getPackedSVEVectorVT(InVT.getVectorElementType());
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 06ea918ea32e8..c7a6acc394d76 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1148,8 +1148,13 @@ class AArch64TargetLowering : public TargetLowering {
// These can make "bitcasting" a multiphase process. REINTERPRET_CAST is used
// to transition between unpacked and packed types of the same element type,
// with BITCAST used otherwise.
+ // This function does not handle predicate bitcasts.
SDValue getSVESafeBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
+ // Returns a safe bitcast between two scalable vector predicates, where
+ // any newly created lanes from a widening bitcast are defined as zero.
+ SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) const;
+
bool isConstantUnsignedBitfieldExtractLegal(unsigned Opc, LLT Ty1,
LLT Ty2) const override;
};
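
[Editorial note] The switch over Op.getOpcode() in the diff above exists to
skip that masking AND when the wide lanes are already known to be zero. A
companion sketch, under the same illustrative assumptions as before (plain
standalone C++, fixed 16-lane granule; not LLVM code): a ptrue, or an i1
splat of all-ones which is lowered via ptrue, for a narrow predicate type
only ever sets the lanes that type defines, so the wide view is already
fully defined and REINTERPRET_CAST alone suffices.

#include <array>
#include <cstdio>

int main() {
  constexpr int FullLanes = 16, Step = 8; // nxv2i1 viewed as nxv16i1

  // ptrue.d sets the governing bit of each 64-bit element, i.e. lanes 0
  // and 8 of this granule; every other predicate lane is architecturally
  // zero, never garbage.
  std::array<int, FullLanes> PTrue{};
  for (int I = 0; I < FullLanes; I += Step)
    PTrue[I] = 1;

  // REINTERPRET_CAST does not change the register contents, so the wide
  // nxv16i1 view is already fully defined; no masking AND is needed.
  for (int I = 0; I < FullLanes; ++I)
    std::printf("%d", PTrue[I]);
  std::printf("\n"); // prints 1000000010000000
  return 0;
}
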