[llvm] 2887f14 - [ISel] Port AArch64 SABD and UABD to DAGCombine
David Green via llvm-commits
llvm-commits at lists.llvm.org
Sat Jun 26 11:34:41 PDT 2021
Author: David Green
Date: 2021-06-26T19:34:16+01:00
New Revision: 2887f1463930044a6093f111dc8eba5594144c33
URL: https://github.com/llvm/llvm-project/commit/2887f1463930044a6093f111dc8eba5594144c33
DIFF: https://github.com/llvm/llvm-project/commit/2887f1463930044a6093f111dc8eba5594144c33.diff
LOG: [ISel] Port AArch64 SABD and UABD to DAGCombine
This ports the AArch64 SABD and UABD over to DAG Combine, where they can be
used by more backends (notably MVE in a follow-up patch). The matching code
has changed very little, just to handle legal operations and types
differently. It selects from (ABS (SUB (EXTEND a), (EXTEND b))), producing
an abds/abdu which is zero-extended to the original type.
Differential Revision: https://reviews.llvm.org/D91937
Added:
Modified:
llvm/include/llvm/CodeGen/ISDOpcodes.h
llvm/include/llvm/Target/TargetSelectionDAG.td
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
llvm/lib/CodeGen/TargetLoweringBase.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
llvm/lib/Target/AArch64/AArch64InstrInfo.td
Removed:
################################################################################
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index adad8c18e5583..6eb70ab477089 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -611,6 +611,13 @@ enum NodeType {
MULHU,
MULHS,
+ // ABDS/ABDU - Absolute difference - Return the absolute difference between
+ // two numbers interpreted as signed/unsigned.
+ // i.e trunc(abs(sext(Op0) - sext(Op1))) becomes abds(Op0, Op1)
+ // or trunc(abs(zext(Op0) - zext(Op1))) becomes abdu(Op0, Op1)
+ ABDS,
+ ABDU,
+
/// [US]{MIN/MAX} - Binary minimum or maximum of signed or unsigned
/// integers.
SMIN,
diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td b/llvm/include/llvm/Target/TargetSelectionDAG.td
index 1913396609fd0..c7f22bf173910 100644
--- a/llvm/include/llvm/Target/TargetSelectionDAG.td
+++ b/llvm/include/llvm/Target/TargetSelectionDAG.td
@@ -369,6 +369,8 @@ def mul : SDNode<"ISD::MUL" , SDTIntBinOp,
[SDNPCommutative, SDNPAssociative]>;
def mulhs : SDNode<"ISD::MULHS" , SDTIntBinOp, [SDNPCommutative]>;
def mulhu : SDNode<"ISD::MULHU" , SDTIntBinOp, [SDNPCommutative]>;
+def abds : SDNode<"ISD::ABDS" , SDTIntBinOp, [SDNPCommutative]>;
+def abdu : SDNode<"ISD::ABDU" , SDTIntBinOp, [SDNPCommutative]>;
def smullohi : SDNode<"ISD::SMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>;
def umullohi : SDNode<"ISD::UMUL_LOHI" , SDTIntBinHiLoOp, [SDNPCommutative]>;
def sdiv : SDNode<"ISD::SDIV" , SDTIntBinOp>;
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 63c979c4dbd73..5ea3de9d0db66 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -9071,6 +9071,40 @@ SDValue DAGCombiner::visitFunnelShift(SDNode *N) {
return SDValue();
}
+// Given a ABS node, detect the following pattern:
+// (ABS (SUB (EXTEND a), (EXTEND b))).
+// Generates UABD/SABD instruction.
+static SDValue combineABSToABD(SDNode *N, SelectionDAG &DAG,
+ const TargetLowering &TLI) {
+ SDValue AbsOp1 = N->getOperand(0);
+ SDValue Op0, Op1;
+
+ if (AbsOp1.getOpcode() != ISD::SUB)
+ return SDValue();
+
+ Op0 = AbsOp1.getOperand(0);
+ Op1 = AbsOp1.getOperand(1);
+
+ unsigned Opc0 = Op0.getOpcode();
+ // Check if the operands of the sub are (zero|sign)-extended.
+ if (Opc0 != Op1.getOpcode() ||
+ (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
+ return SDValue();
+
+ EVT VT1 = Op0.getOperand(0).getValueType();
+ EVT VT2 = Op1.getOperand(0).getValueType();
+ // Check if the operands are of same type and valid size.
+ unsigned ABDOpcode = (Opc0 == ISD::SIGN_EXTEND) ? ISD::ABDS : ISD::ABDU;
+ if (VT1 != VT2 || !TLI.isOperationLegalOrCustom(ABDOpcode, VT1))
+ return SDValue();
+
+ Op0 = Op0.getOperand(0);
+ Op1 = Op1.getOperand(0);
+ SDValue ABD =
+ DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1);
+ return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD);
+}
+
SDValue DAGCombiner::visitABS(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
@@ -9084,6 +9118,10 @@ SDValue DAGCombiner::visitABS(SDNode *N) {
// fold (abs x) -> x iff not-negative
if (DAG.SignBitIsZero(N0))
return N0;
+
+ if (SDValue ABD = combineABSToABD(N, DAG, TLI))
+ return ABD;
+
return SDValue();
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 73c207e589fba..40083c614a6c7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -231,6 +231,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::MUL: return "mul";
case ISD::MULHU: return "mulhu";
case ISD::MULHS: return "mulhs";
+ case ISD::ABDS: return "abds";
+ case ISD::ABDU: return "abdu";
case ISD::SDIV: return "sdiv";
case ISD::UDIV: return "udiv";
case ISD::SREM: return "srem";
diff --git a/llvm/lib/CodeGen/TargetLoweringBase.cpp b/llvm/lib/CodeGen/TargetLoweringBase.cpp
index d2c291f2ae72b..ebac779984ec5 100644
--- a/llvm/lib/CodeGen/TargetLoweringBase.cpp
+++ b/llvm/lib/CodeGen/TargetLoweringBase.cpp
@@ -813,6 +813,10 @@ void TargetLoweringBase::initActions() {
setOperationAction(ISD::SUBC, VT, Expand);
setOperationAction(ISD::SUBE, VT, Expand);
+ // Absolute difference
+ setOperationAction(ISD::ABDS, VT, Expand);
+ setOperationAction(ISD::ABDU, VT, Expand);
+
// These default to Expand so they will be expanded to CTLZ/CTTZ by default.
setOperationAction(ISD::CTLZ_ZERO_UNDEF, VT, Expand);
setOperationAction(ISD::CTTZ_ZERO_UNDEF, VT, Expand);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index fd5c9e05a8a25..9886d6374665b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1050,6 +1050,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::USUBSAT, VT, Legal);
}
+ for (MVT VT : {MVT::v8i8, MVT::v4i16, MVT::v2i32, MVT::v16i8, MVT::v8i16,
+ MVT::v4i32}) {
+ setOperationAction(ISD::ABDS, VT, Legal);
+ setOperationAction(ISD::ABDU, VT, Legal);
+ }
+
// Vector reductions
for (MVT VT : { MVT::v4f16, MVT::v2f32,
MVT::v8f16, MVT::v4f32, MVT::v2f64 }) {
@@ -2116,8 +2122,6 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::CTPOP_MERGE_PASSTHRU)
MAKE_CASE(AArch64ISD::DUP_MERGE_PASSTHRU)
MAKE_CASE(AArch64ISD::INDEX_VECTOR)
- MAKE_CASE(AArch64ISD::UABD)
- MAKE_CASE(AArch64ISD::SABD)
MAKE_CASE(AArch64ISD::UADDLP)
MAKE_CASE(AArch64ISD::CALL_RVMARKER)
}
@@ -4082,8 +4086,8 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
}
case Intrinsic::aarch64_neon_sabd:
case Intrinsic::aarch64_neon_uabd: {
- unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uabd ? AArch64ISD::UABD
- : AArch64ISD::SABD;
+ unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uabd ? ISD::ABDU
+ : ISD::ABDS;
return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
Op.getOperand(2));
}
@@ -12099,8 +12103,8 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
SDValue UABDHigh8Op1 =
DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
DAG.getConstant(8, DL, MVT::i64));
- SDValue UABDHigh8 = DAG.getNode(IsZExt ? AArch64ISD::UABD : AArch64ISD::SABD,
- DL, MVT::v8i8, UABDHigh8Op0, UABDHigh8Op1);
+ SDValue UABDHigh8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8,
+ UABDHigh8Op0, UABDHigh8Op1);
SDValue UABDL = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDHigh8);
// Second, create the node pattern of UABAL.
@@ -12110,8 +12114,8 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
SDValue UABDLo8Op1 =
DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
DAG.getConstant(0, DL, MVT::i64));
- SDValue UABDLo8 = DAG.getNode(IsZExt ? AArch64ISD::UABD : AArch64ISD::SABD,
- DL, MVT::v8i8, UABDLo8Op0, UABDLo8Op1);
+ SDValue UABDLo8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8,
+ UABDLo8Op0, UABDLo8Op1);
SDValue ZExtUABD = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDLo8);
SDValue UABAL = DAG.getNode(ISD::ADD, DL, MVT::v8i16, UABDL, ZExtUABD);
@@ -12170,48 +12174,6 @@ static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
}
-// Given a ABS node, detect the following pattern:
-// (ABS (SUB (EXTEND a), (EXTEND b))).
-// Generates UABD/SABD instruction.
-static SDValue performABSCombine(SDNode *N, SelectionDAG &DAG,
- TargetLowering::DAGCombinerInfo &DCI,
- const AArch64Subtarget *Subtarget) {
- SDValue AbsOp1 = N->getOperand(0);
- SDValue Op0, Op1;
-
- if (AbsOp1.getOpcode() != ISD::SUB)
- return SDValue();
-
- Op0 = AbsOp1.getOperand(0);
- Op1 = AbsOp1.getOperand(1);
-
- unsigned Opc0 = Op0.getOpcode();
- // Check if the operands of the sub are (zero|sign)-extended.
- if (Opc0 != Op1.getOpcode() ||
- (Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND))
- return SDValue();
-
- EVT VectorT1 = Op0.getOperand(0).getValueType();
- EVT VectorT2 = Op1.getOperand(0).getValueType();
- // Check if vectors are of same type and valid size.
- uint64_t Size = VectorT1.getFixedSizeInBits();
- if (VectorT1 != VectorT2 || (Size != 64 && Size != 128))
- return SDValue();
-
- // Check if vector element types are valid.
- EVT VT1 = VectorT1.getVectorElementType();
- if (VT1 != MVT::i8 && VT1 != MVT::i16 && VT1 != MVT::i32)
- return SDValue();
-
- Op0 = Op0.getOperand(0);
- Op1 = Op1.getOperand(0);
- unsigned ABDOpcode =
- (Opc0 == ISD::SIGN_EXTEND) ? AArch64ISD::SABD : AArch64ISD::UABD;
- SDValue ABD =
- DAG.getNode(ABDOpcode, SDLoc(N), Op0->getValueType(0), Op0, Op1);
- return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), ABD);
-}
-
static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
@@ -14377,8 +14339,8 @@ static SDValue performExtendCombine(SDNode *N,
// helps the backend to decide that an sabdl2 would be useful, saving a real
// extract_high operation.
if (!DCI.isBeforeLegalizeOps() && N->getOpcode() == ISD::ZERO_EXTEND &&
- (N->getOperand(0).getOpcode() == AArch64ISD::UABD ||
- N->getOperand(0).getOpcode() == AArch64ISD::SABD)) {
+ (N->getOperand(0).getOpcode() == ISD::ABDU ||
+ N->getOperand(0).getOpcode() == ISD::ABDS)) {
SDNode *ABDNode = N->getOperand(0).getNode();
SDValue NewABD =
tryCombineLongOpWithDup(Intrinsic::not_intrinsic, ABDNode, DCI, DAG);
@@ -16344,8 +16306,6 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
default:
LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
break;
- case ISD::ABS:
- return performABSCombine(N, DAG, DCI, Subtarget);
case ISD::ADD:
case ISD::SUB:
return performAddSubCombine(N, DCI, DAG);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index ced2607862123..7daa61996739f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -236,10 +236,6 @@ enum NodeType : unsigned {
SRHADD,
URHADD,
- // Absolute difference
- UABD,
- SABD,
-
// Unsigned Add Long Pairwise
UADDLP,
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 301f1ed69638d..7802144fb2c98 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -579,14 +579,11 @@ def AArch64urhadd : SDNode<"AArch64ISD::URHADD", SDT_AArch64binvec>;
def AArch64shadd : SDNode<"AArch64ISD::SHADD", SDT_AArch64binvec>;
def AArch64uhadd : SDNode<"AArch64ISD::UHADD", SDT_AArch64binvec>;
-def AArch64uabd_n : SDNode<"AArch64ISD::UABD", SDT_AArch64binvec>;
-def AArch64sabd_n : SDNode<"AArch64ISD::SABD", SDT_AArch64binvec>;
-
def AArch64uabd : PatFrags<(ops node:$lhs, node:$rhs),
- [(AArch64uabd_n node:$lhs, node:$rhs),
+ [(abdu node:$lhs, node:$rhs),
(int_aarch64_neon_uabd node:$lhs, node:$rhs)]>;
def AArch64sabd : PatFrags<(ops node:$lhs, node:$rhs),
- [(AArch64sabd_n node:$lhs, node:$rhs),
+ [(abds node:$lhs, node:$rhs),
(int_aarch64_neon_sabd node:$lhs, node:$rhs)]>;
def AArch64uaddlp_n : SDNode<"AArch64ISD::UADDLP", SDT_AArch64uaddlp>;
More information about the llvm-commits
mailing list