[llvm] [AArch64][SVE] Add partial reduction SDNodes (PR #117185)
James Chesterman via llvm-commits
llvm-commits at lists.llvm.org
Thu Jan 23 06:29:52 PST 2025
https://github.com/JamesChesterman updated https://github.com/llvm/llvm-project/pull/117185
>From ee6f89fc31ad0ba2f206e6171bc3ddd04ca4fdc1 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 21 Nov 2024 16:19:19 +0000
Subject: [PATCH 01/14] [AArch64][SVE] Add partial reduction SDNodes
Add the opcode 'ISD::PARTIAL_REDUCE_ADD' and use it when building
SDNodes. When the input and output types allow lowering to wide add
or dot product instruction(s), convert the corresponding intrinsic to
this SDNode instead of a target intrinsic node. This will make the
legalisation that a future patch adds easier to implement.
---
llvm/include/llvm/CodeGen/ISDOpcodes.h | 5 +
llvm/include/llvm/CodeGen/SelectionDAG.h | 5 +
llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 16 ++++
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 6 ++
.../SelectionDAG/SelectionDAGBuilder.cpp | 12 ++-
.../SelectionDAG/SelectionDAGBuilder.h | 1 +
.../SelectionDAG/SelectionDAGDumper.cpp | 2 +
.../Target/AArch64/AArch64ISelLowering.cpp | 94 +++++++++----------
8 files changed, 92 insertions(+), 49 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index fd8784a4c10034..6ba76a42d06f85 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1451,6 +1451,11 @@ enum NodeType {
VECREDUCE_UMAX,
VECREDUCE_UMIN,
+ // The `llvm.experimental.vector.partial.reduce.add` intrinsic
+ // Operands: Accumulator, Input
+ // Outputs: Output
+ PARTIAL_REDUCE_ADD,
+
// The `llvm.experimental.stackmap` intrinsic.
// Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
// Outputs: output chain, glue
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index ba0538f7084eec..803339480f67cc 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1604,6 +1604,11 @@ class SelectionDAG {
/// the target's desired shift amount type.
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
+ /// Get a partial reduction SD node for the DAG. This is done when the input
+ /// and output types can be legalised for wide add(s) or dot product(s)
+ SDValue getPartialReduceAddSDNode(SDLoc DL, SDValue Chain, SDValue Acc,
+ SDValue Input);
+
/// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
/// its operands and ReducedTY is the intrinsic's return type.
SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 49467ce0a54cd0..32b3bcfd3dcde9 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -3033,6 +3033,22 @@ class MaskedHistogramSDNode : public MaskedGatherScatterSDNode {
}
};
+class PartialReduceAddSDNode : public SDNode {
+public:
+ friend class SelectionDAG;
+
+ PartialReduceAddSDNode(const DebugLoc &dl, SDVTList VTs)
+ : SDNode(ISD::PARTIAL_REDUCE_ADD, 0, dl, VTs) {}
+
+ const SDValue &getChain() const { return getOperand(0); }
+ const SDValue &getAcc() const { return getOperand(1); }
+ const SDValue &getInput() const { return getOperand(2); }
+
+ static bool classof(const SDNode *N) {
+ return N->getOpcode() == ISD::PARTIAL_REDUCE_ADD;
+ }
+};
+
class FPStateAccessSDNode : public MemSDNode {
public:
friend class SelectionDAG;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 0dfd0302ae5438..c2a52cf1e8314f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2467,6 +2467,12 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
}
+SDValue SelectionDAG::getPartialReduceAddSDNode(SDLoc DL, SDValue Chain,
+ SDValue Acc, SDValue Input) {
+ return getNode(ISD::PARTIAL_REDUCE_ADD, DL, Acc.getValueType(), Chain, Acc,
+ Input);
+}
+
SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
SDValue Op2) {
EVT FullTy = Op2.getValueType();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 8a5d7c0b022d90..f055e57027f4e4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -6422,6 +6422,16 @@ void SelectionDAGBuilder::visitVectorHistogram(const CallInst &I,
DAG.setRoot(Histogram);
}
+void SelectionDAGBuilder::visitPartialReduceAdd(const CallInst &I,
+ unsigned IntrinsicID) {
+ SDLoc dl = getCurSDLoc();
+ SDValue Acc = getValue(I.getOperand(0));
+ SDValue Input = getValue(I.getOperand(1));
+ SDValue Chain = getRoot();
+
+ setValue(&I, DAG.getPartialReduceAddSDNode(dl, Chain, Acc, Input));
+}
+
void SelectionDAGBuilder::visitVectorExtractLastActive(const CallInst &I,
unsigned Intrinsic) {
assert(Intrinsic == Intrinsic::experimental_vector_extract_last_active &&
@@ -8120,7 +8130,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
case Intrinsic::experimental_vector_partial_reduce_add: {
if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
- visitTargetIntrinsic(I, Intrinsic);
+ visitPartialReduceAdd(I, Intrinsic);
return;
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
index 3a8dc25e98700e..a9e0c8f1ea10c1 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
@@ -629,6 +629,7 @@ class SelectionDAGBuilder {
void visitConstrainedFPIntrinsic(const ConstrainedFPIntrinsic &FPI);
void visitConvergenceControl(const CallInst &I, unsigned Intrinsic);
void visitVectorHistogram(const CallInst &I, unsigned IntrinsicID);
+ void visitPartialReduceAdd(const CallInst &, unsigned IntrinsicID);
void visitVectorExtractLastActive(const CallInst &I, unsigned Intrinsic);
void visitVPLoad(const VPIntrinsic &VPIntrin, EVT VT,
const SmallVectorImpl<SDValue> &OpValues);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index f63c8dd3df1c83..f0126cd41f671d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -570,6 +570,8 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::VECTOR_FIND_LAST_ACTIVE:
return "find_last_active";
+ case ISD::PARTIAL_REDUCE_ADD:
+ return "partial_reduce_add";
// Vector Predication
#define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...) \
case ISD::SDID: \
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 8a3a9f75415fbc..9edc7573272905 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1128,6 +1128,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setTargetDAGCombine(
{ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
+ setTargetDAGCombine(ISD::PARTIAL_REDUCE_ADD);
+
setTargetDAGCombine(ISD::FP_EXTEND);
setTargetDAGCombine(ISD::GlobalAddress);
@@ -22011,40 +22013,23 @@ static SDValue tryCombineWhileLo(SDNode *N,
return SDValue(N, 0);
}
-SDValue tryLowerPartialReductionToDot(SDNode *N,
+SDValue tryLowerPartialReductionToDot(PartialReduceAddSDNode *PR,
const AArch64Subtarget *Subtarget,
SelectionDAG &DAG) {
- assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
- getIntrinsicID(N) ==
- Intrinsic::experimental_vector_partial_reduce_add &&
- "Expected a partial reduction node");
-
- bool Scalable = N->getValueType(0).isScalableVector();
+ bool Scalable = PR->getValueType(0).isScalableVector();
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
return SDValue();
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
return SDValue();
- SDLoc DL(N);
+ SDLoc DL(PR);
- SDValue Op2 = N->getOperand(2);
- unsigned Op2Opcode = Op2->getOpcode();
- SDValue MulOpLHS, MulOpRHS;
- bool MulOpLHSIsSigned, MulOpRHSIsSigned;
- if (ISD::isExtOpcode(Op2Opcode)) {
- MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND);
- MulOpLHS = Op2->getOperand(0);
- MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType());
- } else if (Op2Opcode == ISD::MUL) {
- SDValue ExtMulOpLHS = Op2->getOperand(0);
- SDValue ExtMulOpRHS = Op2->getOperand(1);
-
- unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
- unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
- if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
- !ISD::isExtOpcode(ExtMulOpRHSOpcode))
- return SDValue();
+ // The narrower of the two operands. Used as the accumulator
+ auto NarrowOp = PR->getAcc();
+ auto MulOp = PR->getInput();
+ if (MulOp->getOpcode() != ISD::MUL)
+ return SDValue();
MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
@@ -22057,9 +22042,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
} else
return SDValue();
- SDValue Acc = N->getOperand(1);
- EVT ReducedVT = N->getValueType(0);
- EVT MulSrcVT = MulOpLHS.getValueType();
+ EVT ReducedType = PR->getValueType(0);
+ EVT MulSrcType = A.getValueType();
// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
@@ -22077,7 +22061,7 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
if (!Subtarget->hasMatMulInt8())
return SDValue();
- bool Scalable = N->getValueType(0).isScalableVT();
+ bool Scalable = PR->getValueType(0).isScalableVT();
// There's no nxv2i64 version of usdot
if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
return SDValue();
@@ -22106,24 +22090,18 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
return DAG.getNode(Opcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
}
-SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
+SDValue tryLowerPartialReductionToWideAdd(PartialReduceAddSDNode *PR,
const AArch64Subtarget *Subtarget,
SelectionDAG &DAG) {
- assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
- getIntrinsicID(N) ==
- Intrinsic::experimental_vector_partial_reduce_add &&
- "Expected a partial reduction node");
-
if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
return SDValue();
- SDLoc DL(N);
+ SDLoc DL(PR);
+
+ auto Acc = PR->getAcc();
+ auto ExtInput = PR->getInput();
- if (!ISD::isExtOpcode(N->getOperand(2).getOpcode()))
- return SDValue();
- SDValue Acc = N->getOperand(1);
- SDValue Ext = N->getOperand(2);
EVT AccVT = Acc.getValueType();
EVT ExtVT = Ext.getValueType();
if (ExtVT.getVectorElementType() != AccVT.getVectorElementType())
@@ -22145,6 +22123,32 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, ExtOp);
}
+static SDValue
+performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
+ const AArch64Subtarget *Subtarget) {
+ auto *PR = cast<PartialReduceAddSDNode>(N);
+ if (auto Dot = tryLowerPartialReductionToDot(PR, Subtarget, DAG))
+ return Dot;
+ if (auto WideAdd = tryLowerPartialReductionToWideAdd(PR, Subtarget, DAG))
+ return WideAdd;
+ return DAG.getPartialReduceAdd(SDLoc(PR), PR->getValueType(0), PR->getAcc(),
+ PR->getInput());
+}
+
+
+
+static SDValue
+performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
+ const AArch64Subtarget *Subtarget) {
+ auto *PR = cast<PartialReduceAddSDNode>(N);
+ if (auto Dot = tryLowerPartialReductionToDot(PR, Subtarget, DAG))
+ return Dot;
+ if (auto WideAdd = tryLowerPartialReductionToWideAdd(PR, Subtarget, DAG))
+ return WideAdd;
+ return DAG.getPartialReduceAdd(SDLoc(PR), PR->getValueType(0), PR->getAcc(),
+ PR->getInput());
+}
+
static SDValue performIntrinsicCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
@@ -22153,14 +22157,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
switch (IID) {
default:
break;
- case Intrinsic::experimental_vector_partial_reduce_add: {
- if (SDValue Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
- return Dot;
- if (SDValue WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
- return WideAdd;
- return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
- N->getOperand(1), N->getOperand(2));
- }
case Intrinsic::aarch64_neon_vcvtfxs2fp:
case Intrinsic::aarch64_neon_vcvtfxu2fp:
return tryCombineFixedPointConvert(N, DCI, DAG);
@@ -26591,6 +26587,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::MSCATTER:
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
return performMaskedGatherScatterCombine(N, DCI, DAG);
+ case ISD::PARTIAL_REDUCE_ADD:
+ return performPartialReduceAddCombine(N, DAG, Subtarget);
case ISD::FP_EXTEND:
return performFPExtendCombine(N, DAG, DCI, Subtarget);
case AArch64ISD::BRCOND:
>From 8bc55db640580583d1da8ce1b7ef061efda43b1e Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Wed, 27 Nov 2024 09:48:06 +0000
Subject: [PATCH 02/14] Remove PartialReduceAddSDNode and change how the
intrinsic is transformed into the PARTIAL_REDUCE_ADD SD node, expanding it
directly when it cannot be lowered.
---
llvm/include/llvm/CodeGen/SelectionDAG.h | 13 +++-----
llvm/include/llvm/CodeGen/SelectionDAGNodes.h | 16 ----------
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 10 ++----
.../SelectionDAG/SelectionDAGBuilder.cpp | 21 ++++---------
.../SelectionDAG/SelectionDAGBuilder.h | 1 -
.../Target/AArch64/AArch64ISelLowering.cpp | 31 +++++++++----------
6 files changed, 27 insertions(+), 65 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 803339480f67cc..089d69d5256c55 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1604,15 +1604,10 @@ class SelectionDAG {
/// the target's desired shift amount type.
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
- /// Get a partial reduction SD node for the DAG. This is done when the input
- /// and output types can be legalised for wide add(s) or dot product(s)
- SDValue getPartialReduceAddSDNode(SDLoc DL, SDValue Chain, SDValue Acc,
- SDValue Input);
-
- /// Create the DAG equivalent of vector_partial_reduce where Op1 and Op2 are
- /// its operands and ReducedTY is the intrinsic's return type.
- SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
- SDValue Op2);
+ /// Expands partial reduce node which can't be lowered to wide add or dot
+ /// product instruction(s)
+ SDValue expandPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
+ SDValue Op2);
/// Expands a node with multiple results to an FP or vector libcall. The
/// libcall is expected to take all the operands of the \p Node followed by
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index 32b3bcfd3dcde9..49467ce0a54cd0 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -3033,22 +3033,6 @@ class MaskedHistogramSDNode : public MaskedGatherScatterSDNode {
}
};
-class PartialReduceAddSDNode : public SDNode {
-public:
- friend class SelectionDAG;
-
- PartialReduceAddSDNode(const DebugLoc &dl, SDVTList VTs)
- : SDNode(ISD::PARTIAL_REDUCE_ADD, 0, dl, VTs) {}
-
- const SDValue &getChain() const { return getOperand(0); }
- const SDValue &getAcc() const { return getOperand(1); }
- const SDValue &getInput() const { return getOperand(2); }
-
- static bool classof(const SDNode *N) {
- return N->getOpcode() == ISD::PARTIAL_REDUCE_ADD;
- }
-};
-
class FPStateAccessSDNode : public MemSDNode {
public:
friend class SelectionDAG;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index c2a52cf1e8314f..7b38d7c424d0df 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2467,14 +2467,8 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
}
-SDValue SelectionDAG::getPartialReduceAddSDNode(SDLoc DL, SDValue Chain,
- SDValue Acc, SDValue Input) {
- return getNode(ISD::PARTIAL_REDUCE_ADD, DL, Acc.getValueType(), Chain, Acc,
- Input);
-}
-
-SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
- SDValue Op2) {
+SDValue SelectionDAG::expandPartialReduceAdd(SDLoc DL, EVT ReducedTy,
+ SDValue Op1, SDValue Op2) {
EVT FullTy = Op2.getValueType();
unsigned Stride = ReducedTy.getVectorMinNumElements();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index f055e57027f4e4..f2f9634e2782d3 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -6422,16 +6422,6 @@ void SelectionDAGBuilder::visitVectorHistogram(const CallInst &I,
DAG.setRoot(Histogram);
}
-void SelectionDAGBuilder::visitPartialReduceAdd(const CallInst &I,
- unsigned IntrinsicID) {
- SDLoc dl = getCurSDLoc();
- SDValue Acc = getValue(I.getOperand(0));
- SDValue Input = getValue(I.getOperand(1));
- SDValue Chain = getRoot();
-
- setValue(&I, DAG.getPartialReduceAddSDNode(dl, Chain, Acc, Input));
-}
-
void SelectionDAGBuilder::visitVectorExtractLastActive(const CallInst &I,
unsigned Intrinsic) {
assert(Intrinsic == Intrinsic::experimental_vector_extract_last_active &&
@@ -8128,15 +8118,16 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
return;
}
case Intrinsic::experimental_vector_partial_reduce_add: {
+ SDLoc dl = getCurSDLoc();
+ SDValue Acc = getValue(I.getOperand(0));
+ EVT AccVT = Acc.getValueType();
+ SDValue Input = getValue(I.getOperand(1));
if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
- visitPartialReduceAdd(I, Intrinsic);
+ setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_ADD, dl, AccVT, Acc, Input));
return;
}
-
- setValue(&I, DAG.getPartialReduceAdd(sdl, EVT::getEVT(I.getType()),
- getValue(I.getOperand(0)),
- getValue(I.getOperand(1))));
+ setValue(&I, DAG.expandPartialReduceAdd(dl, AccVT, Acc, Input));
return;
}
case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
index a9e0c8f1ea10c1..3a8dc25e98700e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
@@ -629,7 +629,6 @@ class SelectionDAGBuilder {
void visitConstrainedFPIntrinsic(const ConstrainedFPIntrinsic &FPI);
void visitConvergenceControl(const CallInst &I, unsigned Intrinsic);
void visitVectorHistogram(const CallInst &I, unsigned IntrinsicID);
- void visitPartialReduceAdd(const CallInst &, unsigned IntrinsicID);
void visitVectorExtractLastActive(const CallInst &I, unsigned Intrinsic);
void visitVPLoad(const VPIntrinsic &VPIntrin, EVT VT,
const SmallVectorImpl<SDValue> &OpValues);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 9edc7573272905..a82a417f4f8fc1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -22013,21 +22013,21 @@ static SDValue tryCombineWhileLo(SDNode *N,
return SDValue(N, 0);
}
-SDValue tryLowerPartialReductionToDot(PartialReduceAddSDNode *PR,
+SDValue tryLowerPartialReductionToDot(SDNode *N,
const AArch64Subtarget *Subtarget,
SelectionDAG &DAG) {
- bool Scalable = PR->getValueType(0).isScalableVector();
+ bool Scalable = N->getValueType(0).isScalableVector();
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
return SDValue();
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
return SDValue();
- SDLoc DL(PR);
+ SDLoc DL(N);
// The narrower of the two operands. Used as the accumulator
- auto NarrowOp = PR->getAcc();
- auto MulOp = PR->getInput();
+ auto NarrowOp = N->getOperand(0);
+ auto MulOp = N->getOperand(1);
if (MulOp->getOpcode() != ISD::MUL)
return SDValue();
@@ -22042,7 +22042,7 @@ SDValue tryLowerPartialReductionToDot(PartialReduceAddSDNode *PR,
} else
return SDValue();
- EVT ReducedType = PR->getValueType(0);
+ EVT ReducedType = N->getValueType(0);
EVT MulSrcType = A.getValueType();
// Dot products operate on chunks of four elements so there must be four times
@@ -22061,7 +22061,7 @@ SDValue tryLowerPartialReductionToDot(PartialReduceAddSDNode *PR,
if (!Subtarget->hasMatMulInt8())
return SDValue();
- bool Scalable = PR->getValueType(0).isScalableVT();
+ bool Scalable = N->getValueType(0).isScalableVT();
// There's no nxv2i64 version of usdot
if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
return SDValue();
@@ -22090,17 +22090,17 @@ SDValue tryLowerPartialReductionToDot(PartialReduceAddSDNode *PR,
return DAG.getNode(Opcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
}
-SDValue tryLowerPartialReductionToWideAdd(PartialReduceAddSDNode *PR,
+SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
const AArch64Subtarget *Subtarget,
SelectionDAG &DAG) {
if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
return SDValue();
- SDLoc DL(PR);
+ SDLoc DL(N);
- auto Acc = PR->getAcc();
- auto ExtInput = PR->getInput();
+ auto Acc = N->getOperand(0);
+ auto ExtInput = N->getOperand(1);
EVT AccVT = Acc.getValueType();
EVT ExtVT = Ext.getValueType();
@@ -22126,13 +22126,12 @@ SDValue tryLowerPartialReductionToWideAdd(PartialReduceAddSDNode *PR,
static SDValue
performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
const AArch64Subtarget *Subtarget) {
- auto *PR = cast<PartialReduceAddSDNode>(N);
- if (auto Dot = tryLowerPartialReductionToDot(PR, Subtarget, DAG))
+ if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
return Dot;
- if (auto WideAdd = tryLowerPartialReductionToWideAdd(PR, Subtarget, DAG))
+ if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
return WideAdd;
- return DAG.getPartialReduceAdd(SDLoc(PR), PR->getValueType(0), PR->getAcc(),
- PR->getInput());
+ return DAG.expandPartialReduceAdd(SDLoc(N), N->getValueType(0),
+ N->getOperand(0), N->getOperand(1));
}
>From 2a8f7ee7fbfe2cc036d60f50e03537c7e9fffbcf Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 28 Nov 2024 12:11:36 +0000
Subject: [PATCH 03/14] Remove unnecessary function parameter and update
comments
---
llvm/include/llvm/CodeGen/ISDOpcodes.h | 3 ++-
llvm/include/llvm/CodeGen/SelectionDAG.h | 9 +++++----
llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 5 +++--
llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp | 2 +-
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 3 +--
5 files changed, 12 insertions(+), 10 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 6ba76a42d06f85..42505f5a6a6b85 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1451,7 +1451,8 @@ enum NodeType {
VECREDUCE_UMAX,
VECREDUCE_UMIN,
- // The `llvm.experimental.vector.partial.reduce.add` intrinsic
+ // This corresponds to the `llvm.experimental.vector.partial.reduce.add`
+ // intrinsic
// Operands: Accumulator, Input
// Outputs: Output
PARTIAL_REDUCE_ADD,
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 089d69d5256c55..259b2f843019f0 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1604,10 +1604,11 @@ class SelectionDAG {
/// the target's desired shift amount type.
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
- /// Expands partial reduce node which can't be lowered to wide add or dot
- /// product instruction(s)
- SDValue expandPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
- SDValue Op2);
+ /// Expands PARTIAL_REDUCE_ADD nodes which can't be lowered.
+ /// @param Op1 Accumulator for where the result is stored for the partial
+ /// reduction operation
+ /// @param Op2 Input for the partial reduction operation
+ SDValue expandPartialReduceAdd(SDLoc DL, SDValue Op1, SDValue Op2);
/// Expands a node with multiple results to an FP or vector libcall. The
/// libcall is expected to take all the operands of the \p Node followed by
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 7b38d7c424d0df..b720379497d911 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2467,8 +2467,9 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
}
-SDValue SelectionDAG::expandPartialReduceAdd(SDLoc DL, EVT ReducedTy,
- SDValue Op1, SDValue Op2) {
+SDValue SelectionDAG::expandPartialReduceAdd(SDLoc DL, SDValue Op1,
+ SDValue Op2) {
+ EVT ReducedTy = Op1.getValueType();
EVT FullTy = Op2.getValueType();
unsigned Stride = ReducedTy.getVectorMinNumElements();
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index f2f9634e2782d3..7d2b0e4f094ff8 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8127,7 +8127,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_ADD, dl, AccVT, Acc, Input));
return;
}
- setValue(&I, DAG.expandPartialReduceAdd(dl, AccVT, Acc, Input));
+ setValue(&I, DAG.expandPartialReduceAdd(dl, Acc, Input));
return;
}
case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a82a417f4f8fc1..74a5234d1b9eca 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -22130,8 +22130,7 @@ performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
return Dot;
if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
return WideAdd;
- return DAG.expandPartialReduceAdd(SDLoc(N), N->getValueType(0),
- N->getOperand(0), N->getOperand(1));
+ return DAG.expandPartialReduceAdd(SDLoc(N), N->getOperand(0), N->getOperand(1));
}
>From e75ff4ae7a36af01be7a8b22fc9b83492c3353fb Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 28 Nov 2024 13:17:15 +0000
Subject: [PATCH 04/14] Code formatting changes
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 74a5234d1b9eca..b4e441e33af5ed 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -22130,7 +22130,8 @@ performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
return Dot;
if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
return WideAdd;
- return DAG.expandPartialReduceAdd(SDLoc(N), N->getOperand(0), N->getOperand(1));
+ return DAG.expandPartialReduceAdd(SDLoc(N), N->getOperand(0),
+ N->getOperand(1));
}
>From ad377d1bcf72737c72ea1c8fd8ed00b5e15286a2 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Mon, 2 Dec 2024 09:41:36 +0000
Subject: [PATCH 05/14] Split the partial reduction ISD node into signed and
unsigned variants (PARTIAL_REDUCE_SADD and PARTIAL_REDUCE_UADD)
---
llvm/include/llvm/CodeGen/ISDOpcodes.h | 5 +-
llvm/include/llvm/CodeGen/SelectionDAG.h | 9 ++--
llvm/include/llvm/CodeGen/TargetLowering.h | 10 ++++
.../SelectionDAG/SelectionDAGBuilder.cpp | 7 ++-
.../SelectionDAG/SelectionDAGDumper.cpp | 7 ++-
.../Target/AArch64/AArch64ISelLowering.cpp | 48 ++++++++++++-------
llvm/lib/Target/AArch64/AArch64ISelLowering.h | 2 +
7 files changed, 61 insertions(+), 27 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 42505f5a6a6b85..a2e856efe931bb 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1451,11 +1451,12 @@ enum NodeType {
VECREDUCE_UMAX,
VECREDUCE_UMIN,
- // This corresponds to the `llvm.experimental.vector.partial.reduce.add`
+ // These correspond to the `llvm.experimental.vector.partial.reduce.add`
// intrinsic
// Operands: Accumulator, Input
// Outputs: Output
- PARTIAL_REDUCE_ADD,
+ PARTIAL_REDUCE_SADD,
+ PARTIAL_REDUCE_UADD,
// The `llvm.experimental.stackmap` intrinsic.
// Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 259b2f843019f0..cdff508ef8d997 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1604,10 +1604,11 @@ class SelectionDAG {
/// the target's desired shift amount type.
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
- /// Expands PARTIAL_REDUCE_ADD nodes which can't be lowered.
- /// @param Op1 Accumulator for where the result is stored for the partial
- /// reduction operation
- /// @param Op2 Input for the partial reduction operation
+ /// Expands PARTIAL_REDUCE_S/UADD nodes to a sequence of subvector extracts
+ /// followed by vector adds.
+ /// \p Op1 Accumulator for where the result is stored for the partial
+ /// reduction operation.
+ /// \p Op2 Input for the partial reduction operation.
SDValue expandPartialReduceAdd(SDLoc DL, SDValue Op1, SDValue Op2);
/// Expands a node with multiple results to an FP or vector libcall. The
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 38ac90f0c081b3..c0dd1a8e369e2b 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -462,6 +462,16 @@ class TargetLoweringBase {
return true;
}
+ /// Return true if there is a sign extend on the input to this function. Used
+ /// to determine whether to transform the
+ /// @llvm.experimental.vector.partial.reduce.* intrinsic to
+ /// PARTIAL_REDUCE_SADD or PARTIAL_REDUCE_UADD. It also removes the extend
+ /// from the input. \p Input The 'Input' operand to the
+ /// @llvm.experimental.vector.partial.reduce.* intrinsic.
+ virtual bool isPartialReductionInputSigned(SDValue &Input) const {
+ return false;
+ }
+
/// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
/// using generic code in SelectionDAGBuilder.
virtual bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 7d2b0e4f094ff8..c63a81884b20e7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8124,7 +8124,12 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
SDValue Input = getValue(I.getOperand(1));
if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
- setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_ADD, dl, AccVT, Acc, Input));
+ if (TLI.isPartialReductionInputSigned(Input))
+ setValue(&I,
+ DAG.getNode(ISD::PARTIAL_REDUCE_SADD, dl, AccVT, Acc, Input));
+ else
+ setValue(&I,
+ DAG.getNode(ISD::PARTIAL_REDUCE_UADD, dl, AccVT, Acc, Input));
return;
}
setValue(&I, DAG.expandPartialReduceAdd(dl, Acc, Input));
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index f0126cd41f671d..5ca42904da852d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -570,8 +570,11 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::VECTOR_FIND_LAST_ACTIVE:
return "find_last_active";
- case ISD::PARTIAL_REDUCE_ADD:
- return "partial_reduce_add";
+ case ISD::PARTIAL_REDUCE_UADD:
+ return "partial_reduce_uadd";
+ case ISD::PARTIAL_REDUCE_SADD:
+ return "partial_reduce_sadd";
+
// Vector Predication
#define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...) \
case ISD::SDID: \
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index b4e441e33af5ed..a2035777b8ca19 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1128,7 +1128,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setTargetDAGCombine(
{ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
- setTargetDAGCombine(ISD::PARTIAL_REDUCE_ADD);
+ setTargetDAGCombine({ISD::PARTIAL_REDUCE_SADD, ISD::PARTIAL_REDUCE_UADD});
setTargetDAGCombine(ISD::FP_EXTEND);
@@ -2048,13 +2048,30 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
EVT VT = EVT::getEVT(I->getType());
auto Op1 = I->getOperand(1);
EVT Op1VT = EVT::getEVT(Op1->getType());
- if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
- (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount() ||
- VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()))
+ if ((Op1VT == MVT::nxv4i64 && VT == MVT::nxv2i64) ||
+ (Op1VT == MVT::nxv8i32 && VT == MVT::nxv4i32) ||
+ (Op1VT == MVT::nxv16i16 && VT == MVT::nxv8i16) ||
+ (Op1VT == MVT::nxv16i64 && VT == MVT::nxv4i64) ||
+ (Op1VT == MVT::nxv16i32 && VT == MVT::nxv4i32) ||
+ (Op1VT == MVT::nxv8i64 && VT == MVT::nxv2i64) ||
+ (Op1VT == MVT::v16i64 && VT == MVT::v4i64) ||
+ (Op1VT == MVT::v16i32 && VT == MVT::v4i32) ||
+ (Op1VT == MVT::v8i32 && VT == MVT::v2i32))
return false;
return true;
}
+bool AArch64TargetLowering::isPartialReductionInputSigned(
+ SDValue &Input) const {
+ unsigned InputOpcode = Input.getOpcode();
+ if (ISD::isExtOpcode(InputOpcode)) {
+ Input = Input.getOperand(0);
+ if (InputOpcode == ISD::SIGN_EXTEND)
+ return true;
+ }
+ return false;
+}
+
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
if (!Subtarget->isSVEorStreamingSVEAvailable())
return true;
@@ -22100,27 +22117,21 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
SDLoc DL(N);
auto Acc = N->getOperand(0);
- auto ExtInput = N->getOperand(1);
+ auto Input = N->getOperand(1);
EVT AccVT = Acc.getValueType();
- EVT ExtVT = Ext.getValueType();
- if (ExtVT.getVectorElementType() != AccVT.getVectorElementType())
- return SDValue();
-
- SDValue ExtOp = Ext->getOperand(0);
- EVT ExtOpVT = ExtOp.getValueType();
+ EVT InputVT = Input.getValueType();
if (!(ExtOpVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
!(ExtOpVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
!(ExtOpVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
return SDValue();
- bool ExtOpIsSigned = Ext.getOpcode() == ISD::SIGN_EXTEND;
- unsigned BottomOpcode =
- ExtOpIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
- unsigned TopOpcode = ExtOpIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
- SDValue BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, ExtOp);
- return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, ExtOp);
+ bool InputIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SADD;
+ auto BottomOpcode = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
+ auto TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
+ auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
+ return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
}
static SDValue
@@ -26586,7 +26597,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::MSCATTER:
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
return performMaskedGatherScatterCombine(N, DCI, DAG);
- case ISD::PARTIAL_REDUCE_ADD:
+ case ISD::PARTIAL_REDUCE_UADD:
+ case ISD::PARTIAL_REDUCE_SADD:
return performPartialReduceAddCombine(N, DAG, Subtarget);
case ISD::FP_EXTEND:
return performFPExtendCombine(N, DAG, DCI, Subtarget);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 61579de50db17e..bf86738b7bb8ff 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -996,6 +996,8 @@ class AArch64TargetLowering : public TargetLowering {
bool
shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
+ bool isPartialReductionInputSigned(SDValue &Input) const override;
+
bool shouldExpandCttzElements(EVT VT) const override;
bool shouldExpandVectorMatch(EVT VT, unsigned SearchSize) const override;
>From ba98a377b608f27e9fecd5cb881ef53c7f529991 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Mon, 9 Dec 2024 13:41:41 +0000
Subject: [PATCH 06/14] Determine which ISD node to use in DAG combine rather
than in SelectionDAGBuilder.
---
llvm/include/llvm/CodeGen/ISDOpcodes.h | 4 +-
llvm/include/llvm/CodeGen/SelectionDAG.h | 3 +-
llvm/include/llvm/CodeGen/TargetLowering.h | 10 --
.../SelectionDAG/SelectionDAGBuilder.cpp | 8 +-
.../Target/AArch64/AArch64ISelLowering.cpp | 97 ++++++++++---------
llvm/lib/Target/AArch64/AArch64ISelLowering.h | 2 -
6 files changed, 55 insertions(+), 69 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index a2e856efe931bb..d58344a6cc9b1e 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1451,8 +1451,8 @@ enum NodeType {
VECREDUCE_UMAX,
VECREDUCE_UMIN,
- // These correspond to the `llvm.experimental.vector.partial.reduce.add`
- // intrinsic
+ // Nodes used to represent a partial reduction addition operation (signed and
+ // unsigned).
// Operands: Accumulator, Input
// Outputs: Output
PARTIAL_REDUCE_SADD,
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index cdff508ef8d997..6daba913b37b63 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1604,8 +1604,7 @@ class SelectionDAG {
/// the target's desired shift amount type.
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
- /// Expands PARTIAL_REDUCE_S/UADD nodes to a sequence of subvector extracts
- /// followed by vector adds.
+ /// Expands PARTIAL_REDUCE_S/UADD nodes.
/// \p Op1 Accumulator for where the result is stored for the partial
/// reduction operation.
/// \p Op2 Input for the partial reduction operation.
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index c0dd1a8e369e2b..38ac90f0c081b3 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -462,16 +462,6 @@ class TargetLoweringBase {
return true;
}
- /// Return true if there is a sign extend on the input to this function. Used
- /// to determine whether to transform the
- /// @llvm.experimental.vector.partial.reduce.* intrinsic to
- /// PARTIAL_REDUCE_SADD or PARTIAL_REDUCE_UADD. It also removes the extend
- /// from the input. \p Input The 'Input' operand to the
- /// @llvm.experimental.vector.partial.reduce.* intrinsic.
- virtual bool isPartialReductionInputSigned(SDValue &Input) const {
- return false;
- }
-
/// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
/// using generic code in SelectionDAGBuilder.
virtual bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index c63a81884b20e7..139513bd9fc93f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8124,12 +8124,8 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
SDValue Input = getValue(I.getOperand(1));
if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
- if (TLI.isPartialReductionInputSigned(Input))
- setValue(&I,
- DAG.getNode(ISD::PARTIAL_REDUCE_SADD, dl, AccVT, Acc, Input));
- else
- setValue(&I,
- DAG.getNode(ISD::PARTIAL_REDUCE_UADD, dl, AccVT, Acc, Input));
+ setValue(&I,
+ DAG.getNode(ISD::PARTIAL_REDUCE_UADD, dl, AccVT, Acc, Input));
return;
}
setValue(&I, DAG.expandPartialReduceAdd(dl, Acc, Input));
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a2035777b8ca19..f2c080430127f0 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2061,17 +2061,6 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
return true;
}
-bool AArch64TargetLowering::isPartialReductionInputSigned(
- SDValue &Input) const {
- unsigned InputOpcode = Input.getOpcode();
- if (ISD::isExtOpcode(InputOpcode)) {
- Input = Input.getOperand(0);
- if (InputOpcode == ISD::SIGN_EXTEND)
- return true;
- }
- return false;
-}
-
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
if (!Subtarget->isSVEorStreamingSVEAvailable())
return true;
@@ -22048,48 +22037,54 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
if (MulOp->getOpcode() != ISD::MUL)
return SDValue();
- MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
- MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
-
- MulOpLHS = ExtMulOpLHS->getOperand(0);
- MulOpRHS = ExtMulOpRHS->getOperand(0);
+ auto A = MulOp->getOperand(0);
+ auto B = MulOp->getOperand(1);
- if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
+ unsigned AOpcode = A->getOpcode();
+ unsigned BOpcode = B->getOpcode();
+ unsigned Opcode;
+ EVT ReducedType = N->getValueType(0);
+ EVT MulSrcType;
+ if (ISD::isExtOpcode(AOpcode) || ISD::isExtOpcode(BOpcode)) {
+ bool AIsSigned = AOpcode == ISD::SIGN_EXTEND;
+ bool BIsSigned = BOpcode == ISD::SIGN_EXTEND;
+
+ A = A->getOperand(0);
+ B = B->getOperand(0);
+ if (A.getValueType() != B.getValueType())
return SDValue();
- } else
- return SDValue();
- EVT ReducedType = N->getValueType(0);
- EVT MulSrcType = A.getValueType();
+ if (AIsSigned != BIsSigned) {
+ if (!Subtarget->hasMatMulInt8())
+ return SDValue();
+
+ bool Scalable = N->getValueType(0).isScalableVT();
+ // There's no nxv2i64 version of usdot
+ if (Scalable && ReducedType != MVT::nxv4i32 &&
+ ReducedType != MVT::nxv4i64)
+ return SDValue();
+
+ Opcode = AArch64ISD::USDOT;
+ // USDOT expects the signed operand to be last
+ if (!BIsSigned)
+ std::swap(A, B);
+ } else if (AIsSigned)
+ Opcode = AArch64ISD::SDOT;
+ else
+ Opcode = AArch64ISD::UDOT;
+ MulSrcType = A.getValueType();
+ }
// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
- if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
- !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
- !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
- !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
- !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
- !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
+ if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
+ !(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
+ !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
+ !(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
+ !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
+ !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
return SDValue();
- // If the extensions are mixed, we should lower it to a usdot instead
- unsigned Opcode = 0;
- if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
- if (!Subtarget->hasMatMulInt8())
- return SDValue();
-
- bool Scalable = N->getValueType(0).isScalableVT();
- // There's no nxv2i64 version of usdot
- if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
- return SDValue();
-
- Opcode = AArch64ISD::USDOT;
- // USDOT expects the signed operand to be last
- if (!MulOpRHSIsSigned)
- std::swap(MulOpLHS, MulOpRHS);
- } else
- Opcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
-
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
// product followed by a zero / sign extension
if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
@@ -22119,15 +22114,23 @@ SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
auto Acc = N->getOperand(0);
auto Input = N->getOperand(1);
- EVT AccVT = Acc.getValueType();
+ unsigned Opcode = N->getOpcode();
+ unsigned InputOpcode = Input.getOpcode();
+ if (ISD::isExtOpcode(InputOpcode)) {
+ Input = Input.getOperand(0);
+ if (InputOpcode == ISD::SIGN_EXTEND)
+ Opcode = ISD::PARTIAL_REDUCE_SADD;
+ }
+
EVT InputVT = Input.getValueType();
+ EVT AccVT = Acc.getValueType();
if (!(ExtOpVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
!(ExtOpVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
!(ExtOpVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
return SDValue();
- bool InputIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SADD;
+ bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SADD;
auto BottomOpcode = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
auto TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index bf86738b7bb8ff..61579de50db17e 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -996,8 +996,6 @@ class AArch64TargetLowering : public TargetLowering {
bool
shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
- bool isPartialReductionInputSigned(SDValue &Input) const override;
-
bool shouldExpandCttzElements(EVT VT) const override;
bool shouldExpandVectorMatch(EVT VT, unsigned SearchSize) const override;
>From 43f73c2acadfc0426043a4e98dd168e5ca04793c Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Wed, 11 Dec 2024 14:41:56 +0000
Subject: [PATCH 07/14] Separate lowering code for PARTIAL_REDUCE_U/SADD
Move the lowering code out of the DAG-combine function.
The DAG-combine now only decides whether the node should be the signed
or unsigned version of partial reduce add. A new function, called from
LowerOperation, performs the actual lowering to wide adds or dot
products when possible.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 265 ++++++++++--------
llvm/lib/Target/AArch64/AArch64ISelLowering.h | 1 +
2 files changed, 146 insertions(+), 120 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index f2c080430127f0..483afe0a3750a9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1846,8 +1846,17 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::nxv2i64,
Custom);
}
+
+ for (auto VT : {MVT::nxv2i64, MVT::nxv4i32, MVT::nxv8i16}) {
+ setOperationAction(ISD::PARTIAL_REDUCE_UADD, VT, Custom);
+ setOperationAction(ISD::PARTIAL_REDUCE_SADD, VT, Custom);
+ }
}
+ for (auto VT : {MVT::v4i64, MVT::v4i32, MVT::v2i32}) {
+ setOperationAction(ISD::PARTIAL_REDUCE_UADD, VT, Custom);
+ setOperationAction(ISD::PARTIAL_REDUCE_SADD, VT, Custom);
+ }
if (Subtarget->hasMOPS() && Subtarget->hasMTE()) {
// Only required for llvm.aarch64.mops.memset.tag
@@ -2046,17 +2055,18 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
return true;
EVT VT = EVT::getEVT(I->getType());
- auto Op1 = I->getOperand(1);
- EVT Op1VT = EVT::getEVT(Op1->getType());
- if ((Op1VT == MVT::nxv4i64 && VT == MVT::nxv2i64) ||
- (Op1VT == MVT::nxv8i32 && VT == MVT::nxv4i32) ||
- (Op1VT == MVT::nxv16i16 && VT == MVT::nxv8i16) ||
- (Op1VT == MVT::nxv16i64 && VT == MVT::nxv4i64) ||
- (Op1VT == MVT::nxv16i32 && VT == MVT::nxv4i32) ||
- (Op1VT == MVT::nxv8i64 && VT == MVT::nxv2i64) ||
- (Op1VT == MVT::v16i64 && VT == MVT::v4i64) ||
- (Op1VT == MVT::v16i32 && VT == MVT::v4i32) ||
- (Op1VT == MVT::v8i32 && VT == MVT::v2i32))
+ auto Input = I->getOperand(1);
+ EVT InputVT = EVT::getEVT(Input->getType());
+
+ if ((InputVT == MVT::nxv4i64 && VT == MVT::nxv2i64) ||
+ (InputVT == MVT::nxv8i32 && VT == MVT::nxv4i32) ||
+ (InputVT == MVT::nxv16i16 && VT == MVT::nxv8i16) ||
+ (InputVT == MVT::nxv16i64 && VT == MVT::nxv4i64) ||
+ (InputVT == MVT::nxv16i32 && VT == MVT::nxv4i32) ||
+ (InputVT == MVT::nxv8i64 && VT == MVT::nxv2i64) ||
+ (InputVT == MVT::v16i64 && VT == MVT::v4i64) ||
+ (InputVT == MVT::v16i32 && VT == MVT::v4i32) ||
+ (InputVT == MVT::v8i32 && VT == MVT::v2i32))
return false;
return true;
}
@@ -7659,6 +7669,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerFLDEXP(Op, DAG);
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
return LowerVECTOR_HISTOGRAM(Op, DAG);
+ case ISD::PARTIAL_REDUCE_UADD:
+ case ISD::PARTIAL_REDUCE_SADD:
+ return LowerPARTIAL_REDUCE_ADD(Op, DAG);
}
}
@@ -22019,147 +22032,126 @@ static SDValue tryCombineWhileLo(SDNode *N,
return SDValue(N, 0);
}
-SDValue tryLowerPartialReductionToDot(SDNode *N,
- const AArch64Subtarget *Subtarget,
- SelectionDAG &DAG) {
-
- bool Scalable = N->getValueType(0).isScalableVector();
+SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
+ const AArch64Subtarget *Subtarget, SDLoc &DL) {
+ bool Scalable = Acc.getValueType().isScalableVector();
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
return SDValue();
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
return SDValue();
- SDLoc DL(N);
-
- // The narrower of the two operands. Used as the accumulator
- auto NarrowOp = N->getOperand(0);
- auto MulOp = N->getOperand(1);
- if (MulOp->getOpcode() != ISD::MUL)
+ unsigned InputOpcode = Input->getOpcode();
+ if (InputOpcode != ISD::MUL)
return SDValue();
-
- auto A = MulOp->getOperand(0);
- auto B = MulOp->getOperand(1);
-
+ auto A = Input->getOperand(0);
+ auto B = Input->getOperand(1);
unsigned AOpcode = A->getOpcode();
unsigned BOpcode = B->getOpcode();
- unsigned Opcode;
- EVT ReducedType = N->getValueType(0);
- EVT MulSrcType;
- if (ISD::isExtOpcode(AOpcode) || ISD::isExtOpcode(BOpcode)) {
- bool AIsSigned = AOpcode == ISD::SIGN_EXTEND;
- bool BIsSigned = BOpcode == ISD::SIGN_EXTEND;
-
- A = A->getOperand(0);
- B = B->getOperand(0);
- if (A.getValueType() != B.getValueType())
- return SDValue();
+ EVT AccVT = Acc->getValueType(0);
- if (AIsSigned != BIsSigned) {
- if (!Subtarget->hasMatMulInt8())
- return SDValue();
+ if (!ISD::isExtOpcode(AOpcode) || !ISD::isExtOpcode(BOpcode))
+ return DAG.expandPartialReduceAdd(DL, Acc, Input);
- bool Scalable = N->getValueType(0).isScalableVT();
- // There's no nxv2i64 version of usdot
- if (Scalable && ReducedType != MVT::nxv4i32 &&
- ReducedType != MVT::nxv4i64)
- return SDValue();
+ bool AIsSigned = AOpcode == ISD::SIGN_EXTEND;
+ bool BIsSigned = BOpcode == ISD::SIGN_EXTEND;
- Opcode = AArch64ISD::USDOT;
- // USDOT expects the signed operand to be last
- if (!BIsSigned)
- std::swap(A, B);
- } else if (AIsSigned)
- Opcode = AArch64ISD::SDOT;
- else
- Opcode = AArch64ISD::UDOT;
- MulSrcType = A.getValueType();
- }
+ A = A->getOperand(0);
+ B = B->getOperand(0);
+ EVT MulSrcVT = A.getValueType();
// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
- if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
- !(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
- !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
- !(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
- !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
- !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
- return SDValue();
+ if (!(AccVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
+ !(AccVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
+ !(AccVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
+ !(AccVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
+ !(AccVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
+ !(AccVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
+ return DAG.expandPartialReduceAdd(DL, Acc, Input);
+
+ unsigned DotOpcode = AIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
+ if (AIsSigned != BIsSigned) {
+ if (!Subtarget->hasMatMulInt8())
+ return DAG.expandPartialReduceAdd(DL, Acc, Input);
+
+ bool Scalable = AccVT.isScalableVT();
+ // There's no nxv2i64 version of usdot
+ if (Scalable && AccVT != MVT::nxv4i32 && AccVT != MVT::nxv4i64)
+ return DAG.expandPartialReduceAdd(DL, Acc, Input);
+
+ if (!BIsSigned)
+ std::swap(A, B);
+ DotOpcode = AArch64ISD::USDOT;
+ // Lower usdot patterns here because legalisation would attempt to split it
+ // unless exts are removed. But, removing the exts would lose the
+ // information about whether each operand is signed.
+ if ((AccVT != MVT::nxv4i64 || MulSrcVT != MVT::nxv16i8) &&
+ (AccVT != MVT::v4i64 || MulSrcVT != MVT::v16i8))
+ return DAG.getNode(DotOpcode, DL, AccVT, Acc, A, B);
+ }
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
- // product followed by a zero / sign extension
- if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
- (ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
- EVT ReducedVTI32 =
- (ReducedVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
+ // product followed by a zero / sign extension. Need to lower this here
+ // because legalisation would attempt to split it.
+ if ((AccVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
+ (AccVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
+ EVT AccVTI32 = (AccVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
- SDValue DotI32 =
- DAG.getNode(Opcode, DL, ReducedVTI32,
- DAG.getConstant(0, DL, ReducedVTI32), MulOpLHS, MulOpRHS);
- SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT);
- return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended);
+ auto DotI32 = DAG.getNode(DotOpcode, DL, AccVTI32,
+ DAG.getConstant(0, DL, AccVTI32), A, B);
+ auto Extended = DAG.getSExtOrTrunc(DotI32, DL, AccVT);
+ return DAG.getNode(ISD::ADD, DL, AccVT, Acc, Extended);
}
- return DAG.getNode(Opcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
-}
+ if (A.getValueType() != B.getValueType())
+ return DAG.expandPartialReduceAdd(DL, Acc, Input);
-SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
- const AArch64Subtarget *Subtarget,
- SelectionDAG &DAG) {
+ unsigned NewOpcode =
+ AIsSigned ? ISD::PARTIAL_REDUCE_SADD : ISD::PARTIAL_REDUCE_UADD;
+ auto NewMul = DAG.getNode(ISD::MUL, DL, A.getValueType(), A, B);
+ return DAG.getNode(NewOpcode, DL, AccVT, Acc, NewMul);
+}
+SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
+ const AArch64Subtarget *Subtarget, SDLoc &DL) {
if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
- return SDValue();
-
- SDLoc DL(N);
-
- auto Acc = N->getOperand(0);
- auto Input = N->getOperand(1);
-
- unsigned Opcode = N->getOpcode();
- unsigned InputOpcode = Input.getOpcode();
- if (ISD::isExtOpcode(InputOpcode)) {
- Input = Input.getOperand(0);
- if (InputOpcode == ISD::SIGN_EXTEND)
- Opcode = ISD::PARTIAL_REDUCE_SADD;
- }
-
+ return DAG.expandPartialReduceAdd(DL, Acc, Input);
+ unsigned InputOpcode = Input->getOpcode();
+ if (!ISD::isExtOpcode(InputOpcode))
+ return DAG.expandPartialReduceAdd(DL, Acc, Input);
+ Input = Input->getOperand(0);
EVT InputVT = Input.getValueType();
- EVT AccVT = Acc.getValueType();
+ EVT AccVT = Acc->getValueType(0);
- if (!(ExtOpVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
- !(ExtOpVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
- !(ExtOpVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
+ if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
+ !(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
+ !(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
return SDValue();
- bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SADD;
- auto BottomOpcode = InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
- auto TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
- auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
- return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
-}
-
-static SDValue
-performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
- const AArch64Subtarget *Subtarget) {
- if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
- return Dot;
- if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
- return WideAdd;
- return DAG.expandPartialReduceAdd(SDLoc(N), N->getOperand(0),
- N->getOperand(1));
+ unsigned NewOpcode = InputOpcode == ISD::SIGN_EXTEND
+ ? ISD::PARTIAL_REDUCE_SADD
+ : ISD::PARTIAL_REDUCE_UADD;
+ return DAG.getNode(NewOpcode, DL, AccVT, Acc, Input);
}
+SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
+ const AArch64Subtarget *Subtarget) {
+ SDLoc DL(N);
+ auto Acc = N->getOperand(0);
+ auto Input = N->getOperand(1);
+ EVT AccElemVT = Acc.getValueType().getVectorElementType();
+ EVT InputElemVT = Input.getValueType().getVectorElementType();
+ // If the exts have already been removed or it has already been lowered to an
+ // usdot instruction, then the element types will not be equal
+ if (InputElemVT != AccElemVT || Input.getOpcode() == AArch64ISD::USDOT)
+ return SDValue(N, 0);
-static SDValue
-performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
- const AArch64Subtarget *Subtarget) {
- auto *PR = cast<PartialReduceAddSDNode>(N);
- if (auto Dot = tryLowerPartialReductionToDot(PR, Subtarget, DAG))
+ if (auto Dot = tryCombineToDotProduct(Acc, Input, DAG, Subtarget, DL))
return Dot;
- if (auto WideAdd = tryLowerPartialReductionToWideAdd(PR, Subtarget, DAG))
+ if (auto WideAdd = tryCombineToWideAdd(Acc, Input, DAG, Subtarget, DL))
return WideAdd;
- return DAG.getPartialReduceAdd(SDLoc(PR), PR->getValueType(0), PR->getAcc(),
- PR->getInput());
+ return SDValue();
}
static SDValue performIntrinsicCombine(SDNode *N,
@@ -29372,6 +29364,39 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
return Scatter;
}
+SDValue
+AArch64TargetLowering::LowerPARTIAL_REDUCE_ADD(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDLoc DL(Op);
+ SDValue Acc = Op.getOperand(0);
+ SDValue Input = Op.getOperand(1);
+
+ EVT AccVT = Acc.getValueType();
+ EVT InputVT = Input.getValueType();
+
+ unsigned Opcode = Op.getOpcode();
+
+ if (AccVT.getVectorElementCount() * 4 == InputVT.getVectorElementCount()) {
+ unsigned IndexAdd = 0;
+ // ISD::MUL may have already been lowered, meaning the operands would be in
+ // different positions.
+ if (Input.getOpcode() != ISD::MUL)
+ IndexAdd = 1;
+ auto A = Input.getOperand(IndexAdd);
+ auto B = Input.getOperand(IndexAdd + 1);
+
+ unsigned DotOpcode = Opcode == ISD::PARTIAL_REDUCE_SADD ? AArch64ISD::SDOT
+ : AArch64ISD::UDOT;
+ return DAG.getNode(DotOpcode, DL, AccVT, Acc, A, B);
+ }
+ bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SADD;
+ unsigned BottomOpcode =
+ InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
+ unsigned TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
+ auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
+ return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
+}
+
SDValue
AArch64TargetLowering::LowerFixedLengthFPToIntToSVE(SDValue Op,
SelectionDAG &DAG) const {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 61579de50db17e..f80cd0e1c8c8d6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1184,6 +1184,7 @@ class AArch64TargetLowering : public TargetLowering {
SDValue LowerVECTOR_DEINTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVECTOR_INTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVECTOR_HISTOGRAM(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerPARTIAL_REDUCE_ADD(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerDIV(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const;
>From 5e31db5d82d06155f25b027190642c602d8db5c8 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 12 Dec 2024 13:50:36 +0000
Subject: [PATCH 08/14] Change the way the dot product pattern is checked for
lowering. Add a condition in the wide add combine that rejects fixed-length
vectors.
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 12 +++++++++++-
1 file changed, 11 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 483afe0a3750a9..5ec409cb7e97ce 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -22041,13 +22041,18 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
return SDValue();
unsigned InputOpcode = Input->getOpcode();
+ EVT AccVT = Acc->getValueType(0);
+ if (AccVT.getVectorElementCount() * 4 ==
+ Input->getValueType(0).getVectorElementCount() &&
+ InputOpcode != ISD::MUL)
+ return DAG.expandPartialReduceAdd(DL, Acc, Input);
if (InputOpcode != ISD::MUL)
return SDValue();
+
auto A = Input->getOperand(0);
auto B = Input->getOperand(1);
unsigned AOpcode = A->getOpcode();
unsigned BOpcode = B->getOpcode();
- EVT AccVT = Acc->getValueType(0);
if (!ISD::isExtOpcode(AOpcode) || !ISD::isExtOpcode(BOpcode))
return DAG.expandPartialReduceAdd(DL, Acc, Input);
@@ -22122,6 +22127,8 @@ SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
Input = Input->getOperand(0);
EVT InputVT = Input.getValueType();
EVT AccVT = Acc->getValueType(0);
+ if (!AccVT.isScalableVector())
+ return DAG.expandPartialReduceAdd(DL, Acc, Input);
if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
!(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
@@ -29376,6 +29383,9 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_ADD(SDValue Op,
unsigned Opcode = Op.getOpcode();
+ // If the following condition is true and the input opcode was not ISD::MUL
+ // during the DAG-combine, it is already expanded. So this condition means the
+ // input opcode must have been ISD::MUL.
if (AccVT.getVectorElementCount() * 4 == InputVT.getVectorElementCount()) {
unsigned IndexAdd = 0;
// ISD::MUL may have already been lowered, meaning the operands would be in
>From 0a06b2a01c5e70f32aa4cca5515deacbbe202c06 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Tue, 17 Dec 2024 13:59:13 +0000
Subject: [PATCH 09/14] Change from adding ISD::PARTIAL_REDUCE_S/UADD to adding
ISD::PARTIAL_REDUCE_S/UMLA
This simplifies the lowering function, because it no longer needs to
check whether the MUL has already been lowered; instead, the MUL's
operands are used directly. If there is no MUL and only one operand,
the other operand is a vector of ones (for value types eligible for
wide add lowering).
---
llvm/include/llvm/CodeGen/ISDOpcodes.h | 4 +-
llvm/include/llvm/CodeGen/SelectionDAG.h | 2 +-
.../SelectionDAG/SelectionDAGBuilder.cpp | 2 +-
.../SelectionDAG/SelectionDAGDumper.cpp | 8 +--
.../Target/AArch64/AArch64ISelLowering.cpp | 60 ++++++++-----------
llvm/lib/Target/AArch64/AArch64ISelLowering.h | 2 +-
6 files changed, 34 insertions(+), 44 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index d58344a6cc9b1e..12d20b90c0c241 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1455,8 +1455,8 @@ enum NodeType {
// unsigned).
// Operands: Accumulator, Input
// Outputs: Output
- PARTIAL_REDUCE_SADD,
- PARTIAL_REDUCE_UADD,
+ PARTIAL_REDUCE_SMLA,
+ PARTIAL_REDUCE_UMLA,
// The `llvm.experimental.stackmap` intrinsic.
// Operands: input chain, glue, <id>, <numShadowBytes>, [live0[, live1...]]
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 6daba913b37b63..8729311770c3e8 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1604,7 +1604,7 @@ class SelectionDAG {
/// the target's desired shift amount type.
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
- /// Expands PARTIAL_REDUCE_S/UADD nodes.
+ /// Expands PARTIAL_REDUCE_S/UMLA nodes.
/// \p Op1 Accumulator for where the result is stored for the partial
/// reduction operation.
/// \p Op2 Input for the partial reduction operation.
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 139513bd9fc93f..d8ed88a25ae468 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8125,7 +8125,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
setValue(&I,
- DAG.getNode(ISD::PARTIAL_REDUCE_UADD, dl, AccVT, Acc, Input));
+ DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, dl, AccVT, Acc, Input));
return;
}
setValue(&I, DAG.expandPartialReduceAdd(dl, Acc, Input));
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
index 5ca42904da852d..a387c10679261b 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
@@ -570,10 +570,10 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::VECTOR_FIND_LAST_ACTIVE:
return "find_last_active";
- case ISD::PARTIAL_REDUCE_UADD:
- return "partial_reduce_uadd";
- case ISD::PARTIAL_REDUCE_SADD:
- return "partial_reduce_sadd";
+ case ISD::PARTIAL_REDUCE_UMLA:
+ return "partial_reduce_umla";
+ case ISD::PARTIAL_REDUCE_SMLA:
+ return "partial_reduce_smla";
// Vector Predication
#define BEGIN_REGISTER_VP_SDNODE(SDID, LEGALARG, NAME, ...) \
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 5ec409cb7e97ce..41bc358e203598 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1128,7 +1128,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setTargetDAGCombine(
{ISD::MGATHER, ISD::MSCATTER, ISD::EXPERIMENTAL_VECTOR_HISTOGRAM});
- setTargetDAGCombine({ISD::PARTIAL_REDUCE_SADD, ISD::PARTIAL_REDUCE_UADD});
+ setTargetDAGCombine({ISD::PARTIAL_REDUCE_SMLA, ISD::PARTIAL_REDUCE_UMLA});
setTargetDAGCombine(ISD::FP_EXTEND);
@@ -1848,14 +1848,14 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
}
for (auto VT : {MVT::nxv2i64, MVT::nxv4i32, MVT::nxv8i16}) {
- setOperationAction(ISD::PARTIAL_REDUCE_UADD, VT, Custom);
- setOperationAction(ISD::PARTIAL_REDUCE_SADD, VT, Custom);
+ setOperationAction(ISD::PARTIAL_REDUCE_UMLA, VT, Custom);
+ setOperationAction(ISD::PARTIAL_REDUCE_SMLA, VT, Custom);
}
}
for (auto VT : {MVT::v4i64, MVT::v4i32, MVT::v2i32}) {
- setOperationAction(ISD::PARTIAL_REDUCE_UADD, VT, Custom);
- setOperationAction(ISD::PARTIAL_REDUCE_SADD, VT, Custom);
+ setOperationAction(ISD::PARTIAL_REDUCE_UMLA, VT, Custom);
+ setOperationAction(ISD::PARTIAL_REDUCE_SMLA, VT, Custom);
}
if (Subtarget->hasMOPS() && Subtarget->hasMTE()) {
@@ -7669,9 +7669,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
return LowerFLDEXP(Op, DAG);
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
return LowerVECTOR_HISTOGRAM(Op, DAG);
- case ISD::PARTIAL_REDUCE_UADD:
- case ISD::PARTIAL_REDUCE_SADD:
- return LowerPARTIAL_REDUCE_ADD(Op, DAG);
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
+ return LowerPARTIAL_REDUCE_MLA(Op, DAG);
}
}
@@ -22112,9 +22112,8 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
return DAG.expandPartialReduceAdd(DL, Acc, Input);
unsigned NewOpcode =
- AIsSigned ? ISD::PARTIAL_REDUCE_SADD : ISD::PARTIAL_REDUCE_UADD;
- auto NewMul = DAG.getNode(ISD::MUL, DL, A.getValueType(), A, B);
- return DAG.getNode(NewOpcode, DL, AccVT, Acc, NewMul);
+ AIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(NewOpcode, DL, AccVT, Acc, A, B);
}
SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
@@ -22136,9 +22135,10 @@ SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
return SDValue();
unsigned NewOpcode = InputOpcode == ISD::SIGN_EXTEND
- ? ISD::PARTIAL_REDUCE_SADD
- : ISD::PARTIAL_REDUCE_UADD;
- return DAG.getNode(NewOpcode, DL, AccVT, Acc, Input);
+ ? ISD::PARTIAL_REDUCE_SMLA
+ : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(NewOpcode, DL, AccVT, Acc, Input,
+ DAG.getConstant(1, DL, InputVT));
}
SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
@@ -26599,8 +26599,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::MSCATTER:
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
return performMaskedGatherScatterCombine(N, DCI, DAG);
- case ISD::PARTIAL_REDUCE_UADD:
- case ISD::PARTIAL_REDUCE_SADD:
+ case ISD::PARTIAL_REDUCE_UMLA:
+ case ISD::PARTIAL_REDUCE_SMLA:
return performPartialReduceAddCombine(N, DAG, Subtarget);
case ISD::FP_EXTEND:
return performFPExtendCombine(N, DAG, DCI, Subtarget);
@@ -29372,39 +29372,29 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
}
SDValue
-AArch64TargetLowering::LowerPARTIAL_REDUCE_ADD(SDValue Op,
+AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
SelectionDAG &DAG) const {
SDLoc DL(Op);
SDValue Acc = Op.getOperand(0);
- SDValue Input = Op.getOperand(1);
+ SDValue Input1 = Op.getOperand(1);
+ SDValue Input2 = Op.getOperand(2);
EVT AccVT = Acc.getValueType();
- EVT InputVT = Input.getValueType();
+ EVT InputVT = Input1.getValueType();
unsigned Opcode = Op.getOpcode();
- // If the following condition is true and the input opcode was not ISD::MUL
- // during the DAG-combine, it is already expanded. So this condition means the
- // input opcode must have been ISD::MUL.
if (AccVT.getVectorElementCount() * 4 == InputVT.getVectorElementCount()) {
- unsigned IndexAdd = 0;
- // ISD::MUL may have already been lowered, meaning the operands would be in
- // different positions.
- if (Input.getOpcode() != ISD::MUL)
- IndexAdd = 1;
- auto A = Input.getOperand(IndexAdd);
- auto B = Input.getOperand(IndexAdd + 1);
-
- unsigned DotOpcode = Opcode == ISD::PARTIAL_REDUCE_SADD ? AArch64ISD::SDOT
+ unsigned DotOpcode = Opcode == ISD::PARTIAL_REDUCE_SMLA ? AArch64ISD::SDOT
: AArch64ISD::UDOT;
- return DAG.getNode(DotOpcode, DL, AccVT, Acc, A, B);
+ return DAG.getNode(DotOpcode, DL, AccVT, Acc, Input1, Input2);
}
- bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SADD;
+ bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SMLA;
unsigned BottomOpcode =
InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
unsigned TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
- auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input);
- return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input);
+ auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input1);
+ return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input1);
}
SDValue
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index f80cd0e1c8c8d6..3231a3fb0a67cf 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1184,7 +1184,7 @@ class AArch64TargetLowering : public TargetLowering {
SDValue LowerVECTOR_DEINTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVECTOR_INTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVECTOR_HISTOGRAM(SDValue Op, SelectionDAG &DAG) const;
- SDValue LowerPARTIAL_REDUCE_ADD(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerDIV(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const;
>From 60924265c42deb262af3e55449e10ab327e37471 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 19 Dec 2024 09:35:48 +0000
Subject: [PATCH 10/14] MUL instructions now included in DAG combines.
---
llvm/include/llvm/CodeGen/SelectionDAG.h | 8 ++-
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 16 +++--
.../SelectionDAG/SelectionDAGBuilder.cpp | 8 ++-
.../Target/AArch64/AArch64ISelLowering.cpp | 67 +++++++++++--------
4 files changed, 62 insertions(+), 37 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 8729311770c3e8..37ffb229d24aeb 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1608,7 +1608,13 @@ class SelectionDAG {
/// \p Op1 Accumulator for where the result is stored for the partial
/// reduction operation.
/// \p Op2 Input for the partial reduction operation.
- SDValue expandPartialReduceAdd(SDLoc DL, SDValue Op1, SDValue Op2);
+ /// Expands PARTIAL_REDUCE_S/UMLA nodes.
+ /// \p Acc Accumulator for where the result is stored for the partial
+ /// reduction operation.
+ /// \p Input1 First input for the partial reduction operation.
+ /// \p Input2 Second input for the partial reduction operation.
+ SDValue expandPartialReduceAdd(SDLoc DL, SDValue Acc, SDValue Input1,
+ SDValue Input2);
/// Expands a node with multiple results to an FP or vector libcall. The
/// libcall is expected to take all the operands of the \p Node followed by
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index b720379497d911..2e82385d5a8aff 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2467,20 +2467,24 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
}
-SDValue SelectionDAG::expandPartialReduceAdd(SDLoc DL, SDValue Op1,
- SDValue Op2) {
- EVT ReducedTy = Op1.getValueType();
- EVT FullTy = Op2.getValueType();
+SDValue SelectionDAG::expandPartialReduceAdd(SDLoc DL, SDValue Acc,
+ SDValue Input1, SDValue Input2) {
+
+ EVT FullTy = Input1.getValueType();
+ Input2 = getAnyExtOrTrunc(Input2, DL, FullTy);
+ SDValue Input = getNode(ISD::MUL, DL, FullTy, Input1, Input2);
+
+ EVT ReducedTy = Acc.getValueType();
unsigned Stride = ReducedTy.getVectorMinNumElements();
unsigned ScaleFactor = FullTy.getVectorMinNumElements() / Stride;
// Collect all of the subvectors
- std::deque<SDValue> Subvectors = {Op1};
+ std::deque<SDValue> Subvectors = {Acc};
for (unsigned I = 0; I < ScaleFactor; I++) {
auto SourceIndex = getVectorIdxConstant(I * Stride, DL);
Subvectors.push_back(
- getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Op2, SourceIndex}));
+ getNode(ISD::EXTRACT_SUBVECTOR, DL, ReducedTy, {Input, SourceIndex}));
}
// Flatten the subvector tree
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index d8ed88a25ae468..26e1693b49735f 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8124,11 +8124,13 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
SDValue Input = getValue(I.getOperand(1));
if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
- setValue(&I,
- DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, dl, AccVT, Acc, Input));
+ setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, dl, AccVT, Acc, Input,
+ DAG.getConstant(1, dl, Input.getValueType())));
return;
}
- setValue(&I, DAG.expandPartialReduceAdd(dl, Acc, Input));
+ setValue(&I,
+ DAG.expandPartialReduceAdd(
+ dl, Acc, Input, DAG.getConstant(1, dl, Input.getValueType())));
return;
}
case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 41bc358e203598..28b5d1102b5aaa 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -22032,7 +22032,8 @@ static SDValue tryCombineWhileLo(SDNode *N,
return SDValue(N, 0);
}
-SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
+SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input1, SDValue &Input2,
+ SelectionDAG &DAG,
const AArch64Subtarget *Subtarget, SDLoc &DL) {
bool Scalable = Acc.getValueType().isScalableVector();
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
@@ -22040,22 +22041,22 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
return SDValue();
- unsigned InputOpcode = Input->getOpcode();
+ unsigned Input1Opcode = Input1->getOpcode();
EVT AccVT = Acc->getValueType(0);
if (AccVT.getVectorElementCount() * 4 ==
- Input->getValueType(0).getVectorElementCount() &&
- InputOpcode != ISD::MUL)
- return DAG.expandPartialReduceAdd(DL, Acc, Input);
- if (InputOpcode != ISD::MUL)
+ Input1->getValueType(0).getVectorElementCount() &&
+ Input1Opcode != ISD::MUL)
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
+ if (Input1Opcode != ISD::MUL)
return SDValue();
- auto A = Input->getOperand(0);
- auto B = Input->getOperand(1);
+ auto A = Input1->getOperand(0);
+ auto B = Input1->getOperand(1);
unsigned AOpcode = A->getOpcode();
unsigned BOpcode = B->getOpcode();
if (!ISD::isExtOpcode(AOpcode) || !ISD::isExtOpcode(BOpcode))
- return DAG.expandPartialReduceAdd(DL, Acc, Input);
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
bool AIsSigned = AOpcode == ISD::SIGN_EXTEND;
bool BIsSigned = BOpcode == ISD::SIGN_EXTEND;
@@ -22064,6 +22065,10 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
B = B->getOperand(0);
EVT MulSrcVT = A.getValueType();
+ Input2 = DAG.getAnyExtOrTrunc(Input2, DL, MulSrcVT);
+ A = DAG.getNode(ISD::MUL, DL, MulSrcVT, A, Input2);
+ B = DAG.getNode(ISD::MUL, DL, MulSrcVT, B, Input2);
+
// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
if (!(AccVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
@@ -22072,17 +22077,17 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
!(AccVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
!(AccVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
!(AccVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
- return DAG.expandPartialReduceAdd(DL, Acc, Input);
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
unsigned DotOpcode = AIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
if (AIsSigned != BIsSigned) {
if (!Subtarget->hasMatMulInt8())
- return DAG.expandPartialReduceAdd(DL, Acc, Input);
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
bool Scalable = AccVT.isScalableVT();
// There's no nxv2i64 version of usdot
if (Scalable && AccVT != MVT::nxv4i32 && AccVT != MVT::nxv4i64)
- return DAG.expandPartialReduceAdd(DL, Acc, Input);
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
if (!BIsSigned)
std::swap(A, B);
@@ -22109,32 +22114,37 @@ SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
}
if (A.getValueType() != B.getValueType())
- return DAG.expandPartialReduceAdd(DL, Acc, Input);
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
unsigned NewOpcode =
AIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
return DAG.getNode(NewOpcode, DL, AccVT, Acc, A, B);
}
-SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input, SelectionDAG &DAG,
+SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input1, SDValue &Input2,
+ SelectionDAG &DAG,
const AArch64Subtarget *Subtarget, SDLoc &DL) {
if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
- return DAG.expandPartialReduceAdd(DL, Acc, Input);
- unsigned InputOpcode = Input->getOpcode();
- if (!ISD::isExtOpcode(InputOpcode))
- return DAG.expandPartialReduceAdd(DL, Acc, Input);
- Input = Input->getOperand(0);
- EVT InputVT = Input.getValueType();
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
+ unsigned Input1Opcode = Input1->getOpcode();
+ if (!ISD::isExtOpcode(Input1Opcode))
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
+
EVT AccVT = Acc->getValueType(0);
+ Input1 = Input1->getOperand(0);
+ EVT InputVT = Input1.getValueType();
+ Input2 = DAG.getAnyExtOrTrunc(Input2, DL, InputVT);
+ SDValue Input = DAG.getNode(ISD::MUL, DL, InputVT, Input1, Input2);
+
if (!AccVT.isScalableVector())
- return DAG.expandPartialReduceAdd(DL, Acc, Input);
+ return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
!(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
!(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
return SDValue();
- unsigned NewOpcode = InputOpcode == ISD::SIGN_EXTEND
+ unsigned NewOpcode = Input1Opcode == ISD::SIGN_EXTEND
? ISD::PARTIAL_REDUCE_SMLA
: ISD::PARTIAL_REDUCE_UMLA;
return DAG.getNode(NewOpcode, DL, AccVT, Acc, Input,
@@ -22145,18 +22155,21 @@ SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
const AArch64Subtarget *Subtarget) {
SDLoc DL(N);
auto Acc = N->getOperand(0);
- auto Input = N->getOperand(1);
+ auto Input1 = N->getOperand(1);
+ auto Input2 = N->getOperand(2);
EVT AccElemVT = Acc.getValueType().getVectorElementType();
- EVT InputElemVT = Input.getValueType().getVectorElementType();
+ EVT InputElemVT = Input1.getValueType().getVectorElementType();
// If the exts have already been removed or it has already been lowered to an
// usdot instruction, then the element types will not be equal
- if (InputElemVT != AccElemVT || Input.getOpcode() == AArch64ISD::USDOT)
+ if (InputElemVT != AccElemVT || Input1.getOpcode() == AArch64ISD::USDOT)
return SDValue(N, 0);
- if (auto Dot = tryCombineToDotProduct(Acc, Input, DAG, Subtarget, DL))
+ if (auto Dot =
+ tryCombineToDotProduct(Acc, Input1, Input2, DAG, Subtarget, DL))
return Dot;
- if (auto WideAdd = tryCombineToWideAdd(Acc, Input, DAG, Subtarget, DL))
+ if (auto WideAdd =
+ tryCombineToWideAdd(Acc, Input1, Input2, DAG, Subtarget, DL))
return WideAdd;
return SDValue();
}
>From 9971a6ef494ef96a4a44a4ad1a178094723394f1 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Fri, 10 Jan 2025 09:27:11 +0000
Subject: [PATCH 11/14] Make the no bin op changes work with adding Partial
Reduction SDNodes.
---
.../Target/AArch64/AArch64ISelLowering.cpp | 186 +++++++++---------
1 file changed, 96 insertions(+), 90 deletions(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 28b5d1102b5aaa..a355a6774cc2b7 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -22032,144 +22032,150 @@ static SDValue tryCombineWhileLo(SDNode *N,
return SDValue(N, 0);
}
-SDValue tryCombineToDotProduct(SDValue &Acc, SDValue &Input1, SDValue &Input2,
+SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
SelectionDAG &DAG,
const AArch64Subtarget *Subtarget, SDLoc &DL) {
- bool Scalable = Acc.getValueType().isScalableVector();
+ bool Scalable = Op0->getValueType(0).isScalableVector();
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
- return SDValue();
+ return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
+ return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
+
+ unsigned Op1Opcode = Op1->getOpcode();
+ SDValue MulOpLHS, MulOpRHS;
+ bool MulOpLHSIsSigned, MulOpRHSIsSigned;
+ if (ISD::isExtOpcode(Op1Opcode)) {
+ MulOpLHSIsSigned = MulOpRHSIsSigned = (Op1Opcode == ISD::SIGN_EXTEND);
+ MulOpLHS = Op1->getOperand(0);
+ MulOpRHS = DAG.getAnyExtOrTrunc(Op2, DL, MulOpLHS.getValueType());
+ } else if (Op1Opcode == ISD::MUL) {
+ SDValue ExtMulOpLHS = Op1->getOperand(0);
+ SDValue ExtMulOpRHS = Op1->getOperand(1);
+
+ unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
+ unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
+ if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
+ !ISD::isExtOpcode(ExtMulOpRHSOpcode))
+ return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
+
+ MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
+ MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
+
+ MulOpLHS = ExtMulOpLHS->getOperand(0);
+ MulOpRHS = ExtMulOpRHS->getOperand(0);
+ EVT MulOpLHSVT = MulOpLHS.getValueType();
+
+ if (MulOpLHSVT != MulOpRHS.getValueType())
+ return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
+
+ Op2 = DAG.getAnyExtOrTrunc(Op2, DL, MulOpLHSVT);
+ MulOpLHS = DAG.getNode(ISD::MUL, DL, MulOpLHSVT, MulOpLHS, Op2);
+ MulOpRHS = DAG.getNode(ISD::MUL, DL, MulOpLHSVT, MulOpRHS, Op2);
+ } else
return SDValue();
- unsigned Input1Opcode = Input1->getOpcode();
- EVT AccVT = Acc->getValueType(0);
- if (AccVT.getVectorElementCount() * 4 ==
- Input1->getValueType(0).getVectorElementCount() &&
- Input1Opcode != ISD::MUL)
- return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
- if (Input1Opcode != ISD::MUL)
- return SDValue();
-
- auto A = Input1->getOperand(0);
- auto B = Input1->getOperand(1);
- unsigned AOpcode = A->getOpcode();
- unsigned BOpcode = B->getOpcode();
-
- if (!ISD::isExtOpcode(AOpcode) || !ISD::isExtOpcode(BOpcode))
- return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
-
- bool AIsSigned = AOpcode == ISD::SIGN_EXTEND;
- bool BIsSigned = BOpcode == ISD::SIGN_EXTEND;
-
- A = A->getOperand(0);
- B = B->getOperand(0);
- EVT MulSrcVT = A.getValueType();
-
- Input2 = DAG.getAnyExtOrTrunc(Input2, DL, MulSrcVT);
- A = DAG.getNode(ISD::MUL, DL, MulSrcVT, A, Input2);
- B = DAG.getNode(ISD::MUL, DL, MulSrcVT, B, Input2);
+ SDValue Acc = Op0;
+ EVT ReducedVT = Acc->getValueType(0);
+ EVT MulSrcVT = MulOpLHS.getValueType();
// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
- if (!(AccVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
- !(AccVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
- !(AccVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
- !(AccVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
- !(AccVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
- !(AccVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
- return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
-
- unsigned DotOpcode = AIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
- if (AIsSigned != BIsSigned) {
+ if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
+ !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
+ !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
+ !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
+ !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
+ !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
+ return SDValue();
+
+ // If the extensions are mixed, we should lower it to a usdot instead
+ unsigned DotOpcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
+ if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
if (!Subtarget->hasMatMulInt8())
- return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
+ return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
- bool Scalable = AccVT.isScalableVT();
+ bool Scalable = ReducedVT.isScalableVT();
// There's no nxv2i64 version of usdot
- if (Scalable && AccVT != MVT::nxv4i32 && AccVT != MVT::nxv4i64)
- return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
+ if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
+ return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
- if (!BIsSigned)
- std::swap(A, B);
+ if (!MulOpRHSIsSigned)
+ std::swap(MulOpLHS, MulOpRHS);
DotOpcode = AArch64ISD::USDOT;
// Lower usdot patterns here because legalisation would attempt to split it
// unless exts are removed. But, removing the exts would lose the
// information about whether each operand is signed.
- if ((AccVT != MVT::nxv4i64 || MulSrcVT != MVT::nxv16i8) &&
- (AccVT != MVT::v4i64 || MulSrcVT != MVT::v16i8))
- return DAG.getNode(DotOpcode, DL, AccVT, Acc, A, B);
+ if ((ReducedVT != MVT::nxv4i64 || MulSrcVT != MVT::nxv16i8) &&
+ (ReducedVT != MVT::v4i64 || MulSrcVT != MVT::v16i8))
+ return DAG.getNode(DotOpcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
}
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
// product followed by a zero / sign extension. Need to lower this here
// because legalisation would attempt to split it.
- if ((AccVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
- (AccVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
- EVT AccVTI32 = (AccVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
+ if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
+ (ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
+ EVT ReducedVTI32 =
+ (ReducedVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
- auto DotI32 = DAG.getNode(DotOpcode, DL, AccVTI32,
- DAG.getConstant(0, DL, AccVTI32), A, B);
- auto Extended = DAG.getSExtOrTrunc(DotI32, DL, AccVT);
- return DAG.getNode(ISD::ADD, DL, AccVT, Acc, Extended);
+ SDValue DotI32 =
+ DAG.getNode(DotOpcode, DL, ReducedVTI32,
+ DAG.getConstant(0, DL, ReducedVTI32), MulOpLHS, MulOpRHS);
+ SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT);
+ return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended);
}
- if (A.getValueType() != B.getValueType())
- return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
-
unsigned NewOpcode =
- AIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
- return DAG.getNode(NewOpcode, DL, AccVT, Acc, A, B);
+ MulOpLHSIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(NewOpcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
}
-SDValue tryCombineToWideAdd(SDValue &Acc, SDValue &Input1, SDValue &Input2,
+SDValue tryCombineToWideAdd(SDValue &Op0, SDValue &Op1, SDValue &Op2,
SelectionDAG &DAG,
const AArch64Subtarget *Subtarget, SDLoc &DL) {
if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
- return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
- unsigned Input1Opcode = Input1->getOpcode();
- if (!ISD::isExtOpcode(Input1Opcode))
- return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
+ return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
+ unsigned Op1Opcode = Op1->getOpcode();
+ if (!ISD::isExtOpcode(Op1Opcode))
+ return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
- EVT AccVT = Acc->getValueType(0);
- Input1 = Input1->getOperand(0);
- EVT InputVT = Input1.getValueType();
- Input2 = DAG.getAnyExtOrTrunc(Input2, DL, InputVT);
- SDValue Input = DAG.getNode(ISD::MUL, DL, InputVT, Input1, Input2);
+ EVT AccVT = Op0->getValueType(0);
+ Op1 = Op1->getOperand(0);
+ EVT Op1VT = Op1.getValueType();
+ Op2 = DAG.getAnyExtOrTrunc(Op2, DL, Op1VT);
+ SDValue Input = DAG.getNode(ISD::MUL, DL, Op1VT, Op1, Op2);
if (!AccVT.isScalableVector())
- return DAG.expandPartialReduceAdd(DL, Acc, Input1, Input2);
+ return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
- if (!(InputVT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
- !(InputVT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
- !(InputVT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
+ if (!(Op1VT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
+ !(Op1VT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
+ !(Op1VT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
return SDValue();
- unsigned NewOpcode = Input1Opcode == ISD::SIGN_EXTEND
- ? ISD::PARTIAL_REDUCE_SMLA
- : ISD::PARTIAL_REDUCE_UMLA;
- return DAG.getNode(NewOpcode, DL, AccVT, Acc, Input,
- DAG.getConstant(1, DL, InputVT));
+ unsigned NewOpcode = Op1Opcode == ISD::SIGN_EXTEND ? ISD::PARTIAL_REDUCE_SMLA
+ : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(NewOpcode, DL, AccVT, Op0, Input,
+ DAG.getConstant(1, DL, Op1VT));
}
SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
const AArch64Subtarget *Subtarget) {
SDLoc DL(N);
- auto Acc = N->getOperand(0);
- auto Input1 = N->getOperand(1);
- auto Input2 = N->getOperand(2);
- EVT AccElemVT = Acc.getValueType().getVectorElementType();
- EVT InputElemVT = Input1.getValueType().getVectorElementType();
+ SDValue Op0 = N->getOperand(0);
+ SDValue Op1 = N->getOperand(1);
+ SDValue Op2 = N->getOperand(2);
+ EVT Op0ElemVT = Op0.getValueType().getVectorElementType();
+ EVT Op1ElemVT = Op1.getValueType().getVectorElementType();
// If the exts have already been removed or it has already been lowered to an
// usdot instruction, then the element types will not be equal
- if (InputElemVT != AccElemVT || Input1.getOpcode() == AArch64ISD::USDOT)
+ if (Op0ElemVT != Op1ElemVT || Op1.getOpcode() == AArch64ISD::USDOT)
return SDValue(N, 0);
- if (auto Dot =
- tryCombineToDotProduct(Acc, Input1, Input2, DAG, Subtarget, DL))
+ if (auto Dot = tryCombineToDotProduct(Op0, Op1, Op2, DAG, Subtarget, DL))
return Dot;
- if (auto WideAdd =
- tryCombineToWideAdd(Acc, Input1, Input2, DAG, Subtarget, DL))
+ if (auto WideAdd = tryCombineToWideAdd(Op0, Op1, Op2, DAG, Subtarget, DL))
return WideAdd;
return SDValue();
}
>From 3cdf92fc3d3b0bde96e4d18611ef837ee2a74fee Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Mon, 20 Jan 2025 09:59:44 +0000
Subject: [PATCH 12/14] Address comments on patch. Remove
shouldExpandPartialReductionIntrinsic().
---
llvm/include/llvm/CodeGen/ISDOpcodes.h | 15 +++++--
llvm/include/llvm/CodeGen/SelectionDAG.h | 4 --
llvm/include/llvm/CodeGen/TargetLowering.h | 7 ---
.../SelectionDAG/SelectionDAGBuilder.cpp | 25 +++++++----
.../Target/AArch64/AArch64ISelLowering.cpp | 45 ++++++-------------
llvm/lib/Target/AArch64/AArch64ISelLowering.h | 3 --
6 files changed, 41 insertions(+), 58 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index 12d20b90c0c241..c844d9505b3ef4 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1451,9 +1451,18 @@ enum NodeType {
VECREDUCE_UMAX,
VECREDUCE_UMIN,
- // Nodes used to represent a partial reduction addition operation (signed and
- // unsigned).
- // Operands: Accumulator, Input
+ // Partial Reduction nodes. These represent multiply-add instructions because
+ // Input1 and Input2 are multiplied together first. This result is then
+ // reduced, by addition, to the number of elements that the Accumulator's type
+ // has.
+ // Input1 and Input2 must be the same type. Accumulator's element type must
+ // match that of Input1 and Input2. The number of elements in Input1 and
+ // Input2 must be a positive integer multiple of the number of elements in the
+ // Accumulator.
+ // The signedness of this node will dictate the signedness of nodes expanded
+ // from it. The signedness of the node is dictated by the signedness of
+ // Input1.
+ // Operands: Accumulator, Input1, Input2
// Outputs: Output
PARTIAL_REDUCE_SMLA,
PARTIAL_REDUCE_UMLA,
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 37ffb229d24aeb..9381b2465684b8 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1604,10 +1604,6 @@ class SelectionDAG {
/// the target's desired shift amount type.
SDValue getShiftAmountOperand(EVT LHSTy, SDValue Op);
- /// Expands PARTIAL_REDUCE_S/UMLA nodes.
- /// \p Op1 Accumulator for where the result is stored for the partial
- /// reduction operation.
- /// \p Op2 Input for the partial reduction operation.
/// Expands PARTIAL_REDUCE_S/UMLA nodes.
/// \p Acc Accumulator for where the result is stored for the partial
/// reduction operation.
diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
index 38ac90f0c081b3..2a77d4922242ec 100644
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -455,13 +455,6 @@ class TargetLoweringBase {
return true;
}
- /// Return true if the @llvm.experimental.vector.partial.reduce.* intrinsic
- /// should be expanded using generic code in SelectionDAGBuilder.
- virtual bool
- shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const {
- return true;
- }
-
/// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded
/// using generic code in SelectionDAGBuilder.
virtual bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const {
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 26e1693b49735f..71f9148ecd0702 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -8122,15 +8122,22 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
SDValue Acc = getValue(I.getOperand(0));
EVT AccVT = Acc.getValueType();
SDValue Input = getValue(I.getOperand(1));
-
- if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
- setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, dl, AccVT, Acc, Input,
- DAG.getConstant(1, dl, Input.getValueType())));
- return;
- }
- setValue(&I,
- DAG.expandPartialReduceAdd(
- dl, Acc, Input, DAG.getConstant(1, dl, Input.getValueType())));
+ EVT InputVT = Input.getValueType();
+
+ assert(AccVT.getVectorElementType() == InputVT.getVectorElementType() &&
+ "Expected operands to have the same vector element type!");
+ assert(InputVT.getVectorElementCount().getKnownMinValue() %
+ AccVT.getVectorElementCount().getKnownMinValue() ==
+ 0 &&
+ "Expected the element count of the Input operand to be a positive "
+ "integer multiple of the element count of the Accumulator operand!");
+
+ // ISD::PARTIAL_REDUCE_UMLA is chosen arbitrarily and would function the
+ // same if ISD::PARTIAL_REDUCE_SMLA was used instead. It should be changed
+ // to its correct signedness when combining or expanding, according to
+ // extends being performed on Input.
+ setValue(&I, DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, dl, AccVT, Acc, Input,
+ DAG.getConstant(1, dl, InputVT)));
return;
}
case Intrinsic::experimental_cttz_elts: {
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a355a6774cc2b7..35781592c3002f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2049,28 +2049,6 @@ bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
return false;
}
-bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
- const IntrinsicInst *I) const {
- if (I->getIntrinsicID() != Intrinsic::experimental_vector_partial_reduce_add)
- return true;
-
- EVT VT = EVT::getEVT(I->getType());
- auto Input = I->getOperand(1);
- EVT InputVT = EVT::getEVT(Input->getType());
-
- if ((InputVT == MVT::nxv4i64 && VT == MVT::nxv2i64) ||
- (InputVT == MVT::nxv8i32 && VT == MVT::nxv4i32) ||
- (InputVT == MVT::nxv16i16 && VT == MVT::nxv8i16) ||
- (InputVT == MVT::nxv16i64 && VT == MVT::nxv4i64) ||
- (InputVT == MVT::nxv16i32 && VT == MVT::nxv4i32) ||
- (InputVT == MVT::nxv8i64 && VT == MVT::nxv2i64) ||
- (InputVT == MVT::v16i64 && VT == MVT::v4i64) ||
- (InputVT == MVT::v16i32 && VT == MVT::v4i32) ||
- (InputVT == MVT::v8i32 && VT == MVT::v2i32))
- return false;
- return true;
-}
-
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
if (!Subtarget->isSVEorStreamingSVEAvailable())
return true;
@@ -22037,9 +22015,9 @@ SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
const AArch64Subtarget *Subtarget, SDLoc &DL) {
bool Scalable = Op0->getValueType(0).isScalableVector();
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
- return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
+ return SDValue();
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
- return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
+ return SDValue();
unsigned Op1Opcode = Op1->getOpcode();
SDValue MulOpLHS, MulOpRHS;
@@ -22056,7 +22034,7 @@ SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
!ISD::isExtOpcode(ExtMulOpRHSOpcode))
- return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
+ return SDValue();
MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
@@ -22066,7 +22044,7 @@ SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
EVT MulOpLHSVT = MulOpLHS.getValueType();
if (MulOpLHSVT != MulOpRHS.getValueType())
- return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
+ return SDValue();
Op2 = DAG.getAnyExtOrTrunc(Op2, DL, MulOpLHSVT);
MulOpLHS = DAG.getNode(ISD::MUL, DL, MulOpLHSVT, MulOpLHS, Op2);
@@ -22092,12 +22070,12 @@ SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
unsigned DotOpcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
if (!Subtarget->hasMatMulInt8())
- return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
+ return SDValue();
bool Scalable = ReducedVT.isScalableVT();
// There's no nxv2i64 version of usdot
if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
- return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
+ return SDValue();
if (!MulOpRHSIsSigned)
std::swap(MulOpLHS, MulOpRHS);
@@ -22134,10 +22112,10 @@ SDValue tryCombineToWideAdd(SDValue &Op0, SDValue &Op1, SDValue &Op2,
SelectionDAG &DAG,
const AArch64Subtarget *Subtarget, SDLoc &DL) {
if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
- return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
+ return SDValue();
unsigned Op1Opcode = Op1->getOpcode();
if (!ISD::isExtOpcode(Op1Opcode))
- return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
+ return SDValue();
EVT AccVT = Op0->getValueType(0);
Op1 = Op1->getOperand(0);
@@ -22146,7 +22124,7 @@ SDValue tryCombineToWideAdd(SDValue &Op0, SDValue &Op1, SDValue &Op2,
SDValue Input = DAG.getNode(ISD::MUL, DL, Op1VT, Op1, Op2);
if (!AccVT.isScalableVector())
- return DAG.expandPartialReduceAdd(DL, Op0, Op1, Op2);
+ return SDValue();
if (!(Op1VT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
!(Op1VT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
@@ -22177,7 +22155,10 @@ SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
return Dot;
if (auto WideAdd = tryCombineToWideAdd(Op0, Op1, Op2, DAG, Subtarget, DL))
return WideAdd;
- return SDValue();
+ // N->getOperand needs calling again because the Op variables may have been
+ // changed by the functions above
+ return DAG.expandPartialReduceAdd(DL, N->getOperand(0), N->getOperand(1),
+ N->getOperand(2));
}
static SDValue performIntrinsicCombine(SDNode *N,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 3231a3fb0a67cf..d00ccc5af5536d 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -993,9 +993,6 @@ class AArch64TargetLowering : public TargetLowering {
bool shouldExpandGetActiveLaneMask(EVT VT, EVT OpVT) const override;
- bool
- shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
-
bool shouldExpandCttzElements(EVT VT) const override;
bool shouldExpandVectorMatch(EVT VT, unsigned SearchSize) const override;
>From 982d6e017cc7d38e9b68c568d21f82cef3a4e067 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Mon, 20 Jan 2025 15:10:52 +0000
Subject: [PATCH 13/14] Add the MUL in LowerPARTIAL_REDUCE_MLA()
Only do it if Input2 is not a splat vector of constant 1s. Still create
the MUL in the DAG combine for the wide add pattern, because there
it is pruned if an operand is a splat of constant 1, or changed to
a shift instruction if an operand is a power of 2. This would not
happen if the MUL were created in LowerPARTIAL_REDUCE_MLA.
---
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 14 +++++++++++++-
1 file changed, 13 insertions(+), 1 deletion(-)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 35781592c3002f..7c727871210055 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -22120,7 +22120,10 @@ SDValue tryCombineToWideAdd(SDValue &Op0, SDValue &Op1, SDValue &Op2,
EVT AccVT = Op0->getValueType(0);
Op1 = Op1->getOperand(0);
EVT Op1VT = Op1.getValueType();
+ // Makes Op2's value type match the value type of Op1 without its extend.
Op2 = DAG.getAnyExtOrTrunc(Op2, DL, Op1VT);
+ // Make a MUL between Op1 and Op2 here so the MUL can be changed if possible
+ // (can be pruned or changed to a shift instruction for example).
SDValue Input = DAG.getNode(ISD::MUL, DL, Op1VT, Op1, Op2);
if (!AccVT.isScalableVector())
@@ -22133,6 +22136,7 @@ SDValue tryCombineToWideAdd(SDValue &Op0, SDValue &Op1, SDValue &Op2,
unsigned NewOpcode = Op1Opcode == ISD::SIGN_EXTEND ? ISD::PARTIAL_REDUCE_SMLA
: ISD::PARTIAL_REDUCE_UMLA;
+ // Return a constant of 1s for Op2 so the MUL is not performed again.
return DAG.getNode(NewOpcode, DL, AccVT, Op0, Input,
DAG.getConstant(1, DL, Op1VT));
}
@@ -29389,11 +29393,19 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
: AArch64ISD::UDOT;
return DAG.getNode(DotOpcode, DL, AccVT, Acc, Input1, Input2);
}
+
+ SDValue MulInput = Input1;
+ // If Input2 is a splat vector of constant 1 then the MUL instruction is not
+ // needed. If it was created here it would not be automatically pruned.
+ if (Input2.getOpcode() != ISD::SPLAT_VECTOR || Input2.getNumOperands() == 0 ||
+ !isOneConstant(Input2.getOperand(0)))
+ MulInput = DAG.getNode(ISD::MUL, DL, InputVT, Input1, Input2);
+
bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SMLA;
unsigned BottomOpcode =
InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
unsigned TopOpcode = InputIsSigned ? AArch64ISD::SADDWT : AArch64ISD::UADDWT;
- auto BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input1);
+ SDValue BottomNode = DAG.getNode(BottomOpcode, DL, AccVT, Acc, Input1);
return DAG.getNode(TopOpcode, DL, AccVT, BottomNode, Input1);
}
>From 03adf25f7a45464305aa442d8c2bd5f6dfd12608 Mon Sep 17 00:00:00 2001
From: James Chesterman <james.chesterman at arm.com>
Date: Thu, 23 Jan 2025 14:27:53 +0000
Subject: [PATCH 14/14] Separate the DAG combine into two stages for when there
is a Mul operation as Input1.
Also change the node documentation in ISDOpcodes.h for PARTIAL_REDUCE_MLA
nodes to state the type restrictions on their operands.
Rename functions to match PARTIAL_REDUCE_MLA rather than
PARTIAL_REDUCE_ADD.
---
llvm/include/llvm/CodeGen/ISDOpcodes.h | 13 +-
llvm/include/llvm/CodeGen/SelectionDAG.h | 2 +-
.../lib/CodeGen/SelectionDAG/SelectionDAG.cpp | 11 +-
.../Target/AArch64/AArch64ISelLowering.cpp | 176 ++++++++++--------
4 files changed, 114 insertions(+), 88 deletions(-)
diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h
index c844d9505b3ef4..8c8314e2ab9c70 100644
--- a/llvm/include/llvm/CodeGen/ISDOpcodes.h
+++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h
@@ -1455,13 +1455,12 @@ enum NodeType {
// Input1 and Input2 are multiplied together first. This result is then
// reduced, by addition, to the number of elements that the Accumulator's type
// has.
- // Input1 and Input2 must be the same type. Accumulator's element type must
- // match that of Input1 and Input2. The number of elements in Input1 and
- // Input2 must be a positive integer multiple of the number of elements in the
- // Accumulator.
- // The signedness of this node will dictate the signedness of nodes expanded
- // from it. The signedness of the node is dictated by the signedness of
- // Input1.
+ // Input1 and Input2 must be the same type. The Accumulator and the Output
+ // must be the same type.
+ // The number of elements in Input1 and Input2 must be a positive integer
+ // multiple of the number of elements in the Accumulator / Output type.
+ // Input1 and Input2 may have a different element type from Accumulator and
+ // Output.
// Operands: Accumulator, Input1, Input2
// Outputs: Output
PARTIAL_REDUCE_SMLA,
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index 9381b2465684b8..8d953522997c4a 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1609,7 +1609,7 @@ class SelectionDAG {
/// reduction operation.
/// \p Input1 First input for the partial reduction operation.
/// \p Input2 Second input for the partial reduction operation.
- SDValue expandPartialReduceAdd(SDLoc DL, SDValue Acc, SDValue Input1,
+ SDValue expandPartialReduceMLA(SDLoc DL, SDValue Acc, SDValue Input1,
SDValue Input2);
/// Expands a node with multiple results to an FP or vector libcall. The
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 2e82385d5a8aff..255eec2e07ad0d 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -2467,12 +2467,17 @@ SDValue SelectionDAG::getShiftAmountOperand(EVT LHSTy, SDValue Op) {
return getZExtOrTrunc(Op, SDLoc(Op), ShTy);
}
-SDValue SelectionDAG::expandPartialReduceAdd(SDLoc DL, SDValue Acc,
+SDValue SelectionDAG::expandPartialReduceMLA(SDLoc DL, SDValue Acc,
SDValue Input1, SDValue Input2) {
EVT FullTy = Input1.getValueType();
- Input2 = getAnyExtOrTrunc(Input2, DL, FullTy);
- SDValue Input = getNode(ISD::MUL, DL, FullTy, Input1, Input2);
+ unsigned Input2Opcode = Input2.getOpcode();
+
+ SDValue Input = Input1;
+ if ((Input2Opcode != ISD::SPLAT_VECTOR &&
+ Input2Opcode != ISD::BUILD_VECTOR) ||
+ !isOneConstant(Input2.getOperand(0)))
+ Input = getNode(ISD::MUL, DL, FullTy, Input1, Input2);
EVT ReducedTy = Acc.getValueType();
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7c727871210055..815bbfd6bb81d9 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -22010,7 +22010,38 @@ static SDValue tryCombineWhileLo(SDNode *N,
return SDValue(N, 0);
}
-SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
+SDValue tryCombinePartialReduceMLAMulOp(SDValue &Op0, SDValue &Op1,
+ SDValue &Op2, SelectionDAG &DAG,
+ SDLoc &DL) {
+ // Makes PARTIAL_REDUCE_MLA(Acc, MUL(EXT(MulOpLHS), EXT(MulOpRHS)), Splat (1))
+ // into PARTIAL_REDUCE_MLA(Acc, EXT(MulOpLHS), EXT(MulOpRHS))
+ if (Op1->getOpcode() != ISD::MUL)
+ return SDValue();
+
+ SDValue ExtMulOpLHS = Op1->getOperand(0);
+ SDValue ExtMulOpRHS = Op1->getOperand(1);
+ unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
+ unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
+ if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
+ !ISD::isExtOpcode(ExtMulOpRHSOpcode))
+ return SDValue();
+
+ SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
+ SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
+ EVT MulOpLHSVT = MulOpLHS.getValueType();
+ if (MulOpLHSVT != MulOpRHS.getValueType())
+ return SDValue();
+
+ unsigned Op2Opcode = Op2->getOpcode();
+ if ((Op2Opcode != ISD::SPLAT_VECTOR && Op2Opcode != ISD::BUILD_VECTOR) ||
+ !isOneConstant(Op2->getOperand(0)))
+ return SDValue();
+
+ return DAG.getNode(ISD::PARTIAL_REDUCE_UMLA, DL, Op0->getValueType(0), Op0,
+ ExtMulOpLHS, ExtMulOpRHS);
+}
+
+SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &ExtOp1, SDValue &ExtOp2,
SelectionDAG &DAG,
const AArch64Subtarget *Subtarget, SDLoc &DL) {
bool Scalable = Op0->getValueType(0).isScalableVector();
@@ -22019,56 +22050,51 @@ SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
return SDValue();
- unsigned Op1Opcode = Op1->getOpcode();
- SDValue MulOpLHS, MulOpRHS;
- bool MulOpLHSIsSigned, MulOpRHSIsSigned;
- if (ISD::isExtOpcode(Op1Opcode)) {
- MulOpLHSIsSigned = MulOpRHSIsSigned = (Op1Opcode == ISD::SIGN_EXTEND);
- MulOpLHS = Op1->getOperand(0);
- MulOpRHS = DAG.getAnyExtOrTrunc(Op2, DL, MulOpLHS.getValueType());
- } else if (Op1Opcode == ISD::MUL) {
- SDValue ExtMulOpLHS = Op1->getOperand(0);
- SDValue ExtMulOpRHS = Op1->getOperand(1);
-
- unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
- unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
- if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
- !ISD::isExtOpcode(ExtMulOpRHSOpcode))
- return SDValue();
-
- MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
- MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
-
- MulOpLHS = ExtMulOpLHS->getOperand(0);
- MulOpRHS = ExtMulOpRHS->getOperand(0);
- EVT MulOpLHSVT = MulOpLHS.getValueType();
-
- if (MulOpLHSVT != MulOpRHS.getValueType())
+ unsigned ExtOp1Opcode = ExtOp1->getOpcode();
+ unsigned ExtOp2Opcode = ExtOp2->getOpcode();
+ SDValue Op1, Op2;
+ bool Op1IsSigned, Op2IsSigned;
+ if (!ISD::isExtOpcode(ExtOp1Opcode))
+ return SDValue();
+ Op1 = ExtOp1->getOperand(0);
+ EVT SrcVT = Op1.getValueType();
+
+ if ((ExtOp2Opcode == ISD::SPLAT_VECTOR ||
+ ExtOp2Opcode == ISD::BUILD_VECTOR) &&
+ isOneConstant(ExtOp2.getOperand(0))) {
+ // Makes PARTIAL_REDUCE_MLA(Acc, Ext(Op1), Splat(1)) into
+ // PARTIAL_REDUCE_MLA(Acc, Op1, Splat(1))
+ Op1IsSigned = Op2IsSigned = (ExtOp1Opcode == ISD::SIGN_EXTEND);
+ // Can only do this because it's a splat vector of constant 1
+ Op2 = DAG.getAnyExtOrTrunc(ExtOp2, DL, SrcVT);
+ } else if (ISD::isExtOpcode(ExtOp2Opcode)) {
+ // Makes PARTIAL_REDUCE_MLA(Acc, Ext(Op1), Ext(Op2)) into
+ // PARTIAL_REDUCE_MLA(Acc, Op1, Op2)
+ Op2 = ExtOp2->getOperand(0);
+ Op1IsSigned = ExtOp1Opcode == ISD::SIGN_EXTEND;
+ Op2IsSigned = ExtOp2Opcode == ISD::SIGN_EXTEND;
+ if (SrcVT != Op2.getValueType())
return SDValue();
-
- Op2 = DAG.getAnyExtOrTrunc(Op2, DL, MulOpLHSVT);
- MulOpLHS = DAG.getNode(ISD::MUL, DL, MulOpLHSVT, MulOpLHS, Op2);
- MulOpRHS = DAG.getNode(ISD::MUL, DL, MulOpLHSVT, MulOpRHS, Op2);
- } else
+ } else {
return SDValue();
+ }
SDValue Acc = Op0;
EVT ReducedVT = Acc->getValueType(0);
- EVT MulSrcVT = MulOpLHS.getValueType();
// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
- if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
- !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
- !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
- !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
- !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
- !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
+ if (!(ReducedVT == MVT::nxv4i64 && SrcVT == MVT::nxv16i8) &&
+ !(ReducedVT == MVT::nxv4i32 && SrcVT == MVT::nxv16i8) &&
+ !(ReducedVT == MVT::nxv2i64 && SrcVT == MVT::nxv8i16) &&
+ !(ReducedVT == MVT::v4i64 && SrcVT == MVT::v16i8) &&
+ !(ReducedVT == MVT::v4i32 && SrcVT == MVT::v16i8) &&
+ !(ReducedVT == MVT::v2i32 && SrcVT == MVT::v8i8))
return SDValue();
// If the extensions are mixed, we should lower it to a usdot instead
- unsigned DotOpcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
- if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
+ unsigned DotOpcode = Op1IsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
+ if (Op1IsSigned != Op2IsSigned) {
if (!Subtarget->hasMatMulInt8())
return SDValue();
@@ -22077,71 +22103,69 @@ SDValue tryCombineToDotProduct(SDValue &Op0, SDValue &Op1, SDValue &Op2,
if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
return SDValue();
- if (!MulOpRHSIsSigned)
- std::swap(MulOpLHS, MulOpRHS);
+ if (!Op2IsSigned)
+ std::swap(Op1, Op2);
DotOpcode = AArch64ISD::USDOT;
// Lower usdot patterns here because legalisation would attempt to split it
// unless exts are removed. But, removing the exts would lose the
// information about whether each operand is signed.
- if ((ReducedVT != MVT::nxv4i64 || MulSrcVT != MVT::nxv16i8) &&
- (ReducedVT != MVT::v4i64 || MulSrcVT != MVT::v16i8))
- return DAG.getNode(DotOpcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
+ if ((ReducedVT != MVT::nxv4i64 || SrcVT != MVT::nxv16i8) &&
+ (ReducedVT != MVT::v4i64 || SrcVT != MVT::v16i8))
+ return DAG.getNode(DotOpcode, DL, ReducedVT, Acc, Op1, Op2);
}
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
// product followed by a zero / sign extension. Need to lower this here
// because legalisation would attempt to split it.
- if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
- (ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
+ if ((ReducedVT == MVT::nxv4i64 && SrcVT == MVT::nxv16i8) ||
+ (ReducedVT == MVT::v4i64 && SrcVT == MVT::v16i8)) {
EVT ReducedVTI32 =
(ReducedVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
SDValue DotI32 =
DAG.getNode(DotOpcode, DL, ReducedVTI32,
- DAG.getConstant(0, DL, ReducedVTI32), MulOpLHS, MulOpRHS);
+ DAG.getConstant(0, DL, ReducedVTI32), Op1, Op2);
SDValue Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT);
return DAG.getNode(ISD::ADD, DL, ReducedVT, Acc, Extended);
}
unsigned NewOpcode =
- MulOpLHSIsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
- return DAG.getNode(NewOpcode, DL, ReducedVT, Acc, MulOpLHS, MulOpRHS);
+ Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+ return DAG.getNode(NewOpcode, DL, ReducedVT, Acc, Op1, Op2);
}
-SDValue tryCombineToWideAdd(SDValue &Op0, SDValue &Op1, SDValue &Op2,
+SDValue tryCombineToWideAdd(SDValue &Op0, SDValue &ExtOp1, SDValue &Op2,
SelectionDAG &DAG,
const AArch64Subtarget *Subtarget, SDLoc &DL) {
+ // Makes PARTIAL_REDUCE_MLA(Acc, Ext(Op1), Splat(1)) into
+ // PARTIAL_REDUCE_MLA(Acc, Op1, Splat(1))
if (!Subtarget->hasSVE2() && !Subtarget->isStreamingSVEAvailable())
return SDValue();
- unsigned Op1Opcode = Op1->getOpcode();
- if (!ISD::isExtOpcode(Op1Opcode))
- return SDValue();
-
EVT AccVT = Op0->getValueType(0);
- Op1 = Op1->getOperand(0);
+ unsigned ExtOp1Opcode = ExtOp1->getOpcode();
+ if (!ISD::isExtOpcode(ExtOp1Opcode))
+ return SDValue();
+ SDValue Op1 = ExtOp1->getOperand(0);
EVT Op1VT = Op1.getValueType();
- // Makes Op2's value type match the value type of Op1 without its extend.
- Op2 = DAG.getAnyExtOrTrunc(Op2, DL, Op1VT);
- // Make a MUL between Op1 and Op2 here so the MUL can be changed if possible
- // (can be pruned or changed to a shift instruction for example).
- SDValue Input = DAG.getNode(ISD::MUL, DL, Op1VT, Op1, Op2);
- if (!AccVT.isScalableVector())
+ unsigned Op2Opcode = Op2->getOpcode();
+ if (Op2Opcode != ISD::SPLAT_VECTOR || !isOneConstant(Op2->getOperand(0)))
return SDValue();
+ Op2 = DAG.getAnyExtOrTrunc(Op2, DL, Op1VT);
if (!(Op1VT == MVT::nxv4i32 && AccVT == MVT::nxv2i64) &&
!(Op1VT == MVT::nxv8i16 && AccVT == MVT::nxv4i32) &&
!(Op1VT == MVT::nxv16i8 && AccVT == MVT::nxv8i16))
return SDValue();
- unsigned NewOpcode = Op1Opcode == ISD::SIGN_EXTEND ? ISD::PARTIAL_REDUCE_SMLA
- : ISD::PARTIAL_REDUCE_UMLA;
- // Return a constant of 1s for Op2 so the MUL is not performed again.
- return DAG.getNode(NewOpcode, DL, AccVT, Op0, Input,
- DAG.getConstant(1, DL, Op1VT));
+ unsigned NewOpcode = ExtOp1Opcode == ISD::SIGN_EXTEND
+ ? ISD::PARTIAL_REDUCE_SMLA
+ : ISD::PARTIAL_REDUCE_UMLA;
+
+ return DAG.getNode(NewOpcode, DL, AccVT, Op0, Op1, Op2);
}
-SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
+SDValue performPartialReduceMLACombine(SDNode *N, SelectionDAG &DAG,
const AArch64Subtarget *Subtarget) {
SDLoc DL(N);
SDValue Op0 = N->getOperand(0);
@@ -22155,13 +22179,18 @@ SDValue performPartialReduceAddCombine(SDNode *N, SelectionDAG &DAG,
if (Op0ElemVT != Op1ElemVT || Op1.getOpcode() == AArch64ISD::USDOT)
return SDValue(N, 0);
+ if (auto MLA = tryCombinePartialReduceMLAMulOp(Op0, Op1, Op2, DAG, DL)) {
+ Op0 = MLA->getOperand(0);
+ Op1 = MLA->getOperand(1);
+ Op2 = MLA->getOperand(2);
+ }
if (auto Dot = tryCombineToDotProduct(Op0, Op1, Op2, DAG, Subtarget, DL))
return Dot;
if (auto WideAdd = tryCombineToWideAdd(Op0, Op1, Op2, DAG, Subtarget, DL))
return WideAdd;
// N->getOperand needs calling again because the Op variables may have been
// changed by the functions above
- return DAG.expandPartialReduceAdd(DL, N->getOperand(0), N->getOperand(1),
+ return DAG.expandPartialReduceMLA(DL, N->getOperand(0), N->getOperand(1),
N->getOperand(2));
}
@@ -26605,7 +26634,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
return performMaskedGatherScatterCombine(N, DCI, DAG);
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA:
- return performPartialReduceAddCombine(N, DAG, Subtarget);
+ return performPartialReduceMLACombine(N, DAG, Subtarget);
case ISD::FP_EXTEND:
return performFPExtendCombine(N, DAG, DCI, Subtarget);
case AArch64ISD::BRCOND:
@@ -29394,13 +29423,6 @@ AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
return DAG.getNode(DotOpcode, DL, AccVT, Acc, Input1, Input2);
}
- SDValue MulInput = Input1;
- // If Input2 is a splat vector of constant 1 then the MUL instruction is not
- // needed. If it was created here it would not be automatically pruned.
- if (Input2.getOpcode() != ISD::SPLAT_VECTOR || Input2.getNumOperands() == 0 ||
- !isOneConstant(Input2.getOperand(0)))
- MulInput = DAG.getNode(ISD::MUL, DL, InputVT, Input1, Input2);
-
bool InputIsSigned = Opcode == ISD::PARTIAL_REDUCE_SMLA;
unsigned BottomOpcode =
InputIsSigned ? AArch64ISD::SADDWB : AArch64ISD::UADDWB;
More information about the llvm-commits
mailing list