[llvm] [NFC][AArch64][SVE] Rename variables in partial reduction lowering functions (PR #120589)

James Chesterman via llvm-commits llvm-commits at lists.llvm.org
Fri Dec 20 02:17:37 PST 2024


https://github.com/JamesChesterman updated https://github.com/llvm/llvm-project/pull/120589

>From 600930a3b48ecec1612d73ef233d085f0a8e5cd1 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 19 Dec 2024 15:11:21 +0000
Subject: [PATCH 1/3] [NFC][AArch64][SVE] Rename variables in partial reduction
 lowering functions

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 102 +++++++++---------
 1 file changed, 50 insertions(+), 52 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d1354ccf376609..290f349c77809f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21739,73 +21739,71 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   SDLoc DL(N);
 
   // The narrower of the two operands. Used as the accumulator
-  auto NarrowOp = N->getOperand(1);
-  auto MulOp = N->getOperand(2);
-  if (MulOp->getOpcode() != ISD::MUL)
+  auto A = N->getOperand(1);
+  auto B = N->getOperand(2);
+  if (B->getOpcode() != ISD::MUL)
     return SDValue();
 
-  auto ExtA = MulOp->getOperand(0);
-  auto ExtB = MulOp->getOperand(1);
+  auto ExtMulOp1 = B->getOperand(0);
+  auto ExtMulOp2 = B->getOperand(1);
 
-  if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
-      !ISD::isExtOpcode(ExtB->getOpcode()))
+  if (!ISD::isExtOpcode(ExtMulOp1->getOpcode()) ||
+      !ISD::isExtOpcode(ExtMulOp2->getOpcode()))
     return SDValue();
-  bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
-  bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
+  bool MulOp1IsSigned = ExtMulOp1->getOpcode() == ISD::SIGN_EXTEND;
+  bool MulOp2IsSigned = ExtMulOp2->getOpcode() == ISD::SIGN_EXTEND;
 
-  auto A = ExtA->getOperand(0);
-  auto B = ExtB->getOperand(0);
-  if (A.getValueType() != B.getValueType())
+  auto MulOp1 = ExtMulOp1->getOperand(0);
+  auto MulOp2 = ExtMulOp2->getOperand(0);
+  if (MulOp1.getValueType() != MulOp2.getValueType())
     return SDValue();
 
-  EVT ReducedType = N->getValueType(0);
-  EVT MulSrcType = A.getValueType();
+  EVT AVT = N->getValueType(0);
+  EVT MulSrcVT = MulOp1.getValueType();
 
   // Dot products operate on chunks of four elements so there must be four times
   // as many elements in the wide type
