[llvm-branch-commits] [llvm] a6f0221 - [SLP] fix fast-math-flag propagation on FP reductions

Sanjay Patel via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sat Jan 23 08:36:19 PST 2021


Author: Sanjay Patel
Date: 2021-01-23T11:17:20-05:00
New Revision: a6f02212764a76935ec5fb704fe86a1a76f65745

URL: https://github.com/llvm/llvm-project/commit/a6f02212764a76935ec5fb704fe86a1a76f65745
DIFF: https://github.com/llvm/llvm-project/commit/a6f02212764a76935ec5fb704fe86a1a76f65745.diff

LOG: [SLP] fix fast-math-flag propagation on FP reductions

As shown in the test diffs, we could miscompile by
propagating flags that did not exist in the original
code.

The flags required for fmin/fmax reductions will be
fixed in a follow-up patch.

Added: 
    

Modified: 
    llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
    llvm/test/Transforms/SLPVectorizer/X86/horizontal.ll

Removed: 
    


################################################################################
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 78ce4870588c..6c2b10e5b9fa 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -6820,12 +6820,18 @@ class HorizontalReduction {
     if (NumReducedVals < 4)
       return false;
 
-    // FIXME: Fast-math-flags should be set based on the instructions in the
-    //        reduction (not all of 'fast' are required).
+    // Intersect the fast-math-flags from all reduction operations.
+    FastMathFlags RdxFMF;
+    RdxFMF.set();
+    for (ReductionOpsType &RdxOp : ReductionOps) {
+      for (Value *RdxVal : RdxOp) {
+        if (auto *FPMO = dyn_cast<FPMathOperator>(RdxVal))
+          RdxFMF &= FPMO->getFastMathFlags();
+      }
+    }
+
     IRBuilder<> Builder(cast<Instruction>(ReductionRoot));
-    FastMathFlags Unsafe;
-    Unsafe.setFast();
-    Builder.setFastMathFlags(Unsafe);
+    Builder.setFastMathFlags(RdxFMF);
 
     BoUpSLP::ExtraValueToDebugLocsMap ExternallyUsedValues;
     // The same extra argument may be used several times, so log each attempt
@@ -7071,9 +7077,6 @@ class HorizontalReduction {
     assert(isPowerOf2_32(ReduxWidth) &&
            "We only handle power-of-two reductions for now");
 
-    // FIXME: The builder should use an FMF guard. It should not be hard-coded
-    //        to 'fast'.
-    assert(Builder.getFastMathFlags().isFast() && "Expected 'fast' FMF");
     return createSimpleTargetReduction(Builder, TTI, VectorizedValue, RdxKind,
                                        ReductionOps.back());
   }

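As background on the second SLPVectorizer.cpp hunk above: the deleted
FIXME asked for an FMF guard instead of the hard-coded 'fast'
assertion, and since the builder is now seeded with the intersected
RdxFMF from the first hunk, the assertion can simply be dropped. For
reference, the guard pattern the FIXME alluded to looks roughly like
the sketch below; the helper and its names are illustrative, while
IRBuilder, FastMathFlagGuard, setFastMathFlags, and CreateFAdd are
real LLVM API.

#include "llvm/IR/IRBuilder.h"
using namespace llvm;

// Illustrative helper, not part of the patch: scope RdxFMF onto the
// builder while emitting reduction math, then restore the old flags.
static Value *emitWithReductionFMF(IRBuilder<> &Builder, FastMathFlags RdxFMF,
                                   Value *LHS, Value *RHS) {
  // RAII guard: saves the builder's current FMF, restores it on scope exit.
  IRBuilder<>::FastMathFlagGuard Guard(Builder);
  Builder.setFastMathFlags(RdxFMF);
  // Every FP instruction created here inherits exactly RdxFMF.
  return Builder.CreateFAdd(LHS, RHS, "rdx.add");
}
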
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/horizontal.ll b/llvm/test/Transforms/SLPVectorizer/X86/horizontal.ll
index 38d36c676fa7..03ec04cb8cbe 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/horizontal.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/horizontal.ll
@@ -1766,7 +1766,6 @@ bb.1:
   ret void
 }
 
-; FIXME: This is a miscompile.
 ; The FMF on the reduction should match the incoming insts.
 
 define float @fadd_v4f32_fmf(float* %p) {
@@ -1776,7 +1775,7 @@ define float @fadd_v4f32_fmf(float* %p) {
 ; CHECK-NEXT:    [[P3:%.*]] = getelementptr inbounds float, float* [[P]], i64 3
 ; CHECK-NEXT:    [[TMP1:%.*]] = bitcast float* [[P]] to <4 x float>*
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4
-; CHECK-NEXT:    [[TMP3:%.*]] = call fast float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
+; CHECK-NEXT:    [[TMP3:%.*]] = call reassoc nsz float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
 ; CHECK-NEXT:    ret float [[TMP3]]
 ;
 ; STORE-LABEL: @fadd_v4f32_fmf(
@@ -1785,7 +1784,7 @@ define float @fadd_v4f32_fmf(float* %p) {
 ; STORE-NEXT:    [[P3:%.*]] = getelementptr inbounds float, float* [[P]], i64 3
 ; STORE-NEXT:    [[TMP1:%.*]] = bitcast float* [[P]] to <4 x float>*
 ; STORE-NEXT:    [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4
-; STORE-NEXT:    [[TMP3:%.*]] = call fast float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
+; STORE-NEXT:    [[TMP3:%.*]] = call reassoc nsz float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
 ; STORE-NEXT:    ret float [[TMP3]]
 ;
   %p1 = getelementptr inbounds float, float* %p, i64 1
@@ -1801,6 +1800,10 @@ define float @fadd_v4f32_fmf(float* %p) {
   ret float %add3
 }
 
+; The minimal FMF for an fadd reduction are "reassoc nsz".
+; Only the common FMF of all operations in the reduction propagate to the result.
+; In this example, "contract nnan arcp" are dropped, but "ninf" transfers with the required flags.
+
 define float @fadd_v4f32_fmf_intersect(float* %p) {
 ; CHECK-LABEL: @fadd_v4f32_fmf_intersect(
 ; CHECK-NEXT:    [[P1:%.*]] = getelementptr inbounds float, float* [[P:%.*]], i64 1
@@ -1808,7 +1811,7 @@ define float @fadd_v4f32_fmf_intersect(float* %p) {
 ; CHECK-NEXT:    [[P3:%.*]] = getelementptr inbounds float, float* [[P]], i64 3
 ; CHECK-NEXT:    [[TMP1:%.*]] = bitcast float* [[P]] to <4 x float>*
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4
-; CHECK-NEXT:    [[TMP3:%.*]] = call fast float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
+; CHECK-NEXT:    [[TMP3:%.*]] = call reassoc ninf nsz float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
 ; CHECK-NEXT:    ret float [[TMP3]]
 ;
 ; STORE-LABEL: @fadd_v4f32_fmf_intersect(
@@ -1817,7 +1820,7 @@ define float @fadd_v4f32_fmf_intersect(float* %p) {
 ; STORE-NEXT:    [[P3:%.*]] = getelementptr inbounds float, float* [[P]], i64 3
 ; STORE-NEXT:    [[TMP1:%.*]] = bitcast float* [[P]] to <4 x float>*
 ; STORE-NEXT:    [[TMP2:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4
-; STORE-NEXT:    [[TMP3:%.*]] = call fast float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
+; STORE-NEXT:    [[TMP3:%.*]] = call reassoc ninf nsz float @llvm.vector.reduce.fadd.v4f32(float -0.000000e+00, <4 x float> [[TMP2]])
 ; STORE-NEXT:    ret float [[TMP3]]
 ;
   %p1 = getelementptr inbounds float, float* %p, i64 1

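A note on the "reassoc nsz" comment in the test diff above: a vector
reduction adds the lanes in a different association order than the
original serial chain of fadds, and IEEE float addition is not
associative, so 'reassoc' is required for the transform to be sound
('nsz' additionally licenses sign-of-zero differences that reordering
can introduce). A small standalone C++ demonstration of the
reassociation hazard:

#include <cstdio>

int main() {
  // IEEE float addition is not associative: regrouping the same three
  // values changes the rounded result.
  float a = 1e20f, b = -1e20f, c = 1.0f;
  printf("(a + b) + c = %g\n", (a + b) + c); // 1: the large terms cancel first
  printf("a + (b + c) = %g\n", a + (b + c)); // 0: b + c rounds back to b
  return 0;
}

Compiled without any fast-math options, both expressions are evaluated
with strict IEEE semantics, so the two results really do differ.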

        

