[llvm] 156e9b4 - [WebAssembly] Use partial_reduce_mla ISD nodes (#161184)

via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 30 00:29:01 PDT 2025


Author: Sam Parker
Date: 2025-09-30T08:28:56+01:00
New Revision: 156e9b4b6989043024ea8c5d15360a716af51631

URL: https://github.com/llvm/llvm-project/commit/156e9b4b6989043024ea8c5d15360a716af51631
DIFF: https://github.com/llvm/llvm-project/commit/156e9b4b6989043024ea8c5d15360a716af51631.diff

LOG: [WebAssembly] Use partial_reduce_mla ISD nodes (#161184)

Addressing issue #160847.
 
Move away from combining the intrinsic call and instead lower the ISD
nodes, using tablegen for pattern matching.

Added: 
    

Modified: 
    llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
    llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
    llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
    llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 64b9dc31f75b7..163bf9ba5b089 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -186,7 +186,6 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
   // SIMD-specific configuration
   if (Subtarget->hasSIMD128()) {
 
-    // Combine partial.reduce.add before legalization gets confused.
     setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
 
     // Combine wide-vector muls, with extend inputs, to extmul_half.
@@ -317,6 +316,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
       setOperationAction(ISD::SIGN_EXTEND_VECTOR_INREG, T, Custom);
       setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, T, Custom);
     }
+
+    // Partial MLA reductions.
+    for (auto Op : {ISD::PARTIAL_REDUCE_SMLA, ISD::PARTIAL_REDUCE_UMLA}) {
+      setPartialReduceMLAAction(Op, MVT::v4i32, MVT::v16i8, Legal);
+      setPartialReduceMLAAction(Op, MVT::v4i32, MVT::v8i16, Legal);
+    }
   }
 
   // As a special case, these operators use the type to mean the type to
@@ -416,41 +421,6 @@ MVT WebAssemblyTargetLowering::getPointerMemTy(const DataLayout &DL,
   return TargetLowering::getPointerMemTy(DL, AS);
 }
 
-bool WebAssemblyTargetLowering::shouldExpandPartialReductionIntrinsic(
-    const IntrinsicInst *I) const {
-  if (I->getIntrinsicID() != Intrinsic::vector_partial_reduce_add)
-    return true;
-
-  EVT VT = EVT::getEVT(I->getType());
-  if (VT.getSizeInBits() > 128)
-    return true;
-
-  auto Op1 = I->getOperand(1);
-
-  if (auto *InputInst = dyn_cast<Instruction>(Op1)) {
-    unsigned Opcode = InstructionOpcodeToISD(InputInst->getOpcode());
-    if (Opcode == ISD::MUL) {
-      if (isa<Instruction>(InputInst->getOperand(0)) &&
-          isa<Instruction>(InputInst->getOperand(1))) {
-        // dot only supports signed inputs but also support lowering unsigned.
-        if (cast<Instruction>(InputInst->getOperand(0))->getOpcode() !=
-            cast<Instruction>(InputInst->getOperand(1))->getOpcode())
-          return true;
-
-        EVT Op1VT = EVT::getEVT(Op1->getType());
-        if (Op1VT.getVectorElementType() == VT.getVectorElementType() &&
-            ((VT.getVectorElementCount() * 2 ==
-              Op1VT.getVectorElementCount()) ||
-             (VT.getVectorElementCount() * 4 == Op1VT.getVectorElementCount())))
-          return false;
-      }
-    } else if (ISD::isExtOpcode(Opcode)) {
-      return false;
-    }
-  }
-  return true;
-}
-
 TargetLowering::AtomicExpansionKind
 WebAssemblyTargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
   // We have wasm instructions for these
@@ -2113,106 +2083,6 @@ SDValue WebAssemblyTargetLowering::LowerVASTART(SDValue Op,
                       MachinePointerInfo(SV));
 }
 
