[llvm] 6592bce - [x86] invert a vector select IR canonicalization with a binop identity constant

Sanjay Patel via llvm-commits llvm-commits at lists.llvm.org
Wed Feb 2 05:17:59 PST 2022


Author: Sanjay Patel
Date: 2022-02-02T08:17:53-05:00
New Revision: 6592bcecd4ffc03f72d23f81bcb8d51f8ebeb07d

URL: https://github.com/llvm/llvm-project/commit/6592bcecd4ffc03f72d23f81bcb8d51f8ebeb07d
DIFF: https://github.com/llvm/llvm-project/commit/6592bcecd4ffc03f72d23f81bcb8d51f8ebeb07d.diff

LOG: [x86] invert a vector select IR canonicalization with a binop identity constant

This is an intentionally limited/different form of D90113.
That patch bravely tries to generalize folds where we pull
a binop into the arms of a select:
N0 + (Cond ? 0 : FVal) --> Cond ? N0 : (N0 + FVal)
...but it is not universally profitable.

This is the inverse of IR canonicalization as discussed in
D113442.

We know that this transform is not entirely profitable even
within x86, so we only handle x86 vector fadd/fsub as a first
step. The intent is to prevent AVX512 regressions as mentioned
in D113442.

The plan is to port this to DAGCombiner (so it will eventually
look more like D90113) and add more types/cases in pieces with
many more tests to verify that we are seeing improvements.

Differential Revision: https://reviews.llvm.org/D118644

Added: 
    

Modified: 
    llvm/lib/Target/X86/X86ISelLowering.cpp
    llvm/test/CodeGen/X86/avx512fp16-arith-intrinsics.ll
    llvm/test/CodeGen/X86/vector-bo-select.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index ea52cc45be12d..1d927d59d1aec 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -48942,6 +48942,83 @@ static SDValue combineFaddCFmul(SDNode *N, SelectionDAG &DAG,
   return DAG.getBitcast(VT, CFmul);
 }
 
