[llvm] [WebAssembly] Constant fold wasm.dot (PR #149619)

Jasmine Tang via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 22 13:46:42 PDT 2025


https://github.com/badumbatish updated https://github.com/llvm/llvm-project/pull/149619

>From 5563f46b0d4b559e0d29d508a3acbbf270a173c2 Mon Sep 17 00:00:00 2001
From: Jasmine Tang <jjasmine at igalia.com>
Date: Fri, 18 Jul 2025 14:11:05 -0700
Subject: [PATCH 1/3] [WebAssembly] Precommit test for constant folding dot

---
 .../InstSimplify/ConstProp/WebAssembly/dot.ll | 37 +++++++++++++++++++
 1 file changed, 37 insertions(+)
 create mode 100644 llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll

diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
new file mode 100644
index 0000000000000..75a500c6278ad
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
@@ -0,0 +1,37 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+
+; RUN: opt -passes=instsimplify -S < %s | FileCheck %s
+
+; Test that llvm.wasm.dot calls with constant operands are constant folded
+
+target triple = "wasm32-unknown-unknown"
+
+
+define <4 x i32> @dot_zero() {
+; CHECK-LABEL: define <4 x i32> @dot_zero() {
+; CHECK-NEXT:    [[RES:%.*]] = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> zeroinitializer, <8 x i16> zeroinitializer)
+; CHECK-NEXT:    ret <4 x i32> [[RES]]
+;
+  %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> zeroinitializer, <8 x i16> zeroinitializer)
+  ret <4 x i32> %res
+}
+
+define <4 x i32> @dot_nonzero() {
+; CHECK-LABEL: define <4 x i32> @dot_nonzero() {
+; CHECK-NEXT:    [[RES:%.*]] = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>, <8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>)
+; CHECK-NEXT:    ret <4 x i32> [[RES]]
+;
+  %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>, <8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>)
+  ret <4 x i32> %res
+}
+
+define <4 x i32> @dot_doubly_negative() {
+; CHECK-LABEL: define <4 x i32> @dot_doubly_negative() {
+; CHECK-NEXT:    [[RES:%.*]] = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> splat (i16 -1), <8 x i16> splat (i16 -1))
+; CHECK-NEXT:    ret <4 x i32> [[RES]]
+;
+  %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>, <8 x i16> <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>)
+  ret <4 x i32> %res
+}
+
+

>From 50ca839dbee8ea3d505dee87c3e21c458ee28b7f Mon Sep 17 00:00:00 2001
From: Jasmine Tang <jjasmine at igalia.com>
Date: Fri, 18 Jul 2025 16:34:41 -0700
Subject: [PATCH 2/3] [WebAssembly] Constant fold dot operation

---
 llvm/lib/Analysis/ConstantFolding.cpp         | 31 +++++++++++++++++++
 .../InstSimplify/ConstProp/WebAssembly/dot.ll | 14 +++++----
 2 files changed, 39 insertions(+), 6 deletions(-)

diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 9c1c2c6e60f02..2304c58b3f95f 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1657,6 +1657,7 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
   case Intrinsic::aarch64_sve_convert_from_svbool:
   case Intrinsic::wasm_alltrue:
   case Intrinsic::wasm_anytrue:
+  case Intrinsic::wasm_dot:
   // WebAssembly float semantics are always known
   case Intrinsic::wasm_trunc_signed:
   case Intrinsic::wasm_trunc_unsigned:
@@ -3826,6 +3827,36 @@ static Constant *ConstantFoldFixedVectorCall(
     }
     return ConstantVector::get(Result);
   }
+  case Intrinsic::wasm_dot: {
+    unsigned NumElements =
+        cast<FixedVectorType>(Operands[0]->getType())->getNumElements();
+    assert(NumElements == 8 && NumElements / 2 == Result.size() &&
+           "wasm dot takes i16x8 and produce i32x4");
+    assert(Ty->isIntegerTy());
+    SmallVector<APInt, 8> MulVector;
+    for (unsigned I = 0; I < NumElements; ++I) {
+      auto *Elt0 =
+          dyn_cast_or_null<ConstantInt>(Operands[0]->getAggregateElement(I));
+      auto *Elt1 =
+          dyn_cast_or_null<ConstantInt>(Operands[1]->getAggregateElement(I));
+      if (!Elt0 || !Elt1)
+        return nullptr; // Undef/poison or ConstantExpr lane: don't fold.
+
+      // sext 32 first, according to specs
+      APInt IMul = Elt0->getValue().sext(32) * Elt1->getValue().sext(32);
+
+      // TODO: imul in specs includes a modulo operation
+      // Is this performed automatically via trunc = true in APInt creation of *
+      MulVector.push_back(IMul);
+    }
+    for (unsigned I = 0; I < Result.size(); ++I) {
+      // Same case as with imul
+      APInt IAdd = MulVector[I] + MulVector[I + Result.size()];
+      Result[I] = ConstantInt::get(Ty, IAdd);
+    }
+
+    return ConstantVector::get(Result);
+  }
   default:
     break;
   }
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
index 75a500c6278ad..02c6649becbce 100644
--- a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
@@ -9,17 +9,20 @@ target triple = "wasm32-unknown-unknown"
 
 define <4 x i32> @dot_zero() {
 ; CHECK-LABEL: define <4 x i32> @dot_zero() {
-; CHECK-NEXT:    [[RES:%.*]] = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> zeroinitializer, <8 x i16> zeroinitializer)
-; CHECK-NEXT:    ret <4 x i32> [[RES]]
+; CHECK-NEXT:    ret <4 x i32> zeroinitializer
 ;
   %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> zeroinitializer, <8 x i16> zeroinitializer)
   ret <4 x i32> %res
 }
 
