[llvm] [AArch64][SVE] Add AArch64ISD nodes for wide add instructions (PR #115895)

James Chesterman via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 14 08:09:38 PST 2024


https://github.com/JamesChesterman updated https://github.com/llvm/llvm-project/pull/115895

>From 0cadef86e778bcba49952d74b0f3327bb53c0a94 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Tue, 12 Nov 2024 16:13:50 +0000
Subject: [PATCH 1/3] [AArch64][SVE] Add AArch64ISD nodes for wide add
 instructions

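When lowering a partial reduction to a pair of wide adds, the
corresponding SVE2 intrinsics were previously emitted as
INTRINSIC_WO_CHAIN nodes. Dedicated AArch64ISD nodes (SADDWB, SADDWT,
UADDWB, UADDWT) are now emitted instead, and the intrinsics themselves
are combined into these nodes so that both paths share the same
instruction-selection patterns.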
---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 31 ++++++++++++-------
 llvm/lib/Target/AArch64/AArch64ISelLowering.h |  6 ++++
 .../lib/Target/AArch64/AArch64SVEInstrInfo.td | 24 +++++++++++---
 3 files changed, 46 insertions(+), 15 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9d1c3d4eddc880..f95745a1a84d8c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2768,6 +2768,10 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::UADDV)
     MAKE_CASE(AArch64ISD::UADDLV)
     MAKE_CASE(AArch64ISD::SADDLV)
+    MAKE_CASE(AArch64ISD::SADDWT)
+    MAKE_CASE(AArch64ISD::SADDWB)
+    MAKE_CASE(AArch64ISD::UADDWT)
+    MAKE_CASE(AArch64ISD::UADDWB)
     MAKE_CASE(AArch64ISD::SDOT)
     MAKE_CASE(AArch64ISD::UDOT)
     MAKE_CASE(AArch64ISD::USDOT)
@@ -21907,17 +21911,10 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
     return SDValue();
 
   bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
-  auto BottomIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwb
-                                       : Intrinsic::aarch64_sve_uaddwb;
-  auto TopIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwt
-                                    : Intrinsic::aarch64_sve_uaddwt;
-
-  auto BottomID = DAG.getTargetConstant(BottomIntrinsic, DL, AccElemVT);
-  auto BottomNode =
-      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, BottomID, Acc, Input);
-  auto TopID = DAG.getTargetConstant(TopIntrinsic, DL, AccElemVT);
-  return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccVT, TopID, BottomNode,
-                     Input);
+  auto BottomISD = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
+  auto TopISD = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
+  auto BottomNode = DAG.getNode(BottomISD, DL, AccVT, Acc, Input);
+  return DAG.getNode(TopISD, DL, AccVT, BottomNode, Input);
 }
 
 static SDValue performIntrinsicCombine(SDNode *N,
@@ -22097,6 +22094,18 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::aarch64_sve_bic_u:
     return DAG.getNode(AArch64ISD::BIC, SDLoc(N), N->getValueType(0),
                        N->getOperand(2), N->getOperand(3));
+  case Intrinsic::aarch64_sve_saddwb:
+    return DAG.getNode(AArch64ISD::SADDWB, SDLoc(N), N->getValueType(0),
+                       N->getOperand(1), N->getOperand(2));
+  case Intrinsic::aarch64_sve_saddwt:
+    return DAG.getNode(AArch64ISD::SADDWT, SDLoc(N), N->getValueType(0),
+                       N->getOperand(1), N->getOperand(2));
+  case Intrinsic::aarch64_sve_uaddwb:
+    return DAG.getNode(AArch64ISD::UADDWB, SDLoc(N), N->getValueType(0),
+                       N->getOperand(1), N->getOperand(2));
+  case Intrinsic::aarch64_sve_uaddwt:
+    return DAG.getNode(AArch64ISD::UADDWT, SDLoc(N), N->getValueType(0),
+                       N->getOperand(1), N->getOperand(2));
   case Intrinsic::aarch64_sve_eor_u:
     return DAG.getNode(ISD::XOR, SDLoc(N), N->getValueType(0), N->getOperand(2),
                        N->getOperand(3));
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index d11da64d3f84eb..176ad57a6ed728 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -273,6 +273,12 @@ enum NodeType : unsigned {
   UADDLV,
   SADDLV,
 
+  // Wide adds
+  SADDWT,
+  SADDWB,
+  UADDWT,
+  UADDWB,
+
   // Add Pairwise of two vectors
   ADDP,
   // Add Long Pairwise
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 4f146b3ee59e9a..1e0cc253c206ea 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -430,6 +430,22 @@ def SDT_AArch64Arith_Unpred : SDTypeProfile<1, 2, [
 
 def AArch64bic_node : SDNode<"AArch64ISD::BIC",  SDT_AArch64Arith_Unpred>;
 
+def SDT_AArch64addw : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisVec<1>]>;
+
+def AArch64saddwt_node : SDNode<"AArch64ISD::SADDWT", SDT_AArch64addw>;
+def AArch64saddwb_node : SDNode<"AArch64ISD::SADDWB", SDT_AArch64addw>;
+def AArch64uaddwt_node : SDNode<"AArch64ISD::UADDWT", SDT_AArch64addw>;
+def AArch64uaddwb_node : SDNode<"AArch64ISD::UADDWB", SDT_AArch64addw>;
+
+def AArch64saddwt : PatFrag<(ops node:$op1, node:$op2),
+                              (AArch64saddwt_node node:$op1, node:$op2)>;
+def AArch64saddwb : PatFrag<(ops node:$op1, node:$op2),
+                              (AArch64saddwb_node node:$op1, node:$op2)>;
+def AArch64uaddwt : PatFrag<(ops node:$op1, node:$op2),
+                              (AArch64uaddwt_node node:$op1, node:$op2)>;
+def AArch64uaddwb : PatFrag<(ops node:$op1, node:$op2),
+                              (AArch64uaddwb_node node:$op1, node:$op2)>;
+
 def AArch64bic : PatFrags<(ops node:$op1, node:$op2),
                           [(and node:$op1, (xor node:$op2, (splat_vector (i32 -1)))),
                            (and node:$op1, (xor node:$op2, (splat_vector (i64 -1)))),
@@ -3674,10 +3690,10 @@ let Predicates = [HasSVE2orSME] in {
   defm UABDLT_ZZZ : sve2_wide_int_arith_long<0b01111, "uabdlt", int_aarch64_sve_uabdlt>;
 
   // SVE2 integer add/subtract wide
-  defm SADDWB_ZZZ : sve2_wide_int_arith_wide<0b000, "saddwb", int_aarch64_sve_saddwb>;
-  defm SADDWT_ZZZ : sve2_wide_int_arith_wide<0b001, "saddwt", int_aarch64_sve_saddwt>;
-  defm UADDWB_ZZZ : sve2_wide_int_arith_wide<0b010, "uaddwb", int_aarch64_sve_uaddwb>;
-  defm UADDWT_ZZZ : sve2_wide_int_arith_wide<0b011, "uaddwt", int_aarch64_sve_uaddwt>;
+  defm SADDWB_ZZZ : sve2_wide_int_arith_wide<0b000, "saddwb", AArch64saddwb>;
+  defm SADDWT_ZZZ : sve2_wide_int_arith_wide<0b001, "saddwt", AArch64saddwt>;
+  defm UADDWB_ZZZ : sve2_wide_int_arith_wide<0b010, "uaddwb", AArch64uaddwb>;
+  defm UADDWT_ZZZ : sve2_wide_int_arith_wide<0b011, "uaddwt", AArch64uaddwt>;
   defm SSUBWB_ZZZ : sve2_wide_int_arith_wide<0b100, "ssubwb", int_aarch64_sve_ssubwb>;
   defm SSUBWT_ZZZ : sve2_wide_int_arith_wide<0b101, "ssubwt", int_aarch64_sve_ssubwt>;
   defm USUBWB_ZZZ : sve2_wide_int_arith_wide<0b110, "usubwb", int_aarch64_sve_usubwb>;

>From b2cdf3aa38f2741aec3da4700d8be484fb3bd404 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 14 Nov 2024 13:53:09 +0000
Subject: [PATCH 2/3] Minor changes

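Remove the redundant PatFrag wrappers around the new wide-add SDNodes and
rename the opcode variables (BottomISD/TopISD -> BottomOpcode/TopOpcode).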
---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp |  8 ++++----
 llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td  | 17 ++++-------------
 2 files changed, 8 insertions(+), 17 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f95745a1a84d8c..4c4cd07c01bad7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21911,10 +21911,10 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
     return SDValue();
 
   bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
-  auto BottomISD = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
-  auto TopISD = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
-  auto BottomNode = DAG.getNode(BottomISD, DL, AccVT, Acc, Input);
-  return DAG.getNode(TopISD, DL, AccVT, BottomNode, Input);
+  auto BottomOpcode = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
+  auto TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
+  auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
+  return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
 }
 
 static SDValue performIntrinsicCombine(SDNode *N,
diff --git a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
index 1e0cc253c206ea..659d5e8b414cee 100644
--- a/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
@@ -432,19 +432,10 @@ def AArch64bic_node : SDNode<"AArch64ISD::BIC",  SDT_AArch64Arith_Unpred>;
 
 def SDT_AArch64addw : SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisVec<1>]>;
 
-def AArch64saddwt_node : SDNode<"AArch64ISD::SADDWT", SDT_AArch64addw>;
-def AArch64saddwb_node : SDNode<"AArch64ISD::SADDWB", SDT_AArch64addw>;
-def AArch64uaddwt_node : SDNode<"AArch64ISD::UADDWT", SDT_AArch64addw>;
-def AArch64uaddwb_node : SDNode<"AArch64ISD::UADDWB", SDT_AArch64addw>;
-
-def AArch64saddwt : PatFrag<(ops node:$op1, node:$op2),
-                              (AArch64saddwt_node node:$op1, node:$op2)>;
-def AArch64saddwb : PatFrag<(ops node:$op1, node:$op2),
-                              (AArch64saddwb_node node:$op1, node:$op2)>;
-def AArch64uaddwt : PatFrag<(ops node:$op1, node:$op2),
-                              (AArch64uaddwt_node node:$op1, node:$op2)>;
-def AArch64uaddwb : PatFrag<(ops node:$op1, node:$op2),
-                              (AArch64uaddwb_node node:$op1, node:$op2)>;
+def AArch64saddwt : SDNode<"AArch64ISD::SADDWT", SDT_AArch64addw>;
+def AArch64saddwb : SDNode<"AArch64ISD::SADDWB", SDT_AArch64addw>;
+def AArch64uaddwt : SDNode<"AArch64ISD::UADDWT", SDT_AArch64addw>;
+def AArch64uaddwb : SDNode<"AArch64ISD::UADDWB", SDT_AArch64addw>;
 
 def AArch64bic : PatFrags<(ops node:$op1, node:$op2),
                           [(and node:$op1, (xor node:$op2, (splat_vector (i32 -1)))),

>From 0996ab06f2574bb152f9766f7c3ecfb5804c1dcc Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 14 Nov 2024 16:08:26 +0000
Subject: [PATCH 3/3] Small changes. Add cases to verifyTargetSDNode() to
 validate the result and operand types of the new wide-add ISD nodes.

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 20 +++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4c4cd07c01bad7..42533a8d641048 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -29711,6 +29711,26 @@ void AArch64TargetLowering::verifyTargetSDNode(const SDNode *N) const {
   switch (N->getOpcode()) {
   default:
     break;
+  case AArch64ISD::SADDWT:
+  case AArch64ISD::SADDWB:
+  case AArch64ISD::UADDWT:
+  case AArch64ISD::UADDWB: {
+    assert(N->getNumValues() == 1 && "Expected one result!");
+    assert(N->getNumOperands() == 2 && "Expected two operands!");
+    EVT VT = N->getValueType(0);
+    EVT Op0VT = N->getOperand(0).getValueType();
+    EVT Op1VT = N->getOperand(1).getValueType();
+    assert(VT.isVector() && Op0VT.isVector() && Op1VT.isVector() &&
+           "Expected vectors!");
+    assert(VT.getSizeInBits() == Op0VT.getSizeInBits() &&
+           Op0VT.getSizeInBits() == Op1VT.getSizeInBits() &&
+           "Expected vectors of equal size!");
+    assert(Op0VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount() &&
+           Op0VT.getVectorElementCount() == VT.getVectorElementCount() &&
+           "Expected result vector and first input vector to have half the "
+           "lanes of the second input vector!");
+    break;
+  }
   case AArch64ISD::SUNPKLO:
   case AArch64ISD::SUNPKHI:
   case AArch64ISD::UUNPKLO:



More information about the llvm-commits mailing list