[llvm] [WebAssembly] Support partial-reduce accumulator (PR #158060)

Sam Parker via llvm-commits llvm-commits at lists.llvm.org
Thu Sep 11 05:49:21 PDT 2025


https://github.com/sparker-arm created https://github.com/llvm/llvm-project/pull/158060

We currently only support partial.reduce.add when it performs a multiply-accumulate. Now add support for any partial reduction whose input is an extend, which lets us take advantage of extadd_pairwise.

>From fc74f9bc20080f4b8d463382d27325ab79f95748 Mon Sep 17 00:00:00 2001
From: Sam Parker <sam.parker at arm.com>
Date: Thu, 11 Sep 2025 13:44:01 +0100
Subject: [PATCH] [WebAssembly] Support partial-reduce accumulator

We currently only support partial.reduce.add when it performs a
multiply-accumulate. Now add support for any partial reduction whose
input is an extend, which lets us take advantage of extadd_pairwise.
---
 .../lib/Target/WebAssembly/WebAssemblyISD.def |   1 +
 .../WebAssembly/WebAssemblyISelLowering.cpp   | 180 ++++++++++--------
 .../WebAssembly/WebAssemblyInstrSIMD.td       |   9 +-
 .../WebAssemblyTargetTransformInfo.cpp        |  27 ++-
 4 files changed, 127 insertions(+), 90 deletions(-)

diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
index 1eae3586d16b8..23108e429eda8 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISD.def
@@ -28,6 +28,7 @@ HANDLE_NODETYPE(BR_IF)
 HANDLE_NODETYPE(BR_TABLE)
 HANDLE_NODETYPE(DOT)
 HANDLE_NODETYPE(EXT_ADD_PAIRWISE_U)
+HANDLE_NODETYPE(EXT_ADD_PAIRWISE_S)
 HANDLE_NODETYPE(SHUFFLE)
 HANDLE_NODETYPE(SWIZZLE)
 HANDLE_NODETYPE(VEC_SHL)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index fe100dab427ef..aea27ba32d37e 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -422,24 +422,30 @@ bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic(
     return true;
 
   EVT VT = EVT::getEVT(I->getType());
+  if (VT.getSizeInBits() > 128)
+    return true;
+
   auto Op1 = I->getOperand(1);
 
   if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
-    if (InstructionOpcodeToISD(InputInst->getOpcode()) != ISD::MUL)
-      return true;
-
-    if (isa<Instruction>(InputInst->getOperand(0)) &&
-        isa<Instruction>(InputInst->getOperand(1))) {
-      // dot only supports signed inputs but also support lowering unsigned.
-      if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
-          cast<Instruction>(InputInst->getOperand(1))->getOpcode())
-        return true;
-
-      EVT Op1VT = EVT::getEVT(Op1->getType());
-      if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
-          ((VT.getVectorElementCount() * 2 == Op1VT.getVectorElementCount()) ||
-           (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
-        return false;
+    unsigned Opcode = InstructionOpcodeToISD(InputInst->getOpcode());
+    if (Opcode == ISD::MUL) {
+      if (isa<Instruction>(InputInst->getOperand(0)) &&
+          isa<Instruction>(InputInst->getOperand(1))) {
+        // dot only supports signed inputs but also support lowering unsigned.
+        if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
+            cast<Instruction>(InputInst->getOperand(1))->getOpcode())
+          return true;
+
+        EVT Op1VT = EVT::getEVT(Op1->getType());
+        if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
+            ((VT.getVectorElementCount() * 2 ==
+              Op1VT.getVectorElementCount()) ||
+             (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
+          return false;
+      }
+    } else if (ISD::isExtOpcode(Opcode)) {
+      return false;
     }
   }
   return true;
@@ -2117,77 +2123,93 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
 
   assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32");
   SDLoc DL(N);
-  SDValue Mul = N->getOperand(2);
-  assert(Mul->getOpcode() == ISD::MUL && "expected mul input");
-
-  SDValue ExtendLHS = Mul->getOperand(0);
-  SDValue ExtendRHS = Mul->getOperand(1);
-  assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
-          ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
-         "expected widening mul");
-  assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
-         "expected mul to use the same extend for both operands");
-
-  SDValue ExtendInLHS = ExtendLHS->getOperand(0);
-  SDValue ExtendInRHS = ExtendRHS->getOperand(0);
-  bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;
-
-  if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
-    if (IsSigned) {
-      // i32x4.dot_i16x8_s
-      SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
-                                ExtendInLHS, ExtendInRHS);
-      return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
-    }
 
-    unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U;
-    unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U;
+  SDValue Input = N->getOperand(2);
+  if (Input->getOpcode() == ISD::MUL) {
+    SDValue ExtendLHS = Input->getOperand(0);
+    SDValue ExtendRHS = Input->getOperand(1);
+    assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
+            ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
+           "expected widening mul or add");
+    assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
+           "expected binop to use the same extend for both operands");
+
+    SDValue ExtendInLHS = ExtendLHS->getOperand(0);
+    SDValue ExtendInRHS = ExtendRHS->getOperand(0);
+    bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;
+    unsigned LowOpc =
+        IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
+    unsigned HighOpc = IsSigned ? WebAssemblyISD::EXTEND_HIGH_S
+                                : WebAssemblyISD::EXTEND_HIGH_U;
+    SDValue LowLHS;
+    SDValue LowRHS;
+    SDValue HighLHS;
+    SDValue HighRHS;
+
+    auto AssignInputs = [&](MVT VT) {
+      LowLHS = DAG.getNode(LowOpc, DL, VT, ExtendInLHS);
+      LowRHS = DAG.getNode(LowOpc, DL, VT, ExtendInRHS);
+      HighLHS = DAG.getNode(HighOpc, DL, VT, ExtendInLHS);
+      HighRHS = DAG.getNode(HighOpc, DL, VT, ExtendInRHS);
+    };
 
