[llvm] [InstCombine] Canonicalize reassoc contract fmuladd to fmul + fadd (PR #90434)

via llvm-commits llvm-commits at lists.llvm.org
Sun Apr 28 23:55:34 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-llvm-transforms

Author: Allen (vfdff)

<details>
<summary>Changes</summary>

According to https://llvm.org/docs/LangRef.html#fast-math-flags, I think only contract and reassoc are required for this transformation: a) contract: allows fusing a multiply followed by an addition into a fused multiply-and-add; b) reassoc: allows reassociation transformations when there may be other fp operations.

Fix https://github.com/llvm/llvm-project/issues/90379

---
Full diff: https://github.com/llvm/llvm-project/pull/90434.diff


2 Files Affected:

- (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+7-7) 
- (modified) llvm/test/Transforms/InstCombine/fma.ll (+28) 


``````````diff
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index e5652458f150b5..bf5472f4fd1e4f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -2389,12 +2389,13 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     break;
   }
   case Intrinsic::fmuladd: {
-    // Canonicalize fast fmuladd to the separate fmul + fadd.
-    if (II->isFast()) {
+    // Canonicalize reassoc contract fmuladd to the separate fmul + fadd.
+    FastMathFlags FMF = II->getFastMathFlags();
+    if (FMF.allowReassoc() && FMF.allowContract()) {
       BuilderTy::FastMathFlagGuard Guard(Builder);
-      Builder.setFastMathFlags(II->getFastMathFlags());
-      Value *Mul = Builder.CreateFMul(II->getArgOperand(0),
-                                      II->getArgOperand(1));
+      Builder.setFastMathFlags(FMF);
+      Value *Mul =
+          Builder.CreateFMul(II->getArgOperand(0), II->getArgOperand(1));
       Value *Add = Builder.CreateFAdd(Mul, II->getArgOperand(2));
       Add->takeName(II);
       return replaceInstUsesWith(*II, Add);
@@ -2402,8 +2403,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
 
     // Try to simplify the underlying FMul.
     if (Value *V = simplifyFMulInst(II->getArgOperand(0), II->getArgOperand(1),
-                                    II->getFastMathFlags(),
-                                    SQ.getWithInstruction(II))) {
+                                    FMF, SQ.getWithInstruction(II))) {
       auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2));
       FAdd->copyFastMathFlags(II);
       return FAdd;
diff --git a/llvm/test/Transforms/InstCombine/fma.ll b/llvm/test/Transforms/InstCombine/fma.ll
index cf3d7f3c525a5f..42e584e9ef9a1b 100644
--- a/llvm/test/Transforms/InstCombine/fma.ll
+++ b/llvm/test/Transforms/InstCombine/fma.ll
@@ -204,6 +204,34 @@ define float @fmuladd_fneg_x_fneg_y_fast(float %x, float %y, float %z) {
   ret float %fmuladd
 }
 
+define float @fmuladd_unfold(float %x, float %y, float %z) {
+; CHECK-LABEL: @fmuladd_unfold(
+; CHECK-NEXT:    [[TMP1:%.*]] = fmul reassoc contract float [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[FMULADD:%.*]] = fadd reassoc contract float [[TMP1]], [[Z:%.*]]
+; CHECK-NEXT:    ret float [[FMULADD]]
+;
+  %fmuladd = call reassoc contract float @llvm.fmuladd.f32(float %x, float %y, float %z)
+  ret float %fmuladd
+}
+
+define float @fmuladd_unfold_missing_reassoc(float %x, float %y, float %z) {
+; CHECK-LABEL: @fmuladd_unfold_missing_reassoc(
+; CHECK-NEXT:    [[FMULADD:%.*]] = call contract float @llvm.fmuladd.f32(float [[X:%.*]], float [[Y:%.*]], float [[Z:%.*]])
+; CHECK-NEXT:    ret float [[FMULADD]]
+;
+  %fmuladd = call contract float @llvm.fmuladd.f32(float %x, float %y, float %z)
+  ret float %fmuladd
+}
+
+define float @fmuladd_unfold_missing_contract(float %x, float %y, float %z) {
+; CHECK-LABEL: @fmuladd_unfold_missing_contract(
+; CHECK-NEXT:    [[FMULADD:%.*]] = call reassoc float @llvm.fmuladd.f32(float [[X:%.*]], float [[Y:%.*]], float [[Z:%.*]])
+; CHECK-NEXT:    ret float [[FMULADD]]
+;
+  %fmuladd = call reassoc float @llvm.fmuladd.f32(float %x, float %y, float %z)
+  ret float %fmuladd
+}
+
 define float @fmuladd_unary_fneg_x_unary_fneg_y_fast(float %x, float %y, float %z) {
 ; CHECK-LABEL: @fmuladd_unary_fneg_x_unary_fneg_y_fast(
 ; CHECK-NEXT:    [[TMP1:%.*]] = fmul fast float [[X:%.*]], [[Y:%.*]]

``````````

</details>


https://github.com/llvm/llvm-project/pull/90434


More information about the llvm-commits mailing list