[llvm] [WebAssembly] Add fold support for dot (PR #151775)
Jasmine Tang via llvm-commits
llvm-commits at lists.llvm.org
Tue Aug 5 11:14:06 PDT 2025
https://github.com/badumbatish updated https://github.com/llvm/llvm-project/pull/151775
>From 4d304c888e4aecac25ee4a17e52ab5e4861e1a6a Mon Sep 17 00:00:00 2001
From: Jasmine Tang <jjasmine at igalia.com>
Date: Fri, 1 Aug 2025 14:03:04 -0700
Subject: [PATCH 1/4] Precommit test
---
.../WebAssembly/simd-dot-reductions.ll | 32 +++++++++++++++++++
1 file changed, 32 insertions(+)
create mode 100644 llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
diff --git a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
new file mode 100644
index 0000000000000..76c20c404e6f0
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
@@ -0,0 +1,32 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mattr=+simd128 | FileCheck %s
+
+target triple = "wasm32-unknown-unknown"
+define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: dot:
+; CHECK: .functype dot (v128, v128) -> (v128)
+; CHECK-NEXT: .local v128
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: local.get 0
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: i32x4.extmul_low_i16x8_s
+; CHECK-NEXT: local.tee 2
+; CHECK-NEXT: local.get 0
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: i32x4.extmul_high_i16x8_s
+; CHECK-NEXT: local.tee 1
+; CHECK-NEXT: i8x16.shuffle 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27
+; CHECK-NEXT: local.get 2
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: i8x16.shuffle 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+; CHECK-NEXT: i32x4.add
+; CHECK-NEXT: # fallthrough-return
+ %sext1 = sext <8 x i16> %a to <8 x i32>
+ %sext2 = sext <8 x i16> %b to <8 x i32>
+ %mul = mul nsw <8 x i32> %sext1, %sext2
+ %shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+ %shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+ %res = add <4 x i32> %shuffle1, %shuffle2
+ ret <4 x i32> %res
+}
+
>From cb9aac0407cb67fbf705a7c18c2b842bb4623466 Mon Sep 17 00:00:00 2001
From: Jasmine Tang <jjasmine at igalia.com>
Date: Fri, 1 Aug 2025 14:21:51 -0700
Subject: [PATCH 2/4] Add DAG combine support for dot
---
.../WebAssembly/WebAssemblyISelLowering.cpp | 51 +++++++++++++++++++
.../WebAssembly/simd-dot-reductions.ll | 13 +----
2 files changed, 52 insertions(+), 12 deletions(-)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index cd434f7a331e4..648e3b6b2b440 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -192,6 +192,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// Combine wide-vector muls, with extend inputs, to extmul_half.
setTargetDAGCombine(ISD::MUL);
+ // Combine add with vector shuffle of muls to dots
+ setTargetDAGCombine(ISD::ADD);
+
// Combine vector mask reductions into alltrue/anytrue
setTargetDAGCombine(ISD::SETCC);
@@ -3436,6 +3439,52 @@ static SDValue performSETCCCombine(SDNode *N,
return SDValue();
}
+static SDValue performAddCombine(SDNode *N, SelectionDAG &DAG) {
+ assert(N->getOpcode() == ISD::ADD);
+ EVT VT = N->getValueType(0);
+ SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
+
+ if (VT != MVT::v4i32)
+ return SDValue();
+
+ auto IsShuffleWithMask = [](SDValue V, ArrayRef<int> ShuffleValue) {
+ if (V.getOpcode() != ISD::VECTOR_SHUFFLE)
+ return SDValue();
+ if (cast<ShuffleVectorSDNode>(V)->getMask() != ShuffleValue)
+ return SDValue();
+ return V;
+ };
+ auto ShuffleA = IsShuffleWithMask(N0, {0, 2, 4, 6});
+ auto ShuffleB = IsShuffleWithMask(N1, {1, 3, 5, 7});
+ // two SDValues must be muls
+ if (!ShuffleA || !ShuffleB)
+ return SDValue();
+
+ if (ShuffleA.getOperand(0) != ShuffleB.getOperand(0) ||
+ ShuffleA.getOperand(1) != ShuffleB.getOperand(1))
+ return SDValue();
+
+ auto IsMulExtend =
+ [](SDValue V, WebAssemblyISD::NodeType I) -> std::pair<SDValue, SDValue> {
+ if (V.getOpcode() != ISD::MUL)
+ return {};
+
+ auto V0 = V.getOperand(0), V1 = V.getOperand(1);
+ if (V0.getOpcode() != I || V1.getOpcode() != I)
+ return {};
+ return {V0.getOperand(0), V1.getOperand(0)};
+ };
+
+ auto [LowA, LowB] =
+ IsMulExtend(ShuffleA.getOperand(0), WebAssemblyISD::EXTEND_LOW_S);
+ auto [HighA, HighB] =
+ IsMulExtend(ShuffleA.getOperand(1), WebAssemblyISD::EXTEND_HIGH_S);
+
+ if (!LowA || !LowB || !HighA || !HighB || LowA != HighA || LowB != HighB)
+ return SDValue();
+
+ return DAG.getNode(WebAssemblyISD::DOT, SDLoc(N), MVT::v4i32, LowA, LowB);
+}
static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
assert(N->getOpcode() == ISD::MUL);
EVT VT = N->getValueType(0);
@@ -3558,5 +3607,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
}
case ISD::MUL:
return performMulCombine(N, DCI.DAG);
+ case ISD::ADD:
+ return performAddCombine(N, DCI.DAG);
}
}
diff --git a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
index 76c20c404e6f0..7ac49794491a1 100644
--- a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
+++ b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
@@ -5,21 +5,10 @@ target triple = "wasm32-unknown-unknown"
define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) {
; CHECK-LABEL: dot:
; CHECK: .functype dot (v128, v128) -> (v128)
-; CHECK-NEXT: .local v128
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
-; CHECK-NEXT: i32x4.extmul_low_i16x8_s
-; CHECK-NEXT: local.tee 2
-; CHECK-NEXT: local.get 0
-; CHECK-NEXT: local.get 1
-; CHECK-NEXT: i32x4.extmul_high_i16x8_s
-; CHECK-NEXT: local.tee 1
-; CHECK-NEXT: i8x16.shuffle 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27
-; CHECK-NEXT: local.get 2
-; CHECK-NEXT: local.get 1
-; CHECK-NEXT: i8x16.shuffle 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
-; CHECK-NEXT: i32x4.add
+; CHECK-NEXT: i32x4.dot_i16x8_s
; CHECK-NEXT: # fallthrough-return
%sext1 = sext <8 x i16> %a to <8 x i32>
%sext2 = sext <8 x i16> %b to <8 x i32>
>From 86fe99b07c58ebd696bd6bd24ae4e74a728c336c Mon Sep 17 00:00:00 2001
From: Jasmine Tang <jjasmine at igalia.com>
Date: Tue, 5 Aug 2025 10:32:51 -0700
Subject: [PATCH 3/4] Move dot pattern matching from DAG combine to TableGen
---
.../WebAssembly/WebAssemblyISelLowering.cpp | 52 -------------------
.../WebAssembly/WebAssemblyInstrSIMD.td | 21 ++++++++
2 files changed, 21 insertions(+), 52 deletions(-)
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 0955e2d2f39b0..3f80b2ab2bd6d 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -192,9 +192,6 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// Combine wide-vector muls, with extend inputs, to extmul_half.
setTargetDAGCombine(ISD::MUL);
- // Combine add with vector shuffle of muls to dots
- setTargetDAGCombine(ISD::ADD);
-
// Combine vector mask reductions into alltrue/anytrue
setTargetDAGCombine(ISD::SETCC);
@@ -3439,53 +3436,6 @@ static SDValue performSETCCCombine(SDNode *N,
return SDValue();
}
-static SDValue performAddCombine(SDNode *N, SelectionDAG &DAG) {
- assert(N->getOpcode() == ISD::ADD);
- EVT VT = N->getValueType(0);
- SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
-
- if (VT != MVT::v4i32)
- return SDValue();
-
- auto IsShuffleWithMask = [](SDValue V, ArrayRef<int> ShuffleValue) {
- if (V.getOpcode() != ISD::VECTOR_SHUFFLE)
- return SDValue();
- if (cast<ShuffleVectorSDNode>(V)->getMask() != ShuffleValue)
- return SDValue();
- return V;
- };
- auto ShuffleA = IsShuffleWithMask(N0, {0, 2, 4, 6});
- auto ShuffleB = IsShuffleWithMask(N1, {1, 3, 5, 7});
- // two SDValues must be muls
- if (!ShuffleA || !ShuffleB)
- return SDValue();
-
- if (ShuffleA.getOperand(0) != ShuffleB.getOperand(0) ||
- ShuffleA.getOperand(1) != ShuffleB.getOperand(1))
- return SDValue();
-
- auto IsMulExtend =
- [](SDValue V, WebAssemblyISD::NodeType I) -> std::pair<SDValue, SDValue> {
- if (V.getOpcode() != ISD::MUL)
- return {};
-
- auto V0 = V.getOperand(0), V1 = V.getOperand(1);
- if (V0.getOpcode() != I || V1.getOpcode() != I)
- return {};
- return {V0.getOperand(0), V1.getOperand(0)};
- };
-
- auto [LowA, LowB] =
- IsMulExtend(ShuffleA.getOperand(0), WebAssemblyISD::EXTEND_LOW_S);
- auto [HighA, HighB] =
- IsMulExtend(ShuffleA.getOperand(1), WebAssemblyISD::EXTEND_HIGH_S);
-
- if (!LowA || !LowB || !HighA || !HighB || LowA != HighA || LowB != HighB)
- return SDValue();
-
- return DAG.getNode(WebAssemblyISD::DOT, SDLoc(N), MVT::v4i32, LowA, LowB);
-}
-
static SDValue TryWideExtMulCombine(SDNode *N, SelectionDAG &DAG) {
EVT VT = N->getValueType(0);
if (VT != MVT::v8i32 && VT != MVT::v16i32)
@@ -3647,7 +3597,5 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
}
case ISD::MUL:
return performMulCombine(N, DCI);
- case ISD::ADD:
- return performAddCombine(N, DCI.DAG);
}
}
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index 143298b700928..15da6567af6f4 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -1210,6 +1210,27 @@ defm EXTMUL_LOW_U :
defm EXTMUL_HIGH_U :
SIMDExtBinary<I64x2, extmul_high_u, "extmul_high_i32x4_u", 0xdf>;
+// Pattern for dot: res[i] = lhs[2i]*rhs[2i] + lhs[2i+1]*rhs[2i+1] (signed)
+def : Pat<
+ (v4i32 (add // add is commutative, so the swapped-operand form also matches
+ (wasm_shuffle // even products; byte indices into concat(low, high)
+ (v4i32 (extmul_low_s v8i16:$lhs, v8i16:$rhs)),
+ (v4i32 (extmul_high_s v8i16:$lhs, v8i16:$rhs)),
+ (i32 0), (i32 1), (i32 2), (i32 3),
+ (i32 8), (i32 9), (i32 10), (i32 11),
+ (i32 16), (i32 17), (i32 18), (i32 19),
+ (i32 24), (i32 25), (i32 26), (i32 27)),
+ (wasm_shuffle // odd products: lanes 1 and 3 of each extmul half
+ (v4i32 (extmul_low_s v8i16:$lhs, v8i16:$rhs)),
+ (v4i32 (extmul_high_s v8i16:$lhs, v8i16:$rhs)),
+ (i32 4), (i32 5), (i32 6), (i32 7),
+ (i32 12), (i32 13), (i32 14), (i32 15),
+ (i32 20), (i32 21), (i32 22), (i32 23),
+ (i32 28), (i32 29), (i32 30), (i32 31)))
+ ),
+ (v4i32 (DOT v8i16:$lhs, v8i16:$rhs))
+>;
+
//===----------------------------------------------------------------------===//
// Floating-point unary arithmetic
//===----------------------------------------------------------------------===//
>From 34f58f17590368f6c55a6a40b0f023a8ef1ce351 Mon Sep 17 00:00:00 2001
From: Jasmine Tang <jjasmine at igalia.com>
Date: Tue, 5 Aug 2025 11:13:28 -0700
Subject: [PATCH 4/4] Address PR review comments
---
.../WebAssembly/simd-dot-reductions.ll | 75 ++++++++++++++++++-
1 file changed, 71 insertions(+), 4 deletions(-)
diff --git a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
index 7ac49794491a1..fd50287a231d3 100644
--- a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
+++ b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
@@ -2,9 +2,10 @@
; RUN: llc < %s -mattr=+simd128 | FileCheck %s
target triple = "wasm32-unknown-unknown"
-define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) {
-; CHECK-LABEL: dot:
-; CHECK: .functype dot (v128, v128) -> (v128)
+
+define <4 x i32> @dot_sext_1(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: dot_sext_1:
+; CHECK: .functype dot_sext_1 (v128, v128) -> (v128)
; CHECK-NEXT: # %bb.0:
; CHECK-NEXT: local.get 0
; CHECK-NEXT: local.get 1
@@ -12,10 +13,76 @@ define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) {
; CHECK-NEXT: # fallthrough-return
%sext1 = sext <8 x i16> %a to <8 x i32>
%sext2 = sext <8 x i16> %b to <8 x i32>
- %mul = mul nsw <8 x i32> %sext1, %sext2
+ %mul = mul <8 x i32> %sext1, %sext2 ; no nsw flag: the dot fold is still expected
+ %shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+ %shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+ %res = add <4 x i32> %shuffle1, %shuffle2
+ ret <4 x i32> %res
+}
+
+
+define <4 x i32> @dot_sext_2(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: dot_sext_2:
+; CHECK: .functype dot_sext_2 (v128, v128) -> (v128)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: local.get 0
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: i32x4.dot_i16x8_s
+; CHECK-NEXT: # fallthrough-return
+ %sext1 = sext <8 x i16> %a to <8 x i32>
+ %sext2 = sext <8 x i16> %b to <8 x i32>
+ %mul = mul <8 x i32> %sext1, %sext2
+ %shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+ %shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+ %res = add <4 x i32> %shuffle2, %shuffle1 ; add operands swapped vs. dot_sext_1; must still fold
+ ret <4 x i32> %res
+}
+
+define <4 x i32> @dot_zext(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: dot_zext:
+; CHECK: .functype dot_zext (v128, v128) -> (v128)
+; CHECK-NEXT: .local v128
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: local.get 0
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: i32x4.extmul_low_i16x8_u
+; CHECK-NEXT: local.tee 2
+; CHECK-NEXT: local.get 0
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: i32x4.extmul_high_i16x8_u
+; CHECK-NEXT: local.tee 1
+; CHECK-NEXT: i8x16.shuffle 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27
+; CHECK-NEXT: local.get 2
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: i8x16.shuffle 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+; CHECK-NEXT: i32x4.add
+; CHECK-NEXT: # fallthrough-return
+ %zext1 = zext <8 x i16> %a to <8 x i32> ; negative test: dot pattern only matches signed extmuls
+ %zext2 = zext <8 x i16> %b to <8 x i32>
+ %mul = mul <8 x i32> %zext1, %zext2
%shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
%shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
%res = add <4 x i32> %shuffle1, %shuffle2
ret <4 x i32> %res
}
+define <4 x i32> @dot_wrong_shuffle(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: dot_wrong_shuffle:
+; CHECK: .functype dot_wrong_shuffle (v128, v128) -> (v128)
+; CHECK-NEXT: # %bb.0:
+; CHECK-NEXT: local.get 0
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: i32x4.extmul_low_i16x8_s
+; CHECK-NEXT: local.get 0
+; CHECK-NEXT: local.get 1
+; CHECK-NEXT: i32x4.extmul_high_i16x8_s
+; CHECK-NEXT: i32x4.add
+; CHECK-NEXT: # fallthrough-return
+ %sext1 = sext <8 x i16> %a to <8 x i32>
+ %sext2 = sext <8 x i16> %b to <8 x i32>
+ %mul = mul <8 x i32> %sext1, %sext2
+ %shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3> ; negative test: low/high halves, not even/odd lanes, so no dot
+ %shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+ %res = add <4 x i32> %shuffle1, %shuffle2
+ ret <4 x i32> %res
+}
More information about the llvm-commits mailing list