[llvm] [NFC][AArch64][SVE] Rename variables in partial reduction lowering functions (PR #120589)
James Chesterman via llvm-commits
llvm-commits at lists.llvm.org
Fri Dec 20 02:05:40 PST 2024
https://github.com/JamesChesterman updated https://github.com/llvm/llvm-project/pull/120589
>From 600930a3b48ecec1612d73ef233d085f0a8e5cd1 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 19 Dec 2024 15:11:21 +0000
Subject: [PATCH 1/2] [NFC][AArch64][SVE] Rename variables in partial reduction
lowering functions
---
.../Target/AArch64/AArch64ISelLowering.cpp | 102 +++++++++---------
1 file changed, 50 insertions(+), 52 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d1354ccf376609..290f349c77809f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21739,73 +21739,71 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
SDLoc DL(N);
// The narrower of the two operands. Used as the accumulator
- auto NarrowOp = N->getOperand(1);
- auto MulOp = N->getOperand(2);
- if (MulOp->getOpcode() != ISD::MUL)
+ auto A = N->getOperand(1);
+ auto B = N->getOperand(2);
+ if (B->getOpcode() != ISD::MUL)
return SDValue();
- auto ExtA = MulOp->getOperand(0);
- auto ExtB = MulOp->getOperand(1);
+ auto ExtMulOp1 = B->getOperand(0);
+ auto ExtMulOp2 = B->getOperand(1);
- if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
- !ISD::isExtOpcode(ExtB->getOpcode()))
+ if (!ISD::isExtOpcode(ExtMulOp1->getOpcode()) ||
+ !ISD::isExtOpcode(ExtMulOp2->getOpcode()))
return SDValue();
- bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
- bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
+ bool MulOp1IsSigned = ExtMulOp1->getOpcode() == ISD::SIGN_EXTEND;
+ bool MulOp2IsSigned = ExtMulOp2->getOpcode() == ISD::SIGN_EXTEND;
- auto A = ExtA->getOperand(0);
- auto B = ExtB->getOperand(0);
- if (A.getValueType() != B.getValueType())
+ auto MulOp1 = ExtMulOp1->getOperand(0);
+ auto MulOp2 = ExtMulOp2->getOperand(0);
+ if (MulOp1.getValueType() != MulOp2.getValueType())
return SDValue();
- EVT ReducedType = N->getValueType(0);
- EVT MulSrcType = A.getValueType();
+ EVT AVT = N->getValueType(0);
+ EVT MulSrcVT = MulOp1.getValueType();
// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
- if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
- !(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
- !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
- !(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
- !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
- !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
+ if (!(AVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
+ !(AVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
+ !(AVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
+ !(AVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
+ !(AVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
+ !(AVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
return SDValue();
// If the extensions are mixed, we should lower it to a usdot instead
unsigned Opcode = 0;
- if (AIsSigned != BIsSigned) {
+ if (MulOp1IsSigned != MulOp2IsSigned) {
if (!Subtarget->hasMatMulInt8())
return SDValue();
bool Scalable = N->getValueType(0).isScalableVT();
// There's no nxv2i64 version of usdot
- if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
+ if (Scalable && AVT != MVT::nxv4i32 && AVT != MVT::nxv4i64)
return SDValue();
Opcode = AArch64ISD::USDOT;
// USDOT expects the signed operand to be last
- if (!BIsSigned)
- std::swap(A, B);
- } else if (AIsSigned)
+ if (!MulOp2IsSigned)
+ std::swap(MulOp1, MulOp2);
+ } else if (MulOp1IsSigned)
Opcode = AArch64ISD::SDOT;
else
Opcode = AArch64ISD::UDOT;
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
// product followed by a zero / sign extension
- if ((ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) ||
- (ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8)) {
- EVT ReducedTypeI32 =
- (ReducedType.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
+ if ((AVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
+ (AVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
+ EVT AVTI32 = (AVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
- auto DotI32 = DAG.getNode(Opcode, DL, ReducedTypeI32,
- DAG.getConstant(0, DL, ReducedTypeI32), A, B);
- auto Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedType);
- return DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(), NarrowOp,
- Extended);
+ auto DotI32 = DAG.getNode(Opcode, DL, AVTI32,
+ DAG.getConstant(0, DL, AVTI32), MulOp1, MulOp2);
+ auto Extended = DAG.getSExtOrTrunc(DotI32, DL, AVT);
+ return DAG.getNode(ISD::ADD, DL, A.getValueType(), A, Extended);
}
- return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
+ return DAG.getNode(Opcode, DL, AVT, A, MulOp1, MulOp2);
}
SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
@@ -21822,32 +21820,32 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
SDLoc DL(N);
- auto Acc = N->getOperand(1);
- auto ExtInput = N->getOperand(2);
+ auto A = N->getOperand(1);
+ auto ExtB = N->getOperand(2);
- EVT AccVT = Acc.getValueType();
- EVT AccElemVT = AccVT.getVectorElementType();
+ EVT AVT = A.getValueType();
+ EVT AElemVT = AVT.getVectorElementType();
- if (ExtInput.getValueType().getVectorElementType() != AccElemVT)
+ if (ExtB.getValueType().getVectorElementType() != AElemVT)
return SDValue();
- unsigned ExtInputOpcode = ExtInput->getOpcode();
- if (!ISD::isExtOpcode(ExtInputOpcode))
+ unsigned ExtBOpcode = ExtB->getOpcode();
+ if (!ISD::isExtOpcode(ExtBOpcode))
return SDValue();
- auto Input = ExtInput->getOperand(0);
- EVT InputVT = Input.getValueType();
+ auto B = ExtB->getOperand(0);
+ EVT BVT = B.getValueType();
- if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
- !(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
- !(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
+ if (!(BVT == MVT::nxv4i32 && AVT == MVT::nxv2i64) &&
+ !(BVT == MVT::nxv8i16 && AVT == MVT::nxv4i32) &&
+ !(BVT == MVT::nxv16i8 && AVT == MVT::nxv8i16))
return SDValue();
- bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
- auto BottomOpcode = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
- auto TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
- auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
- return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
+ bool BIsSigned = ExtBOpcode == ISD::SIGN_EXTEND;
+ auto BottomOpcode = BIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
+ auto TopOpcode = BIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
+ auto BottomNode = DAG.getNode(BottomOpcode, DL, AVT, A, B);
+ return DAG.getNode(TopOpcode, DL, AVT, BottomNode, B);
}
static SDValue performIntrinsicCombine(SDNode *N,
>From 8389aa3a6f758ab8c5cb9b4da1d94b095d5eac5e Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 20 Dec 2024 10:00:10 +0000
Subject: [PATCH 2/2] Address review feedback on variable renaming and clean
up the code
---
.../Target/AArch64/AArch64ISelLowering.cpp | 112 +++++++++---------
1 file changed, 54 insertions(+), 58 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 290f349c77809f..b7c4845d39730b 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21738,72 +21738,72 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
SDLoc DL(N);
- // The narrower of the two operands. Used as the accumulator
- auto A = N->getOperand(1);
- auto B = N->getOperand(2);
- if (B->getOpcode() != ISD::MUL)
+ SDValue Op2 = N->getOperand(2);
+ if (Op2->getOpcode() != ISD::MUL ||
+ !ISD::isExtOpcode(Op2->getOperand(0)->getOpcode()) ||
+ !ISD::isExtOpcode(Op2->getOperand(1)->getOpcode()))
return SDValue();
- auto ExtMulOp1 = B->getOperand(0);
- auto ExtMulOp2 = B->getOperand(1);
-
- if (!ISD::isExtOpcode(ExtMulOp1->getOpcode()) ||
- !ISD::isExtOpcode(ExtMulOp2->getOpcode()))
- return SDValue();
- bool MulOp1IsSigned = ExtMulOp1->getOpcode() == ISD::SIGN_EXTEND;
- bool MulOp2IsSigned = ExtMulOp2->getOpcode() == ISD::SIGN_EXTEND;
+ SDValue Acc = N->getOperand(1);
+ SDValue Mul = N->getOperand(2);
+ SDValue ExtMulOpLHS = Mul->getOperand(0);
+ SDValue ExtMulOpRHS = Mul->getOperand(1);
- auto MulOp1 = ExtMulOp1->getOperand(0);
- auto MulOp2 = ExtMulOp2->getOperand(0);
- if (MulOp1.getValueType() != MulOp2.getValueType())
+ SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
+ SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
+ if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
return SDValue();
- EVT AVT = N->getValueType(0);
- EVT MulSrcVT = MulOp1.getValueType();
+ EVT ReducedVT = N->getValueType(0);
+ EVT MulSrcVT = MulOpLHS.getValueType();
// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
- if (!(AVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
- !(AVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
- !(AVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
- !(AVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
- !(AVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
- !(AVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
+ if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
+ !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
+ !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
+ !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
+ !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
+ !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
return SDValue();
+ bool MulOpLHSIsSigned = ExtMulOpLHS->getOpcode() == ISD::SIGN_EXTEND;
+ bool MulOpRHSIsSigned = ExtMulOpRHS->getOpcode() == ISD::SIGN_EXTEND;
// If the extensions are mixed, we should lower it to a usdot instead
unsigned Opcode = 0;
- if (MulOp1IsSigned != MulOp2IsSigned) {
+ if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
if (!Subtarget->hasMatMulInt8())
return SDValue();
bool Scalable = N->getValueType(0).isScalableVT();
// There's no nxv2i64 version of usdot
- if (Scalable && AVT != MVT::nxv4i32 && AVT != MVT::nxv4i64)
+ if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
return SDValue();
Opcode = AArch64ISD::USDOT;
// USDOT expects the signed operand to be last
- if (!MulOp2IsSigned)
- std::swap(MulOp1, MulOp2);
- } else if (MulOp1IsSigned)
+ if (!MulOpRHSIsSigned)
+ std::swap(MulOpLHS, MulOpRHS);
+ } else if (MulOpLHSIsSigned)
Opcode = AArch64ISD::SDOT;
else
Opcode = AArch64ISD::UDOT;
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
// product followed by a zero / sign extension
- if ((AVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
- (AVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
- EVT AVTI32 = (AVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
+ if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
+ (ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
+ EVT ReducedVTI32 =
+ (ReducedVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
- auto DotI32 = DAG.getNode(Opcode, DL, AVTI32,
- DAG.getConstant(0, DL, AVTI32), MulOp1, MulOp2);
- auto Extended = DAG.getSExtOrTrunc(DotI32, DL, AVT);
- return DAG.getNode(ISD::ADD, DL, A.getValueType(), A, Extended);
+ SDValue DotI32 =
+ DAG.getNode(Opcode, DL, ReducedVTI32,
+ DAG.getConstant(0, DL, ReducedVTI32), MulOpLHS, MulOpRHS);
+ SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT);
+ return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended);
}
- return DAG.getNode(Opcode, DL, AVT, A, MulOp1, MulOp2);
+ return DAG.getNode(Opcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
}
SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
@@ -21820,32 +21820,28 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
SDLoc DL(N);
- auto A = N->getOperand(1);
- auto ExtB = N->getOperand(2);
-
- EVT AVT = A.getValueType();
- EVT AElemVT = AVT.getVectorElementType();
-
- if (ExtB.getValueType().getVectorElementType() != AElemVT)
+ if (!ISD::isExtOpcode(N->getOperand(2).getOpcode()))
return SDValue();
-
- unsigned ExtBOpcode = ExtB->getOpcode();
- if (!ISD::isExtOpcode(ExtBOpcode))
+ SDValue Acc = N->getOperand(1);
+ SDValue Ext = N->getOperand(2);
+ EVT AccVT = Acc.getValueType();
+ EVT ExtVT = Ext.getValueType();
+ if (ExtVT.getVectorElementType() != AccVT.getVectorElementType())
return SDValue();
- auto B = ExtB->getOperand(0);
- EVT BVT = B.getValueType();
+ SDValue ExtOp = Ext->getOperand(0);
+ EVT ExtOpVT = ExtOp.getValueType();
- if (!(BVT == MVT::nxv4i32 && AVT == MVT::nxv2i64) &&
- !(BVT == MVT::nxv8i16 && AVT == MVT::nxv4i32) &&
- !(BVT == MVT::nxv16i8 && AVT == MVT::nxv8i16))
+ if (!(ExtOpVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
+ !(ExtOpVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
+ !(ExtOpVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
return SDValue();
- bool BIsSigned = ExtBOpcode == ISD::SIGN_EXTEND;
- auto BottomOpcode = BIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
- auto TopOpcode = BIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
- auto BottomNode = DAG.getNode(BottomOpcode, DL, AVT, A, B);
- return DAG.getNode(TopOpcode, DL, AVT, BottomNode, B);
+ bool ExtOpIsSigned = Ext.getOpcode() == ISD::SIGN_EXTEND;
+ unsigned BottomOpcode = ExtOpIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
+ unsigned TopOpcode = ExtOpIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
+ SDValue BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, ExtOp);
+ return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, ExtOp);
}
static SDValue performIntrinsicCombine(SDNode *N,
@@ -21857,9 +21853,9 @@ static SDValue performIntrinsicCombine(SDNode *N,
default:
break;
case Intrinsic::experimental_vector_partial_reduce_add: {
- if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
+ if (SDValue Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
return Dot;
- if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
+ if (SDValue WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
return WideAdd;
return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
N->getOperand(1), N->getOperand(2));
More information about the llvm-commits
mailing list