[llvm] [WebAssembly] Add fold support for dot (PR #151775)
Jasmine Tang via llvm-commits
llvm-commits at lists.llvm.org
Fri Aug 1 14:45:57 PDT 2025
https://github.com/badumbatish created https://github.com/llvm/llvm-project/pull/151775
Fixes https://github.com/llvm/llvm-project/issues/50154
>From 4d304c888e4aecac25ee4a17e52ab5e4861e1a6a Mon Sep 17 00:00:00 2001
From: Jasmine Tang <jjasmine at igalia.com>
Date: Fri, 1 Aug 2025 14:03:04 -0700
Subject: [PATCH 1/2] Precommit test
---
.../WebAssembly/simd-dot-reductions.ll | 32 +++++++++++++++++++
1 file changed, 32 insertions(+)
create mode 100644 llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
diff --git a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
new file mode 100644
index 0000000000000..76c20c404e6f0
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
@@ -0,0 +1,32 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mattr=+simd128 | FileCheck %s
+
+target triple = "wasm32-unknown-unknown"
+define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: dot:
+; CHECK: .functype dot (v128, v128) -> (v128)
+; CHECK-NEXT: .local v128
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: local.get 0
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: i32x4.extmul_low_i16x8_s
+; CHECK-NEXT: local.tee 2
+; CHECK-NEXT: local.get 0
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: i32x4.extmul_high_i16x8_s
+; CHECK-NEXT: local.tee 1
+; CHECK-NEXT: i8x16.shuffle 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27
+; CHECK-NEXT: local.get 2
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: i8x16.shuffle 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+; CHECK-NEXT: i32x4.add
+; CHECK-NEXT: # fallthrough-return
+ %sext1 = sext <8 x i16> %a to <8 x i32>
+ %sext2 = sext <8 x i16> %b to <8 x i32>
+ %mul = mul nsw <8 x i32> %sext1, %sext2
+ %shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+ %shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+ %res = add <4 x i32> %shuffle1, %shuffle2
+ ret <4 x i32> %res
+}
+
>From cb9aac0407cb67fbf705a7c18c2b842bb4623466 Mon Sep 17 00:00:00 2001
From: Jasmine Tang <jjasmine at igalia.com>
Date: Fri, 1 Aug 2025 14:21:51 -0700
Subject: [PATCH 2/2] Added combine support for dot
---
.../WebAssembly/WebAssemblyISelLowering.cpp | 51 +++++++++++++++++++
.../WebAssembly/simd-dot-reductions.ll | 13 +----
2 files changed, 52 insertions(+), 12 deletions(-)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index cd434f7a331e4..648e3b6b2b440 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -192,6 +192,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// Combine wide-vector muls, with extend inputs, to extmul_half.
setTargetDAGCombine(ISD::MUL);
+ // Combine an add of even/odd-lane shuffles of extending muls into a dot.
+ setTargetDAGCombine(ISD::ADD);
+
// Combine vector mask reductions into alltrue/anytrue
setTargetDAGCombine(ISD::SETCC);
@@ -3436,6 +3439,52 @@ static SDValue performSETCCCombine(SDNode *N,
return SDValue();
}
+/// Fold a pairwise add of widened products into a single WebAssembly dot.
+///
+/// Matches, for a v4i32 result only:
+///   add (shuffle Lo, Hi, <0,2,4,6>), (shuffle Lo, Hi, <1,3,5,7>)
+/// where
+///   Lo = mul (EXTEND_LOW_S A),  (EXTEND_LOW_S B)
+///   Hi = mul (EXTEND_HIGH_S A), (EXTEND_HIGH_S B)
+/// and rewrites it to WebAssemblyISD::DOT A, B.
+///
+/// ADD and MUL are commutative, so the two shuffles may appear as either
+/// operand of the add, and the operands of Hi may be swapped relative to Lo.
+///
+/// \param N   The ISD::ADD node being combined.
+/// \param DAG The DAG used to create the replacement node.
+/// \returns the DOT node, or an empty SDValue when the pattern does not match.
+static SDValue performAddCombine(SDNode *N, SelectionDAG &DAG) {
+  assert(N->getOpcode() == ISD::ADD);
+
+  if (N->getValueType(0) != MVT::v4i32)
+    return SDValue();
+
+  // True when V is a vector shuffle using exactly the given lane mask.
+  auto HasShuffleMask = [](SDValue V, ArrayRef<int> Mask) {
+    return V.getOpcode() == ISD::VECTOR_SHUFFLE &&
+           cast<ShuffleVectorSDNode>(V)->getMask() == Mask;
+  };
+
+  // The even-lane shuffle may be either operand of the add; normalize so the
+  // remaining checks only deal with one ordering.
+  SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
+  const bool EvenIsFirst = HasShuffleMask(N0, {0, 2, 4, 6});
+  SDValue EvenShuffle = EvenIsFirst ? N0 : N1;
+  SDValue OddShuffle = EvenIsFirst ? N1 : N0;
+  if (!HasShuffleMask(EvenShuffle, {0, 2, 4, 6}) ||
+      !HasShuffleMask(OddShuffle, {1, 3, 5, 7}))
+    return SDValue();
+
+  // Both shuffles must read from the same pair of inputs, in the same order.
+  SDValue MulLow = EvenShuffle.getOperand(0);
+  SDValue MulHigh = EvenShuffle.getOperand(1);
+  if (OddShuffle.getOperand(0) != MulLow ||
+      OddShuffle.getOperand(1) != MulHigh)
+    return SDValue();
+
+  // If V is a MUL whose operands are both extends of kind ExtOpc, return the
+  // two pre-extension sources; otherwise return a pair of empty SDValues.
+  auto MatchMulOfExtends =
+      [](SDValue V,
+         WebAssemblyISD::NodeType ExtOpc) -> std::pair<SDValue, SDValue> {
+    if (V.getOpcode() != ISD::MUL)
+      return {};
+
+    SDValue LHS = V.getOperand(0), RHS = V.getOperand(1);
+    if (LHS.getOpcode() != ExtOpc || RHS.getOpcode() != ExtOpc)
+      return {};
+    return {LHS.getOperand(0), RHS.getOperand(0)};
+  };
+
+  auto [LowA, LowB] = MatchMulOfExtends(MulLow, WebAssemblyISD::EXTEND_LOW_S);
+  auto [HighA, HighB] =
+      MatchMulOfExtends(MulHigh, WebAssemblyISD::EXTEND_HIGH_S);
+  if (!LowA || !LowB || !HighA || !HighB)
+    return SDValue();
+
+  // The low and high halves must extend the same two vectors. MUL is
+  // commutative, so accept the high operands in either order.
+  const bool SameOrder = LowA == HighA && LowB == HighB;
+  const bool SwappedOrder = LowA == HighB && LowB == HighA;
+  if (!SameOrder && !SwappedOrder)
+    return SDValue();
+
+  // Defensive: only build the v4i32 dot from v8i16 inputs.
+  if (LowA.getValueType() != MVT::v8i16 || LowB.getValueType() != MVT::v8i16)
+    return SDValue();
+
+  return DAG.getNode(WebAssemblyISD::DOT, SDLoc(N), MVT::v4i32, LowA, LowB);
+}
static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
assert(N->getOpcode() == ISD::MUL);
EVT VT = N->getValueType(0);
@@ -3558,5 +3607,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
}
case ISD::MUL:
return performMulCombine(N, DCI.DAG);
+ case ISD::ADD:
+ return performAddCombine(N, DCI.DAG);
}
}
diff --git a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
index 76c20c404e6f0..7ac49794491a1 100644
--- a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
+++ b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
@@ -5,21 +5,10 @@ target triple = "wasm32-unknown-unknown"
define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) {
; CHECK-LABEL: dot:
; CHECK: .functype dot (v128, v128) -> (v128)
-; CHECK-NEXT: .local v128
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
-; CHECK-NEXT: i32x4.extmul_low_i16x8_s
-; CHECK-NEXT: local.tee 2
-; CHECK-NEXT: local.get 0
-; CHECK-NEXT: local.get 1
-; CHECK-NEXT: i32x4.extmul_high_i16x8_s
-; CHECK-NEXT: local.tee 1
-; CHECK-NEXT: i8x16.shuffle 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27
-; CHECK-NEXT: local.get 2
-; CHECK-NEXT: local.get 1
-; CHECK-NEXT: i8x16.shuffle 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
-; CHECK-NEXT: i32x4.add
+; CHECK-NEXT: i32x4.dot_i16x8_s
; CHECK-NEXT: # fallthrough-return
%sext1 = sext <8 x i16> %a to <8 x i32>
%sext2 = sext <8 x i16> %b to <8 x i32>
More information about the llvm-commits
mailing list