-    // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
-    SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v4i32, ExtendInLHS);
-    SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v4i32, ExtendInRHS);
-    SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v4i32, ExtendInLHS);
-    SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v4i32, ExtendInRHS);
+    if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
+      if (IsSigned) {
+        // i32x4.dot_i16x8_s
+        SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
+                                  ExtendInLHS, ExtendInRHS);
+        return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
+      }
 
-    SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v4i32, LowLHS, LowRHS);
-    SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v4i32, HighLHS, HighRHS);
-    SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, MulLow, MulHigh);
-    return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+      // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
+      MVT VT = MVT::v4i32;
+      AssignInputs(VT);
+      SDValue MulLow = DAG.getNode(ISD::MUL, DL, VT, LowLHS, LowRHS);
+      SDValue MulHigh = DAG.getNode(ISD::MUL, DL, VT, HighLHS, HighRHS);
+      SDValue Add = DAG.getNode(ISD::ADD, DL, VT, MulLow, MulHigh);
+      return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(1), Add);
+    } else {
+      assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
+             "expected v16i8 input types");
+      AssignInputs(MVT::v8i16);
+      // Lower to a wider tree, using twice the operations compared to above.
+      if (IsSigned) {
+        // Use two dots
+        SDValue DotLHS =
+            DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
+        SDValue DotRHS =
+            DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
+        SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
+        return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+      }
+
+      SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
+      SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
+
+      SDValue AddLow = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
+                                   MVT::v4i32, MulLow);
+      SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
+                                    MVT::v4i32, MulHigh);
+      SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
+      return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+    }
   } else {
-    assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
-           "expected v16i8 input types");
-    // Lower to a wider tree, using twice the operations compared to above.
-    if (IsSigned) {
-      // Use two dots
-      unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_S;
-      unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_S;
-      SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInLHS);
-      SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInRHS);
-      SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInLHS);
-      SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInRHS);
-      SDValue DotLHS =
-          DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
-      SDValue DotRHS =
-          DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
-      SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
+    // Accumulate the input using extadd_pairwise.
+    assert(ISD::isExtOpcode(Input.getOpcode()) && "expected extend");
+    bool IsSigned = Input->getOpcode() == ISD::SIGN_EXTEND;
+    unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
+                                    : WebAssemblyISD::EXT_ADD_PAIRWISE_U;
+    SDValue ExtendIn = Input->getOperand(0);
+    if (ExtendIn->getValueType(0) == MVT::v8i16) {
+      SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendIn);
       return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
     }
 
