[llvm] [WebAssembly] Use partial_reduce_mla ISD nodes (PR #161184)

via llvm-commits llvm-commits at lists.llvm.org
Mon Sep 29 05:02:27 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-backend-webassembly

Author: Sam Parker (sparker-arm)

<details>
<summary>Changes</summary>

Addressing issue #<!-- -->160847.

Move away from combining the intrinsic call and instead lower the PARTIAL_REDUCE_SMLA/UMLA ISD nodes, relying more on TableGen for pattern matching.

---
Full diff: https://github.com/llvm/llvm-project/pull/161184.diff


4 Files Affected:

- (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp (+49-91) 
- (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h (+2-4) 
- (modified) llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td (+9) 
- (modified) llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll (+1-1) 


``````````diff
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 64b9dc31f75b7..e830def066087 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -186,7 +186,6 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
   // SIMD-specific configuration
   if (Subtarget->hasSIMD128()) {
 
-    // Combine partial.reduce.add before legalization gets confused.
     setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
 
     // Combine wide-vector muls, with extend inputs, to extmul_half.
@@ -317,6 +316,18 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
       setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, T, Custom);
       setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom);
     }
+
+    // Partial MLA reductions.
+    // We only have native support with i32x4.dot_i16x8_s, the rest are custom
+    // lowered.
+    setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SMLA, MVT::v4i32, MVT::v8i16,
+                              Legal);
+    setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_UMLA, MVT::v4i32, MVT::v8i16,
+                              Custom);
+    setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_SMLA, MVT::v4i32, MVT::v16i8,
+                              Custom);
+    setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_UMLA, MVT::v4i32, MVT::v16i8,
+                              Custom);
   }
 
   // As a special case, these operators use the type to mean the type to
@@ -416,41 +427,6 @@ MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL,
   return TargetLowering::getPointerMemTy(DL, AS);
 }
 
-bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic(
-    const IntrinsicInst *I) const {
-  if (I->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
-    return true;
-
-  EVT VT = EVT::getEVT(I->getType());
-  if (VT.getSizeInBits() > 128)
-    return true;
-
-  auto Op1 = I->getOperand(1);
-
-  if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
-    unsigned Opcode = InstructionOpcodeToISD(InputInst->getOpcode());
-    if (Opcode == ISD::MUL) {
-      if (isa<Instruction>(InputInst->getOperand(0)) &&
-          isa<Instruction>(InputInst->getOperand(1))) {
-        // dot only supports signed inputs but also support lowering unsigned.
-        if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
-            cast<Instruction>(InputInst->getOperand(1))->getOpcode())
-          return true;
-
-        EVT Op1VT = EVT::getEVT(Op1->getType());
-        if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
-            ((VT.getVectorElementCount() * 2 ==
-              Op1VT.getVectorElementCount()) ||
-             (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
-          return false;
-      }
-    } else if (ISD::isExtOpcode(Opcode)) {
-      return false;
-    }
-  }
-  return true;
-}
-
 TargetLowering::AtomicExpansionKind
 WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
   // We have wasm instructions for these
@@ -1706,6 +1682,9 @@ SDValue WebAssemblyTargetLowering::LowerOperation(SDValue Op,
     return LowerMUL_LOHI(Op, DAG);
   case ISD::UADDO:
     return LowerUADDO(Op, DAG);
+  case ISD::PARTIAL_REDUCE_SMLA:
+  case ISD::PARTIAL_REDUCE_UMLA:
+    return LowerPARTIAL_REDUCE_MLA(Op, DAG);
   }
 }
 
@@ -2113,29 +2092,36 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op,
                       MachinePointerInfo(SV));
 }
 
