[llvm] [NFC][AArch64][SVE] Rename variables in partial reduction lowering functions (PR #120589)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Dec 19 07:14:27 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-aarch64
Author: James Chesterman (JamesChesterman)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/120589.diff
1 File Affected:
- (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+50-52)
``````````diff
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d1354ccf376609..290f349c77809f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21739,73 +21739,71 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
SDLoc DL(N);
// The narrower of the two operands. Used as the accumulator
- auto NarrowOp = N->getOperand(1);
- auto MulOp = N->getOperand(2);
- if (MulOp->getOpcode() != ISD::MUL)
+ auto A = N->getOperand(1);
+ auto B = N->getOperand(2);
+ if (B->getOpcode() != ISD::MUL)
return SDValue();
- auto ExtA = MulOp->getOperand(0);
- auto ExtB = MulOp->getOperand(1);
+ auto ExtMulOp1 = B->getOperand(0);
+ auto ExtMulOp2 = B->getOperand(1);
- if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
- !ISD::isExtOpcode(ExtB->getOpcode()))
+ if (!ISD::isExtOpcode(ExtMulOp1->getOpcode()) ||
+ !ISD::isExtOpcode(ExtMulOp2->getOpcode()))
return SDValue();
- bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
- bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
+ bool MulOp1IsSigned = ExtMulOp1->getOpcode() == ISD::SIGN_EXTEND;
+ bool MulOp2IsSigned = ExtMulOp2->getOpcode() == ISD::SIGN_EXTEND;
- auto A = ExtA->getOperand(0);
- auto B = ExtB->getOperand(0);
- if (A.getValueType() != B.getValueType())
+ auto MulOp1 = ExtMulOp1->getOperand(0);
+ auto MulOp2 = ExtMulOp2->getOperand(0);
+ if (MulOp1.getValueType() != MulOp2.getValueType())
return SDValue();
- EVT ReducedType = N->getValueType(0);
- EVT MulSrcType = A.getValueType();
+ EVT AVT = N->getValueType(0);
+ EVT MulSrcVT = MulOp1.getValueType();
// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
- if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
- !(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
- !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
- !(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
- !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
- !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
+ if (!(AVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
+ !(AVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
+ !(AVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
+ !(AVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
+ !(AVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
+ !(AVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
return SDValue();
// If the extensions are mixed, we should lower it to a usdot instead
unsigned Opcode = 0;
- if (AIsSigned != BIsSigned) {
+ if (MulOp1IsSigned != MulOp2IsSigned) {
if (!Subtarget->hasMatMulInt8())
return SDValue();
bool Scalable = N->getValueType(0).isScalableVT();
// There's no nxv2i64 version of usdot
- if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
+ if (Scalable && AVT != MVT::nxv4i32 && AVT != MVT::nxv4i64)
return SDValue();
Opcode = AArch64ISD::USDOT;
// USDOT expects the signed operand to be last
- if (!BIsSigned)
- std::swap(A, B);
- } else if (AIsSigned)
+ if (!MulOp2IsSigned)
+ std::swap(MulOp1, MulOp2);
+ } else if (MulOp1IsSigned)
Opcode = AArch64ISD::SDOT;
else
Opcode = AArch64ISD::UDOT;
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
// product followed by a zero / sign extension
- if ((ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) ||
- (ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8)) {
- EVT ReducedTypeI32 =
- (ReducedType.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
+ if ((AVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
+ (AVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
+ EVT AVTI32 = (AVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
- auto DotI32 = DAG.getNode(Opcode, DL, ReducedTypeI32,
- DAG.getConstant(0, DL, ReducedTypeI32), A, B);
- auto Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedType);
- return DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(), NarrowOp,
- Extended);
+ auto DotI32 = DAG.getNode(Opcode, DL, AVTI32,
+ DAG.getConstant(0, DL, AVTI32), MulOp1, MulOp2);
+ auto Extended = DAG.getSExtOrTrunc(DotI32, DL, AVT);
+ return DAG.getNode(ISD::ADD, DL, A.getValueType(), A, Extended);
}
- return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
+ return DAG.getNode(Opcode, DL, AVT, A, MulOp1, MulOp2);
}
SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
@@ -21822,32 +21820,32 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
SDLoc DL(N);
- auto Acc = N->getOperand(1);
- auto ExtInput = N->getOperand(2);
+ auto A = N->getOperand(1);
+ auto ExtB = N->getOperand(2);
- EVT AccVT = Acc.getValueType();
- EVT AccElemVT = AccVT.getVectorElementType();
+ EVT AVT = A.getValueType();
+ EVT AElemVT = AVT.getVectorElementType();
- if (ExtInput.getValueType().getVectorElementType() != AccElemVT)
+ if (ExtB.getValueType().getVectorElementType() != AElemVT)
return SDValue();
- unsigned ExtInputOpcode = ExtInput->getOpcode();
- if (!ISD::isExtOpcode(ExtInputOpcode))
+ unsigned ExtBOpcode = ExtB->getOpcode();
+ if (!ISD::isExtOpcode(ExtBOpcode))
return SDValue();
- auto Input = ExtInput->getOperand(0);
- EVT InputVT = Input.getValueType();
+ auto B = ExtB->getOperand(0);
+ EVT BVT = B.getValueType();
- if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
- !(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
- !(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
+ if (!(BVT == MVT::nxv4i32 && AVT == MVT::nxv2i64) &&
+ !(BVT == MVT::nxv8i16 && AVT == MVT::nxv4i32) &&
+ !(BVT == MVT::nxv16i8 && AVT == MVT::nxv8i16))
return SDValue();
- bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
- auto BottomOpcode = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
- auto TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
- auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
- return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
+ bool BIsSigned = ExtBOpcode == ISD::SIGN_EXTEND;
+ auto BottomOpcode = BIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
+ auto TopOpcode = BIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
+ auto BottomNode = DAG.getNode(BottomOpcode, DL, AVT, A, B);
+ return DAG.getNode(TopOpcode, DL, AVT, BottomNode, B);
}
static SDValue performIntrinsicCombine(SDNode *N,
``````````
</details>
https://github.com/llvm/llvm-project/pull/120589
More information about the llvm-commits mailing list