+/// Invert an IR canonicalization that replaces a variable select arm with a
+/// binop identity constant. Reusing the variable operand avoids loading the
+/// constant, and the result can become a masked vector op where supported.
+/// If \p ShouldCommuteOperands, operands 0 and 1 are swapped before matching.
+static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
+                                              bool ShouldCommuteOperands) {
+  // Look for the select in operand 1 (after the optional swap below). For a
+  // non-commutative binop, the identity constant is only valid as operand 1.
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  if (ShouldCommuteOperands)
+    std::swap(N0, N1);
+
+  // TODO: Should this apply to scalar select too?
+  if (!N1.hasOneUse() || N1.getOpcode() != ISD::VSELECT)
+    return SDValue(); // Only a single-use vector select is handled.
+
+  unsigned Opcode = N->getOpcode();
+  EVT VT = N->getValueType(0);
+  SDValue Cond = N1.getOperand(0);
+  SDValue TVal = N1.getOperand(1);
+  SDValue FVal = N1.getOperand(2);
+
+  // TODO: This (and possibly the entire function) belongs in a
+  //       target-independent location with target hooks.
+  // TODO: The cases should match with IR's ConstantExpr::getBinOpIdentity().
+  // TODO: With fast-math (NSZ), allow the opposite-sign form of zero?
+  auto isIdentityConstantForOpcode = [](unsigned Opcode, SDValue V) {
+    if (ConstantFPSDNode *C = isConstOrConstSplatFP(V)) {
+      switch (Opcode) {
+      case ISD::FADD: // X + -0.0 --> X
+        return C->isZero() && C->isNegative();
+      case ISD::FSUB: // X - 0.0 --> X
+        return C->isZero() && !C->isNegative();
+      }
+    }
+    return false;
+  };
+
+  // N0 is used twice below (select arm and binop), so freeze it to be safe.
+  // binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
+  if (isIdentityConstantForOpcode(Opcode, TVal)) {
+    SDValue F0 = DAG.getFreeze(N0);
+    SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags());
+    return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO);
+  }
+  // binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
+  if (isIdentityConstantForOpcode(Opcode, FVal)) {
+    SDValue F0 = DAG.getFreeze(N0);
+    SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags());
+    return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0);
+  }
+
+  return SDValue(); // Neither select arm is the identity constant.
+}
+
+static SDValue combineBinopWithSelect(SDNode *N, SelectionDAG &DAG,
+                                      const X86Subtarget &Subtarget) {
+  // TODO: This bailout is too broad: some pre-AVX512 codegen would benefit,
+  //       and the transform may also be profitable for scalar code.
+  if (!Subtarget.hasAVX512())
+    return SDValue();
+
+  if (!Subtarget.hasVLX() && !N->getValueType(0).is512BitVector())
+    return SDValue(); // Without VLX, only 512-bit vector results are handled.
+
+  if (SDValue Sel = foldSelectWithIdentityConstant(N, DAG, false))
+    return Sel; // Matched with the select as operand 1.
+
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (TLI.isCommutativeBinOp(N->getOpcode()))
+    if (SDValue Sel = foldSelectWithIdentityConstant(N, DAG, true))
+      return Sel; // Commutative op: matched with the select as operand 0.
+
+  return SDValue();
+}
+
 /// Do target-specific dag combines on floating-point adds/subs.
 static SDValue combineFaddFsub(SDNode *N, SelectionDAG &DAG,
                                const X86Subtarget &Subtarget) {
@@ -48951,6 +49028,9 @@ static SDValue combineFaddFsub(SDNode *N, SelectionDAG &DAG,
   if (SDValue COp = combineFaddCFmul(N, DAG, Subtarget))
     return COp;
 
+  if (SDValue Sel = combineBinopWithSelect(N, DAG, Subtarget))
+    return Sel;
+
   return SDValue();
 }
 

diff  --git a/llvm/test/CodeGen/X86/avx512fp16-arith-intrinsics.ll b/llvm/test/CodeGen/X86/avx512fp16-arith-intrinsics.ll
index 8148585f8d793..d94da4a94e977 100644
--- a/llvm/test/CodeGen/X86/avx512fp16-arith-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/avx512fp16-arith-intrinsics.ll
@@ -83,8 +83,8 @@ define <32 x half> @test_int_x86_avx512fp16_maskz_sub_ph_512(<32 x half> %src, <
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    kmovd %edi, %k1
 ; CHECK-NEXT:    vsubph %zmm2, %zmm1, %zmm0 {%k1} {z}
-; CHECK-NEXT:    vsubph (%rsi), %zmm1, %zmm1 {%k1} {z}
-; CHECK-NEXT:    vsubph %zmm1, %zmm0, %zmm0
+; CHECK-NEXT:    vsubph (%rsi), %zmm1, %zmm1
+; CHECK-NEXT:    vsubph %zmm1, %zmm0, %zmm0 {%k1}
 ; CHECK-NEXT:    retq
   %mask = bitcast i32 %msk to <32 x i1>
   %val = load <32 x half>, <32 x half>* %ptr

diff  --git a/llvm/test/CodeGen/X86/vector-bo-select.ll b/llvm/test/CodeGen/X86/vector-bo-select.ll
index 9cc0761905ef8..93777915b1b88 100644
--- a/llvm/test/CodeGen/X86/vector-bo-select.ll
+++ b/llvm/test/CodeGen/X86/vector-bo-select.ll
@@ -27,9 +27,8 @@ define <4 x float> @fadd_v4f32(<4 x i1> %b, <4 x float> noundef %x, <4 x float>
 ; AVX512VL:       # %bb.0:
 ; AVX512VL-NEXT:    vpslld $31, %xmm0, %xmm0
 ; AVX512VL-NEXT:    vptestmd %xmm0, %xmm0, %k1
-; AVX512VL-NEXT:    vbroadcastss {{.*#+}} xmm0 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
-; AVX512VL-NEXT:    vmovaps %xmm2, %xmm0 {%k1}
-; AVX512VL-NEXT:    vaddps %xmm0, %xmm1, %xmm0
+; AVX512VL-NEXT:    vaddps %xmm2, %xmm1, %xmm1 {%k1}
+; AVX512VL-NEXT:    vmovaps %xmm1, %xmm0
 ; AVX512VL-NEXT:    retq
   %s = select <4 x i1> %b, <4 x float> %y, <4 x float> <float -0.0, float -0.0, float -0.0, float -0.0>
   %r = fadd <4 x float> %x, %s
@@ -62,9 +61,8 @@ define <8 x float> @fadd_v8f32_commute(<8 x i1> %b, <8 x float> noundef %x, <8 x
 ; AVX512VL-NEXT:    vpmovsxwd %xmm0, %ymm0
 ; AVX512VL-NEXT:    vpslld $31, %ymm0, %ymm0
 ; AVX512VL-NEXT:    vptestmd %ymm0, %ymm0, %k1
-; AVX512VL-NEXT:    vbroadcastss {{.*#+}} ymm0 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
-; AVX512VL-NEXT:    vmovaps %ymm2, %ymm0 {%k1}
-; AVX512VL-NEXT:    vaddps %ymm1, %ymm0, %ymm0
+; AVX512VL-NEXT:    vaddps %ymm2, %ymm1, %ymm1 {%k1}
+; AVX512VL-NEXT:    vmovaps %ymm1, %ymm0
 ; AVX512VL-NEXT:    retq
   %s = select <8 x i1> %b, <8 x float> %y, <8 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>
   %r = fadd <8 x float> %s, %x
@@ -92,8 +90,8 @@ define <16 x float> @fadd_v16f32_swap(<16 x i1> %b, <16 x float> noundef %x, <16
 ; AVX512-NEXT:    vpmovsxbd %xmm0, %zmm0
 ; AVX512-NEXT:    vpslld $31, %zmm0, %zmm0
 ; AVX512-NEXT:    vptestmd %zmm0, %zmm0, %k1
-; AVX512-NEXT:    vbroadcastss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2 {%k1}
 ; AVX512-NEXT:    vaddps %zmm2, %zmm1, %zmm0
+; AVX512-NEXT:    vmovaps %zmm1, %zmm0 {%k1}
 ; AVX512-NEXT:    retq
   %s = select <16 x i1> %b, <16 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>, <16 x float> %y
   %r = fadd <16 x float> %x, %s
@@ -121,8 +119,8 @@ define <16 x float> @fadd_v16f32_commute_swap(<16 x i1> %b, <16 x float> noundef
 ; AVX512-NEXT:    vpmovsxbd %xmm0, %zmm0
 ; AVX512-NEXT:    vpslld $31, %zmm0, %zmm0
 ; AVX512-NEXT:    vptestmd %zmm0, %zmm0, %k1
-; AVX512-NEXT:    vbroadcastss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2 {%k1}
-; AVX512-NEXT:    vaddps %zmm1, %zmm2, %zmm0
+; AVX512-NEXT:    vaddps %zmm2, %zmm1, %zmm0
+; AVX512-NEXT:    vmovaps %zmm1, %zmm0 {%k1}
 ; AVX512-NEXT:    retq
   %s = select <16 x i1> %b, <16 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>, <16 x float> %y
   %r = fadd <16 x float> %s, %x
@@ -152,14 +150,16 @@ define <4 x float> @fsub_v4f32(<4 x i1> %b, <4 x float> noundef %x, <4 x float>
 ; AVX512VL:       # %bb.0:
 ; AVX512VL-NEXT:    vpslld $31, %xmm0, %xmm0
 ; AVX512VL-NEXT:    vptestmd %xmm0, %xmm0, %k1
-; AVX512VL-NEXT:    vmovaps %xmm2, %xmm0 {%k1} {z}
-; AVX512VL-NEXT:    vsubps %xmm0, %xmm1, %xmm0
+; AVX512VL-NEXT:    vsubps %xmm2, %xmm1, %xmm1 {%k1}
+; AVX512VL-NEXT:    vmovaps %xmm1, %xmm0
 ; AVX512VL-NEXT:    retq
   %s = select <4 x i1> %b, <4 x float> %y, <4 x float> zeroinitializer
   %r = fsub <4 x float> %x, %s
   ret <4 x float> %r
 }
 
+; negative test - fsub is not commutative; there is no identity constant for operand 0
+
 define <8 x float> @fsub_v8f32_commute(<8 x i1> %b, <8 x float> noundef %x, <8 x float> noundef %y) {
 ; AVX2-LABEL: fsub_v8f32_commute:
 ; AVX2:       # %bb.0:
@@ -214,15 +214,17 @@ define <16 x float> @fsub_v16f32_swap(<16 x i1> %b, <16 x float> noundef %x, <16
 ; AVX512:       # %bb.0:
 ; AVX512-NEXT:    vpmovsxbd %xmm0, %zmm0
 ; AVX512-NEXT:    vpslld $31, %zmm0, %zmm0
-; AVX512-NEXT:    vptestnmd %zmm0, %zmm0, %k1
-; AVX512-NEXT:    vmovaps %zmm2, %zmm0 {%k1} {z}
-; AVX512-NEXT:    vsubps %zmm0, %zmm1, %zmm0
+; AVX512-NEXT:    vptestmd %zmm0, %zmm0, %k1
+; AVX512-NEXT:    vsubps %zmm2, %zmm1, %zmm0
+; AVX512-NEXT:    vmovaps %zmm1, %zmm0 {%k1}
 ; AVX512-NEXT:    retq
   %s = select <16 x i1> %b, <16 x float> zeroinitializer, <16 x float> %y
   %r = fsub <16 x float> %x, %s
   ret <16 x float> %r
 }
 
+; negative test - fsub is not commutative; there is no identity constant for operand 0
+
 define <16 x float> @fsub_v16f32_commute_swap(<16 x i1> %b, <16 x float> noundef %x, <16 x float> noundef %y) {
 ; AVX2-LABEL: fsub_v16f32_commute_swap:
 ; AVX2:       # %bb.0:
@@ -570,9 +572,7 @@ define <8 x float> @fadd_v8f32_cast_cond(i8 noundef zeroext %pb, <8 x float> nou
 ; AVX512VL-LABEL: fadd_v8f32_cast_cond:
 ; AVX512VL:       # %bb.0:
 ; AVX512VL-NEXT:    kmovw %edi, %k1
-; AVX512VL-NEXT:    vbroadcastss {{.*#+}} ymm2 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
-; AVX512VL-NEXT:    vmovaps %ymm1, %ymm2 {%k1}
-; AVX512VL-NEXT:    vaddps %ymm2, %ymm0, %ymm0
+; AVX512VL-NEXT:    vaddps %ymm1, %ymm0, %ymm0 {%k1}
 ; AVX512VL-NEXT:    retq
   %b = bitcast i8 %pb to <8 x i1>
   %s = select <8 x i1> %b, <8 x float> %y, <8 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>
@@ -636,9 +636,7 @@ define <8 x double> @fadd_v8f64_cast_cond(i8 noundef zeroext %pb, <8 x double> n
 ; AVX512-LABEL: fadd_v8f64_cast_cond:
 ; AVX512:       # %bb.0:
 ; AVX512-NEXT:    kmovw %edi, %k1
-; AVX512-NEXT:    vbroadcastsd {{.*#+}} zmm2 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
-; AVX512-NEXT:    vmovapd %zmm1, %zmm2 {%k1}
-; AVX512-NEXT:    vaddpd %zmm2, %zmm0, %zmm0
+; AVX512-NEXT:    vaddpd %zmm1, %zmm0, %zmm0 {%k1}
 ; AVX512-NEXT:    retq
   %b = bitcast i8 %pb to <8 x i1>
   %s = select <8 x i1> %b, <8 x double> %y, <8 x double> <double -0.0, double -0.0, double -0.0, double -0.0, double -0.0, double -0.0, double -0.0, double -0.0>
@@ -709,8 +707,7 @@ define <8 x float> @fsub_v8f32_cast_cond(i8 noundef zeroext %pb, <8 x float> nou
 ; AVX512VL-LABEL: fsub_v8f32_cast_cond:
 ; AVX512VL:       # %bb.0:
 ; AVX512VL-NEXT:    kmovw %edi, %k1
-; AVX512VL-NEXT:    vmovaps %ymm1, %ymm1 {%k1} {z}
-; AVX512VL-NEXT:    vsubps %ymm1, %ymm0, %ymm0
+; AVX512VL-NEXT:    vsubps %ymm1, %ymm0, %ymm0 {%k1}
 ; AVX512VL-NEXT:    retq
   %b = bitcast i8 %pb to <8 x i1>
   %s = select <8 x i1> %b, <8 x float> %y, <8 x float> zeroinitializer
@@ -775,8 +772,7 @@ define <8 x double> @fsub_v8f64_cast_cond(i8 noundef zeroext %pb, <8 x double> n
 ; AVX512-LABEL: fsub_v8f64_cast_cond:
 ; AVX512:       # %bb.0:
 ; AVX512-NEXT:    kmovw %edi, %k1
-; AVX512-NEXT:    vmovapd %zmm1, %zmm1 {%k1} {z}
-; AVX512-NEXT:    vsubpd %zmm1, %zmm0, %zmm0
+; AVX512-NEXT:    vsubpd %zmm1, %zmm0, %zmm0 {%k1}
 ; AVX512-NEXT:    retq
   %b = bitcast i8 %pb to <8 x i1>
   %s = select <8 x i1> %b, <8 x double> %y, <8 x double> zeroinitializer


        


More information about the llvm-commits mailing list