[llvm] relaxed simd fma (PR #147487)
via llvm-commits
llvm-commits at lists.llvm.org
Tue Jul 8 02:03:21 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-llvm-selectiondag
Author: jjasmine (badumbatish)
<details>
<summary>Changes</summary>
- **Precommit test for #<!-- -->121311**
- **[WASM] Optimize fma when relaxed and ffast-math**
---
Full diff: https://github.com/llvm/llvm-project/pull/147487.diff
3 Files Affected:
- (modified) llvm/include/llvm/CodeGen/SelectionDAGNodes.h (+1)
- (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp (+41)
- (added) llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll (+43)
``````````diff
diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
index a3675eecfea3f..ec566b168bc3d 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -475,6 +475,7 @@ struct SDNodeFlags {
bool hasAllowReassociation() const { return Flags & AllowReassociation; }
bool hasNoFPExcept() const { return Flags & NoFPExcept; }
bool hasUnpredictable() const { return Flags & Unpredictable; }
+ bool hasFastMath() const { return Flags & FastMathFlags; }
bool operator==(const SDNodeFlags &Other) const {
return Flags == Other.Flags;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index bf2e04caa0a61..ef0146f28aba1 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -182,6 +182,12 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
// SIMD-specific configuration
if (Subtarget->hasSIMD128()) {
+ // Enable fma optimization for wasm relaxed simd
+ if (Subtarget->hasRelaxedSIMD()) {
+ setTargetDAGCombine(ISD::FADD);
+ setTargetDAGCombine(ISD::FMA);
+ }
+
// Combine partial.reduce.add before legalization gets confused.
setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
@@ -3412,6 +3418,37 @@ static SDValue performSETCCCombine(SDNode *N,
return SDValue();
}
+static SDValue performFAddCombine(SDNode *N, SelectionDAG &DAG) {
+ assert(N->getOpcode() == ISD::FADD);
+ using namespace llvm::SDPatternMatch;
+ if (!N->getFlags().hasFastMath())
+ return SDValue();
+
+ SDLoc DL(N);
+ SDValue A, B, C;
+ EVT VecVT = N->getValueType(0);
+ if (sd_match(N, m_FAdd(m_Value(A), m_FMul(m_Value(B), m_Value(C)))))
+ return DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, VecVT,
+ {DAG.getConstant(Intrinsic::wasm_relaxed_madd, DL, MVT::i32), A, B, C});
+
+ return SDValue();
+}
+
+static SDValue performFMACombine(SDNode *N, SelectionDAG &DAG) {
+ assert(N->getOpcode() == ISD::FMA);
+ if (!N->getFlags().hasFastMath())
+ return SDValue();
+
+ SDLoc DL(N);
+ SDValue A = N->getOperand(0), B = N->getOperand(1), C = N->getOperand(2);
+ EVT VecVT = N->getValueType(0);
+
+ return DAG.getNode(
+ ISD::INTRINSIC_WO_CHAIN, DL, VecVT,
+ {DAG.getConstant(Intrinsic::wasm_relaxed_madd, DL, MVT::i32), A, B, C});
+}
+
static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG) {
assert(N->getOpcode() == ISD::MUL);
EVT VT = N->getValueType(0);
@@ -3529,6 +3566,10 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
return AnyAllCombine;
return performLowerPartialReduction(N, DCI.DAG);
}
+ case ISD::FADD:
+ return performFAddCombine(N, DCI.DAG);
+ case ISD::FMA:
+ return performFMACombine(N, DCI.DAG);
case ISD::MUL:
return performMulCombine(N, DCI.DAG);
}
diff --git a/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll b/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll
new file mode 100644
index 0000000000000..fe5e8573f12b4
--- /dev/null
+++ b/llvm/test/CodeGen/WebAssembly/simd-relaxed-fma.ll
@@ -0,0 +1,43 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+
+; RUN: llc < %s -verify-machineinstrs -disable-wasm-fallthrough-return-opt -wasm-disable-explicit-locals -wasm-keep-registers -mattr=+simd128,+relaxed-simd | FileCheck %s
+target triple = "wasm32"
+define void @fma_seperate(ptr %a, ptr %b, ptr %c, ptr %dest) {
+; CHECK-LABEL: fma_seperate:
+; CHECK: .functype fma_seperate (i32, i32, i32, i32) -> ()
+; CHECK-NEXT: # %bb.0: # %entry
+; CHECK-NEXT: v128.load $push2=, 0($2):p2align=0
+; CHECK-NEXT: v128.load $push1=, 0($1):p2align=0
+; CHECK-NEXT: v128.load $push0=, 0($0):p2align=0
+; CHECK-NEXT: f32x4.relaxed_madd $push3=, $pop2, $pop1, $pop0
+; CHECK-NEXT: v128.store 0($3):p2align=0, $pop3
+; CHECK-NEXT: return
+entry:
+ %0 = load <4 x float>, ptr %a, align 1
+ %1 = load <4 x float>, ptr %b, align 1
+ %2 = load <4 x float>, ptr %c, align 1
+ %mul.i = fmul fast <4 x float> %1, %0
+ %add.i = fadd fast <4 x float> %mul.i, %2
+ store <4 x float> %add.i, ptr %dest, align 1
+ ret void
+}
+
+; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: readwrite)
+define void @fma_llvm(ptr %a, ptr %b, ptr %c, ptr %dest) {
+; CHECK-LABEL: fma_llvm:
+; CHECK: .functype fma_llvm (i32, i32, i32, i32) -> ()
+; CHECK-NEXT: # %bb.0: # %entry
+; CHECK-NEXT: v128.load $push2=, 0($0):p2align=0
+; CHECK-NEXT: v128.load $push1=, 0($1):p2align=0
+; CHECK-NEXT: v128.load $push0=, 0($2):p2align=0
+; CHECK-NEXT: f32x4.relaxed_madd $push3=, $pop2, $pop1, $pop0
+; CHECK-NEXT: v128.store 0($3):p2align=0, $pop3
+; CHECK-NEXT: return
+entry:
+ %0 = load <4 x float>, ptr %a, align 1
+ %1 = load <4 x float>, ptr %b, align 1
+ %2 = load <4 x float>, ptr %c, align 1
+ %fma = tail call fast <4 x float> @llvm.fma.v4f32(<4 x float> %0, <4 x float> %1, <4 x float> %2)
+ store <4 x float> %fma, ptr %dest, align 1
+ ret void
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/147487
More information about the llvm-commits
mailing list