[llvm] [WebAssembly] Add fold support for dot (PR #151775)

Jasmine Tang via llvm-commits llvm-commits at lists.llvm.org
Tue Aug 5 11:14:06 PDT 2025


https://github.com/badumbatish updated https://github.com/llvm/llvm-project/pull/151775

>From 4d304c888e4aecac25ee4a17e52ab5e4861e1a6a Mon Sep 17 00:00:00 2001
From: Jasmine Tang <jjasmine at igalia.com>
Date: Fri, 1 Aug 2025 14:03:04 -0700
Subject: [PATCH 1/4] Precommit test

---
 .../WebAssembly/simd-dot-reductions.ll        | 32 +++++++++++++++++++
 1 file changed, 32 insertions(+)
 create mode 100644 llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll

diff --git a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
new file mode 100644
index 0000000000000..76c20c404e6f0
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
@@ -0,0 +1,32 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -mattr=+simd128 | FileCheck %s
+
+target triple = "wasm32-unknown-unknown"
+define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: dot:
+; CHECK:         .functype dot (v128, v128) -> (v128)
+; CHECK-NEXT:    .local v128
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i32x4.extmul_low_i16x8_s
+; CHECK-NEXT:    local.tee 2
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i32x4.extmul_high_i16x8_s
+; CHECK-NEXT:    local.tee 1
+; CHECK-NEXT:    i8x16.shuffle 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27
+; CHECK-NEXT:    local.get 2
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i8x16.shuffle 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+; CHECK-NEXT:    i32x4.add
+; CHECK-NEXT:    # fallthrough-return
+  %sext1 = sext <8 x i16> %a to <8 x i32>
+  %sext2 = sext <8 x i16> %b to <8 x i32>
+  %mul = mul nsw <8 x i32> %sext1, %sext2
+  %shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  %shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+  %res = add <4 x i32> %shuffle1, %shuffle2
+  ret <4 x i32> %res
+}
+

>From cb9aac0407cb67fbf705a7c18c2b842bb4623466 Mon Sep 17 00:00:00 2001
From: Jasmine Tang <jjasmine at igalia.com>
Date: Fri, 1 Aug 2025 14:21:51 -0700
Subject: [PATCH 2/4] Added combine support for dot

---
 .../WebAssembly/WebAssemblyISelLowering.cpp   | 51 +++++++++++++++++++
 .../WebAssembly/simd-dot-reductions.ll        | 13 +----
 2 files changed, 52 insertions(+), 12 deletions(-)

diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index cd434f7a331e4..648e3b6b2b440 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -192,6 +192,9 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
     // Combine wide-vector muls, with extend inputs, to extmul_half.
     setTargetDAGCombine(ISD::MUL);
 
+    // Combine add with vector shuffle of muls to dots
+    setTargetDAGCombine(ISD::ADD);
+
     // Combine vector mask reductions into alltrue/anytrue
     setTargetDAGCombine(ISD::SETCC);
 
@@ -3436,6 +3439,52 @@ static SDValue performSETCCCombine(SDNode *N,
   return SDValue();
 }
 
+static SDValue performAddCombine(SDNode *N, SelectionDAG &DAG) {
+  assert(N->getOpcode() == ISD::ADD);
+  EVT VT = N->getValueType(0);
+  SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
+
+  if (VT != MVT::v4i32)
+    return SDValue();
+
+  auto IsShuffleWithMask = [](SDValue V, ArrayRef<int> ShuffleValue) {
+    if (V.getOpcode() != ISD::VECTOR_SHUFFLE)
+      return SDValue();
+    if (cast<ShuffleVectorSDNode>(V)->getMask() != ShuffleValue)
+      return SDValue();
+    return V;
+  };
+  auto ShuffleA = IsShuffleWithMask(N0, {0, 2, 4, 6});
+  auto ShuffleB = IsShuffleWithMask(N1, {1, 3, 5, 7});
+  // two SDValues must be muls
+  if (!ShuffleA || !ShuffleB)
+    return SDValue();
+
+  if (ShuffleA.getOperand(0) != ShuffleB.getOperand(0) ||
+      ShuffleA.getOperand(1) != ShuffleB.getOperand(1))
+    return SDValue();
+
+  auto IsMulExtend =
+      [](SDValue V, WebAssemblyISD::NodeType I) -> std::pair<SDValue, SDValue> {
+    if (V.getOpcode() != ISD::MUL)
+      return {};
+
+    auto V0 = V.getOperand(0), V1 = V.getOperand(1);
+    if (V0.getOpcode() != I || V1.getOpcode() != I)
+      return {};
+    return {V0.getOperand(0), V1.getOperand(0)};
+  };
+
+  auto [LowA, LowB] =
+      IsMulExtend(ShuffleA.getOperand(0), WebAssemblyISD::EXTEND_LOW_S);
+  auto [HighA, HighB] =
+      IsMulExtend(ShuffleA.getOperand(1), WebAssemblyISD::EXTEND_HIGH_S);
+
+  if (!LowA || !LowB || !HighA || !HighB || LowA != HighA || LowB != HighB)
+    return SDValue();
+
+  return DAG.getNode(WebAssemblyISD::DOT, SDLoc(N), MVT::v4i32, LowA, LowB);
+}
 static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
   assert(N->getOpcode() == ISD::MUL);
   EVT VT = N->getValueType(0);
