[llvm] [WebAssembly] Lower wide SIMD i8 muls (PR #130785)

via llvm-commits llvm-commits at lists.llvm.org
Tue Mar 11 08:51:59 PDT 2025


llvmbot wrote:


@llvm/pr-subscribers-backend-webassembly

Author: Sam Parker (sparker-arm)

Currently, a 'wide' i32 SIMD multiplication with extended i8 elements performs the multiplication in i32. So, for IR like the following:
```
  %wide.a = sext <8 x i8> %a to <8 x i32>
  %wide.b = sext <8 x i8> %a to <8 x i32>
  %mul = mul <8 x i32> %wide.a, %wide.b
  ret <8 x i32> %mul
```
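
For context, this IR shape typically comes from the loop vectorizer widening a byte multiply. A scalar loop along these lines (a hypothetical, illustrative example) produces it:
```
#include <cstdint>

// Hypothetical example: each i8 element is widened to i32 before the
// multiply, which the vectorizer turns into the sext + mul <8 x i32>
// pattern shown above.
void mul_widen(int32_t *dst, const int8_t *a, const int8_t *b, int n) {
  for (int i = 0; i < n; ++i)
    dst[i] = int32_t(a[i]) * int32_t(b[i]);
}
```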

From that IR, we currently generate the following sequence:
```
  i16x8.extend_low_i8x16_s $push6=, $1
  local.tee $push5=, $3=, $pop6
  i32x4.extmul_low_i16x8_s $push0=, $pop5, $3
  v128.store 0($0), $pop0
  i8x16.shuffle $push1=, $1, $1, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
  i16x8.extend_low_i8x16_s $push4=, $pop1
  local.tee $push3=, $1=, $pop4
  i32x4.extmul_low_i16x8_s $push2=, $pop3, $1
  v128.store 16($0), $pop2
  return
```

With this patch, the multiplication is instead performed in i16, resulting in:
```
  i16x8.extmul_low_i8x16_s $push3=, $1, $1
  local.tee $push2=, $1=, $pop3
  i32x4.extend_high_i16x8_s $push0=, $pop2
  v128.store 16($0), $pop0
  i32x4.extend_low_i16x8_s $push1=, $1
  v128.store 0($0), $pop1
  return
```
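
Performing the multiply in i16 is lossless here: the product of two i8 values, whether sign- or zero-extended, always fits in 16 bits, so the subsequent extends to i32 preserve the exact result. A minimal compile-time sketch of those bounds (illustrative only):
```
#include <cstdint>

// Illustrative bounds check: the extreme i8 x i8 products fit in the
// corresponding 16-bit type, so an extmul in i16 followed by an extend
// to i32 loses nothing.
static_assert(int32_t(-128) * int32_t(-128) <= INT16_MAX, "signed max product fits in i16");
static_assert(int32_t(-128) * int32_t(127) >= INT16_MIN, "signed min product fits in i16");
static_assert(uint32_t(255) * uint32_t(255) <= UINT16_MAX, "unsigned max product fits in u16");
```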

---
Full diff: https://github.com/llvm/llvm-project/pull/130785.diff


2 Files Affected:

- (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp (+93-2) 
- (added) llvm/test/CodeGen/WebAssembly/wide-simd-mul.ll (+197) 


``````````diff
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index b24a45c2d8898..9ae46e709d823 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -183,6 +183,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
     // Combine partial.reduce.add before legalization gets confused.
     setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
 
+    // Combine wide-vector muls, with extend inputs, to extmul_half.
+    setTargetDAGCombine(ISD::MUL);
+
     // Combine vector mask reductions into alltrue/anytrue
     setTargetDAGCombine(ISD::SETCC);
 
@@ -1461,8 +1464,7 @@ WebAssemblyTargetLowering::LowerCall(CallLoweringInfo &CLI,
 
 bool WebAssemblyTargetLowering::CanLowerReturn(
     CallingConv::ID /*CallConv*/, MachineFunction & /*MF*/, bool /*IsVarArg*/,
-    const SmallVectorImpl<ISD::OutputArg> &Outs,
-    LLVMContext & /*Context*/,
+    const SmallVectorImpl<ISD::OutputArg> &Outs, LLVMContext & /*Context*/,
     const Type *RetTy) const {
   // WebAssembly can only handle returning tuples with multivalue enabled
   return WebAssembly::canLowerReturn(Outs.size(), Subtarget);
@@ -3254,6 +3256,93 @@ static SDValue performSETCCCombine(SDNode *N,
   return SDValue();
 }
 
+static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
+  assert(N->getOpcode() == ISD::MUL);
+  EVT VT = N->getValueType(0);
+  if (VT != MVT::v8i32 && VT != MVT::v16i32)
+    return SDValue();
+
+  // Mul with extending inputs.
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  if (LHS.getOpcode() != RHS.getOpcode())
+    return SDValue();
+
+  if (LHS.getOpcode() != ISD::SIGN_EXTEND &&
+      LHS.getOpcode() != ISD::ZERO_EXTEND)
+    return SDValue();
+
+  if (LHS->getOperand(0).getValueType() != RHS->getOperand(0).getValueType())
+    return SDValue();
+
+  EVT FromVT = LHS->getOperand(0).getValueType();
+  EVT EltTy = FromVT.getVectorElementType();
+  if (EltTy != MVT::i8)
+    return SDValue();
+
+  // For an input DAG that looks like this
+  // %a = input_type
+  // %b = input_type
+  // %lhs = extend %a to output_type
+  // %rhs = extend %b to output_type
+  // %mul = mul %lhs, %rhs
+
+  // input_type | output_type | instructions
+  // v16i8      | v16i32      | %low = i16x8.extmul_low_i8x16_ %a, %b
+  //            |             | %high = i16x8.extmul_high_i8x16_, %a, %b
+  //            |             | %low_low = i32x4.ext_low_i16x8_ %low
+  //            |             | %low_high = i32x4.ext_high_i16x8_ %low
+  //            |             | %high_low = i32x4.ext_low_i16x8_ %high
+  //            |             | %high_high = i32x4.ext_high_i16x8_ %high
+  //            |             | %res = concat_vector(...)
+  // v8i8       | v8i32       | %low = i16x8.extmul_low_i8x16_ %a, %b
+  //            |             | %low_low = i32x4.ext_low_i16x8_ %low
+  //            |             | %low_high = i32x4.ext_high_i16x8_ %low
+  //            |             | %res = concat_vector(%low_low, %low_high)
+
+  SDLoc DL(N);
+  unsigned NumElts = VT.getVectorNumElements();
+  SDValue ExtendInLHS = LHS->getOperand(0);
+  SDValue ExtendInRHS = RHS->getOperand(0);
+  bool IsSigned = LHS->getOpcode() == ISD::SIGN_EXTEND;
+  unsigned ExtendLowOpc =
+      IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
+  unsigned ExtendHighOpc =
+      IsSigned ? WebAssemblyISD::EXTEND_HIGH_S : WebAssemblyISD::EXTEND_HIGH_U;
+
+  auto GetExtendLow = [&DAG, &DL, &ExtendLowOpc](EVT VT, SDValue Op) {
+    return DAG.getNode(ExtendLowOpc, DL, VT, Op);
+  };
+  auto GetExtendHigh = [&DAG, &DL, &ExtendHighOpc](EVT VT, SDValue Op) {
+    return DAG.getNode(ExtendHighOpc, DL, VT, Op);
+  };
+
+  if (NumElts == 16) {
+    SDValue LowLHS = GetExtendLow(MVT::v8i16, ExtendInLHS);
+    SDValue LowRHS = GetExtendLow(MVT::v8i16, ExtendInRHS);
+    SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
+    SDValue HighLHS = GetExtendHigh(MVT::v8i16, ExtendInLHS);
+    SDValue HighRHS = GetExtendHigh(MVT::v8i16, ExtendInRHS);
+    SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
+    SDValue SubVectors[] = {
+        GetExtendLow(MVT::v4i32, MulLow),
+        GetExtendHigh(MVT::v4i32, MulLow),
+        GetExtendLow(MVT::v4i32, MulHigh),
+        GetExtendHigh(MVT::v4i32, MulHigh),
+    };
+    return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SubVectors);
+  } else {
+    assert(NumElts == 8);
+    SDValue LowLHS = DAG.getNode(LHS->getOpcode(), DL, MVT::v8i16, ExtendInLHS);
+    SDValue LowRHS = DAG.getNode(RHS->getOpcode(), DL, MVT::v8i16, ExtendInRHS);
+    SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
+    SDValue Lo = GetExtendLow(MVT::v4i32, MulLow);
+    SDValue Hi = GetExtendHigh(MVT::v4i32, MulLow);
+    return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
+  }
+  return SDValue();
+}
+
 SDValue
 WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
                                              DAGCombinerInfo &DCI) const {
@@ -3281,5 +3370,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
     return performTruncateCombine(N, DCI);
   case ISD::INTRINSIC_WO_CHAIN:
     return performLowerPartialReduction(N, DCI.DAG);
+  case ISD::MUL:
+    return performMulCombine(N, DCI.DAG);
   }
 }
diff --git a/llvm/test/CodeGen/WebAssembly/wide-simd-mul.ll b/llvm/test/CodeGen/WebAssembly/wide-simd-mul.ll
new file mode 100644
index 0000000000000..94aa197bfd564
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/wide-simd-mul.ll
@@ -0,0 +1,197 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=wasm32 -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128 | FileCheck %s
+
+define <8 x i32> @sext_mul_v8i8(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: sext_mul_v8i8:
+; CHECK:         .functype sext_mul_v8i8 (i32, v128, v128) -> ()
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    i16x8.extmul_low_i8x16_s $push3=, $1, $1
+; CHECK-NEXT:    local.tee $push2=, $1=, $pop3
+; CHECK-NEXT:    i32x4.extend_high_i16x8_s $push0=, $pop2
+; CHECK-NEXT:    v128.store 16($0), $pop0
+; CHECK-NEXT:    i32x4.extend_low_i16x8_s $push1=, $1
+; CHECK-NEXT:    v128.store 0($0), $pop1
+; CHECK-NEXT:    return
+  %wide.a = sext <8 x i8> %a to <8 x i32>
+  %wide.b = sext <8 x i8> %a to <8 x i32>
+  %mul = mul <8 x i32> %wide.a, %wide.b
+  ret <8 x i32> %mul
+}
+
+define <16 x i32> @sext_mul_v16i8(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: sext_mul_v16i8:
+; CHECK:         .functype sext_mul_v16i8 (i32, v128, v128) -> ()
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    i16x8.extmul_high_i8x16_s $push7=, $1, $1
+; CHECK-NEXT:    local.tee $push6=, $3=, $pop7
+; CHECK-NEXT:    i32x4.extend_high_i16x8_s $push0=, $pop6
+; CHECK-NEXT:    v128.store 48($0), $pop0
+; CHECK-NEXT:    i32x4.extend_low_i16x8_s $push1=, $3
+; CHECK-NEXT:    v128.store 32($0), $pop1
+; CHECK-NEXT:    i16x8.extmul_low_i8x16_s $push5=, $1, $1
+; CHECK-NEXT:    local.tee $push4=, $1=, $pop5
+; CHECK-NEXT:    i32x4.extend_high_i16x8_s $push2=, $pop4
+; CHECK-NEXT:    v128.store 16($0), $pop2
+; CHECK-NEXT:    i32x4.extend_low_i16x8_s $push3=, $1
+; CHECK-NEXT:    v128.store 0($0), $pop3
+; CHECK-NEXT:    return
+  %wide.a = sext <16 x i8> %a to <16 x i32>
+  %wide.b = sext <16 x i8> %a to <16 x i32>
+  %mul = mul <16 x i32> %wide.a, %wide.b
+  ret <16 x i32> %mul
+}
+
+define <8 x i32> @sext_mul_v8i16(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: sext_mul_v8i16:
+; CHECK:         .functype sext_mul_v8i16 (i32, v128, v128) -> ()
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    i32x4.extmul_high_i16x8_s $push0=, $1, $1
+; CHECK-NEXT:    v128.store 16($0), $pop0
+; CHECK-NEXT:    i32x4.extmul_low_i16x8_s $push1=, $1, $1
+; CHECK-NEXT:    v128.store 0($0), $pop1
+; CHECK-NEXT:    return
+  %wide.a = sext <8 x i16> %a to <8 x i32>
+  %wide.b = sext <8 x i16> %a to <8 x i32>
+  %mul = mul <8 x i32> %wide.a, %wide.b
+  ret <8 x i32> %mul
+}
+
+define <8 x i32> @zext_mul_v8i8(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: zext_mul_v8i8:
+; CHECK:         .functype zext_mul_v8i8 (i32, v128, v128) -> ()
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    i16x8.extmul_low_i8x16_u $push3=, $1, $1
+; CHECK-NEXT:    local.tee $push2=, $1=, $pop3
+; CHECK-NEXT:    i32x4.extend_high_i16x8_u $push0=, $pop2
+; CHECK-NEXT:    v128.store 16($0), $pop0
+; CHECK-NEXT:    i32x4.extend_low_i16x8_u $push1=, $1
+; CHECK-NEXT:    v128.store 0($0), $pop1
+; CHECK-NEXT:    return
+  %wide.a = zext <8 x i8> %a to <8 x i32>
+  %wide.b = zext <8 x i8> %a to <8 x i32>
+  %mul = mul <8 x i32> %wide.a, %wide.b
+  ret <8 x i32> %mul
+}
+
+define <16 x i32> @zext_mul_v16i8(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: zext_mul_v16i8:
+; CHECK:         .functype zext_mul_v16i8 (i32, v128, v128) -> ()
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    i16x8.extmul_high_i8x16_u $push7=, $1, $1
+; CHECK-NEXT:    local.tee $push6=, $3=, $pop7
+; CHECK-NEXT:    i32x4.extend_high_i16x8_u $push0=, $pop6
+; CHECK-NEXT:    v128.store 48($0), $pop0
+; CHECK-NEXT:    i32x4.extend_low_i16x8_u $push1=, $3
+; CHECK-NEXT:    v128.store 32($0), $pop1
+; CHECK-NEXT:    i16x8.extmul_low_i8x16_u $push5=, $1, $1
+; CHECK-NEXT:    local.tee $push4=, $1=, $pop5
+; CHECK-NEXT:    i32x4.extend_high_i16x8_u $push2=, $pop4
+; CHECK-NEXT:    v128.store 16($0), $pop2
+; CHECK-NEXT:    i32x4.extend_low_i16x8_u $push3=, $1
+; CHECK-NEXT:    v128.store 0($0), $pop3
+; CHECK-NEXT:    return
+  %wide.a = zext <16 x i8> %a to <16 x i32>
+  %wide.b = zext <16 x i8> %a to <16 x i32>
+  %mul = mul <16 x i32> %wide.a, %wide.b
+  ret <16 x i32> %mul
+}
+
+define <8 x i32> @zext_mul_v8i16(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: zext_mul_v8i16:
+; CHECK:         .functype zext_mul_v8i16 (i32, v128, v128) -> ()
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    i32x4.extmul_high_i16x8_u $push0=, $1, $1
+; CHECK-NEXT:    v128.store 16($0), $pop0
+; CHECK-NEXT:    i32x4.extmul_low_i16x8_u $push1=, $1, $1
+; CHECK-NEXT:    v128.store 0($0), $pop1
+; CHECK-NEXT:    return
+  %wide.a = zext <8 x i16> %a to <8 x i32>
+  %wide.b = zext <8 x i16> %a to <8 x i32>
+  %mul = mul <8 x i32> %wide.a, %wide.b
+  ret <8 x i32> %mul
+}
+
+define <8 x i32> @sext_zext_mul_v8i8(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: sext_zext_mul_v8i8:
+; CHECK:         .functype sext_zext_mul_v8i8 (i32, v128, v128) -> ()
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    i16x8.extend_low_i8x16_s $push2=, $1
+; CHECK-NEXT:    i32x4.extend_low_i16x8_s $push3=, $pop2
+; CHECK-NEXT:    i16x8.extend_low_i8x16_u $push0=, $1
+; CHECK-NEXT:    i32x4.extend_low_i16x8_u $push1=, $pop0
+; CHECK-NEXT:    i32x4.mul $push4=, $pop3, $pop1
+; CHECK-NEXT:    v128.store 0($0), $pop4
+; CHECK-NEXT:    i8x16.shuffle $push11=, $1, $1, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+; CHECK-NEXT:    local.tee $push10=, $1=, $pop11
+; CHECK-NEXT:    i16x8.extend_low_i8x16_s $push7=, $pop10
+; CHECK-NEXT:    i32x4.extend_low_i16x8_s $push8=, $pop7
+; CHECK-NEXT:    i16x8.extend_low_i8x16_u $push5=, $1
+; CHECK-NEXT:    i32x4.extend_low_i16x8_u $push6=, $pop5
+; CHECK-NEXT:    i32x4.mul $push9=, $pop8, $pop6
+; CHECK-NEXT:    v128.store 16($0), $pop9
+; CHECK-NEXT:    return
+  %wide.a = sext <8 x i8> %a to <8 x i32>
+  %wide.b = zext <8 x i8> %a to <8 x i32>
+  %mul = mul <8 x i32> %wide.a, %wide.b
+  ret <8 x i32> %mul
+}
+
+define <16 x i32> @sext_zext_mul_v16i8(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: sext_zext_mul_v16i8:
+; CHECK:         .functype sext_zext_mul_v16i8 (i32, v128, v128) -> ()
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    i16x8.extend_low_i8x16_s $push2=, $1
+; CHECK-NEXT:    i32x4.extend_low_i16x8_s $push3=, $pop2
+; CHECK-NEXT:    i16x8.extend_low_i8x16_u $push0=, $1
+; CHECK-NEXT:    i32x4.extend_low_i16x8_u $push1=, $pop0
+; CHECK-NEXT:    i32x4.mul $push4=, $pop3, $pop1
+; CHECK-NEXT:    v128.store 0($0), $pop4
+; CHECK-NEXT:    i8x16.shuffle $push25=, $1, $1, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+; CHECK-NEXT:    local.tee $push24=, $3=, $pop25
+; CHECK-NEXT:    i16x8.extend_low_i8x16_s $push7=, $pop24
+; CHECK-NEXT:    i32x4.extend_low_i16x8_s $push8=, $pop7
+; CHECK-NEXT:    i16x8.extend_low_i8x16_u $push5=, $3
+; CHECK-NEXT:    i32x4.extend_low_i16x8_u $push6=, $pop5
+; CHECK-NEXT:    i32x4.mul $push9=, $pop8, $pop6
+; CHECK-NEXT:    v128.store 48($0), $pop9
+; CHECK-NEXT:    i8x16.shuffle $push23=, $1, $1, 8, 9, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0
+; CHECK-NEXT:    local.tee $push22=, $3=, $pop23
+; CHECK-NEXT:    i16x8.extend_low_i8x16_s $push12=, $pop22
+; CHECK-NEXT:    i32x4.extend_low_i16x8_s $push13=, $pop12
+; CHECK-NEXT:    i16x8.extend_low_i8x16_u $push10=, $3
+; CHECK-NEXT:    i32x4.extend_low_i16x8_u $push11=, $pop10
+; CHECK-NEXT:    i32x4.mul $push14=, $pop13, $pop11
+; CHECK-NEXT:    v128.store 32($0), $pop14
+; CHECK-NEXT:    i8x16.shuffle $push21=, $1, $1, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+; CHECK-NEXT:    local.tee $push20=, $1=, $pop21
+; CHECK-NEXT:    i16x8.extend_low_i8x16_s $push17=, $pop20
+; CHECK-NEXT:    i32x4.extend_low_i16x8_s $push18=, $pop17
+; CHECK-NEXT:    i16x8.extend_low_i8x16_u $push15=, $1
+; CHECK-NEXT:    i32x4.extend_low_i16x8_u $push16=, $pop15
+; CHECK-NEXT:    i32x4.mul $push19=, $pop18, $pop16
+; CHECK-NEXT:    v128.store 16($0), $pop19
+; CHECK-NEXT:    return
+  %wide.a = sext <16 x i8> %a to <16 x i32>
+  %wide.b = zext <16 x i8> %a to <16 x i32>
+  %mul = mul <16 x i32> %wide.a, %wide.b
+  ret <16 x i32> %mul
+}
+
+define <8 x i32> @zext_sext_mul_v8i16(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: zext_sext_mul_v8i16:
+; CHECK:         .functype zext_sext_mul_v8i16 (i32, v128, v128) -> ()
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    i32x4.extend_high_i16x8_u $push1=, $1
+; CHECK-NEXT:    i32x4.extend_high_i16x8_s $push0=, $1
+; CHECK-NEXT:    i32x4.mul $push2=, $pop1, $pop0
+; CHECK-NEXT:    v128.store 16($0), $pop2
+; CHECK-NEXT:    i32x4.extend_low_i16x8_u $push4=, $1
+; CHECK-NEXT:    i32x4.extend_low_i16x8_s $push3=, $1
+; CHECK-NEXT:    i32x4.mul $push5=, $pop4, $pop3
+; CHECK-NEXT:    v128.store 0($0), $pop5
+; CHECK-NEXT:    return
+  %wide.a = zext <8 x i16> %a to <8 x i32>
+  %wide.b = sext <8 x i16> %a to <8 x i32>
+  %mul = mul <8 x i32> %wide.a, %wide.b
+  ret <8 x i32> %mul
+}

``````````



https://github.com/llvm/llvm-project/pull/130785


More information about the llvm-commits mailing list