-// Try to lower partial.reduce.add to a dot or fallback to a sequence with
-// extmul and adds.
-SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
-  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN);
-  if (N->getConstantOperandVal(0) != Intrinsic::vector_partial_reduce_add)
-    return SDValue();
+// We only have native support with i32x4.dot_i16x8_s, so for the unsigned
+// case we can expand to extmul and add. For v16i8 inputs, we can use two dots
+// for signed, and an expanded tree of extmul adds for unsigned.
+SDValue
+WebAssemblyTargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
+                                                   SelectionDAG &DAG) const {
+  assert(Op->getValueType(0) == MVT::v4i32 && "can only support v4i32");
+  SDLoc DL(Op);
 
-  assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32");
-  SDLoc DL(N);
+  SDValue Acc = Op.getOperand(0);
+  SDValue ExtendInLHS = Op.getOperand(1);
+  SDValue ExtendInRHS = Op.getOperand(2);
+  bool IsSigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
 
-  SDValue Input = N->getOperand(2);
-  if (Input->getOpcode() == ISD::MUL) {
-    SDValue ExtendLHS = Input->getOperand(0);
-    SDValue ExtendRHS = Input->getOperand(1);
-    assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
-            ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
-           "expected widening mul or add");
-    assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
-           "expected binop to use the same extend for both operands");
-
-    SDValue ExtendInLHS = ExtendLHS->getOperand(0);
-    SDValue ExtendInRHS = ExtendRHS->getOperand(0);
-    bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;
+  APInt Imm;
+  if (ISD::isConstantSplatVector(ExtendInRHS.getNode(), Imm) && Imm == 1) {
+    // Accumulate the input using extadd_pairwise.
+    unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
+                                    : WebAssemblyISD::EXT_ADD_PAIRWISE_U;
+    if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
+      SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendInLHS);
+      return DAG.getNode(ISD::ADD, DL, MVT::v4i32, Acc, Add);
+    }
+    assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
+           "expected v16i8 input types");
+    SDValue Add =
+        DAG.getNode(PairwiseOpc, DL, MVT::v4i32,
+                    DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendInLHS));
+    return DAG.getNode(ISD::ADD, DL, MVT::v4i32, Acc, Add);
+  } else {
     unsigned LowOpc =
         IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
     unsigned HighOpc = IsSigned ? WebAssemblyISD::EXTEND_HIGH_S
@@ -2151,22 +2137,15 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
       HighLHS = DAG.getNode(HighOpc, DL, VT, ExtendInLHS);
       HighRHS = DAG.getNode(HighOpc, DL, VT, ExtendInRHS);
     };
-
     if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
-      if (IsSigned) {
-        // i32x4.dot_i16x8_s
-        SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
-                                  ExtendInLHS, ExtendInRHS);
-        return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
-      }
-
+      assert(!IsSigned && "expected unsigned");
       // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
       MVT VT = MVT::v4i32;
       AssignInputs(VT);
       SDValue MulLow = DAG.getNode(ISD::MUL, DL, VT, LowLHS, LowRHS);
       SDValue MulHigh = DAG.getNode(ISD::MUL, DL, VT, HighLHS, HighRHS);
       SDValue Add = DAG.getNode(ISD::ADD, DL, VT, MulLow, MulHigh);
-      return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(1), Add);
+      return DAG.getNode(ISD::ADD, DL, VT, Acc, Add);
     } else {
       assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
              "expected v16i8 input types");
@@ -2179,7 +2158,7 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
         SDValue DotRHS =
             DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
         SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
-        return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+        return DAG.getNode(ISD::ADD, DL, MVT::v4i32, Acc, Add);
       }
 
       SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
@@ -2190,26 +2169,8 @@ SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
       SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
                                     MVT::v4i32, MulHigh);
       SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
-      return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
+      return DAG.getNode(ISD::ADD, DL, MVT::v4i32, Acc, Add);
     }
-  } else {
-    // Accumulate the input using extadd_pairwise.
-    assert(ISD::isExtOpcode(Input.getOpcode()) && "expected extend");
-    bool IsSigned = Input->getOpcode() == ISD::SIGN_EXTEND;
-    unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
-                                    : WebAssemblyISD::EXT_ADD_PAIRWISE_U;
-    SDValue ExtendIn = Input->getOperand(0);
-    if (ExtendIn->getValueType(0) == MVT::v8i16) {
-      SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendIn);
-      return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
-    }
-
-    assert(ExtendIn->getValueType(0) == MVT::v16i8 &&
-           "expected v16i8 input types");
-    SDValue Add =
-        DAG.getNode(PairwiseOpc, DL, MVT::v4i32,
-                    DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendIn));
-    return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
   }
 }
 
@@ -3683,11 +3644,8 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
     return performVectorTruncZeroCombine(N, DCI);
   case ISD::TRUNCATE:
     return performTruncateCombine(N, DCI);