@@ -3558,5 +3607,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
   }
   case ISD::MUL:
     return performMulCombine(N, DCI.DAG);
+  case ISD::ADD:
+    return performAddCombine(N, DCI.DAG);
   }
 }
diff --git a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
index 76c20c404e6f0..7ac49794491a1 100644
--- a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
+++ b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
@@ -5,21 +5,10 @@ target triple = "wasm32-unknown-unknown"
 define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) {
 ; CHECK-LABEL: dot:
 ; CHECK:         .functype dot (v128, v128) -> (v128)
-; CHECK-NEXT:    .local v128
 ; CHECK-NEXT:  # %bb.0:
 ; CHECK-NEXT:    local.get 0
 ; CHECK-NEXT:    local.get 1
-; CHECK-NEXT:    i32x4.extmul_low_i16x8_s
-; CHECK-NEXT:    local.tee 2
-; CHECK-NEXT:    local.get 0
-; CHECK-NEXT:    local.get 1
-; CHECK-NEXT:    i32x4.extmul_high_i16x8_s
-; CHECK-NEXT:    local.tee 1
-; CHECK-NEXT:    i8x16.shuffle 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27
-; CHECK-NEXT:    local.get 2
-; CHECK-NEXT:    local.get 1
-; CHECK-NEXT:    i8x16.shuffle 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
-; CHECK-NEXT:    i32x4.add
+; CHECK-NEXT:    i32x4.dot_i16x8_s
 ; CHECK-NEXT:    # fallthrough-return
   %sext1 = sext <8 x i16> %a to <8 x i32>
   %sext2 = sext <8 x i16> %b to <8 x i32>

>From 86fe99b07c58ebd696bd6bd24ae4e74a728c336c Mon Sep 17 00:00:00 2001
From: Jasmine Tang <jjasmine at igalia.com>
Date: Tue, 5 Aug 2025 10:32:51 -0700
Subject: [PATCH 3/4] Transition to tablegen for pattern

---
 .../WebAssembly/WebAssemblyISelLowering.cpp   | 52 -------------------
 .../WebAssembly/WebAssemblyInstrSIMD.td       | 21 ++++++++
 2 files changed, 21 insertions(+), 52 deletions(-)

diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index 0955e2d2f39b0..3f80b2ab2bd6d 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -192,9 +192,6 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
     // Combine wide-vector muls, with extend inputs, to extmul_half.
     setTargetDAGCombine(ISD::MUL);
 
-    // Combine add with vector shuffle of muls to dots
-    setTargetDAGCombine(ISD::ADD);
-
     // Combine vector mask reductions into alltrue/anytrue
     setTargetDAGCombine(ISD::SETCC);
 
@@ -3439,53 +3436,6 @@ static SDValue performSETCCCombine(SDNode *N,
   return SDValue();
 }
 
-static SDValue performAddCombine(SDNode *N, SelectionDAG &DAG) {
-  assert(N->getOpcode() == ISD::ADD);
-  EVT VT = N->getValueType(0);
-  SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
-
-  if (VT != MVT::v4i32)
-    return SDValue();
-
-  auto IsShuffleWithMask = [](SDValue V, ArrayRef<int> ShuffleValue) {
-    if (V.getOpcode() != ISD::VECTOR_SHUFFLE)
-      return SDValue();
-    if (cast<ShuffleVectorSDNode>(V)->getMask() != ShuffleValue)
-      return SDValue();
-    return V;
-  };
-  auto ShuffleA = IsShuffleWithMask(N0, {0, 2, 4, 6});
-  auto ShuffleB = IsShuffleWithMask(N1, {1, 3, 5, 7});
-  // two SDValues must be muls
-  if (!ShuffleA || !ShuffleB)
-    return SDValue();
-
-  if (ShuffleA.getOperand(0) != ShuffleB.getOperand(0) ||
-      ShuffleA.getOperand(1) != ShuffleB.getOperand(1))
-    return SDValue();
-
-  auto IsMulExtend =
-      [](SDValue V, WebAssemblyISD::NodeType I) -> std::pair<SDValue, SDValue> {
-    if (V.getOpcode() != ISD::MUL)
-      return {};
-
-    auto V0 = V.getOperand(0), V1 = V.getOperand(1);
-    if (V0.getOpcode() != I || V1.getOpcode() != I)
-      return {};
-    return {V0.getOperand(0), V1.getOperand(0)};
-  };
-
-  auto [LowA, LowB] =
-      IsMulExtend(ShuffleA.getOperand(0), WebAssemblyISD::EXTEND_LOW_S);
-  auto [HighA, HighB] =
-      IsMulExtend(ShuffleA.getOperand(1), WebAssemblyISD::EXTEND_HIGH_S);
-
-  if (!LowA || !LowB || !HighA || !HighB || LowA != HighA || LowB != HighB)
-    return SDValue();
-
-  return DAG.getNode(WebAssemblyISD::DOT, SDLoc(N), MVT::v4i32, LowA, LowB);
-}
-
 static SDValue TryWideExtMulCombine(SDNode *N, SelectionDAG &DAG) {
   EVT VT = N->getValueType(0);
   if (VT != MVT::v8i32 && VT != MVT::v16i32)
@@ -3647,7 +3597,5 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
   }
   case ISD::MUL:
     return performMulCombine(N, DCI);
-  case ISD::ADD:
-    return performAddCombine(N, DCI.DAG);
   }
 }
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
index 143298b700928..15da6567af6f4 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -1210,6 +1210,27 @@ defm EXTMUL_LOW_U :
 defm EXTMUL_HIGH_U :
   SIMDExtBinary<I64x2, extmul_high_u, "extmul_high_i32x4_u", 0xdf>;
 
