[llvm] [WebAssembly] Support partial-reduce accumulator (PR #158060)
Sam Parker via llvm-commits
llvm-commits at lists.llvm.org
Thu Sep 11 05:49:21 PDT 2025
https://github.com/sparker-arm created https://github.com/llvm/llvm-project/pull/158060
We currently only support partial.reduce.add when it performs a multiply-accumulate. This patch adds support for partial reductions whose input is a plain extend (with no multiply), which can be lowered using extadd_pairwise.
>From fc74f9bc20080f4b8d463382d27325ab79f95748 Mon Sep 17 00:00:00 2001
From: Sam Parker <sam.parker at arm.com>
Date: Thu, 11 Sep 2025 13:44:01 +0100
Subject: [PATCH] [WebAssembly] Support partial-reduce accumulator
We currently only support partial.reduce.add when it performs a
multiply-accumulate. Add support for partial reductions whose input is
a plain extend (with no multiply), which can be lowered using
extadd_pairwise.
---
.../lib/Target/WebAssembly/WebAssemblyISD.def | 1 +
.../WebAssembly/WebAssemblyISelLowering.cpp | 180 ++++++++++--------
.../WebAssembly/WebAssemblyInstrSIMD.td | 9 +-
.../WebAssemblyTargetTransformInfo.cpp | 27 ++-
4 files changed, 127 insertions(+), 90 deletions(-)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
index 1eae3586d16b8..23108e429eda8 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
@@ -28,6 +28,7 @@ HANDLE_NODETYPE(BR_IF)
HANDLE_NODETYPE(BR_TABLE)
HANDLE_NODETYPE(DOT)
HANDLE_NODETYPE(EXT_ADD_PAIRWISE_U)
+HANDLE_NODETYPE(EXT_ADD_PAIRWISE_S)
HANDLE_NODETYPE(SHUFFLE)
HANDLE_NODETYPE(SWIZZLE)
HANDLE_NODETYPE(VEC_SHL)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index fe100dab427ef..aea27ba32d37e 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -422,24 +422,30 @@ bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic(
return true;
EVT VT = EVT::getEVT(I->getType());
+ if (VT.getSizeInBits() > 128)
+ return true;
+
auto Op1 = I->getOperand(1);
if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
- if (InstructionOpcodeToISD(InputInst->getOpcode()) != ISD::MUL)
- return true;
-
- if (isa<Instruction>(InputInst->getOperand(0)) &&
- isa<Instruction>(InputInst->getOperand(1))) {
- // dot only supports signed inputs but also support lowering unsigned.
- if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
- cast<Instruction>(InputInst->getOperand(1))->getOpcode())
- return true;
-
- EVT Op1VT = EVT::getEVT(Op1->getType());
- if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
- ((VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()) ||
- (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
- return false;
+ unsigned Opcode = InstructionOpcodeToISD(InputInst->getOpcode());
+ if (Opcode == ISD::MUL) {
+ if (isa<Instruction>(InputInst->getOperand(0)) &&
+ isa<Instruction>(InputInst->getOperand(1))) {
+ // dot only supports signed inputs but also support lowering unsigned.
+ if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
+ cast<Instruction>(InputInst->getOperand(1))->getOpcode())
+ return true;
+
+ EVT Op1VT = EVT::getEVT(Op1->getType());
+ if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
+ ((VT.getVectorElementCount() * 2 ==
+ Op1VT.getVectorElementCount()) ||
+ (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
+ return false;
+ }
+ } else if (ISD::isExtOpcode(Opcode)) {
+ return false;
}
}
return true;
@@ -2117,77 +2123,93 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32");
SDLoc DL(N);
- SDValue Mul = N->getOperand(2);
- assert(Mul->getOpcode() == ISD::MUL && "expected mul input");
-
- SDValue ExtendLHS = Mul->getOperand(0);
- SDValue ExtendRHS = Mul->getOperand(1);
- assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
- ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
- "expected widening mul");
- assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
- "expected mul to use the same extend for both operands");
-
- SDValue ExtendInLHS = ExtendLHS->getOperand(0);
- SDValue ExtendInRHS = ExtendRHS->getOperand(0);
- bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;
-
- if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
- if (IsSigned) {
- // i32x4.dot_i16x8_s
- SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
- ExtendInLHS, ExtendInRHS);
- return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
- }
- unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U;
- unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U;
+ SDValue Input = N->getOperand(2);
+ if (Input->getOpcode() == ISD::MUL) {
+ SDValue ExtendLHS = Input->getOperand(0);
+ SDValue ExtendRHS = Input->getOperand(1);
+ assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
+ ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
+ "expected widening mul or add");
+ assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
+ "expected binop to use the same extend for both operands");
+
+ SDValue ExtendInLHS = ExtendLHS->getOperand(0);
+ SDValue ExtendInRHS = ExtendRHS->getOperand(0);
+ bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;
+ unsigned LowOpc =
+ IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
+ unsigned HighOpc = IsSigned ? WebAssemblyISD::EXTEND_HIGH_S
+ : WebAssemblyISD::EXTEND_HIGH_U;
+ SDValue LowLHS;
+ SDValue LowRHS;
+ SDValue HighLHS;
+ SDValue HighRHS;
+
+ auto AssignInputs = [&](MVT VT) {
+ LowLHS = DAG.getNode(LowOpc, DL, VT, ExtendInLHS);
+ LowRHS = DAG.getNode(LowOpc, DL, VT, ExtendInRHS);
+ HighLHS = DAG.getNode(HighOpc, DL, VT, ExtendInLHS);
+ HighRHS = DAG.getNode(HighOpc, DL, VT, ExtendInRHS);
+ };
- // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
- SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v4i32, ExtendInLHS);
- SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v4i32, ExtendInRHS);
- SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v4i32, ExtendInLHS);
- SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v4i32, ExtendInRHS);
+ if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
+ if (IsSigned) {
+ // i32x4.dot_i16x8_s
+ SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
+ ExtendInLHS, ExtendInRHS);
+ return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
+ }
- SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v4i32, LowLHS, LowRHS);
- SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v4i32, HighLHS, HighRHS);
- SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, MulLow, MulHigh);
- return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+ // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
+ MVT VT = MVT::v4i32;
+ AssignInputs(VT);
+ SDValue MulLow = DAG.getNode(ISD::MUL, DL, VT, LowLHS, LowRHS);
+ SDValue MulHigh = DAG.getNode(ISD::MUL, DL, VT, HighLHS, HighRHS);
+ SDValue Add = DAG.getNode(ISD::ADD, DL, VT, MulLow, MulHigh);
+ return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(1), Add);
+ } else {
+ assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
+ "expected v16i8 input types");
+ AssignInputs(MVT::v8i16);
+ // Lower to a wider tree, using twice the operations compared to above.
+ if (IsSigned) {
+ // Use two dots
+ SDValue DotLHS =
+ DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
+ SDValue DotRHS =
+ DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
+ SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
+ return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+ }
+
+ SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
+ SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
+
+ SDValue AddLow = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
+ MVT::v4i32, MulLow);
+ SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
+ MVT::v4i32, MulHigh);
+ SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
+ return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+ }
} else {
- assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
- "expected v16i8 input types");
- // Lower to a wider tree, using twice the operations compared to above.
- if (IsSigned) {
- // Use two dots
- unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_S;
- unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_S;
- SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInLHS);
- SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInRHS);
- SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInLHS);
- SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInRHS);
- SDValue DotLHS =
- DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
- SDValue DotRHS =
- DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
- SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
+ // Accumulate the input using extadd_pairwise.
+ assert(ISD::isExtOpcode(Input.getOpcode()) && "expected extend");
+ bool IsSigned = Input->getOpcode() == ISD::SIGN_EXTEND;
+ unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
+ : WebAssemblyISD::EXT_ADD_PAIRWISE_U;
+ SDValue ExtendIn = Input->getOperand(0);
+ if (ExtendIn->getValueType(0) == MVT::v8i16) {
+ SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendIn);
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
}
- unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U;
- unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U;
- SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInLHS);
- SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInRHS);
- SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInLHS);
- SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInRHS);
-
- SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
- SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
-
- SDValue AddLow =
- DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL, MVT::v4i32, MulLow);
- SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
- MVT::v4i32, MulHigh);
- SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
+ assert(ExtendIn->getValueType(0) == MVT::v16i8 &&
+ "expected v16i8 input types");
+ SDValue Add =
+ DAG.getNode(PairwiseOpc, DL, MVT::v4i32,
+ DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendIn));
return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
}
}
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index 3c26b453c4482..d8948ad2df037 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -1454,12 +1454,13 @@ def : Pat<(t1.vt (bitconvert (t2.vt V128:$v))), (t1.vt V128:$v)>;
// Extended pairwise addition
def extadd_pairwise_u : SDNode<"WebAssemblyISD::EXT_ADD_PAIRWISE_U", extend_t>;
+def extadd_pairwise_s : SDNode<"WebAssemblyISD::EXT_ADD_PAIRWISE_S", extend_t>;
-defm "" : SIMDConvert<I16x8, I8x16, int_wasm_extadd_pairwise_signed,
+defm "" : SIMDConvert<I16x8, I8x16, extadd_pairwise_s,
"extadd_pairwise_i8x16_s", 0x7c>;
defm "" : SIMDConvert<I16x8, I8x16, extadd_pairwise_u,
"extadd_pairwise_i8x16_u", 0x7d>;
-defm "" : SIMDConvert<I32x4, I16x8, int_wasm_extadd_pairwise_signed,
+defm "" : SIMDConvert<I32x4, I16x8, extadd_pairwise_s,
"extadd_pairwise_i16x8_s", 0x7e>;
defm "" : SIMDConvert<I32x4, I16x8, extadd_pairwise_u,
"extadd_pairwise_i16x8_u", 0x7f>;
@@ -1468,6 +1469,10 @@ def : Pat<(v4i32 (int_wasm_extadd_pairwise_unsigned (v8i16 V128:$in))),
(extadd_pairwise_u_I32x4 V128:$in)>;
def : Pat<(v8i16 (int_wasm_extadd_pairwise_unsigned (v16i8 V128:$in))),
(extadd_pairwise_u_I16x8 V128:$in)>;
+def : Pat<(v4i32 (int_wasm_extadd_pairwise_signed (v8i16 V128:$in))),
+ (extadd_pairwise_s_I32x4 V128:$in)>;
+def : Pat<(v8i16 (int_wasm_extadd_pairwise_signed (v16i8 V128:$in))),
+ (extadd_pairwise_s_I16x8 V128:$in)>;
// f64x2 <-> f32x4 conversions
def demote_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
index 0eefd3e2b3500..92a9812df2127 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -316,7 +316,13 @@ InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
if (CostKind != TTI::TCK_RecipThroughput)
return Invalid;
- InstructionCost Cost(TTI::TCC_Basic);
+ if (Opcode != Instruction::Add)
+ return Invalid;
+
+ EVT AccumEVT = EVT::getEVT(AccumType);
+ // TODO: Add i64 accumulator.
+ if (AccumEVT != MVT::i32)
+ return Invalid;
// Possible options:
// - i16x8.extadd_pairwise_i8x16_sx
@@ -324,23 +330,26 @@ InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
// - i32x4.dot_i16x8_s
// Only try to support dot, for now.
- if (Opcode != Instruction::Add)
+ EVT InputEVT = EVT::getEVT(InputTypeA);
+ if (!((InputEVT == MVT::i16 && VF.getFixedValue() == 8) ||
+ (InputEVT == MVT::i8 && VF.getFixedValue() == 16))) {
return Invalid;
+ }
- if (!BinOp || *BinOp != Instruction::Mul)
+ if (OpAExtend == TTI::PR_None)
return Invalid;
- if (InputTypeA != InputTypeB)
- return Invalid;
+ InstructionCost Cost(TTI::TCC_Basic);
+ if (!BinOp)
+ return Cost;
if (OpAExtend != OpBExtend)
return Invalid;
- EVT InputEVT = EVT::getEVT(InputTypeA);
- EVT AccumEVT = EVT::getEVT(AccumType);
+ if (*BinOp != Instruction::Mul)
+ return Invalid;
- // TODO: Add i64 accumulator.
- if (AccumEVT != MVT::i32)
+ if (InputTypeA != InputTypeB)
return Invalid;
// Signed inputs can lower to dot
More information about the llvm-commits
mailing list