-  case ISD::INTRINSIC_WO_CHAIN: {
-    if (auto AnyAllCombine = performAnyAllCombine(N, DCI.DAG))
-      return AnyAllCombine;
-    return performLowerPartialReduction(N, DCI.DAG);
-  }
+  case ISD::INTRINSIC_WO_CHAIN:
+    return performAnyAllCombine(N, DCI.DAG);
   case ISD::MUL:
     return performMulCombine(N, DCI);
   }
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
index 72401a7a259c0..3ff8346e12a6f 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
@@ -45,8 +45,6 @@ class WebAssemblyTargetLowering final : public TargetLowering {
   /// right decision when generating code for different targets.
   const WebAssemblySubtarget *Subtarget;
 
-  bool
-  shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
   AtomicExpansionKind shouldExpandAtomicRMWInIR(AtomicRMWInst *) const override;
   bool shouldScalarizeBinop(SDValue VecOp) const override;
   FastISel *createFastISel(FunctionLoweringInfo &FuncInfo,
@@ -89,8 +87,7 @@ class WebAssemblyTargetLowering final : public TargetLowering {
   bool CanLowerReturn(CallingConv::ID CallConv, MachineFunction &MF,
                       bool isVarArg,
                       const SmallVectorImpl<ISD::OutputArg> &Outs,
-                      LLVMContext &Context,
-                      const Type *RetTy) const override;
+                      LLVMContext &Context, const Type *RetTy) const override;
   SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
                       const SmallVectorImpl<ISD::OutputArg> &Outs,
                       const SmallVectorImpl<SDValue> &OutVals, const SDLoc &dl,
@@ -134,6 +131,7 @@ class WebAssemblyTargetLowering final : public TargetLowering {
   SDValue LowerMUL_LOHI(SDValue Op, SelectionDAG &DAG) const;
   SDValue Replace128Op(SDNode *N, SelectionDAG &DAG) const;
   SDValue LowerUADDO(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
 
   // Custom DAG combine hooks
   SDValue
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index d8948ad2df037..b5724ecd90155 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -1159,6 +1159,9 @@ defm DOT : SIMD_I<(outs V128:$dst), (ins V128:$lhs, V128:$rhs), (outs), (ins),
                   186>;
 def : Pat<(wasm_dot V128:$lhs, V128:$rhs),
           (DOT $lhs, $rhs)>;
+def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$lhs),
+                                                         (v8i16 V128:$rhs))),
+          (ADD_I32x4 (DOT $lhs, $rhs), $acc)>;
 
 // Extending multiplication: extmul_{low,high}_P, extmul_high
 def extend_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;
@@ -1473,6 +1476,12 @@ def : Pat<(v4i32 (int_wasm_extadd_pairwise_signed (v8i16 V128:$in))),
           (extadd_pairwise_s_I32x4 V128:$in)>;
 def : Pat<(v8i16 (int_wasm_extadd_pairwise_signed (v16i8 V128:$in))),
           (extadd_pairwise_s_I16x8 V128:$in)>;
+def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$in),
+                                                         (I16x8.splat (i32 1)))),
+          (ADD_I32x4 (extadd_pairwise_s_I32x4 V128:$in), V128:$acc)>;
+def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v8i16 V128:$in),
+                                                         (I16x8.splat (i32 1)))),
+          (ADD_I32x4 (extadd_pairwise_u_I32x4 V128:$in), V128:$acc)>;
 
 // f64x2 <-> f32x4 conversions
 def demote_t : SDTypeProfile<1, 1, [SDTCisVec<0>, SDTCisVec<1>]>;
diff --git a/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll b/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll
index 47ea762864cc2..c9e486a3f29b4 100644
--- a/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll
+++ b/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll
@@ -402,10 +402,10 @@ define hidden i32 @accumulate_add_s16_s16(ptr noundef readonly  %a, ptr noundef
 ; MAX-BANDWIDTH: loop
 ; MAX-BANDWIDTH: v128.load
 ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
-; MAX-BANDWIDTH: i32x4.add
 ; MAX-BANDWIDTH: v128.load
 ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
 ; MAX-BANDWIDTH: i32x4.add
+; MAX-BANDWIDTH: i32x4.add
 entry:
   %cmp8.not = icmp eq i32 %N, 0
   br i1 %cmp8.not, label %for.cond.cleanup, label %for.body

``````````

</details>


https://github.com/llvm/llvm-project/pull/161184


More information about the llvm-commits mailing list