+// Pattern for dot
+def : Pat<
+  (v4i32 (add
+    (wasm_shuffle
+      (v4i32 (extmul_low_s v8i16:$lhs, v8i16:$rhs)),
+      (v4i32 (extmul_high_s v8i16:$lhs, v8i16:$rhs)),
+      (i32 0), (i32 1), (i32 2), (i32 3),
+      (i32 8), (i32 9), (i32 10), (i32 11),
+      (i32 16), (i32 17), (i32 18), (i32 19),
+      (i32 24), (i32 25), (i32 26), (i32 27)),
+    (wasm_shuffle
+      (v4i32 (extmul_low_s v8i16:$lhs, v8i16:$rhs)),
+      (v4i32 (extmul_high_s v8i16:$lhs, v8i16:$rhs)),
+      (i32 4), (i32 5), (i32 6), (i32 7),
+      (i32 12), (i32 13), (i32 14), (i32 15),
+      (i32 20), (i32 21), (i32 22), (i32 23),
+      (i32 28), (i32 29), (i32 30), (i32 31)))
+  ),
+  (v4i32 (DOT v8i16:$lhs, v8i16:$rhs))
+>;
+
 //===----------------------------------------------------------------------===//
 // Floating-point unary arithmetic
 //===----------------------------------------------------------------------===//

>From 34f58f17590368f6c55a6a40b0f023a8ef1ce351 Mon Sep 17 00:00:00 2001
From: Jasmine Tang <jjasmine at igalia.com>
Date: Tue, 5 Aug 2025 11:13:28 -0700
Subject: [PATCH 4/4] Addresses PR reviews