-    unsigned LowOpc = WebAssemblyISD::EXTEND_LOW_U;
-    unsigned HighOpc = WebAssemblyISD::EXTEND_HIGH_U;
-    SDValue LowLHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInLHS);
-    SDValue LowRHS = DAG.getNode(LowOpc, DL, MVT::v8i16, ExtendInRHS);
-    SDValue HighLHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInLHS);
-    SDValue HighRHS = DAG.getNode(HighOpc, DL, MVT::v8i16, ExtendInRHS);
-
-    SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
-    SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
-
-    SDValue AddLow =
-        DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL, MVT::v4i32, MulLow);
-    SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
-                                  MVT::v4i32, MulHigh);
-    SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
+    assert(ExtendIn->getValueType(0) == MVT::v16i8 &&
+           "expected v16i8 input types");
+    SDValue Add =
+        DAG.getNode(PairwiseOpc, DL, MVT::v4i32,
+                    DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendIn));
     return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
   }
 }
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index 3c26b453c4482..d8948ad2df037 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -1454,12 +1454,13 @@ def : Pat<(t1.vt (bitconvert (t2.vt V128:$v))), (t1.vt V128:$v)>;
 
 // Extended pairwise addition
 def extadd_pairwise_u : SDNode<"WebAssemblyISD::EXT_ADD_PAIRWISE_U", extend_t>;
+def extadd_pairwise_s : SDNode<"WebAssemblyISD::EXT_ADD_PAIRWISE_S", extend_t>;
 
-defm "" : SIMDConvert<I16x8, I8x16, int_wasm_extadd_pairwise_signed,
+defm "" : SIMDConvert<I16x8, I8x16, extadd_pairwise_s,
                       "extadd_pairwise_i8x16_s", 0x7c>;
 defm "" : SIMDConvert<I16x8, I8x16, extadd_pairwise_u,
                       "extadd_pairwise_i8x16_u", 0x7d>;
-defm "" : SIMDConvert<I32x4, I16x8, int_wasm_extadd_pairwise_signed,
+defm "" : SIMDConvert<I32x4, I16x8, extadd_pairwise_s,
                       "extadd_pairwise_i16x8_s", 0x7e>;
 defm "" : SIMDConvert<I32x4, I16x8, extadd_pairwise_u,
                       "extadd_pairwise_i16x8_u", 0x7f>;
@@ -1468,6 +1469,10 @@ def : Pat<(v4i32 (int_wasm_extadd_pairwise_unsigned (v8i16 V128:$in))),
           (extadd_pairwise_u_I32x4 V128:$in)>;
 def : Pat<(v8i16 (int_wasm_extadd_pairwise_unsigned (v16i8 V128:$in))),
           (extadd_pairwise_u_I16x8 V128:$in)>;
+def : Pat<(v4i32 (int_wasm_extadd_pairwise_signed (v8i16 V128:$in))),
+          (extadd_pairwise_s_I32x4 V128:$in)>;
+def : Pat<(v8i16 (int_wasm_extadd_pairwise_signed (v16i8 V128:$in))),
+          (extadd_pairwise_s_I16x8 V128:$in)>;
 
 // f64x2 <-> f32x4 conversions
 def demote_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
index 0eefd3e2b3500..92a9812df2127 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp
@@ -316,7 +316,13 @@ InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
   if (CostKind != TTI::TCK_RecipThroughput)
     return Invalid;
 
-  InstructionCost Cost(TTI::TCC_Basic);
+  if (Opcode != Instruction::Add)
+    return Invalid;
+
+  EVT AccumEVT = EVT::getEVT(AccumType);
+  // TODO: Add i64 accumulator.
+  if (AccumEVT != MVT::i32)
+    return Invalid;
 
   // Possible options:
   // - i16x8.extadd_pairwise_i8x16_sx
@@ -324,23 +330,26 @@ InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
   // - i32x4.dot_i16x8_s
   // Only try to support dot, for now.
 
-  if (Opcode != Instruction::Add)
+  EVT InputEVT = EVT::getEVT(InputTypeA);
+  if (!((InputEVT == MVT::i16 && VF.getFixedValue() == 8) ||
+        (InputEVT == MVT::i8 && VF.getFixedValue() == 16))) {
     return Invalid;
+  }
 
-  if (!BinOp || *BinOp != Instruction::Mul)
+  if (OpAExtend == TTI::PR_None)
     return Invalid;
 
-  if (InputTypeA != InputTypeB)
-    return Invalid;
+  InstructionCost Cost(TTI::TCC_Basic);
+  if (!BinOp)
+    return Cost;
 
   if (OpAExtend != OpBExtend)
     return Invalid;
 
-  EVT InputEVT = EVT::getEVT(InputTypeA);
-  EVT AccumEVT = EVT::getEVT(AccumType);
+  if (*BinOp != Instruction::Mul)
+    return Invalid;
 
-  // TODO: Add i64 accumulator.
-  if (AccumEVT != MVT::i32)
+  if (InputTypeA != InputTypeB)
     return Invalid;
 
   // Signed inputs can lower to dot



More information about the llvm-commits mailing list