[llvm] [WebAssembly] Lower wide SIMD i8 muls (PR #130785)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Mar 11 08:51:59 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-webassembly
Author: Sam Parker (sparker-arm)
<details>
<summary>Changes</summary>
Currently, 'wide' i32 SIMD multiplication, with extended i8 elements, will perform the multiplication with i32. So, for IR like the following:
```
%wide.a = sext <8 x i8> %a to <8 x i32>
%wide.b = sext <8 x i8> %a to <8 x i32>
%mul = mul <8 x i32> %wide.a, %wide.b
ret <8 x i32> %mul
```
We would generate the following sequence:
```
i16x8.extend_low_i8x16_s $push6=, $1
local.tee $push5=, $3=, $pop6
i32x4.extmul_low_i16x8_s $push0=, $pop5, $3
v128.store 0($0), $pop0
i8x16.shuffle $push1=, $1, $1, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
i16x8.extend_low_i8x16_s $push4=, $pop1
local.tee $push3=, $1=, $pop4
i32x4.extmul_low_i16x8_s $push2=, $pop3, $1
v128.store 16($0), $pop2
return
```
But now we perform the multiplication with i16, resulting in:
```
i16x8.extmul_low_i8x16_s $push3=, $1, $1
local.tee $push2=, $1=, $pop3
i32x4.extend_high_i16x8_s $push0=, $pop2
v128.store 16($0), $pop0
i32x4.extend_low_i16x8_s $push1=, $1
v128.store 0($0), $pop1
return
```
---
Full diff: https://github.com/llvm/llvm-project/pull/130785.diff
2 Files Affected:
- (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp (+93-2)
- (added) llvm/test/CodeGen/WebAssembly/wide-simd-mul.ll (+197)
``````````diff
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index b24a45c2d8898..9ae46e709d823 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -183,6 +183,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// Combine partial.reduce.add before legalization gets confused.
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
+ // Combine wide-vector muls, with extend inputs, to extmul_half.
+ setTargetDAGCombine(ISD::MUL);
+
// Combine vector mask reductions into alltrue/anytrue
setTargetDAGCombine(ISD::SETCC);
@@ -1461,8 +1464,7 @@ WebAssemblyTargetLowering::LowerCall(CallLoweringInfo &CLI,
bool WebAssemblyTargetLowering::CanLowerReturn(
CallingConv::ID /*CallConv*/, MachineFunction & /*MF*/, bool /*IsVarArg*/,
- const SmallVectorImpl<ISD::OutputArg> &Outs,
- LLVMContext & /*Context*/,
+ const SmallVectorImpl<ISD::OutputArg> &Outs, LLVMContext & /*Context*/,
const Type *RetTy) const {
// WebAssembly can only handle returning tuples with multivalue enabled
return WebAssembly::canLowerReturn(Outs.size(), Subtarget);
@@ -3254,6 +3256,93 @@ static SDValue performSETCCCombine(SDNode *N,
return SDValue();
}
+static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
+ assert(N->getOpcode() == ISD::MUL);
+ EVT VT = N->getValueType(0);
+ if (VT != MVT::v8i32 && VT != MVT::v16i32)
+ return SDValue();
+
+ // Mul with extending inputs.
+ SDValue LHS = N->getOperand(0);
+ SDValue RHS = N->getOperand(1);
+ if (LHS.getOpcode() != RHS.getOpcode())
+ return SDValue();
+
+ if (LHS.getOpcode() != ISD::SIGN_EXTEND &&
+ LHS.getOpcode() != ISD::ZERO_EXTEND)
+ return SDValue();
+
+ if (LHS->getOperand(0).getValueType() != RHS->getOperand(0).getValueType())
+ return SDValue();
+
+ EVT FromVT = LHS->getOperand(0).getValueType();
+ EVT EltTy = FromVT.getVectorElementType();
+ if (EltTy != MVT::i8)
+ return SDValue();
+
+ // For an input DAG that looks like this
+ // %a = input_type
+ // %b = input_type
+ // %lhs = extend %a to output_type
+ // %rhs = extend %b to output_type
+ // %mul = mul %lhs, %rhs
+
+ // input_type | output_type | instructions
+ // v16i8 | v16i32 | %low = i16x8.extmul_low_i8x16_ %a, %b
+ // | | %high = i16x8.extmul_high_i8x16_, %a, %b
+ // | | %low_low = i32x4.ext_low_i16x8_ %low
+ // | | %low_high = i32x4.ext_high_i16x8_ %low
+ // | | %high_low = i32x4.ext_low_i16x8_ %high
+ // | | %high_high = i32x4.ext_high_i16x8_ %high
+ // | | %res = concat_vector(...)
+ // v8i8 | v8i32 | %low = i16x8.extmul_low_i8x16_ %a, %b
+ // | | %low_low = i32x4.ext_low_i16x8_ %low
+ // | | %low_high = i32x4.ext_high_i16x8_ %low
+ // | | %res = concat_vector(%low_low, %low_high)
+
+ SDLoc DL(N);
+ unsigned NumElts = VT.getVectorNumElements();
+ SDValue ExtendInLHS = LHS->getOperand(0);
+ SDValue ExtendInRHS = RHS->getOperand(0);
+ bool IsSigned = LHS->getOpcode() == ISD::SIGN_EXTEND;
+ unsigned ExtendLowOpc =
+ IsSigned ? WebAssemblyISD::EXTEND_LOW_S : WebAssemblyISD::EXTEND_LOW_U;
+ unsigned ExtendHighOpc =
+ IsSigned ? WebAssemblyISD::EXTEND_HIGH_S : WebAssemblyISD::EXTEND_HIGH_U;
+
+ auto GetExtendLow = [&DAG, &DL, &ExtendLowOpc](EVT VT, SDValue Op) {
+ return DAG.getNode(ExtendLowOpc, DL, VT, Op);
+ };
+ auto GetExtendHigh = [&DAG, &DL, &ExtendHighOpc](EVT VT, SDValue Op) {
+ return DAG.getNode(ExtendHighOpc, DL, VT, Op);
+ };
+
+ if (NumElts == 16) {
+ SDValue LowLHS = GetExtendLow(MVT::v8i16, ExtendInLHS);
+ SDValue LowRHS = GetExtendLow(MVT::v8i16, ExtendInRHS);
+ SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
+ SDValue HighLHS = GetExtendHigh(MVT::v8i16, ExtendInLHS);
+ SDValue HighRHS = GetExtendHigh(MVT::v8i16, ExtendInRHS);
+ SDValue MulHigh = DAG.getNode(ISD::MUL, DL, MVT::v8i16, HighLHS, HighRHS);
+ SDValue SubVectors[] = {
+ GetExtendLow(MVT::v4i32, MulLow),
+ GetExtendHigh(MVT::v4i32, MulLow),
+ GetExtendLow(MVT::v4i32, MulHigh),
+ GetExtendHigh(MVT::v4i32, MulHigh),
+ };
+ return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SubVectors);
+ } else {
+ assert(NumElts == 8);
+ SDValue LowLHS = DAG.getNode(LHS->getOpcode(), DL, MVT::v8i16, ExtendInLHS);
+ SDValue LowRHS = DAG.getNode(RHS->getOpcode(), DL, MVT::v8i16, ExtendInRHS);
+ SDValue MulLow = DAG.getNode(ISD::MUL, DL, MVT::v8i16, LowLHS, LowRHS);
+ SDValue Lo = GetExtendLow(MVT::v4i32, MulLow);
+ SDValue Hi = GetExtendHigh(MVT::v4i32, MulLow);
+ return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
+ }
+ return SDValue();
+}
+
SDValue
WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
@@ -3281,5 +3370,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
return performTruncateCombine(N, DCI);
case ISD::INTRINSIC_WO_CHAIN:
return performLowerPartialReduction(N, DCI.DAG);
+ case ISD::MUL:
+ return performMulCombine(N, DCI.DAG);
}
}
diff --git a/llvm/test/CodeGen/WebAssembly/wide-simd-mul.ll b/llvm/test/CodeGen/WebAssembly/wide-simd-mul.ll
new file mode 100644
index 0000000000000..94aa197bfd564
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/wide-simd-mul.ll
@@ -0,0 +1,197 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mtriple=wasm32 -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128 | FileCheck %s
+
+define <8 x i32> @sext_mul_v8i8(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: sext_mul_v8i8:
+; CHECK: .functype sext_mul_v8i8 (i32, v128, v128) -> ()
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i16x8.extmul_low_i8x16_s $push3=, $1, $1
+; CHECK-NEXT: local.tee $push2=, $1=, $pop3
+; CHECK-NEXT: i32x4.extend_high_i16x8_s $push0=, $pop2
+; CHECK-NEXT: v128.store 16($0), $pop0
+; CHECK-NEXT: i32x4.extend_low_i16x8_s $push1=, $1
+; CHECK-NEXT: v128.store 0($0), $pop1
+; CHECK-NEXT: return
+ %wide.a = sext <8 x i8> %a to <8 x i32>
+ %wide.b = sext <8 x i8> %a to <8 x i32>
+ %mul = mul <8 x i32> %wide.a, %wide.b
+ ret <8 x i32> %mul
+}
+
+define <16 x i32> @sext_mul_v16i8(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: sext_mul_v16i8:
+; CHECK: .functype sext_mul_v16i8 (i32, v128, v128) -> ()
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i16x8.extmul_high_i8x16_s $push7=, $1, $1
+; CHECK-NEXT: local.tee $push6=, $3=, $pop7
+; CHECK-NEXT: i32x4.extend_high_i16x8_s $push0=, $pop6
+; CHECK-NEXT: v128.store 48($0), $pop0
+; CHECK-NEXT: i32x4.extend_low_i16x8_s $push1=, $3
+; CHECK-NEXT: v128.store 32($0), $pop1
+; CHECK-NEXT: i16x8.extmul_low_i8x16_s $push5=, $1, $1
+; CHECK-NEXT: local.tee $push4=, $1=, $pop5
+; CHECK-NEXT: i32x4.extend_high_i16x8_s $push2=, $pop4
+; CHECK-NEXT: v128.store 16($0), $pop2
+; CHECK-NEXT: i32x4.extend_low_i16x8_s $push3=, $1
+; CHECK-NEXT: v128.store 0($0), $pop3
+; CHECK-NEXT: return
+ %wide.a = sext <16 x i8> %a to <16 x i32>
+ %wide.b = sext <16 x i8> %a to <16 x i32>
+ %mul = mul <16 x i32> %wide.a, %wide.b
+ ret <16 x i32> %mul
+}
+
+define <8 x i32> @sext_mul_v8i16(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: sext_mul_v8i16:
+; CHECK: .functype sext_mul_v8i16 (i32, v128, v128) -> ()
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i32x4.extmul_high_i16x8_s $push0=, $1, $1
+; CHECK-NEXT: v128.store 16($0), $pop0
+; CHECK-NEXT: i32x4.extmul_low_i16x8_s $push1=, $1, $1
+; CHECK-NEXT: v128.store 0($0), $pop1
+; CHECK-NEXT: return
+ %wide.a = sext <8 x i16> %a to <8 x i32>
+ %wide.b = sext <8 x i16> %a to <8 x i32>
+ %mul = mul <8 x i32> %wide.a, %wide.b
+ ret <8 x i32> %mul
+}
+
+define <8 x i32> @zext_mul_v8i8(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: zext_mul_v8i8:
+; CHECK: .functype zext_mul_v8i8 (i32, v128, v128) -> ()
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i16x8.extmul_low_i8x16_u $push3=, $1, $1
+; CHECK-NEXT: local.tee $push2=, $1=, $pop3
+; CHECK-NEXT: i32x4.extend_high_i16x8_u $push0=, $pop2
+; CHECK-NEXT: v128.store 16($0), $pop0
+; CHECK-NEXT: i32x4.extend_low_i16x8_u $push1=, $1
+; CHECK-NEXT: v128.store 0($0), $pop1
+; CHECK-NEXT: return
+ %wide.a = zext <8 x i8> %a to <8 x i32>
+ %wide.b = zext <8 x i8> %a to <8 x i32>
+ %mul = mul <8 x i32> %wide.a, %wide.b
+ ret <8 x i32> %mul
+}
+
+define <16 x i32> @zext_mul_v16i8(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: zext_mul_v16i8:
+; CHECK: .functype zext_mul_v16i8 (i32, v128, v128) -> ()
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i16x8.extmul_high_i8x16_u $push7=, $1, $1
+; CHECK-NEXT: local.tee $push6=, $3=, $pop7
+; CHECK-NEXT: i32x4.extend_high_i16x8_u $push0=, $pop6
+; CHECK-NEXT: v128.store 48($0), $pop0
+; CHECK-NEXT: i32x4.extend_low_i16x8_u $push1=, $3
+; CHECK-NEXT: v128.store 32($0), $pop1
+; CHECK-NEXT: i16x8.extmul_low_i8x16_u $push5=, $1, $1
+; CHECK-NEXT: local.tee $push4=, $1=, $pop5
+; CHECK-NEXT: i32x4.extend_high_i16x8_u $push2=, $pop4
+; CHECK-NEXT: v128.store 16($0), $pop2
+; CHECK-NEXT: i32x4.extend_low_i16x8_u $push3=, $1
+; CHECK-NEXT: v128.store 0($0), $pop3
+; CHECK-NEXT: return
+ %wide.a = zext <16 x i8> %a to <16 x i32>
+ %wide.b = zext <16 x i8> %a to <16 x i32>
+ %mul = mul <16 x i32> %wide.a, %wide.b
+ ret <16 x i32> %mul
+}
+
+define <8 x i32> @zext_mul_v8i16(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: zext_mul_v8i16:
+; CHECK: .functype zext_mul_v8i16 (i32, v128, v128) -> ()
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i32x4.extmul_high_i16x8_u $push0=, $1, $1
+; CHECK-NEXT: v128.store 16($0), $pop0
+; CHECK-NEXT: i32x4.extmul_low_i16x8_u $push1=, $1, $1
+; CHECK-NEXT: v128.store 0($0), $pop1
+; CHECK-NEXT: return
+ %wide.a = zext <8 x i16> %a to <8 x i32>
+ %wide.b = zext <8 x i16> %a to <8 x i32>
+ %mul = mul <8 x i32> %wide.a, %wide.b
+ ret <8 x i32> %mul
+}
+
+define <8 x i32> @sext_zext_mul_v8i8(<8 x i8> %a, <8 x i8> %b) {
+; CHECK-LABEL: sext_zext_mul_v8i8:
+; CHECK: .functype sext_zext_mul_v8i8 (i32, v128, v128) -> ()
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i16x8.extend_low_i8x16_s $push2=, $1
+; CHECK-NEXT: i32x4.extend_low_i16x8_s $push3=, $pop2
+; CHECK-NEXT: i16x8.extend_low_i8x16_u $push0=, $1
+; CHECK-NEXT: i32x4.extend_low_i16x8_u $push1=, $pop0
+; CHECK-NEXT: i32x4.mul $push4=, $pop3, $pop1
+; CHECK-NEXT: v128.store 0($0), $pop4
+; CHECK-NEXT: i8x16.shuffle $push11=, $1, $1, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+; CHECK-NEXT: local.tee $push10=, $1=, $pop11
+; CHECK-NEXT: i16x8.extend_low_i8x16_s $push7=, $pop10
+; CHECK-NEXT: i32x4.extend_low_i16x8_s $push8=, $pop7
+; CHECK-NEXT: i16x8.extend_low_i8x16_u $push5=, $1
+; CHECK-NEXT: i32x4.extend_low_i16x8_u $push6=, $pop5
+; CHECK-NEXT: i32x4.mul $push9=, $pop8, $pop6
+; CHECK-NEXT: v128.store 16($0), $pop9
+; CHECK-NEXT: return
+ %wide.a = sext <8 x i8> %a to <8 x i32>
+ %wide.b = zext <8 x i8> %a to <8 x i32>
+ %mul = mul <8 x i32> %wide.a, %wide.b
+ ret <8 x i32> %mul
+}
+
+define <16 x i32> @sext_zext_mul_v16i8(<16 x i8> %a, <16 x i8> %b) {
+; CHECK-LABEL: sext_zext_mul_v16i8:
+; CHECK: .functype sext_zext_mul_v16i8 (i32, v128, v128) -> ()
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i16x8.extend_low_i8x16_s $push2=, $1
+; CHECK-NEXT: i32x4.extend_low_i16x8_s $push3=, $pop2
+; CHECK-NEXT: i16x8.extend_low_i8x16_u $push0=, $1
+; CHECK-NEXT: i32x4.extend_low_i16x8_u $push1=, $pop0
+; CHECK-NEXT: i32x4.mul $push4=, $pop3, $pop1
+; CHECK-NEXT: v128.store 0($0), $pop4
+; CHECK-NEXT: i8x16.shuffle $push25=, $1, $1, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+; CHECK-NEXT: local.tee $push24=, $3=, $pop25
+; CHECK-NEXT: i16x8.extend_low_i8x16_s $push7=, $pop24
+; CHECK-NEXT: i32x4.extend_low_i16x8_s $push8=, $pop7
+; CHECK-NEXT: i16x8.extend_low_i8x16_u $push5=, $3
+; CHECK-NEXT: i32x4.extend_low_i16x8_u $push6=, $pop5
+; CHECK-NEXT: i32x4.mul $push9=, $pop8, $pop6
+; CHECK-NEXT: v128.store 48($0), $pop9
+; CHECK-NEXT: i8x16.shuffle $push23=, $1, $1, 8, 9, 10, 11, 12, 13, 14, 15, 0, 0, 0, 0, 0, 0, 0, 0
+; CHECK-NEXT: local.tee $push22=, $3=, $pop23
+; CHECK-NEXT: i16x8.extend_low_i8x16_s $push12=, $pop22
+; CHECK-NEXT: i32x4.extend_low_i16x8_s $push13=, $pop12
+; CHECK-NEXT: i16x8.extend_low_i8x16_u $push10=, $3
+; CHECK-NEXT: i32x4.extend_low_i16x8_u $push11=, $pop10
+; CHECK-NEXT: i32x4.mul $push14=, $pop13, $pop11
+; CHECK-NEXT: v128.store 32($0), $pop14
+; CHECK-NEXT: i8x16.shuffle $push21=, $1, $1, 4, 5, 6, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+; CHECK-NEXT: local.tee $push20=, $1=, $pop21
+; CHECK-NEXT: i16x8.extend_low_i8x16_s $push17=, $pop20
+; CHECK-NEXT: i32x4.extend_low_i16x8_s $push18=, $pop17
+; CHECK-NEXT: i16x8.extend_low_i8x16_u $push15=, $1
+; CHECK-NEXT: i32x4.extend_low_i16x8_u $push16=, $pop15
+; CHECK-NEXT: i32x4.mul $push19=, $pop18, $pop16
+; CHECK-NEXT: v128.store 16($0), $pop19
+; CHECK-NEXT: return
+ %wide.a = sext <16 x i8> %a to <16 x i32>
+ %wide.b = zext <16 x i8> %a to <16 x i32>
+ %mul = mul <16 x i32> %wide.a, %wide.b
+ ret <16 x i32> %mul
+}
+
+define <8 x i32> @zext_sext_mul_v8i16(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: zext_sext_mul_v8i16:
+; CHECK: .functype zext_sext_mul_v8i16 (i32, v128, v128) -> ()
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: i32x4.extend_high_i16x8_u $push1=, $1
+; CHECK-NEXT: i32x4.extend_high_i16x8_s $push0=, $1
+; CHECK-NEXT: i32x4.mul $push2=, $pop1, $pop0
+; CHECK-NEXT: v128.store 16($0), $pop2
+; CHECK-NEXT: i32x4.extend_low_i16x8_u $push4=, $1
+; CHECK-NEXT: i32x4.extend_low_i16x8_s $push3=, $1
+; CHECK-NEXT: i32x4.mul $push5=, $pop4, $pop3
+; CHECK-NEXT: v128.store 0($0), $pop5
+; CHECK-NEXT: return
+ %wide.a = zext <8 x i16> %a to <8 x i32>
+ %wide.b = sext <8 x i16> %a to <8 x i32>
+ %mul = mul <8 x i32> %wide.a, %wide.b
+ ret <8 x i32> %mul
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/130785
More information about the llvm-commits
mailing list