---
 .../WebAssembly/simd-dot-reductions.ll        | 75 ++++++++++++++++++-
 1 file changed, 71 insertions(+), 4 deletions(-)

diff --git a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
index 7ac49794491a1..fd50287a231d3 100644
--- a/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
+++ b/llvm/test/CodeGen/WebAssembly/simd-dot-reductions.ll
@@ -2,9 +2,10 @@
 ; RUN: llc < %s -mattr=+simd128 | FileCheck %s
 
 target triple = "wasm32-unknown-unknown"
-define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) {
-; CHECK-LABEL: dot:
-; CHECK:         .functype dot (v128, v128) -> (v128)
+
+define <4 x i32> @dot_sext_1(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: dot_sext_1:
+; CHECK:         .functype dot_sext_1 (v128, v128) -> (v128)
 ; CHECK-NEXT:  # %bb.0:
 ; CHECK-NEXT:    local.get 0
 ; CHECK-NEXT:    local.get 1
@@ -12,10 +13,76 @@ define <4 x i32> @dot(<8 x i16> %a, <8 x i16> %b) {
 ; CHECK-NEXT:    # fallthrough-return
   %sext1 = sext <8 x i16> %a to <8 x i32>
   %sext2 = sext <8 x i16> %b to <8 x i32>
-  %mul = mul nsw <8 x i32> %sext1, %sext2
+  %mul = mul <8 x i32> %sext1, %sext2
+  %shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  %shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+  %res = add <4 x i32> %shuffle1, %shuffle2
+  ret <4 x i32> %res
+}
+
+
+define <4 x i32> @dot_sext_2(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: dot_sext_2:
+; CHECK:         .functype dot_sext_2 (v128, v128) -> (v128)
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i32x4.dot_i16x8_s
+; CHECK-NEXT:    # fallthrough-return
+  %sext1 = sext <8 x i16> %a to <8 x i32>
+  %sext2 = sext <8 x i16> %b to <8 x i32>
+  %mul = mul <8 x i32> %sext1, %sext2
+  %shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
+  %shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
+  %res = add <4 x i32> %shuffle2, %shuffle1
+  ret <4 x i32> %res
+}
+
+define <4 x i32> @dot_zext(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: dot_zext:
+; CHECK:         .functype dot_zext (v128, v128) -> (v128)
+; CHECK-NEXT:    .local v128
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i32x4.extmul_low_i16x8_u
+; CHECK-NEXT:    local.tee 2
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i32x4.extmul_high_i16x8_u
+; CHECK-NEXT:    local.tee 1
+; CHECK-NEXT:    i8x16.shuffle 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27
+; CHECK-NEXT:    local.get 2
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i8x16.shuffle 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
+; CHECK-NEXT:    i32x4.add
+; CHECK-NEXT:    # fallthrough-return
+  %zext1 = zext <8 x i16> %a to <8 x i32>
+  %zext2 = zext <8 x i16> %b to <8 x i32>
+  %mul = mul <8 x i32> %zext1, %zext2
   %shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
   %shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 1, i32 3, i32 5, i32 7>
   %res = add <4 x i32> %shuffle1, %shuffle2
   ret <4 x i32> %res
 }
 
+define <4 x i32> @dot_wrong_shuffle(<8 x i16> %a, <8 x i16> %b) {
+; CHECK-LABEL: dot_wrong_shuffle:
+; CHECK:         .functype dot_wrong_shuffle (v128, v128) -> (v128)
+; CHECK-NEXT:  # %bb.0:
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i32x4.extmul_low_i16x8_s
+; CHECK-NEXT:    local.get 0
+; CHECK-NEXT:    local.get 1
+; CHECK-NEXT:    i32x4.extmul_high_i16x8_s
+; CHECK-NEXT:    i32x4.add
+; CHECK-NEXT:    # fallthrough-return
+  %sext1 = sext <8 x i16> %a to <8 x i32>
+  %sext2 = sext <8 x i16> %b to <8 x i32>
+  %mul = mul <8 x i32> %sext1, %sext2
+  %shuffle1 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+  %shuffle2 = shufflevector <8 x i32> %mul, <8 x i32> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
+  %res = add <4 x i32> %shuffle1, %shuffle2
+  ret <4 x i32> %res
+}



More information about the llvm-commits mailing list