[llvm] [WebAssembly] Use partial_reduce_mla ISD nodes (PR #161184)
Sam Parker via llvm-commits
llvm-commits at lists.llvm.org
Mon Sep 29 05:01:36 PDT 2025
https://github.com/sparker-arm created https://github.com/llvm/llvm-project/pull/161184
Addressing issue #160847.
Move away from combining the intrinsic call and instead lower the ISD nodes, using more tablegen for pattern matching.
>From bddd802db7ebdb4a2c4b98c19a50f7740d598d2b Mon Sep 17 00:00:00 2001
From: Sam Parker <sam.parker at arm.com>
Date: Mon, 29 Sep 2025 12:54:29 +0100
Subject: [PATCH] [WebAssembly] Use partial_reduce_mla ISD nodes
Move away from combining the intrinsic call and instead lower the ISD
nodes, using more tablegen for pattern matching.
---
.../WebAssembly/WebAssemblyISelLowering.cpp | 140 ++++++------------
.../WebAssembly/WebAssemblyISelLowering.h | 6 +-
.../WebAssembly/WebAssemblyInstrSIMD.td | 9 ++
.../WebAssembly/partial-reduce-accumulate.ll | 2 +-
4 files changed, 61 insertions(+), 96 deletions(-)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 64b9dc31f75b7..e830def066087 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -186,7 +186,6 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// SIMD-specific configuration
if (Subtarget->hasSIMD128()) {
- // Combine partial.reduce.add before legalization gets confused.
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
// Combine wide-vector muls, with extend inputs, to extmul_half.
@@ -317,6 +316,18 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, T, Custom);
setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom);
}
+
+ // Partial MLA reductions.
+ // We only have native support with i32x4.dot_i16x8_s, the rest are custom
+ // lowered.
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SMLA, MVT::v4i32, MVT::v8i16,
+ Legal);
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_UMLA, MVT::v4i32, MVT::v8i16,
+ Custom);
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SMLA, MVT::v4i32, MVT::v16i8,
+ Custom);
+ setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_UMLA, MVT::v4i32, MVT::v16i8,
+ Custom);
}
// As a special case, these operators use the type to mean the type to
@@ -416,41 +427,6 @@ MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL,
return TargetLowering::getPointerMemTy(DL, AS);
}
-bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic(
- const IntrinsicInst *I) const {
- if (I->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
- return true;
-
- EVT VT = EVT::getEVT(I->getType());
- if (VT.getSizeInBits() > 128)
- return true;
-
- auto Op1 = I->getOperand(1);
-
- if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
- unsigned Opcode = InstructionOpcodeToISD(InputInst->getOpcode());
- if (Opcode == ISD::MUL) {
- if (isa<Instruction>(InputInst->getOperand(0)) &&
- isa<Instruction>(InputInst->getOperand(1))) {
- // dot only supports signed inputs but also support lowering unsigned.
- if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
- cast<Instruction>(InputInst->getOperand(1))->getOpcode())
- return true;
-
- EVT Op1VT = EVT::getEVT(Op1->getType());
- if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
- ((VT.getVectorElementCount() * 2 ==
- Op1VT.getVectorElementCount()) ||
- (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
- return false;
- }
- } else if (ISD::isExtOpcode(Opcode)) {
- return false;
- }
- }
- return true;
-}
-
TargetLowering::AtomicExpansionKind
WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
// We have wasm instructions for these
@@ -1706,6 +1682,9 @@ SDValue WebAssemblyTargetLowering::LowerOperation(SDValue Op,
return LowerMUL_LOHI(Op, DAG);
case ISD::UADDO:
return LowerUADDO(Op, DAG);
+ case ISD::PARTIAL_REDUCE_SMLA:
+ case ISD::PARTIAL_REDUCE_UMLA:
+ return LowerPARTIAL_REDUCE_MLA(Op, DAG);
}
}
@@ -2113,29 +2092,36 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op,
MachinePointerInfo(SV));
}
-// Try to lower partial.reduce.add to a dot or fallback to a sequence with
-// extmul and adds.
-SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
- assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN);
- if (N->getConstantOperandVal(0) != Intrinsic::vector_partial_reduce_add)
- return SDValue();
+// We only have native support with i32x4.dot_i16x8_s, so for the unsigned
+// case we can expand to extmul and add. For v16i8 inputs, we can use two dots
+// for signed, or an expanded tree of extmul adds for unsigned.
+SDValue
+WebAssemblyTargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
+ SelectionDAG &DAG) const {
+ assert(Op->getValueType(0) == MVT::v4i32 && "can only support v4i32");
+ SDLoc DL(Op);
- assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32");
- SDLoc DL(N);
+ SDValue Acc = Op.getOperand(0);
+ SDValue ExtendInLHS = Op.getOperand(1);
+ SDValue ExtendInRHS = Op.getOperand(2);
+ bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
- SDValue Input = N->getOperand(2);
- if (Input->getOpcode() == ISD::MUL) {
- SDValue ExtendLHS = Input->getOperand(0);
- SDValue ExtendRHS = Input->getOperand(1);
- assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
- ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
- "expected widening mul or add");
- assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
- "expected binop to use the same extend for both operands");
-
- SDValue ExtendInLHS = ExtendLHS->getOperand(0);
- SDValue ExtendInRHS = ExtendRHS->getOperand(0);
- bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;
+ APInt Imm;
+ if (ISD::isConstantSplatVector(ExtendInRHS.getNode(), Imm) && Imm == 1) {
+ // Accumulate the input using extadd_pairwise.
+ unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
+ : WebAssemblyISD::EXT_ADD_PAIRWISE_U;
+ if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
+ SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendInLHS);
+ return DAG.getNode(ISD::ADD, DL, MVT::v4i32, Acc, Add);
+ }
+ assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
+ "expected v16i8 input types");
+ SDValue Add =
+ DAG.getNode(PairwiseOpc, DL, MVT::v4i32,
+ DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendInLHS));
+ return DAG.getNode(ISD::ADD, DL, MVT::v4i32, Acc, Add);
+ } else {
unsigned LowOpc =
IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
unsigned HighOpc = IsSigned ? WebAssemblyISD::EXTEND_HIGH_S
@@ -2151,22 +2137,15 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
HighLHS = DAG.getNode(HighOpc, DL, VT, ExtendInLHS);
HighRHS = DAG.getNode(HighOpc, DL, VT, ExtendInRHS);
};
-
if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
- if (IsSigned) {
- // i32x4.dot_i16x8_s
- SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
- ExtendInLHS, ExtendInRHS);
- return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
- }
-
+ assert(!IsSigned && "expected unsigned");
// (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
MVT VT = MVT::v4i32;
AssignInputs(VT);
SDValue MulLow = DAG.getNode(ISD::MUL, DL, VT, LowLHS, LowRHS);
SDValue MulHigh = DAG.getNode(ISD::MUL, DL, VT, HighLHS, HighRHS);
SDValue Add = DAG.getNode(ISD::ADD, DL, VT, MulLow, MulHigh);
- return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(1), Add);
+ return DAG.getNode(ISD::ADD, DL, VT, Acc, Add);
} else {
assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
"expected v16i8 input types");
@@ -2179,7 +2158,7 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
SDValue DotRHS =
DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
- return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+ return DAG.getNode(ISD::ADD, DL, MVT::v4i32, Acc, Add);
}
SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
@@ -2190,26 +2169,8 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
MVT::v4i32, MulHigh);
SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
- return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+ return DAG.getNode(ISD::ADD, DL, MVT::v4i32, Acc, Add);
}
- } else {
- // Accumulate the input using extadd_pairwise.
- assert(ISD::isExtOpcode(Input.getOpcode()) && "expected extend");
- bool IsSigned = Input->getOpcode() == ISD::SIGN_EXTEND;
- unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
- : WebAssemblyISD::EXT_ADD_PAIRWISE_U;
- SDValue ExtendIn = Input->getOperand(0);
- if (ExtendIn->getValueType(0) == MVT::v8i16) {
- SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendIn);
- return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
- }
-
- assert(ExtendIn->getValueType(0) == MVT::v16i8 &&
- "expected v16i8 input types");
- SDValue Add =
- DAG.getNode(PairwiseOpc, DL, MVT::v4i32,
- DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendIn));
- return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
}
}
@@ -3683,11 +3644,8 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
return performVectorTruncZeroCombine(N, DCI);
case ISD::TRUNCATE:
return performTruncateCombine(N, DCI);
- case ISD::INTRINSIC_WO_CHAIN: {
- if (auto AnyAllCombine = performAnyAllCombine(N, DCI.DAG))
- return AnyAllCombine;
- return performLowerPartialReduction(N, DCI.DAG);
- }
+ case ISD::INTRINSIC_WO_CHAIN:
+ return performAnyAllCombine(N, DCI.DAG);
case ISD::MUL:
return performMulCombine(N, DCI);
}
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
index 72401a7a259c0..3ff8346e12a6f 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
@@ -45,8 +45,6 @@ class WebAssemblyTargetLowering final : public TargetLowering {
/// right decision when generating code for different targets.
const WebAssemblySubtarget *Subtarget;
- bool
- shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
AtomicExpansionKind shouldExpandAtomicRMWInIR(AtomicRMWInst *) const override;
bool shouldScalarizeBinop(SDValue VecOp) const override;
FastISel *createFastISel(FunctionLoweringInfo &FuncInfo,
@@ -89,8 +87,7 @@ class WebAssemblyTargetLowering final : public TargetLowering {
bool CanLowerReturn(CallingConv::ID CallConv, MachineFunction &MF,
bool isVarArg,
const SmallVectorImpl<ISD::OutputArg> &Outs,
- LLVMContext &Context,
- const Type *RetTy) const override;
+ LLVMContext &Context, const Type *RetTy) const override;
SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
const SmallVectorImpl<ISD::OutputArg> &Outs,
const SmallVectorImpl<SDValue> &OutVals, const SDLoc &dl,
@@ -134,6 +131,7 @@ class WebAssemblyTargetLowering final : public TargetLowering {
SDValue LowerMUL_LOHI(SDValue Op, SelectionDAG &DAG) const;
SDValue Replace128Op(SDNode *N, SelectionDAG &DAG) const;
SDValue LowerUADDO(SDValue Op, SelectionDAG &DAG) const;
+ SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
// Custom DAG combine hooks
SDValue
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index d8948ad2df037..b5724ecd90155 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -1159,6 +1159,9 @@ defm DOT : SIMD_I<(outs V128:$dst), (ins V128:$lhs, V128:$rhs), (outs), (ins),
186>;
def : Pat<(wasm_dot V128:$lhs, V128:$rhs),
(DOT $lhs, $rhs)>;
+def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$lhs),
+ (v8i16 V128:$rhs))),
+ (ADD_I32x4 (DOT $lhs, $rhs), $acc)>;
// Extending multiplication: extmul_{low,high}_P, extmul_high
def extend_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;
@@ -1473,6 +1476,12 @@ def : Pat<(v4i32 (int_wasm_extadd_pairwise_signed (v8i16 V128:$in))),
(extadd_pairwise_s_I32x4 V128:$in)>;
def : Pat<(v8i16 (int_wasm_extadd_pairwise_signed (v16i8 V128:$in))),
(extadd_pairwise_s_I16x8 V128:$in)>;
+def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$in),
+ (I16x8.splat (i32 1)))),
+ (ADD_I32x4 (extadd_pairwise_s_I32x4 V128:$in), V128:$acc)>;
+def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v8i16 V128:$in),
+ (I16x8.splat (i32 1)))),
+ (ADD_I32x4 (extadd_pairwise_u_I32x4 V128:$in), V128:$acc)>;
// f64x2 <-> f32x4 conversions
def demote_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;
diff --git a/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll b/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll
index 47ea762864cc2..c9e486a3f29b4 100644
--- a/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll
+++ b/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll
@@ -402,10 +402,10 @@ define hidden i32 @accumulate_add_s16_s16(ptr noundef readonly %a, ptr noundef
; MAX-BANDWIDTH: loop
; MAX-BANDWIDTH: v128.load
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
-; MAX-BANDWIDTH: i32x4.add
; MAX-BANDWIDTH: v128.load
; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
; MAX-BANDWIDTH: i32x4.add
+; MAX-BANDWIDTH: i32x4.add
entry:
%cmp8.not = icmp eq i32 %N, 0
br i1 %cmp8.not, label %for.cond.cleanup, label %for.body
More information about the llvm-commits
mailing list