+; a               =   1    2    3    4    5    6    7    8
+; b               =   1    2    3    4    5    6    7    8
+; k1|k2 = a * b   =   1    4    9   16   25   36   49   64
+; k1 + k2         =   (1+25) |  (4+36) | (9+49)  | (16+64)
+; result          =    26    |   40    |   58    |   80
 define <4 x i32> @dot_nonzero() {
 ; CHECK-LABEL: define <4 x i32> @dot_nonzero() {
-; CHECK-NEXT:    [[RES:%.*]] = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>, <8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>)
-; CHECK-NEXT:    ret <4 x i32> [[RES]]
+; CHECK-NEXT:    ret <4 x i32> <i32 26, i32 40, i32 58, i32 80>
 ;
   %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>, <8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>)
   ret <4 x i32> %res
@@ -27,8 +30,7 @@ define <4 x i32> @dot_nonzero() {
 
 define <4 x i32> @dot_doubly_negative() {
 ; CHECK-LABEL: define <4 x i32> @dot_doubly_negative() {
-; CHECK-NEXT:    [[RES:%.*]] = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> splat (i16 -1), <8 x i16> splat (i16 -1))
-; CHECK-NEXT:    ret <4 x i32> [[RES]]
+; CHECK-NEXT:    ret <4 x i32> splat (i32 2)
 ;
   %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>, <8 x i16> <i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1, i16 -1>)
   ret <4 x i32> %res

>From fa8c096313e2eab8661a5d63b04477f01c86ed78 Mon Sep 17 00:00:00 2001
From: Jasmine Tang <jjasmine at igalia.com>
Date: Tue, 22 Jul 2025 13:40:55 -0700
Subject: [PATCH 3/3] Address spec questions and add a test reflecting them

---
 llvm/lib/Analysis/ConstantFolding.cpp          |  9 ++++-----
 .../InstSimplify/ConstProp/WebAssembly/dot.ll  | 18 +++++++++++++++---
 2 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 2304c58b3f95f..a63be47e21eaa 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -3845,13 +3845,12 @@ static Constant *ConstantFoldFixedVectorCall(
       // sext 32 first, according to specs
       APInt IMul = Elt0->getValue().sext(32) * Elt1->getValue().sext(32);

-      // TODO: imul in specs includes a modulo operation
-      // Is this performed automatically via trunc = true in APInt creation of *
+      // |a * b| <= 2^30 for i16 a and b, so imul's modulo 2^32 is a no-op
       MulVector.push_back(IMul);
     }
-    for (unsigned I = 0; I < Result.size(); ++I) {
-      // Same case as with imul
-      APInt IAdd = MulVector[I] + MulVector[I + Result.size()];
+    for (unsigned I = 0; I < Result.size(); I++) {
+      // Sum can reach 2^31 > INT32_MAX; 32-bit APInt add wraps like iadd
+      APInt IAdd = MulVector[I * 2] + MulVector[I * 2 + 1];
       Result[I] = ConstantInt::get(Ty, IAdd);
     }
 
diff --git a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
index 02c6649becbce..b2f23d0f153ef 100644
--- a/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
+++ b/llvm/test/Transforms/InstSimplify/ConstProp/WebAssembly/dot.ll
@@ -18,11 +18,11 @@ define <4 x i32> @dot_zero() {
 ; a               =   1    2    3    4    5    6    7    8
 ; b               =   1    2    3    4    5    6    7    8
 ; k1|k2 = a * b   =   1    4    9   16   25   36   49   64
-; k1 + k2         =   (1+25) |  (4+36) | (9+49)  | (16+64)
-; result          =    26    |   40    |   58    |   80
+; k1 + k2         =   (1+4) | (9+16) | (25+36) | (49+64)
+; result          =     5   |   25   |   61    |   113
 define <4 x i32> @dot_nonzero() {
 ; CHECK-LABEL: define <4 x i32> @dot_nonzero() {
-; CHECK-NEXT:    ret <4 x i32> <i32 26, i32 40, i32 58, i32 80>
+; CHECK-NEXT:    ret <4 x i32> <i32 5, i32 25, i32 61, i32 113>
 ;
   %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>, <8 x i16> <i16 1, i16 2, i16 3, i16 4, i16 5, i16 6, i16 7, i16 8>)
   ret <4 x i32> %res
@@ -36,4 +36,16 @@ define <4 x i32> @dot_doubly_negative() {
   ret <4 x i32> %res
 }
 
+; Check that wasm.dot follows the spec's modulo-2^32 imul and iadd.
+; Lane 0: 2 * (2^15 - 1)^2 = 2147352578 still fits in a signed i32.
+; Lane 1: 2 * (-2^15)^2 = 2^31 exceeds INT32_MAX, so iadd wraps modulo
+;   2^32 to -2^31. No single product (at most 2^30) can wrap, so the
+;   final addition is the only place the modulo is observable.
+define <4 x i32> @dot_follow_modulo_spec() {
+; CHECK-LABEL: define <4 x i32> @dot_follow_modulo_spec() {
+; CHECK-NEXT:    ret <4 x i32> <i32 2147352578, i32 -2147483648, i32 0, i32 0>
+;
+  %res = tail call <4 x i32> @llvm.wasm.dot(<8 x i16> <i16 32767, i16 32767, i16 -32768, i16 -32768, i16 0, i16 0, i16 0, i16 0>, <8 x i16> <i16 32767, i16 32767, i16 -32768, i16 -32768, i16 0, i16 0, i16 0, i16 0>)
+  ret <4 x i32> %res
+}
 



More information about the llvm-commits mailing list