-// Try to lower partial.reduce.add to a dot or fallback to a sequence with
-// extmul and adds.
-SDValue performLowerPartialReduction(SDNode *N, SelectionDAG &DAG) {
-  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN);
-  if (N->getConstantOperandVal(0) != Intrinsic::vector_partial_reduce_add)
-    return SDValue();
-
-  assert(N->getValueType(0) == MVT::v4i32 && "can only support v4i32");
-  SDLoc DL(N);
-
-  SDValue Input = N->getOperand(2);
-  if (Input->getOpcode() == ISD::MUL) {
-    SDValue ExtendLHS = Input->getOperand(0);
-    SDValue ExtendRHS = Input->getOperand(1);
-    assert((ISD::isExtOpcode(ExtendLHS.getOpcode()) &&
-            ISD::isExtOpcode(ExtendRHS.getOpcode())) &&
-           "expected widening mul or add");
-    assert(ExtendLHS.getOpcode() == ExtendRHS.getOpcode() &&
-           "expected binop to use the same extend for both operands");
-
-    SDValue ExtendInLHS = ExtendLHS->getOperand(0);
-    SDValue ExtendInRHS = ExtendRHS->getOperand(0);
-    bool IsSigned = ExtendLHS->getOpcode() == ISD::SIGN_EXTEND;
-    unsigned LowOpc =
-        IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
-    unsigned HighOpc = IsSigned ? WebAssemblyISD::EXTEND_HIGH_S
-                                : WebAssemblyISD::EXTEND_HIGH_U;
-    SDValue LowLHS;
-    SDValue LowRHS;
-    SDValue HighLHS;
-    SDValue HighRHS;
-
-    auto AssignInputs = [&](MVT VT) {
-      LowLHS = DAG.getNode(LowOpc, DL, VT, ExtendInLHS);
-      LowRHS = DAG.getNode(LowOpc, DL, VT, ExtendInRHS);
-      HighLHS = DAG.getNode(HighOpc, DL, VT, ExtendInLHS);
-      HighRHS = DAG.getNode(HighOpc, DL, VT, ExtendInRHS);
-    };
-
-    if (ExtendInLHS->getValueType(0) == MVT::v8i16) {
-      if (IsSigned) {
-        // i32x4.dot_i16x8_s
-        SDValue Dot = DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32,
-                                  ExtendInLHS, ExtendInRHS);
-        return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Dot);
-      }
-
-      // (add (add (extmul_low_sx lhs, rhs), (extmul_high_sx lhs, rhs)))
-      MVT VT = MVT::v4i32;
-      AssignInputs(VT);
-      SDValue MulLow = DAG.getNode(ISD::MUL, DL, VT, LowLHS, LowRHS);
-      SDValue MulHigh = DAG.getNode(ISD::MUL, DL, VT, HighLHS, HighRHS);
-      SDValue Add = DAG.getNode(ISD::ADD, DL, VT, MulLow, MulHigh);
-      return DAG.getNode(ISD::ADD, DL, VT, N->getOperand(1), Add);
-    } else {
-      assert(ExtendInLHS->getValueType(0) == MVT::v16i8 &&
-             "expected v16i8 input types");
-      AssignInputs(MVT::v8i16);
-      // Lower to a wider tree, using twice the operations compared to above.
-      if (IsSigned) {
-        // Use two dots
-        SDValue DotLHS =
-            DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, LowLHS, LowRHS);
-        SDValue DotRHS =
-            DAG.getNode(WebAssemblyISD::DOT, DL, MVT::v4i32, HighLHS, HighRHS);
-        SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, DotLHS, DotRHS);
-        return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
-      }
-
-      SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
-      SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
-
-      SDValue AddLow = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
-                                   MVT::v4i32, MulLow);
-      SDValue AddHigh = DAG.getNode(WebAssemblyISD::EXT_ADD_PAIRWISE_U, DL,
-                                    MVT::v4i32, MulHigh);
-      SDValue Add = DAG.getNode(ISD::ADD, DL, MVT::v4i32, AddLow, AddHigh);
-      return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
-    }
-  } else {
-    // Accumulate the input using extadd_pairwise.
-    assert(ISD::isExtOpcode(Input.getOpcode()) && "expected extend");
-    bool IsSigned = Input->getOpcode() == ISD::SIGN_EXTEND;
-    unsigned PairwiseOpc = IsSigned ? WebAssemblyISD::EXT_ADD_PAIRWISE_S
-                                    : WebAssemblyISD::EXT_ADD_PAIRWISE_U;
-    SDValue ExtendIn = Input->getOperand(0);
-    if (ExtendIn->getValueType(0) == MVT::v8i16) {
-      SDValue Add = DAG.getNode(PairwiseOpc, DL, MVT::v4i32, ExtendIn);
-      return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
-    }
-
-    assert(ExtendIn->getValueType(0) == MVT::v16i8 &&
-           "expected v16i8 input types");
-    SDValue Add =
-        DAG.getNode(PairwiseOpc, DL, MVT::v4i32,
-                    DAG.getNode(PairwiseOpc, DL, MVT::v8i16, ExtendIn));
-    return DAG.getNode(ISD::ADD, DL, MVT::v4i32, N->getOperand(1), Add);
-  }
-}
-
 SDValue WebAssemblyTargetLowering::LowerIntrinsic(SDValue Op,
                                                   SelectionDAG &DAG) const {
   MachineFunction &MF = DAG.getMachineFunction();
@@ -3683,11 +3553,8 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
     return performVectorTruncZeroCombine(N, DCI);
   case ISD::TRUNCATE:
     return performTruncateCombine(N, DCI);
-  case ISD::INTRINSIC_WO_CHAIN: {
-    if (auto AnyAllCombine = performAnyAllCombine(N, DCI.DAG))
-      return AnyAllCombine;
-    return performLowerPartialReduction(N, DCI.DAG);
-  }
+  case ISD::INTRINSIC_WO_CHAIN:
+    return performAnyAllCombine(N, DCI.DAG);
   case ISD::MUL:
     return performMulCombine(N, DCI);
   }

diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
index 72401a7a259c0..b33a8530310be 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.h
@@ -45,8 +45,6 @@ class WebAssemblyTargetLowering final : public TargetLowering {
   /// right decision when generating code for different targets.
   const WebAssemblySubtarget *Subtarget;
 
-  bool
-  shouldExpandPartialReductionIntrinsic(const IntrinsicInst *I) const override;
   AtomicExpansionKind shouldExpandAtomicRMWInIR(AtomicRMWInst *) const override;
   bool shouldScalarizeBinop(SDValue VecOp) const override;
   FastISel *createFastISel(FunctionLoweringInfo &FuncInfo,
@@ -89,8 +87,7 @@ class WebAssemblyTargetLowering final : public TargetLowering {
   bool CanLowerReturn(CallingConv::ID CallConv, MachineFunction &MF,
                       bool isVarArg,
                       const SmallVectorImpl<ISD::OutputArg> &Outs,
-                      LLVMContext &Context,
-                      const Type *RetTy) const override;
+                      LLVMContext &Context, const Type *RetTy) const override;
   SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
                       const SmallVectorImpl<ISD::OutputArg> &Outs,
                       const SmallVectorImpl<SDValue> &OutVals, const SDLoc &dl,

diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index d8948ad2df037..130602650d34e 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -1504,6 +1504,51 @@ def : Pat<(v2f64 (extloadv2f32 (i64 I64:$addr))),
 defm Q15MULR_SAT_S :
   SIMDBinary<I16x8, int_wasm_q15mulr_sat_signed, "q15mulr_sat_s", 0x82>;
 
+//===----------------------------------------------------------------------===//
+// Partial reductions, using: dot, extmul and extadd_pairwise
+//===----------------------------------------------------------------------===//
+// MLA: v8i16 -> v4i32
+def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$lhs),
+                                                         (v8i16 V128:$rhs))),
+          (ADD_I32x4 (DOT $lhs, $rhs), $acc)>;
+def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v8i16 V128:$lhs),
+                                                         (v8i16 V128:$rhs))),
+          (ADD_I32x4 (ADD_I32x4 (EXTMUL_LOW_U_I32x4 $lhs, $rhs),
+                                (EXTMUL_HIGH_U_I32x4 $lhs, $rhs)),
+                     $acc)>;
+// MLA: v16i8 -> v4i32
+def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v16i8 V128:$lhs),
+                                                         (v16i8 V128:$rhs))),
+          (ADD_I32x4 (ADD_I32x4 (DOT (extend_low_s_I16x8 $lhs),
+                                     (extend_low_s_I16x8 $rhs)),
+                                (DOT (extend_high_s_I16x8 $lhs),
+                                     (extend_high_s_I16x8 $rhs))),
+                      $acc)>;
+def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v16i8 V128:$lhs),
+                                                         (v16i8 V128:$rhs))),
+          (ADD_I32x4 (ADD_I32x4 (extadd_pairwise_u_I32x4 (EXTMUL_LOW_U_I16x8 $lhs, $rhs)),
+                                (extadd_pairwise_u_I32x4 (EXTMUL_HIGH_U_I16x8 $lhs, $rhs))),
+                     $acc)>;
+
+// Accumulate: v8i16 -> v4i32
+def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v8i16 V128:$in),
+                                                         (I16x8.splat (i32 1)))),
+          (ADD_I32x4 (extadd_pairwise_s_I32x4 $in), $acc)>;
+
+def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v8i16 V128:$in),
+                                                         (I16x8.splat (i32 1)))),
+          (ADD_I32x4 (extadd_pairwise_u_I32x4 $in), $acc)>;
+
+// Accumulate: v16i8 -> v4i32
+def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$acc), (v16i8 V128:$in),
+                                                         (I8x16.splat (i32 1)))),
+          (ADD_I32x4 (extadd_pairwise_s_I32x4 (extadd_pairwise_s_I16x8 $in)),
+                     $acc)>;
+def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$acc), (v16i8 V128:$in),
+                                                         (I8x16.splat (i32 1)))),
+          (ADD_I32x4 (extadd_pairwise_u_I32x4 (extadd_pairwise_u_I16x8 $in)),
+                     $acc)>;
+
 //===----------------------------------------------------------------------===//
 // Relaxed swizzle
 //===----------------------------------------------------------------------===//