-  if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
-      !(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
-      !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
-      !(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
-      !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
-      !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
+  if (!(AVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
+      !(AVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
+      !(AVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
+      !(AVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
+      !(AVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
+      !(AVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
     return SDValue();
 
   // If the extensions are mixed, we should lower it to a usdot instead
   unsigned Opcode = 0;
-  if (AIsSigned != BIsSigned) {
+  if (MulOp1IsSigned != MulOp2IsSigned) {
     if (!Subtarget->hasMatMulInt8())
       return SDValue();
 
     bool Scalable = N->getValueType(0).isScalableVT();
     // There's no nxv2i64 version of usdot
-    if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
+    if (Scalable && AVT != MVT::nxv4i32 && AVT != MVT::nxv4i64)
       return SDValue();
 
     Opcode = AArch64ISD::USDOT;
     // USDOT expects the signed operand to be last
-    if (!BIsSigned)
-      std::swap(A, B);
-  } else if (AIsSigned)
+    if (!MulOp2IsSigned)
+      std::swap(MulOp1, MulOp2);
+  } else if (MulOp1IsSigned)
     Opcode = AArch64ISD::SDOT;
   else
     Opcode = AArch64ISD::UDOT;
 
   // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
   // product followed by a zero / sign extension
-  if ((ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) ||
-      (ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8)) {
-    EVT ReducedTypeI32 =
-        (ReducedType.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
+  if ((AVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
+      (AVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
+    EVT AVTI32 = (AVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
 
-    auto DotI32 = DAG.getNode(Opcode, DL, ReducedTypeI32,
-                              DAG.getConstant(0, DL, ReducedTypeI32), A, B);
-    auto Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedType);
-    return DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(), NarrowOp,
-                       Extended);
+    auto DotI32 = DAG.getNode(Opcode, DL, AVTI32,
+                              DAG.getConstant(0, DL, AVTI32), MulOp1, MulOp2);
+    auto Extended = DAG.getSExtOrTrunc(DotI32, DL, AVT);
+    return DAG.getNode(ISD::ADD, DL, A.getValueType(), A, Extended);
   }
 
-  return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
+  return DAG.getNode(Opcode, DL, AVT, A, MulOp1, MulOp2);
 }
 
 SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
@@ -21822,32 +21820,32 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
 
   SDLoc DL(N);
 
-  auto Acc = N->getOperand(1);
-  auto ExtInput = N->getOperand(2);
+  auto A = N->getOperand(1);
+  auto ExtB = N->getOperand(2);
 
-  EVT AccVT = Acc.getValueType();
-  EVT AccElemVT = AccVT.getVectorElementType();
+  EVT AVT = A.getValueType();
+  EVT AElemVT = AVT.getVectorElementType();
 
-  if (ExtInput.getValueType().getVectorElementType() != AccElemVT)
+  if (ExtB.getValueType().getVectorElementType() != AElemVT)
     return SDValue();
 
-  unsigned ExtInputOpcode = ExtInput->getOpcode();
-  if (!ISD::isExtOpcode(ExtInputOpcode))
+  unsigned ExtBOpcode = ExtB->getOpcode();
+  if (!ISD::isExtOpcode(ExtBOpcode))
     return SDValue();
 
-  auto Input = ExtInput->getOperand(0);
-  EVT InputVT = Input.getValueType();
+  auto B = ExtB->getOperand(0);
+  EVT BVT = B.getValueType();
 
-  if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
-      !(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
-      !(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
+  if (!(BVT == MVT::nxv4i32 && AVT == MVT::nxv2i64) &&
+      !(BVT == MVT::nxv8i16 && AVT == MVT::nxv4i32) &&
+      !(BVT == MVT::nxv16i8 && AVT == MVT::nxv8i16))
     return SDValue();
 
-  bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
-  auto BottomOpcode = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
-  auto TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
-  auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
-  return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
+  bool BIsSigned = ExtBOpcode == ISD::SIGN_EXTEND;
+  auto BottomOpcode = BIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
+  auto TopOpcode = BIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
+  auto BottomNode = DAG.getNode(BottomOpcode, DL, AVT, A, B);
+  return DAG.getNode(TopOpcode, DL, AVT, BottomNode, B);
 }
 
 static SDValue performIntrinsicCombine(SDNode *N,

>From 8389aa3a6f758ab8c5cb9b4da1d94b095d5eac5e Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 20 Dec 2024 10:00:10 +0000
Subject: [PATCH 2/3] Address feedback on variable renaming and clean up
 code

---
 .../Target/AArch64/AArch64ISelLowering.cpp    | 112 +++++++++---------
 1 file changed, 54 insertions(+), 58 deletions(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 290f349c77809f..b7c4845d39730b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21738,72 +21738,72 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
 
   SDLoc DL(N);
 
-  // The narrower of the two operands. Used as the accumulator
-  auto A = N->getOperand(1);
-  auto B = N->getOperand(2);
-  if (B->getOpcode() != ISD::MUL)
+  SDValue Op2 = N->getOperand(2);
+  if (Op2->getOpcode() != ISD::MUL ||
+      !ISD::isExtOpcode(Op2->getOperand(0)->getOpcode()) ||
+      !ISD::isExtOpcode(Op2->getOperand(1)->getOpcode()))
     return SDValue();
 
-  auto ExtMulOp1 = B->getOperand(0);
-  auto ExtMulOp2 = B->getOperand(1);
-
-  if (!ISD::isExtOpcode(ExtMulOp1->getOpcode()) ||
-      !ISD::isExtOpcode(ExtMulOp2->getOpcode()))
-    return SDValue();
-  bool MulOp1IsSigned = ExtMulOp1->getOpcode() == ISD::SIGN_EXTEND;
-  bool MulOp2IsSigned = ExtMulOp2->getOpcode() == ISD::SIGN_EXTEND;
+  SDValue Acc = N->getOperand(1);
+  SDValue Mul = N->getOperand(2);
+  SDValue ExtMulOpLHS = Mul->getOperand(0);
+  SDValue ExtMulOpRHS = Mul->getOperand(1);
 
-  auto MulOp1 = ExtMulOp1->getOperand(0);
-  auto MulOp2 = ExtMulOp2->getOperand(0);
-  if (MulOp1.getValueType() != MulOp2.getValueType())
+  SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
+  SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
+  if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
     return SDValue();
 
-  EVT AVT = N->getValueType(0);
-  EVT MulSrcVT = MulOp1.getValueType();
+  EVT ReducedVT = N->getValueType(0);
+  EVT MulSrcVT = MulOpLHS.getValueType();
 
   // Dot products operate on chunks of four elements so there must be four times
   // as many elements in the wide type
-  if (!(AVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
-      !(AVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
-      !(AVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
-      !(AVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
-      !(AVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
-      !(AVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
+  if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
+      !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
+      !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
+      !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
+      !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
+      !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
     return SDValue();
 
+  bool MulOpLHSIsSigned = ExtMulOpLHS->getOpcode() == ISD::SIGN_EXTEND;
+  bool MulOpRHSIsSigned = ExtMulOpRHS->getOpcode() == ISD::SIGN_EXTEND;
   // If the extensions are mixed, we should lower it to a usdot instead
   unsigned Opcode = 0;
-  if (MulOp1IsSigned != MulOp2IsSigned) {
+  if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
     if (!Subtarget->hasMatMulInt8())
       return SDValue();
 
     bool Scalable = N->getValueType(0).isScalableVT();
     // There's no nxv2i64 version of usdot
-    if (Scalable && AVT != MVT::nxv4i32 && AVT != MVT::nxv4i64)
+    if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
       return SDValue();
 
     Opcode = AArch64ISD::USDOT;
     // USDOT expects the signed operand to be last
-    if (!MulOp2IsSigned)
-      std::swap(MulOp1, MulOp2);
-  } else if (MulOp1IsSigned)
+    if (!MulOpRHSIsSigned)
+      std::swap(MulOpLHS, MulOpRHS);
+  } else if (MulOpLHSIsSigned)
     Opcode = AArch64ISD::SDOT;
   else
     Opcode = AArch64ISD::UDOT;
 
   // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
   // product followed by a zero / sign extension
-  if ((AVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
-      (AVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
-    EVT AVTI32 = (AVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
+  if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
+      (ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
+    EVT ReducedVTI32 =
+        (ReducedVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
 
-    auto DotI32 = DAG.getNode(Opcode, DL, AVTI32,
-                              DAG.getConstant(0, DL, AVTI32), MulOp1, MulOp2);
-    auto Extended = DAG.getSExtOrTrunc(DotI32, DL, AVT);
-    return DAG.getNode(ISD::ADD, DL, A.getValueType(), A, Extended);
+    SDValue DotI32 =
+        DAG.getNode(Opcode, DL, ReducedVTI32,
+                    DAG.getConstant(0, DL, ReducedVTI32), MulOpLHS, MulOpRHS);
+    SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT);
+    return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended);
   }
 
-  return DAG.getNode(Opcode, DL, AVT, A, MulOp1, MulOp2);
+  return DAG.getNode(Opcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
 }
 
 SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
@@ -21820,32 +21820,28 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
 
   SDLoc DL(N);
 
-  auto A = N->getOperand(1);
-  auto ExtB = N->getOperand(2);
-
-  EVT AVT = A.getValueType();
-  EVT AElemVT = AVT.getVectorElementType();
-
-  if (ExtB.getValueType().getVectorElementType() != AElemVT)
+  if (!ISD::isExtOpcode(N->getOperand(2).getOpcode()))
     return SDValue();
-
-  unsigned ExtBOpcode = ExtB->getOpcode();
-  if (!ISD::isExtOpcode(ExtBOpcode))
+  SDValue Acc = N->getOperand(1);
+  SDValue Ext = N->getOperand(2);
+  EVT AccVT = Acc.getValueType();
+  EVT ExtVT = Ext.getValueType();
+  if (ExtVT.getVectorElementType() != AccVT.getVectorElementType())
     return SDValue();
 
-  auto B = ExtB->getOperand(0);
-  EVT BVT = B.getValueType();
+  SDValue ExtOp = Ext->getOperand(0);
+  EVT ExtOpVT = ExtOp.getValueType();
 
-  if (!(BVT == MVT::nxv4i32 && AVT == MVT::nxv2i64) &&
-      !(BVT == MVT::nxv8i16 && AVT == MVT::nxv4i32) &&
-      !(BVT == MVT::nxv16i8 && AVT == MVT::nxv8i16))
+  if (!(ExtOpVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
+      !(ExtOpVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
+      !(ExtOpVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
     return SDValue();
 
-  bool BIsSigned = ExtBOpcode == ISD::SIGN_EXTEND;
-  auto BottomOpcode = BIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
-  auto TopOpcode = BIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
-  auto BottomNode = DAG.getNode(BottomOpcode, DL, AVT, A, B);
-  return DAG.getNode(TopOpcode, DL, AVT, BottomNode, B);
+  bool ExtOpIsSigned = Ext.getOpcode() == ISD::SIGN_EXTEND;
+  unsigned BottomOpcode = ExtOpIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
+  unsigned TopOpcode = ExtOpIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
+  SDValue BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, ExtOp);
+  return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, ExtOp);
 }
 
 static SDValue performIntrinsicCombine(SDNode *N,
@@ -21857,9 +21853,9 @@ static SDValue performIntrinsicCombine(SDNode *N,
   default:
     break;
   case Intrinsic::experimental_vector_partial_reduce_add: {
-    if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
+    if (SDValue Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
       return Dot;
-    if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
+    if (SDValue WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
       return WideAdd;
     return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
                                    N->getOperand(1), N->getOperand(2));

>From 0c38702fafa295686ac9168a58cdf3b0c9574217 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 20 Dec 2024 10:16:48 +0000
Subject: [PATCH 3/3] Correct code formatting

---
 llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b7c4845d39730b..f5a316d1d8be1c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21838,7 +21838,8 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
     return SDValue();
 
   bool ExtOpIsSigned = Ext.getOpcode() == ISD::SIGN_EXTEND;
-  unsigned BottomOpcode = ExtOpIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
+  unsigned BottomOpcode =
+      ExtOpIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
   unsigned TopOpcode = ExtOpIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
   SDValue BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, ExtOp);
   return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, ExtOp);



More information about the llvm-commits mailing list