[llvm] relaxed simd fma (PR #147487)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 8 02:02:45 PDT 2025


https://github.com/badumbatish created https://github.com/llvm/llvm-project/pull/147487

- **Precommit test for #121311**
- **[WASM] Optimize fma when relaxed and ffast-math**


>From 683fae7878c6c9250bf7142a2fd16170aa734f71 Mon Sep 17 00:00:00 2001
From: Jasmine Tang <jjasmine at igalia.com>
Date: Tue, 8 Jul 2025 01:05:36 -0700
Subject: [PATCH 1/2] Precommit test for #121311

---
 .../CodeGen/WebAssembly/simd-relaxed-fma.ll   | 66 +++++++++++++++++++
 1 file changed, 66 insertions(+)
 create mode 100644 llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll

diff --git a/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll b/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll
new file mode 100644
index 0000000000000..ea3ee2a33cfa4
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll
@@ -0,0 +1,66 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; Fast-math <4 x float> multiply-add, written as fmul+fadd and as llvm.fma (#121311).
+; RUN: llc < %s -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers  -mattr=+simd128,+relaxed-simd | FileCheck %s
+target triple = "wasm32"
+define void @fma_seperate(ptr %a, ptr %b, ptr %c, ptr %dest) {
+; CHECK-LABEL: fma_seperate:
+; CHECK:         .functype fma_seperate (i32, i32, i32, i32) -> ()
+; CHECK-NEXT:  # %bb.0: # %entry
+; CHECK-NEXT:    v128.load $push1=, 0($1):p2align=0
+; CHECK-NEXT:    v128.load $push0=, 0($0):p2align=0
+; CHECK-NEXT:    f32x4.mul $push2=, $pop1, $pop0
+; CHECK-NEXT:    v128.load $push3=, 0($2):p2align=0
+; CHECK-NEXT:    f32x4.add $push4=, $pop2, $pop3
+; CHECK-NEXT:    v128.store 0($3):p2align=0, $pop4
+; CHECK-NEXT:    return
+entry:
+  %0 = load <4 x float>, ptr %a, align 1
+  %1 = load <4 x float>, ptr %b, align 1
+  %2 = load <4 x float>, ptr %c, align 1
+  %mul.i = fmul fast <4 x float> %1, %0
+  %add.i = fadd fast <4 x float> %mul.i, %2
+  store <4 x float> %add.i, ptr %dest, align 1
+  ret void
+}
+
+; Like @fma_seperate (a * b + c), but expressed through the llvm.fma intrinsic.
+define void @fma_llvm(ptr %a, ptr %b, ptr %c, ptr %dest) {
+; CHECK-LABEL: fma_llvm:
+; CHECK:         .functype fma_llvm (i32, i32, i32, i32) -> ()
+; CHECK-NEXT:  # %bb.0: # %entry
+; CHECK-NEXT:    v128.load $push25=, 0($0):p2align=0
+; CHECK-NEXT:    local.tee $push24=, $6=, $pop25
+; CHECK-NEXT:    f32x4.extract_lane $push2=, $pop24, 0
+; CHECK-NEXT:    v128.load $push23=, 0($1):p2align=0
+; CHECK-NEXT:    local.tee $push22=, $5=, $pop23
+; CHECK-NEXT:    f32x4.extract_lane $push1=, $pop22, 0
+; CHECK-NEXT:    v128.load $push21=, 0($2):p2align=0
+; CHECK-NEXT:    local.tee $push20=, $4=, $pop21
+; CHECK-NEXT:    f32x4.extract_lane $push0=, $pop20, 0
+; CHECK-NEXT:    call $push3=, fmaf, $pop2, $pop1, $pop0
+; CHECK-NEXT:    f32x4.splat $push4=, $pop3
+; CHECK-NEXT:    f32x4.extract_lane $push7=, $6, 1
+; CHECK-NEXT:    f32x4.extract_lane $push6=, $5, 1
+; CHECK-NEXT:    f32x4.extract_lane $push5=, $4, 1
+; CHECK-NEXT:    call $push8=, fmaf, $pop7, $pop6, $pop5
+; CHECK-NEXT:    f32x4.replace_lane $push9=, $pop4, 1, $pop8
+; CHECK-NEXT:    f32x4.extract_lane $push12=, $6, 2
+; CHECK-NEXT:    f32x4.extract_lane $push11=, $5, 2
+; CHECK-NEXT:    f32x4.extract_lane $push10=, $4, 2
+; CHECK-NEXT:    call $push13=, fmaf, $pop12, $pop11, $pop10
+; CHECK-NEXT:    f32x4.replace_lane $push14=, $pop9, 2, $pop13
+; CHECK-NEXT:    f32x4.extract_lane $push17=, $6, 3
+; CHECK-NEXT:    f32x4.extract_lane $push16=, $5, 3
+; CHECK-NEXT:    f32x4.extract_lane $push15=, $4, 3
+; CHECK-NEXT:    call $push18=, fmaf, $pop17, $pop16, $pop15
+; CHECK-NEXT:    f32x4.replace_lane $push19=, $pop14, 3, $pop18
+; CHECK-NEXT:    v128.store 0($3):p2align=0, $pop19
+; CHECK-NEXT:    return
+entry:
+  %0 = load <4 x float>, ptr %a, align 1
+  %1 = load <4 x float>, ptr %b, align 1
+  %2 = load <4 x float>, ptr %c, align 1
+  %fma = tail call fast <4 x float> @llvm.fma.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+  store <4 x float> %fma, ptr %dest, align 1
+  ret void
+}

>From b1c4c01dd18259980d8faae6a9e4f71cb30208c6 Mon Sep 17 00:00:00 2001
From: Jasmine Tang <jjasmine at igalia.com>
Date: Tue, 8 Jul 2025 01:49:37 -0700
Subject: [PATCH 2/2] [WASM] Optimize fma when relaxed and ffast-math

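Fixes #121311 by folding a multiply followed by an add into
wasm.relaxed.madd (e.g. f32x4.relaxed_madd) when compiling with
-mrelaxed-simd and -ffast-math.

Also lower the built-in llvm.fma to wasm.relaxed.madd under the same
conditions, instead of scalarizing it into per-lane fmaf libcalls.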
---
 llvm/include/llvm/CodeGen/SelectionDAGNodes.h |  1 +
 .../WebAssembly/WebAssemblyISelLowering.cpp   | 41 +++++++++++++++++++
 .../CodeGen/WebAssembly/simd-relaxed-fma.ll   | 39 ++++--------------
 3 files changed, 50 insertions(+), 31 deletions(-)

diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index a3675eecfea3f..ec566b168bc3d 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -475,6 +475,9 @@ struct SDNodeFlags {
   bool hasAllowReassociation() const { return Flags & AllowReassociation; }
   bool hasNoFPExcept() const { return Flags & NoFPExcept; }
   bool hasUnpredictable() const { return Flags & Unpredictable; }
+  /// Returns true only if every fast-math flag is set (the SelectionDAG
+  /// counterpart of FastMathFlags::isFast()), not when just some are set.
+  bool hasFastMath() const { return (Flags & FastMathFlags) == FastMathFlags; }
 
   bool operator==(const SDNodeFlags &Other) const {
     return Flags == Other.Flags;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index bf2e04caa0a61..ef0146f28aba1 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -182,6 +182,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
   // SIMD-specific configuration
   if (Subtarget->hasSIMD128()) {
 
+    // Relaxed SIMD provides relaxed_madd; route FADD and FMA through
+    // PerformDAGCombine so fast-math multiply-adds can be folded into
+    // that instruction.
+    if (Subtarget->hasRelaxedSIMD())
+      setTargetDAGCombine({ISD::FADD, ISD::FMA});
+
     // Combine partial.reduce.add before legalization gets confused.
     setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
 
@@ -3412,6 +3418,68 @@ static SDValue performSETCCCombine(SDNode *N,
   return SDValue();
 }
 
+/// Returns true if \p VT is a type that wasm.relaxed.madd can be selected
+/// for (f32x4.relaxed_madd / f64x2.relaxed_madd).
+static bool isRelaxedMaddType(EVT VT) {
+  return VT == MVT::v4f32 || VT == MVT::v2f64;
+}
+
+/// Builds wasm.relaxed.madd(\p A, \p B, \p C), i.e. A * B + C, where the
+/// intermediate product may or may not be rounded.
+static SDValue buildRelaxedMadd(SelectionDAG &DAG, const SDLoc &DL, EVT VT,
+                                SDValue A, SDValue B, SDValue C) {
+  return DAG.getNode(
+      ISD::INTRINSIC_WO_CHAIN, DL, VT,
+      {DAG.getConstant(Intrinsic::wasm_relaxed_madd, DL, MVT::i32), A, B, C});
+}
+
+/// Folds (fadd (fmul X, Y), Z), with the multiply in either operand, into
+/// relaxed_madd(X, Y, Z). relaxed_madd may fuse the operations, so both the
+/// add and the multiply must carry the 'contract' fast-math flag.
+static SDValue performFAddCombine(SDNode *N, SelectionDAG &DAG) {
+  assert(N->getOpcode() == ISD::FADD);
+  EVT VT = N->getValueType(0);
+  // Only v4f32/v2f64 are handled; scalars and illegal vector types (seen
+  // before type legalization) cannot be selected as relaxed_madd.
+  if (!isRelaxedMaddType(VT) || !N->getFlags().hasAllowContract())
+    return SDValue();
+
+  auto TryFold = [&](SDValue Mul, SDValue Addend) -> SDValue {
+    // Only absorb a multiply used solely by this add; otherwise the product
+    // would still be computed separately and nothing is saved.
+    if (Mul.getOpcode() != ISD::FMUL || !Mul.hasOneUse() ||
+        !Mul->getFlags().hasAllowContract())
+      return SDValue();
+    // relaxed_madd(a, b, c) computes a * b + c: multiplicands first, then
+    // the addend.
+    return buildRelaxedMadd(DAG, SDLoc(N), VT, Mul.getOperand(0),
+                            Mul.getOperand(1), Addend);
+  };
+
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  if (SDValue Folded = TryFold(LHS, RHS))
+    return Folded;
+  return TryFold(RHS, LHS);
+}
+
+/// Lowers a vector ISD::FMA to relaxed_madd instead of scalarizing it into
+/// per-lane fmaf libcalls. This gives up llvm.fma's single-rounding
+/// guarantee, since relaxed_madd may also evaluate unfused.
+static SDValue performFMACombine(SDNode *N, SelectionDAG &DAG) {
+  assert(N->getOpcode() == ISD::FMA);
+  EVT VT = N->getValueType(0);
+  // An unfused result only approximates fma, so require the 'afn'
+  // (approximate functions) flag rather than any single fast-math flag.
+  // NOTE(review): confirm 'afn' is the intended gate for llvm.fma.
+  if (!isRelaxedMaddType(VT) || !N->getFlags().hasApproximateFuncs())
+    return SDValue();
+
+  // ISD::FMA(A, B, C) and relaxed_madd(A, B, C) both compute A * B + C.
+  return buildRelaxedMadd(DAG, SDLoc(N), VT, N->getOperand(0),
+                          N->getOperand(1), N->getOperand(2));
+}
+
 static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
   assert(N->getOpcode() == ISD::MUL);
   EVT VT = N->getValueType(0);
@@ -3529,6 +3597,10 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
       return AnyAllCombine;
     return performLowerPartialReduction(N, DCI.DAG);
   }
+  case ISD::FADD:
+    return performFAddCombine(N, DCI.DAG);
+  case ISD::FMA:
+    return performFMACombine(N, DCI.DAG);
   case ISD::MUL:
     return performMulCombine(N, DCI.DAG);
   }
diff --git a/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll b/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll
index ea3ee2a33cfa4..fe5e8573f12b4 100644
--- a/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll
+++ b/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll
@@ -6,12 +6,11 @@ define void @fma_seperate(ptr %a, ptr %b, ptr %c, ptr %dest) {
 ; CHECK-LABEL: fma_seperate:
 ; CHECK:         .functype fma_seperate (i32, i32, i32, i32) -> ()
 ; CHECK-NEXT:  # %bb.0: # %entry
-; CHECK-NEXT:    v128.load $push1=, 0($1):p2align=0
-; CHECK-NEXT:    v128.load $push0=, 0($0):p2align=0
-; CHECK-NEXT:    f32x4.mul $push2=, $pop1, $pop0
-; CHECK-NEXT:    v128.load $push3=, 0($2):p2align=0
-; CHECK-NEXT:    f32x4.add $push4=, $pop2, $pop3
-; CHECK-NEXT:    v128.store 0($3):p2align=0, $pop4
+; CHECK-NEXT:    v128.load $push2=, 0($1):p2align=0
+; CHECK-NEXT:    v128.load $push1=, 0($0):p2align=0
+; CHECK-NEXT:    v128.load $push0=, 0($2):p2align=0
+; CHECK-NEXT:    f32x4.relaxed_madd $push3=, $pop2, $pop1, $pop0
+; CHECK-NEXT:    v128.store 0($3):p2align=0, $pop3
 ; CHECK-NEXT:    return
 entry:
   %0 = load <4 x float>, ptr %a, align 1
@@ -28,33 +27,11 @@ define void @fma_llvm(ptr %a, ptr %b, ptr %c, ptr %dest) {
 ; CHECK-LABEL: fma_llvm:
 ; CHECK:         .functype fma_llvm (i32, i32, i32, i32) -> ()
 ; CHECK-NEXT:  # %bb.0: # %entry
-; CHECK-NEXT:    v128.load $push25=, 0($0):p2align=0
-; CHECK-NEXT:    local.tee $push24=, $6=, $pop25
-; CHECK-NEXT:    f32x4.extract_lane $push2=, $pop24, 0
-; CHECK-NEXT:    v128.load $push23=, 0($1):p2align=0
-; CHECK-NEXT:    local.tee $push22=, $5=, $pop23
-; CHECK-NEXT:    f32x4.extract_lane $push1=, $pop22, 0
-; CHECK-NEXT:    v128.load $push21=, 0($2):p2align=0
-; CHECK-NEXT:    local.tee $push20=, $4=, $pop21
-; CHECK-NEXT:    f32x4.extract_lane $push0=, $pop20, 0
-; CHECK-NEXT:    call $push3=, fmaf, $pop2, $pop1, $pop0
-; CHECK-NEXT:    f32x4.splat $push4=, $pop3
-; CHECK-NEXT:    f32x4.extract_lane $push7=, $6, 1
-; CHECK-NEXT:    f32x4.extract_lane $push6=, $5, 1
-; CHECK-NEXT:    f32x4.extract_lane $push5=, $4, 1
-; CHECK-NEXT:    call $push8=, fmaf, $pop7, $pop6, $pop5
-; CHECK-NEXT:    f32x4.replace_lane $push9=, $pop4, 1, $pop8
-; CHECK-NEXT:    f32x4.extract_lane $push12=, $6, 2
-; CHECK-NEXT:    f32x4.extract_lane $push11=, $5, 2
-; CHECK-NEXT:    f32x4.extract_lane $push10=, $4, 2
-; CHECK-NEXT:    call $push13=, fmaf, $pop12, $pop11, $pop10
-; CHECK-NEXT:    f32x4.replace_lane $push14=, $pop9, 2, $pop13
-; CHECK-NEXT:    f32x4.extract_lane $push17=, $6, 3
-; CHECK-NEXT:    f32x4.extract_lane $push16=, $5, 3
-; CHECK-NEXT:    f32x4.extract_lane $push15=, $4, 3
-; CHECK-NEXT:    call $push18=, fmaf, $pop17, $pop16, $pop15
-; CHECK-NEXT:    f32x4.replace_lane $push19=, $pop14, 3, $pop18
-; CHECK-NEXT:    v128.store 0($3):p2align=0, $pop19
+; CHECK-NEXT:    v128.load $push2=, 0($0):p2align=0
+; CHECK-NEXT:    v128.load $push1=, 0($1):p2align=0
+; CHECK-NEXT:    v128.load $push0=, 0($2):p2align=0
+; CHECK-NEXT:    f32x4.relaxed_madd $push3=, $pop2, $pop1, $pop0
+; CHECK-NEXT:    v128.store 0($3):p2align=0, $pop3
 ; CHECK-NEXT:    return
 entry:
   %0 = load <4 x float>, ptr %a, align 1



More information about the llvm-commits mailing list