diff --git a/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll b/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll
index 47ea762864cc2..a599f4653f323 100644
--- a/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll
+++ b/llvm/test/CodeGen/WebAssembly/partial-reduce-accumulate.ll
@@ -19,11 +19,11 @@ define hidden i32 @accumulate_add_u8_u8(ptr noundef readonly  %a, ptr noundef re
 ; MAX-BANDWIDTH: v128.load
 ; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_u
 ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_u
-; MAX-BANDWIDTH: i32x4.add
 ; MAX-BANDWIDTH: v128.load
 ; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_u
 ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_u
 ; MAX-BANDWIDTH: i32x4.add
+; MAX-BANDWIDTH: i32x4.add
 
 entry:
   %cmp8.not = icmp eq i32 %N, 0
@@ -65,11 +65,11 @@ define hidden i32 @accumulate_add_s8_s8(ptr noundef readonly  %a, ptr noundef re
 ; MAX-BANDWIDTH: v128.load
 ; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_s
 ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
-; MAX-BANDWIDTH: i32x4.add
 ; MAX-BANDWIDTH: v128.load
 ; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_s
 ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
 ; MAX-BANDWIDTH: i32x4.add
+; MAX-BANDWIDTH: i32x4.add
 entry:
   %cmp8.not = icmp eq i32 %N, 0
   br i1 %cmp8.not, label %for.cond.cleanup, label %for.body
@@ -108,12 +108,11 @@ define hidden i32 @accumulate_add_s8_u8(ptr noundef readonly  %a, ptr noundef re
 
 ; MAX-BANDWIDTH: loop
 ; MAX-BANDWIDTH: v128.load
-; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_s
-; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
-; MAX-BANDWIDTH: i32x4.add
-; MAX-BANDWIDTH: v128.load
 ; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_u
 ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_u
+; MAX-BANDWIDTH: v128.load
+; MAX-BANDWIDTH: i16x8.extadd_pairwise_i8x16_s
+; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
 ; MAX-BANDWIDTH: i32x4.add
 entry:
   %cmp8.not = icmp eq i32 %N, 0
@@ -363,10 +362,10 @@ define hidden i32 @accumulate_add_u16_u16(ptr noundef readonly  %a, ptr noundef
 ; MAX-BANDWIDTH: loop
 ; MAX-BANDWIDTH: v128.load
 ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_u
-; MAX-BANDWIDTH: i32x4.add
 ; MAX-BANDWIDTH: v128.load
 ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_u
 ; MAX-BANDWIDTH: i32x4.add
+; MAX-BANDWIDTH: i32x4.add
 entry:
   %cmp8.not = icmp eq i32 %N, 0
   br i1 %cmp8.not, label %for.cond.cleanup, label %for.body
@@ -402,10 +401,10 @@ define hidden i32 @accumulate_add_s16_s16(ptr noundef readonly  %a, ptr noundef
 ; MAX-BANDWIDTH: loop
 ; MAX-BANDWIDTH: v128.load
 ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
-; MAX-BANDWIDTH: i32x4.add
 ; MAX-BANDWIDTH: v128.load
 ; MAX-BANDWIDTH: i32x4.extadd_pairwise_i16x8_s
 ; MAX-BANDWIDTH: i32x4.add
+; MAX-BANDWIDTH: i32x4.add
 entry:
   %cmp8.not = icmp eq i32 %N, 0
   br i1 %cmp8.not, label %for.cond.cleanup, label %for.body


        


More information about the llvm-commits mailing list