[llvm] [X86][StrictFP] Combine fcmp + select to fmin/fmax for some predicates (PR #109512)

Phoebe Wang via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 20 21:29:53 PDT 2024


https://github.com/phoebewang created https://github.com/llvm/llvm-project/pull/109512

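X86 MAXSS/MINSS and related instructions never quiet an SNaN: when either input is a NaN they return the second source operand unchanged, just as a select would. This lets fcmp + select be combined into them for some predicates.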

>From b29686db8dbd1c43ba20a24e97d39fb33758cca5 Mon Sep 17 00:00:00 2001
From: "Wang, Phoebe" <phoebe.wang at intel.com>
Date: Sat, 21 Sep 2024 12:24:46 +0800
Subject: [PATCH] [X86][StrictFP] Combine fcmp + select to fmin/fmax for some
 predicates

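X86 MAXSS/MINSS and related instructions never quiet an SNaN: when either
input is a NaN they return the second source operand unchanged, just as a
select would. This lets fcmp + select be combined into them for some
predicates.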
---
 llvm/lib/Target/X86/X86ISelLowering.cpp       |  33 ++--
 llvm/lib/Target/X86/X86ISelLowering.h         |   4 +
 llvm/lib/Target/X86/X86InstrAVX512.td         |   8 +-
 llvm/lib/Target/X86/X86InstrFragmentsSIMD.td  |  12 ++
 llvm/lib/Target/X86/X86InstrSSE.td            |   8 +-
 llvm/test/CodeGen/X86/fp-strict-scalar-cmp.ll | 149 +++++++++++++++++-
 6 files changed, 196 insertions(+), 18 deletions(-)

diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index b9c9e5703849ae..3927211c39b0bf 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -34176,10 +34176,12 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(FMAXS)
   NODE_NAME_CASE(FMAX_SAE)
   NODE_NAME_CASE(FMAXS_SAE)
+  NODE_NAME_CASE(STRICT_FMAX)
   NODE_NAME_CASE(FMIN)
   NODE_NAME_CASE(FMINS)
   NODE_NAME_CASE(FMIN_SAE)
   NODE_NAME_CASE(FMINS_SAE)
+  NODE_NAME_CASE(STRICT_FMIN)
   NODE_NAME_CASE(FMAXC)
   NODE_NAME_CASE(FMINC)
   NODE_NAME_CASE(FRSQRT)
@@ -46494,17 +46496,24 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
   // x<=y?x:y, because of how they handle negative zero (which can be
   // ignored in unsafe-math mode).
   // We also try to create v2f32 min/max nodes, which we later widen to v4f32.
-  if (Cond.getOpcode() == ISD::SETCC && VT.isFloatingPoint() &&
-      VT != MVT::f80 && VT != MVT::f128 && !isSoftF16(VT, Subtarget) &&
-      (TLI.isTypeLegal(VT) || VT == MVT::v2f32) &&
+  // Under strict FP only signaling compares (STRICT_FSETCCS) are matched:
+  // MINSS/MAXSS raise Invalid for QNaN as well as SNaN operands, which a
+  // quiet compare (STRICT_FSETCC) must not do.
+  if ((Cond.getOpcode() == ISD::SETCC ||
+       Cond.getOpcode() == ISD::STRICT_FSETCCS) &&
+      VT.isFloatingPoint() && VT != MVT::f80 && VT != MVT::f128 &&
+      !isSoftF16(VT, Subtarget) && (TLI.isTypeLegal(VT) || VT == MVT::v2f32) &&
       (Subtarget.hasSSE2() ||
        (Subtarget.hasSSE1() && VT.getScalarType() == MVT::f32))) {
-    ISD::CondCode CC = cast<CondCodeSDNode>(Cond.getOperand(2))->get();
+    bool IsStrict = Cond->isStrictFPOpcode();
+    ISD::CondCode CC =
+        cast<CondCodeSDNode>(Cond.getOperand(IsStrict ? 3 : 2))->get();
+    SDValue Op0 = Cond.getOperand(IsStrict ? 1 : 0);
+    SDValue Op1 = Cond.getOperand(IsStrict ? 2 : 1);
 
     unsigned Opcode = 0;
     // Check for x CC y ? x : y.
-    if (DAG.isEqualTo(LHS, Cond.getOperand(0)) &&
-        DAG.isEqualTo(RHS, Cond.getOperand(1))) {
+    if (DAG.isEqualTo(LHS, Op0) && DAG.isEqualTo(RHS, Op1)) {
       switch (CC) {
       default: break;
       case ISD::SETULT:
@@ -46572,8 +46581,7 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
         break;
       }
     // Check for x CC y ? y : x -- a min/max with reversed arms.
-    } else if (DAG.isEqualTo(LHS, Cond.getOperand(1)) &&
-               DAG.isEqualTo(RHS, Cond.getOperand(0))) {
+    } else if (DAG.isEqualTo(LHS, Op1) && DAG.isEqualTo(RHS, Op0)) {
       switch (CC) {
       default: break;
       case ISD::SETOGE:
@@ -46638,8 +46646,18 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG,
       }
     }
 
-    if (Opcode)
+    if (Opcode) {
+      if (IsStrict) {
+        SDValue Ret = DAG.getNode(Opcode == X86ISD::FMIN ? X86ISD::STRICT_FMIN
+                                                         : X86ISD::STRICT_FMAX,
+                                  DL, {N->getValueType(0), MVT::Other},
+                                  {Cond.getOperand(0), LHS, RHS});
+        // Route users of the compare's output chain through the min/max node.
+        DAG.ReplaceAllUsesOfValueWith(Cond.getValue(1), Ret.getValue(1));
+        return Ret;
+      }
       return DAG.getNode(Opcode, DL, N->getValueType(0), LHS, RHS);
+    }
   }
 
   // Some mask scalar intrinsics rely on checking if only one bit is set
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index 0ab42f032c3ea6..bf1db9f6da366e 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -850,6 +850,10 @@ namespace llvm {
     // Perform an FP80 add after changing precision control in FPCW.
     STRICT_FP80_ADD,
 
+    /// Strict (chained) versions of FMAX and FMIN.
+    STRICT_FMAX,
+    STRICT_FMIN,
+
     // WARNING: Only add nodes here if they are strict FP nodes. Non-memory and
     // non-strict FP nodes should be above FIRST_TARGET_STRICTFP_OPCODE.
 
diff --git a/llvm/lib/Target/X86/X86InstrAVX512.td b/llvm/lib/Target/X86/X86InstrAVX512.td
index 928abac46da866..f574bc882dd638 100644
--- a/llvm/lib/Target/X86/X86InstrAVX512.td
+++ b/llvm/lib/Target/X86/X86InstrAVX512.td
@@ -5395,7 +5395,7 @@ multiclass avx512_fp_scalar_round<bits<8> opc, string OpcodeStr,X86VectorVTInfo
                           EVEX_B, EVEX_RC, Sched<[sched]>;
 }
 multiclass avx512_fp_scalar_sae<bits<8> opc, string OpcodeStr,X86VectorVTInfo _,
-                                SDNode OpNode, SDNode VecNode, SDNode SaeNode,
+                                SDPatternOperator OpNode, SDNode VecNode, SDNode SaeNode,
                                 X86FoldableSchedWrite sched, bit IsCommutable> {
   let ExeDomain = _.ExeDomain in {
   defm rr_Int : AVX512_maskable_scalar<opc, MRMSrcReg, _, (outs _.RC:$dst),
@@ -5458,7 +5458,7 @@ multiclass avx512_binop_s_round<bits<8> opc, string OpcodeStr, SDPatternOperator
                                 T_MAP5, XS, EVEX, VVVV, VEX_LIG, EVEX_CD8<16, CD8VT1>;
 }
 
-multiclass avx512_binop_s_sae<bits<8> opc, string OpcodeStr, SDNode OpNode,
+multiclass avx512_binop_s_sae<bits<8> opc, string OpcodeStr, SDPatternOperator OpNode,
                               SDNode VecNode, SDNode SaeNode,
                               X86SchedWriteSizes sched, bit IsCommutable> {
   defm SSZ : avx512_fp_scalar_sae<opc, OpcodeStr#"ss", f32x_info, OpNode,
@@ -5481,9 +5481,9 @@ defm VSUB : avx512_binop_s_round<0x5C, "vsub", any_fsub, X86fsubs, X86fsubRnds,
                                  SchedWriteFAddSizes, 0>;
 defm VDIV : avx512_binop_s_round<0x5E, "vdiv", any_fdiv, X86fdivs, X86fdivRnds,
                                  SchedWriteFDivSizes, 0>;
-defm VMIN : avx512_binop_s_sae<0x5D, "vmin", X86fmin, X86fmins, X86fminSAEs,
+defm VMIN : avx512_binop_s_sae<0x5D, "vmin", X86any_fmin, X86fmins, X86fminSAEs,
                                SchedWriteFCmpSizes, 0>;
-defm VMAX : avx512_binop_s_sae<0x5F, "vmax", X86fmax, X86fmaxs, X86fmaxSAEs,
+defm VMAX : avx512_binop_s_sae<0x5F, "vmax", X86any_fmax, X86fmaxs, X86fmaxSAEs,
                                SchedWriteFCmpSizes, 0>;
 
 // MIN/MAX nodes are commutable under "unsafe-fp-math". In this case we use
diff --git a/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td b/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td
index ed1bff05b7316c..c09522709d2f0d 100644
--- a/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td
+++ b/llvm/lib/Target/X86/X86InstrFragmentsSIMD.td
@@ -46,6 +46,18 @@ def X86fminc    : SDNode<"X86ISD::FMINC", SDTFPBinOp,
 def X86fmaxc    : SDNode<"X86ISD::FMAXC", SDTFPBinOp,
     [SDNPCommutative, SDNPAssociative]>;
 
+def X86strict_fmin : SDNode<"X86ISD::STRICT_FMIN", SDTFPBinOp,
+                            [SDNPHasChain]>;
+def X86strict_fmax : SDNode<"X86ISD::STRICT_FMAX", SDTFPBinOp,
+                            [SDNPHasChain]>;
+
+def X86any_fmin : PatFrags<(ops node:$src1, node:$src2),
+                           [(X86strict_fmin node:$src1, node:$src2),
+                            (X86fmin node:$src1, node:$src2)]>;
+def X86any_fmax : PatFrags<(ops node:$src1, node:$src2),
+                           [(X86strict_fmax node:$src1, node:$src2),
+                            (X86fmax node:$src1, node:$src2)]>;
+
 def X86fand    : SDNode<"X86ISD::FAND",      SDTFPBinOp,
                         [SDNPCommutative, SDNPAssociative]>;
 def X86for     : SDNode<"X86ISD::FOR",       SDTFPBinOp,
diff --git a/llvm/lib/Target/X86/X86InstrSSE.td b/llvm/lib/Target/X86/X86InstrSSE.td
index d51125a209db9d..e77e56aa96c670 100644
--- a/llvm/lib/Target/X86/X86InstrSSE.td
+++ b/llvm/lib/Target/X86/X86InstrSSE.td
@@ -2730,11 +2730,11 @@ let isCommutable = 0 in {
   defm DIV : basic_sse12_fp_binop_p<0x5E, "div", any_fdiv, SchedWriteFDivSizes>,
              basic_sse12_fp_binop_s<0x5E, "div", any_fdiv, SchedWriteFDivSizes>,
              basic_sse12_fp_binop_s_int<0x5E, "div", null_frag, SchedWriteFDivSizes>;
-  defm MAX : basic_sse12_fp_binop_p<0x5F, "max", X86fmax, SchedWriteFCmpSizes>,
-             basic_sse12_fp_binop_s<0x5F, "max", X86fmax, SchedWriteFCmpSizes>,
+  defm MAX : basic_sse12_fp_binop_p<0x5F, "max", X86any_fmax, SchedWriteFCmpSizes>,
+             basic_sse12_fp_binop_s<0x5F, "max", X86any_fmax, SchedWriteFCmpSizes>,
              basic_sse12_fp_binop_s_int<0x5F, "max", X86fmaxs, SchedWriteFCmpSizes>;
-  defm MIN : basic_sse12_fp_binop_p<0x5D, "min", X86fmin, SchedWriteFCmpSizes>,
-             basic_sse12_fp_binop_s<0x5D, "min", X86fmin, SchedWriteFCmpSizes>,
+  defm MIN : basic_sse12_fp_binop_p<0x5D, "min", X86any_fmin, SchedWriteFCmpSizes>,
+             basic_sse12_fp_binop_s<0x5D, "min", X86any_fmin, SchedWriteFCmpSizes>,
              basic_sse12_fp_binop_s_int<0x5D, "min", X86fmins, SchedWriteFCmpSizes>;
 }
 
diff --git a/llvm/test/CodeGen/X86/fp-strict-scalar-cmp.ll b/llvm/test/CodeGen/X86/fp-strict-scalar-cmp.ll
index cb1876fee05aea..272d2b0729136a 100644
--- a/llvm/test/CodeGen/X86/fp-strict-scalar-cmp.ll
+++ b/llvm/test/CodeGen/X86/fp-strict-scalar-cmp.ll
@@ -4202,7 +4202,154 @@ define void @foo(float %0, float %1) #0 {
 }
 declare dso_local void @bar()
 
-attributes #0 = { strictfp }
+define float @fcmp_select_ogt(float %f1, float %f2) #0 {
+; SSE-32-LABEL: fcmp_select_ogt:
+; SSE-32:       # %bb.0:
+; SSE-32-NEXT:    pushl %eax
+; SSE-32-NEXT:    movss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; SSE-32-NEXT:    maxss {{[0-9]+}}(%esp), %xmm0
+; SSE-32-NEXT:    movss %xmm0, (%esp)
+; SSE-32-NEXT:    flds (%esp)
+; SSE-32-NEXT:    wait
+; SSE-32-NEXT:    popl %eax
+; SSE-32-NEXT:    retl
+;
+; SSE-64-LABEL: fcmp_select_ogt:
+; SSE-64:       # %bb.0:
+; SSE-64-NEXT:    maxss %xmm1, %xmm0
+; SSE-64-NEXT:    retq
+;
+; AVX-32-LABEL: fcmp_select_ogt:
+; AVX-32:       # %bb.0:
+; AVX-32-NEXT:    pushl %eax
+; AVX-32-NEXT:    vmovss {{.*#+}} xmm0 = mem[0],zero,zero,zero
+; AVX-32-NEXT:    vmaxss {{[0-9]+}}(%esp), %xmm0, %xmm0
+; AVX-32-NEXT:    vmovss %xmm0, (%esp)
+; AVX-32-NEXT:    flds (%esp)
+; AVX-32-NEXT:    wait
+; AVX-32-NEXT:    popl %eax
+; AVX-32-NEXT:    retl
+;
+; AVX-64-LABEL: fcmp_select_ogt:
+; AVX-64:       # %bb.0:
+; AVX-64-NEXT:    vmaxss %xmm1, %xmm0, %xmm0
+; AVX-64-NEXT:    retq
+;
+; X87-LABEL: fcmp_select_ogt:
+; X87:       # %bb.0:
+; X87-NEXT:    flds {{[0-9]+}}(%esp)
+; X87-NEXT:    flds {{[0-9]+}}(%esp)
+; X87-NEXT:    fcom %st(1)
+; X87-NEXT:    wait
+; X87-NEXT:    fnstsw %ax
+; X87-NEXT:    # kill: def $ah killed $ah killed $ax
+; X87-NEXT:    sahf
+; X87-NEXT:    ja .LBB57_2
+; X87-NEXT:  # %bb.1:
+; X87-NEXT:    fstp %st(0)
+; X87-NEXT:    fldz
+; X87-NEXT:    fxch %st(1)
+; X87-NEXT:  .LBB57_2:
+; X87-NEXT:    fstp %st(1)
+; X87-NEXT:    wait
+; X87-NEXT:    retl
+;
+; X87-CMOV-LABEL: fcmp_select_ogt:
+; X87-CMOV:       # %bb.0:
+; X87-CMOV-NEXT:    flds {{[0-9]+}}(%esp)
+; X87-CMOV-NEXT:    flds {{[0-9]+}}(%esp)
+; X87-CMOV-NEXT:    fcomi %st(1), %st
+; X87-CMOV-NEXT:    fxch %st(1)
+; X87-CMOV-NEXT:    fcmovnbe %st(1), %st
+; X87-CMOV-NEXT:    fstp %st(1)
+; X87-CMOV-NEXT:    wait
+; X87-CMOV-NEXT:    retl
+  %cond = call i1 @llvm.experimental.constrained.fcmps.f32(
+                                               float %f1, float %f2, metadata !"ogt",
+                                               metadata !"fpexcept.strict")
+  %res = select i1 %cond, float %f1, float %f2
+  ret float %res
+}
+
+define double @fcmp_select_ule(double %f1, double %f2) #0 {
+; SSE-32-LABEL: fcmp_select_ule:
+; SSE-32:       # %bb.0:
+; SSE-32-NEXT:    pushl %ebp
+; SSE-32-NEXT:    movl %esp, %ebp
+; SSE-32-NEXT:    andl $-8, %esp
+; SSE-32-NEXT:    subl $8, %esp
+; SSE-32-NEXT:    movsd {{.*#+}} xmm0 = mem[0],zero
+; SSE-32-NEXT:    minsd 8(%ebp), %xmm0
+; SSE-32-NEXT:    movsd %xmm0, (%esp)
+; SSE-32-NEXT:    fldl (%esp)
+; SSE-32-NEXT:    wait
+; SSE-32-NEXT:    movl %ebp, %esp
+; SSE-32-NEXT:    popl %ebp
+; SSE-32-NEXT:    retl
+;
+; SSE-64-LABEL: fcmp_select_ule:
+; SSE-64:       # %bb.0:
+; SSE-64-NEXT:    minsd %xmm0, %xmm1
+; SSE-64-NEXT:    movapd %xmm1, %xmm0
+; SSE-64-NEXT:    retq
+;
+; AVX-32-LABEL: fcmp_select_ule:
+; AVX-32:       # %bb.0:
+; AVX-32-NEXT:    pushl %ebp
+; AVX-32-NEXT:    movl %esp, %ebp
+; AVX-32-NEXT:    andl $-8, %esp
+; AVX-32-NEXT:    subl $8, %esp
+; AVX-32-NEXT:    vmovsd {{.*#+}} xmm0 = mem[0],zero
+; AVX-32-NEXT:    vminsd 8(%ebp), %xmm0, %xmm0
+; AVX-32-NEXT:    vmovsd %xmm0, (%esp)
+; AVX-32-NEXT:    fldl (%esp)
+; AVX-32-NEXT:    wait
+; AVX-32-NEXT:    movl %ebp, %esp
+; AVX-32-NEXT:    popl %ebp
+; AVX-32-NEXT:    retl
+;
+; AVX-64-LABEL: fcmp_select_ule:
+; AVX-64:       # %bb.0:
+; AVX-64-NEXT:    vminsd %xmm0, %xmm1, %xmm0
+; AVX-64-NEXT:    retq
+;
+; X87-LABEL: fcmp_select_ule:
+; X87:       # %bb.0:
+; X87-NEXT:    fldl {{[0-9]+}}(%esp)
+; X87-NEXT:    fldl {{[0-9]+}}(%esp)
+; X87-NEXT:    fcom %st(1)
+; X87-NEXT:    wait
+; X87-NEXT:    fnstsw %ax
+; X87-NEXT:    # kill: def $ah killed $ah killed $ax
+; X87-NEXT:    sahf
+; X87-NEXT:    jbe .LBB58_2
+; X87-NEXT:  # %bb.1:
+; X87-NEXT:    fstp %st(0)
+; X87-NEXT:    fldz
+; X87-NEXT:    fxch %st(1)
+; X87-NEXT:  .LBB58_2:
+; X87-NEXT:    fstp %st(1)
+; X87-NEXT:    wait
+; X87-NEXT:    retl
+;
+; X87-CMOV-LABEL: fcmp_select_ule:
+; X87-CMOV:       # %bb.0:
+; X87-CMOV-NEXT:    fldl {{[0-9]+}}(%esp)
+; X87-CMOV-NEXT:    fldl {{[0-9]+}}(%esp)
+; X87-CMOV-NEXT:    fcomi %st(1), %st
+; X87-CMOV-NEXT:    fxch %st(1)
+; X87-CMOV-NEXT:    fcmovbe %st(1), %st
+; X87-CMOV-NEXT:    fstp %st(1)
+; X87-CMOV-NEXT:    wait
+; X87-CMOV-NEXT:    retl
+  %cond = call i1 @llvm.experimental.constrained.fcmps.f64(
+                                               double %f1, double %f2, metadata !"ule",
+                                               metadata !"fpexcept.strict")
+  %res = select i1 %cond, double %f1, double %f2
+  ret double %res
+}
+
+attributes #0 = { nounwind strictfp }
 
 declare i1 @llvm.experimental.constrained.fcmp.f32(float, float, metadata, metadata)
 declare i1 @llvm.experimental.constrained.fcmp.f64(double, double, metadata, metadata)



More information about the llvm-commits mailing list