[llvm] 4e1db6a - [AArch64][SVE] Add AArch64ISD nodes for wide add instructions (#115895)
via llvm-commits
llvm-commits at lists.llvm.org
Fri Nov 15 03:01:14 PST 2024
Author: James Chesterman
Date: 2024-11-15T11:01:10Z
New Revision: 4e1db6a318775d9d0c49357baea6ca02fe5b5389
URL: https://github.com/llvm/llvm-project/commit/4e1db6a318775d9d0c49357baea6ca02fe5b5389
DIFF: https://github.com/llvm/llvm-project/commit/4e1db6a318775d9d0c49357baea6ca02fe5b5389.diff
LOG: [AArch64][SVE] Add AArch64ISD nodes for wide add instructions (#115895)
When lowering a partial reduction to a pair of wide adds, the lowering
previously emitted the corresponding intrinsics as INTRINSIC_WO_CHAIN
nodes. It now emits dedicated AArch64ISD nodes instead.
Added:
Modified:
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
llvm/lib/Target/AArch64/AArch64ISelLowering.h
llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Removed:
################################################################################
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9d1c3d4eddc880..ff577f238d1839 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2768,6 +2768,10 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::UADDV)
MAKE_CASE(AArch64ISD::UADDLV)
MAKE_CASE(AArch64ISD::SADDLV)
+ MAKE_CASE(AArch64ISD::SADDWT)
+ MAKE_CASE(AArch64ISD::SADDWB)
+ MAKE_CASE(AArch64ISD::UADDWT)
+ MAKE_CASE(AArch64ISD::UADDWB)
MAKE_CASE(AArch64ISD::SDOT)
MAKE_CASE(AArch64ISD::UDOT)
MAKE_CASE(AArch64ISD::USDOT)
@@ -21907,17 +21911,10 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
return SDValue();
bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
- auto BottomIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwb
- : Intrinsic::aarch64_sve_uaddwb;
- auto TopIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwt
- : Intrinsic::aarch64_sve_uaddwt;
-
- auto BottomID = DAG.getTargetConstant(BottomIntrinsic, DL, AccElemVT);
- auto BottomNode =
- DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, BottomID, Acc, Input);
- auto TopID = DAG.getTargetConstant(TopIntrinsic, DL, AccElemVT);
- return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, TopID, BottomNode,
- Input);
+ auto BottomOpcode = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
+ auto TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
+ auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
+ return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
}
static SDValue performIntrinsicCombine(SDNode *N,
@@ -22097,6 +22094,18 @@ static SDValue performIntrinsicCombine(SDNode *N,
case Intrinsic::aarch64_sve_bic_u:
return DAG.getNode(AArch64ISD::BIC, SDLoc(N), N->getValueType(0),
N->getOperand(2), N->getOperand(3));
+ case Intrinsic::aarch64_sve_saddwb:
+ return DAG.getNode(AArch64ISD::SADDWB, SDLoc(N), N->getValueType(0),
+ N->getOperand(1), N->getOperand(2));
+ case Intrinsic::aarch64_sve_saddwt:
+ return DAG.getNode(AArch64ISD::SADDWT, SDLoc(N), N->getValueType(0),
+ N->getOperand(1), N->getOperand(2));
+ case Intrinsic::aarch64_sve_uaddwb:
+ return DAG.getNode(AArch64ISD::UADDWB, SDLoc(N), N->getValueType(0),
+ N->getOperand(1), N->getOperand(2));
+ case Intrinsic::aarch64_sve_uaddwt:
+ return DAG.getNode(AArch64ISD::UADDWT, SDLoc(N), N->getValueType(0),
+ N->getOperand(1), N->getOperand(2));
case Intrinsic::aarch64_sve_eor_u:
return DAG.getNode(ISD::XOR, SDLoc(N), N->getValueType(0), N->getOperand(2),
N->getOperand(3));
@@ -29702,6 +29711,27 @@ void AArch64TargetLowering::verifyTargetSDNode(const SDNode *N) const {
switch (N->getOpcode()) {
default:
break;
+ case AArch64ISD::SADDWT:
+ case AArch64ISD::SADDWB:
+ case AArch64ISD::UADDWT:
+ case AArch64ISD::UADDWB: {
+ assert(N->getNumValues() == 1 && "Expected one result!");
+ assert(N->getNumOperands() == 2 && "Expected two operands!");
+ EVT VT = N->getValueType(0);
+ EVT Op0VT = N->getOperand(0).getValueType();
+ EVT Op1VT = N->getOperand(1).getValueType();
+ assert(VT.isVector() && Op0VT.isVector() && Op1VT.isVector() &&
+ VT.isInteger() && Op0VT.isInteger() && Op1VT.isInteger() &&
+ "Expected integer vectors!");
+ assert(VT == Op0VT &&
+ "Expected result and first input to have the same type!");
+ assert(Op0VT.getSizeInBits() == Op1VT.getSizeInBits() &&
+ "Expected vectors of equal size!");
+ assert(Op0VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount() &&
+ "Expected result vector and first input vector to have half the "
+ "lanes of the second input vector!");
+ break;
+ }
case AArch64ISD::SUNPKLO:
case AArch64ISD::SUNPKHI:
case AArch64ISD::UUNPKLO:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index d11da64d3f84eb..176ad57a6ed728 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -273,6 +273,12 @@ enum NodeType : unsigned {
UADDLV,
SADDLV,
+ // Wide adds
+ SADDWT,
+ SADDWB,
+ UADDWT,
+ UADDWB,
+
// Add Pairwise of two vectors
ADDP,
// Add Long Pairwise
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 69fd7547ce85e8..8791ce6266c86c 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -430,6 +430,13 @@ def SDT_AArch64Arith_Unpred : SDTypeProfile<1, 2, [
def AArch64bic_node : SDNode<"AArch64ISD::BIC", SDT_AArch64Arith_Unpred>;
+def SDT_AArch64addw : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisVec<1>]>;
+
+def AArch64saddwt : SDNode<"AArch64ISD::SADDWT", SDT_AArch64addw>;
+def AArch64saddwb : SDNode<"AArch64ISD::SADDWB", SDT_AArch64addw>;
+def AArch64uaddwt : SDNode<"AArch64ISD::UADDWT", SDT_AArch64addw>;
+def AArch64uaddwb : SDNode<"AArch64ISD::UADDWB", SDT_AArch64addw>;
+
def AArch64bic : PatFrags<(ops node:$op1, node:$op2),
[(and node:$op1, (xor node:$op2, (splat_vector (i32 -1)))),
(and node:$op1, (xor node:$op2, (splat_vector (i64 -1)))),
@@ -3674,10 +3681,10 @@ let Predicates = [HasSVE2orSME] in {
defm UABDLT_ZZZ : sve2_wide_int_arith_long<0b01111, "uabdlt", int_aarch64_sve_uabdlt>;
// SVE2 integer add/subtract wide
- defm SADDWB_ZZZ : sve2_wide_int_arith_wide<0b000, "saddwb", int_aarch64_sve_saddwb>;
- defm SADDWT_ZZZ : sve2_wide_int_arith_wide<0b001, "saddwt", int_aarch64_sve_saddwt>;
- defm UADDWB_ZZZ : sve2_wide_int_arith_wide<0b010, "uaddwb", int_aarch64_sve_uaddwb>;
- defm UADDWT_ZZZ : sve2_wide_int_arith_wide<0b011, "uaddwt", int_aarch64_sve_uaddwt>;
+ defm SADDWB_ZZZ : sve2_wide_int_arith_wide<0b000, "saddwb", AArch64saddwb>;
+ defm SADDWT_ZZZ : sve2_wide_int_arith_wide<0b001, "saddwt", AArch64saddwt>;
+ defm UADDWB_ZZZ : sve2_wide_int_arith_wide<0b010, "uaddwb", AArch64uaddwb>;
+ defm UADDWT_ZZZ : sve2_wide_int_arith_wide<0b011, "uaddwt", AArch64uaddwt>;
defm SSUBWB_ZZZ : sve2_wide_int_arith_wide<0b100, "ssubwb", int_aarch64_sve_ssubwb>;
defm SSUBWT_ZZZ : sve2_wide_int_arith_wide<0b101, "ssubwt", int_aarch64_sve_ssubwt>;
defm USUBWB_ZZZ : sve2_wide_int_arith_wide<0b110, "usubwb", int_aarch64_sve_usubwb>;
More information about the llvm-commits
mailing list