[llvm] 76656ec - [X86][FP16] Combine the FADD(A, FMA(B, C, 0)) to FMA(B, C, A)
Author: Liu, Chen3
Date: 2021-09-23T15:37:08+08:00
New Revision: 76656ec8ec535bd17afc998112a83c7c55ad7719
URL: https://github.com/llvm/llvm-project/commit/76656ec8ec535bd17afc998112a83c7c55ad7719
DIFF: https://github.com/llvm/llvm-project/commit/76656ec8ec535bd17afc998112a83c7c55ad7719.diff
LOG: [X86][FP16] Combine the FADD(A, FMA(B, C, 0)) to FMA(B, C, A)
This patch supports transforming
  _mm512_add_ph(acc, _mm512_fmadd_pch(a, b, _mm512_setzero_ph()))
into
  _mm512_fmadd_pch(a, b, acc).
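For reference, a minimal source-level sketch of the pattern this targets (the
wrapper name cmadd is hypothetical; assumes AVX512-FP16 intrinsics and
contraction enabled, e.g. -ffp-contract=fast):

  #include <immintrin.h>

  // Complex half-precision multiply-accumulate spelled as an FMA with a
  // zero accumulator followed by an explicit add. With this combine the
  // vaddph folds away, leaving a single vfmaddcph that uses acc as the
  // accumulator.
  __m512h cmadd(__m512h acc, __m512h a, __m512h b) {
    return _mm512_add_ph(acc, _mm512_fmadd_pch(a, b, _mm512_setzero_ph()));
  }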
Differential Revision: https://reviews.llvm.org/D109953
Added:
llvm/test/CodeGen/X86/avx512fp16-combine-vfmac-fadd.ll
Modified:
llvm/lib/Target/X86/X86ISelLowering.cpp
Removed:
################################################################################
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 5de3a4c9d8dbb..c79ff5b1590a4 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -47662,22 +47662,37 @@ static SDValue combineFMulcFCMulc(SDNode *N, SelectionDAG &DAG,
return Res;
}
-// Try to combine the following nodes
-// t21: v16f32 = X86ISD::VFMULC/VFCMULC t7, t8
-// t15: v32f16 = bitcast t21
-// t16: v32f16 = fadd nnan ninf nsz arcp contract afn reassoc t15, t2
-// into X86ISD::VFMADDC/VFCMADDC if possible:
-// t22: v16f32 = bitcast t2
-// t23: v16f32 = nnan ninf nsz arcp contract afn reassoc
-// X86ISD::VFMADDC/VFCMADDC t7, t8, t22
-// t24: v32f16 = bitcast t23
+// Try to combine the following nodes:
+// FADD(A, FMA(B, C, 0)) and FADD(A, FMUL(B, C)) to FMA(B, C, A)
static SDValue combineFaddCFmul(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
- auto AllowContract = [&DAG](SDNode *N) {
+ auto AllowContract = [&DAG](const SDNodeFlags &Flags) {
return DAG.getTarget().Options.AllowFPOpFusion == FPOpFusion::Fast ||
- N->getFlags().hasAllowContract();
+ Flags.hasAllowContract();
};
- if (N->getOpcode() != ISD::FADD || !Subtarget.hasFP16() || !AllowContract(N))
+
+ auto HasNoSignedZero = [&DAG](const SDNodeFlags &Flags) {
+ return DAG.getTarget().Options.NoSignedZerosFPMath ||
+ Flags.hasNoSignedZeros();
+ };
+ auto IsVectorAllNegativeZero = [](const SDNode *N) {
+ if (N->getOpcode() != X86ISD::VBROADCAST_LOAD)
+ return false;
+ assert(N->getSimpleValueType(0).getScalarType() == MVT::f32 &&
+ "Unexpected vector type!");
+ if (ConstantPoolSDNode *CP =
+ dyn_cast<ConstantPoolSDNode>(N->getOperand(1)->getOperand(0))) {
+ APInt AI = APInt(32, 0x80008000, true);
+ if (const auto *CI = dyn_cast<ConstantInt>(CP->getConstVal()))
+ return CI->getValue() == AI;
+ if (const auto *CF = dyn_cast<ConstantFP>(CP->getConstVal()))
+ return CF->getValue() == APFloat(APFloat::IEEEsingle(), AI);
+ }
+ return false;
+ };
+
+ if (N->getOpcode() != ISD::FADD || !Subtarget.hasFP16() ||
+ !AllowContract(N->getFlags()))
return SDValue();
EVT VT = N->getValueType(0);
@@ -47686,16 +47701,33 @@ static SDValue combineFaddCFmul(SDNode *N, SelectionDAG &DAG,
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
- SDValue CFmul, FAddOp1;
- auto GetCFmulFrom = [&CFmul, &AllowContract](SDValue N) -> bool {
+ bool IsConj;
+ SDValue FAddOp1, MulOp0, MulOp1;
+ auto GetCFmulFrom = [&MulOp0, &MulOp1, &IsConj, &AllowContract,
+ &IsVectorAllNegativeZero,
+ &HasNoSignedZero](SDValue N) -> bool {
if (!N.hasOneUse() || N.getOpcode() != ISD::BITCAST)
- return false;
+ return false;
SDValue Op0 = N.getOperand(0);
unsigned Opcode = Op0.getOpcode();
- if (Op0.hasOneUse() && AllowContract(Op0.getNode()) &&
- (Opcode == X86ISD::VFMULC || Opcode == X86ISD::VFCMULC))
- CFmul = Op0;
- return !!CFmul;
+ if (Op0.hasOneUse() && AllowContract(Op0->getFlags())) {
+ if ((Opcode == X86ISD::VFMULC || Opcode == X86ISD::VFCMULC)) {
+ MulOp0 = Op0.getOperand(0);
+ MulOp1 = Op0.getOperand(1);
+ IsConj = Opcode == X86ISD::VFCMULC;
+ return true;
+ }
+ if ((Opcode == X86ISD::VFMADDC || Opcode == X86ISD::VFCMADDC) &&
+ ((ISD::isBuildVectorAllZeros(Op0->getOperand(2).getNode()) &&
+ HasNoSignedZero(Op0->getFlags())) ||
+ IsVectorAllNegativeZero(Op0->getOperand(2).getNode()))) {
+ MulOp0 = Op0.getOperand(0);
+ MulOp1 = Op0.getOperand(1);
+ IsConj = Opcode == X86ISD::VFCMADDC;
+ return true;
+ }
+ }
+ return false;
};
if (GetCFmulFrom(LHS))
@@ -47706,14 +47738,12 @@ static SDValue combineFaddCFmul(SDNode *N, SelectionDAG &DAG,
return SDValue();
MVT CVT = MVT::getVectorVT(MVT::f32, VT.getVectorNumElements() / 2);
- assert(CFmul->getValueType(0) == CVT && "Complex type mismatch");
FAddOp1 = DAG.getBitcast(CVT, FAddOp1);
- unsigned newOp = CFmul.getOpcode() == X86ISD::VFMULC ? X86ISD::VFMADDC
- : X86ISD::VFCMADDC;
+ unsigned NewOp = IsConj ? X86ISD::VFCMADDC : X86ISD::VFMADDC;
   // FIXME: How do we handle when fast math flags of FADD are different from
   // CFMUL's?
- CFmul = DAG.getNode(newOp, SDLoc(N), CVT, FAddOp1, CFmul.getOperand(0),
- CFmul.getOperand(1), N->getFlags());
+ SDValue CFmul =
+ DAG.getNode(NewOp, SDLoc(N), CVT, FAddOp1, MulOp0, MulOp1, N->getFlags());
return DAG.getBitcast(VT, CFmul);
}
diff --git a/llvm/test/CodeGen/X86/avx512fp16-combine-vfmac-fadd.ll b/llvm/test/CodeGen/X86/avx512fp16-combine-vfmac-fadd.ll
new file mode 100644
index 0000000000000..e5d93a1dfdbc7
--- /dev/null
+++ b/llvm/test/CodeGen/X86/avx512fp16-combine-vfmac-fadd.ll
@@ -0,0 +1,234 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown --fp-contract=fast --enable-no-signed-zeros-fp-math -mattr=avx512fp16 | FileCheck %s --check-prefixes=CHECK,NO-SZ
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown --fp-contract=fast -mattr=avx512fp16 | FileCheck %s --check-prefixes=CHECK,HAS-SZ
+
+; FADD(acc, FMA(a, b, +0.0)) can be combined to FMA(a, b, acc) if the nsz flag is set.
+define dso_local <32 x half> @test1(<32 x half> %acc, <32 x half> %a, <32 x half> %b) {
+; NO-SZ-LABEL: test1:
+; NO-SZ: # %bb.0: # %entry
+; NO-SZ-NEXT: vfcmaddcph %zmm1, %zmm0, %zmm2
+; NO-SZ-NEXT: vmovaps %zmm2, %zmm0
+; NO-SZ-NEXT: retq
+;
+; HAS-SZ-LABEL: test1:
+; HAS-SZ: # %bb.0: # %entry
+; HAS-SZ-NEXT: vxorps %xmm3, %xmm3, %xmm3
+; HAS-SZ-NEXT: vfcmaddcph %zmm2, %zmm1, %zmm3
+; HAS-SZ-NEXT: vaddph %zmm0, %zmm3, %zmm0
+; HAS-SZ-NEXT: retq
+entry:
+ %0 = bitcast <32 x half> %a to <16 x float>
+ %1 = bitcast <32 x half> %b to <16 x float>
+ %2 = tail call <16 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.512(<16 x float> %0, <16 x float> %1, <16 x float> zeroinitializer, i16 -1, i32 4)
+ %3 = bitcast <16 x float> %2 to <32 x half>
+ %add.i = fadd <32 x half> %3, %acc
+ ret <32 x half> %add.i
+}
+
+define dso_local <32 x half> @test2(<32 x half> %acc, <32 x half> %a, <32 x half> %b) {
+; NO-SZ-LABEL: test2:
+; NO-SZ: # %bb.0: # %entry
+; NO-SZ-NEXT: vfmaddcph %zmm1, %zmm0, %zmm2
+; NO-SZ-NEXT: vmovaps %zmm2, %zmm0
+; NO-SZ-NEXT: retq
+;
+; HAS-SZ-LABEL: test2:
+; HAS-SZ: # %bb.0: # %entry
+; HAS-SZ-NEXT: vxorps %xmm3, %xmm3, %xmm3
+; HAS-SZ-NEXT: vfmaddcph %zmm2, %zmm1, %zmm3
+; HAS-SZ-NEXT: vaddph %zmm0, %zmm3, %zmm0
+; HAS-SZ-NEXT: retq
+entry:
+ %0 = bitcast <32 x half> %a to <16 x float>
+ %1 = bitcast <32 x half> %b to <16 x float>
+ %2 = tail call <16 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.512(<16 x float> %0, <16 x float> %1, <16 x float> zeroinitializer, i16 -1, i32 4)
+ %3 = bitcast <16 x float> %2 to <32 x half>
+ %add.i = fadd <32 x half> %3, %acc
+ ret <32 x half> %add.i
+}
+
+define dso_local <16 x half> @test3(<16 x half> %acc, <16 x half> %a, <16 x half> %b) {
+; NO-SZ-LABEL: test3:
+; NO-SZ: # %bb.0: # %entry
+; NO-SZ-NEXT: vfcmaddcph %ymm1, %ymm0, %ymm2
+; NO-SZ-NEXT: vmovaps %ymm2, %ymm0
+; NO-SZ-NEXT: retq
+;
+; HAS-SZ-LABEL: test3:
+; HAS-SZ: # %bb.0: # %entry
+; HAS-SZ-NEXT: vxorps %xmm3, %xmm3, %xmm3
+; HAS-SZ-NEXT: vfcmaddcph %ymm2, %ymm1, %ymm3
+; HAS-SZ-NEXT: vaddph %ymm0, %ymm3, %ymm0
+; HAS-SZ-NEXT: retq
+entry:
+ %0 = bitcast <16 x half> %a to <8 x float>
+ %1 = bitcast <16 x half> %b to <8 x float>
+ %2 = tail call <8 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.256(<8 x float> %0, <8 x float> %1, <8 x float> zeroinitializer, i8 -1)
+ %3 = bitcast <8 x float> %2 to <16 x half>
+ %add.i = fadd <16 x half> %3, %acc
+ ret <16 x half> %add.i
+}
+
+define dso_local <16 x half> @test4(<16 x half> %acc, <16 x half> %a, <16 x half> %b) {
+; NO-SZ-LABEL: test4:
+; NO-SZ: # %bb.0: # %entry
+; NO-SZ-NEXT: vfmaddcph %ymm1, %ymm0, %ymm2
+; NO-SZ-NEXT: vmovaps %ymm2, %ymm0
+; NO-SZ-NEXT: retq
+;
+; HAS-SZ-LABEL: test4:
+; HAS-SZ: # %bb.0: # %entry
+; HAS-SZ-NEXT: vxorps %xmm3, %xmm3, %xmm3
+; HAS-SZ-NEXT: vfmaddcph %ymm2, %ymm1, %ymm3
+; HAS-SZ-NEXT: vaddph %ymm0, %ymm3, %ymm0
+; HAS-SZ-NEXT: retq
+entry:
+ %0 = bitcast <16 x half> %a to <8 x float>
+ %1 = bitcast <16 x half> %b to <8 x float>
+ %2 = tail call <8 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.256(<8 x float> %0, <8 x float> %1, <8 x float> zeroinitializer, i8 -1)
+ %3 = bitcast <8 x float> %2 to <16 x half>
+ %add.i = fadd <16 x half> %3, %acc
+ ret <16 x half> %add.i
+}
+
+define dso_local <8 x half> @test5(<8 x half> %acc, <8 x half> %a, <8 x half> %b) {
+; NO-SZ-LABEL: test5:
+; NO-SZ: # %bb.0: # %entry
+; NO-SZ-NEXT: vfcmaddcph %xmm1, %xmm0, %xmm2
+; NO-SZ-NEXT: vmovaps %xmm2, %xmm0
+; NO-SZ-NEXT: retq
+;
+; HAS-SZ-LABEL: test5:
+; HAS-SZ: # %bb.0: # %entry
+; HAS-SZ-NEXT: vxorps %xmm3, %xmm3, %xmm3
+; HAS-SZ-NEXT: vfcmaddcph %xmm2, %xmm1, %xmm3
+; HAS-SZ-NEXT: vaddph %xmm0, %xmm3, %xmm0
+; HAS-SZ-NEXT: retq
+entry:
+ %0 = bitcast <8 x half> %a to <4 x float>
+ %1 = bitcast <8 x half> %b to <4 x float>
+ %2 = tail call <4 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.128(<4 x float> %0, <4 x float> %1, <4 x float> zeroinitializer, i8 -1)
+ %3 = bitcast <4 x float> %2 to <8 x half>
+ %add.i = fadd <8 x half> %3, %acc
+ ret <8 x half> %add.i
+}
+
+define dso_local <8 x half> @test6(<8 x half> %acc, <8 x half> %a, <8 x half> %b) {
+; NO-SZ-LABEL: test6:
+; NO-SZ: # %bb.0: # %entry
+; NO-SZ-NEXT: vfmaddcph %xmm1, %xmm0, %xmm2
+; NO-SZ-NEXT: vmovaps %xmm2, %xmm0
+; NO-SZ-NEXT: retq
+;
+; HAS-SZ-LABEL: test6:
+; HAS-SZ: # %bb.0: # %entry
+; HAS-SZ-NEXT: vxorps %xmm3, %xmm3, %xmm3
+; HAS-SZ-NEXT: vfmaddcph %xmm2, %xmm1, %xmm3
+; HAS-SZ-NEXT: vaddph %xmm0, %xmm3, %xmm0
+; HAS-SZ-NEXT: retq
+entry:
+ %0 = bitcast <8 x half> %a to <4 x float>
+ %1 = bitcast <8 x half> %b to <4 x float>
+ %2 = tail call <4 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.128(<4 x float> %0, <4 x float> %1, <4 x float> zeroinitializer, i8 -1)
+ %3 = bitcast <4 x float> %2 to <8 x half>
+ %add.i = fadd <8 x half> %3, %acc
+ ret <8 x half> %add.i
+}
+
+; FADD(acc, FMA(a, b, -0.0)) can be combined to FMA(a, b, acc) whether or not the nsz flag is set.
+define dso_local <32 x half> @test13(<32 x half> %acc, <32 x half> %a, <32 x half> %b) {
+; CHECK-LABEL: test13:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vfcmaddcph %zmm1, %zmm0, %zmm2
+; CHECK-NEXT: vmovaps %zmm2, %zmm0
+; CHECK-NEXT: retq
+entry:
+ %0 = bitcast <32 x half> %a to <16 x float>
+ %1 = bitcast <32 x half> %b to <16 x float>
+ %2 = tail call <16 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.512(<16 x float> %0, <16 x float> %1, <16 x float> <float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000>, i16 -1, i32 4)
+ %3 = bitcast <16 x float> %2 to <32 x half>
+ %add.i = fadd <32 x half> %3, %acc
+ ret <32 x half> %add.i
+}
+
+define dso_local <32 x half> @test14(<32 x half> %acc, <32 x half> %a, <32 x half> %b) {
+; CHECK-LABEL: test14:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vfmaddcph %zmm1, %zmm0, %zmm2
+; CHECK-NEXT: vmovaps %zmm2, %zmm0
+; CHECK-NEXT: retq
+entry:
+ %0 = bitcast <32 x half> %a to <16 x float>
+ %1 = bitcast <32 x half> %b to <16 x float>
+ %2 = tail call <16 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.512(<16 x float> %0, <16 x float> %1, <16 x float> <float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000>, i16 -1, i32 4)
+ %3 = bitcast <16 x float> %2 to <32 x half>
+ %add.i = fadd <32 x half> %3, %acc
+ ret <32 x half> %add.i
+}
+
+define dso_local <16 x half> @test15(<16 x half> %acc, <16 x half> %a, <16 x half> %b) {
+; CHECK-LABEL: test15:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vfcmaddcph %ymm1, %ymm0, %ymm2
+; CHECK-NEXT: vmovaps %ymm2, %ymm0
+; CHECK-NEXT: retq
+entry:
+ %0 = bitcast <16 x half> %a to <8 x float>
+ %1 = bitcast <16 x half> %b to <8 x float>
+ %2 = tail call <8 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.256(<8 x float> %0, <8 x float> %1, <8 x float> <float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000>, i8 -1)
+ %3 = bitcast <8 x float> %2 to <16 x half>
+ %add.i = fadd <16 x half> %3, %acc
+ ret <16 x half> %add.i
+}
+
+define dso_local <16 x half> @test16(<16 x half> %acc, <16 x half> %a, <16 x half> %b) {
+; CHECK-LABEL: test16:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vfmaddcph %ymm1, %ymm0, %ymm2
+; CHECK-NEXT: vmovaps %ymm2, %ymm0
+; CHECK-NEXT: retq
+entry:
+ %0 = bitcast <16 x half> %a to <8 x float>
+ %1 = bitcast <16 x half> %b to <8 x float>
+ %2 = tail call <8 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.256(<8 x float> %0, <8 x float> %1, <8 x float> <float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000>, i8 -1)
+ %3 = bitcast <8 x float> %2 to <16 x half>
+ %add.i = fadd <16 x half> %3, %acc
+ ret <16 x half> %add.i
+}
+
+define dso_local <8 x half> @test17(<8 x half> %acc, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: test17:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vfcmaddcph %xmm1, %xmm0, %xmm2
+; CHECK-NEXT: vmovaps %xmm2, %xmm0
+; CHECK-NEXT: retq
+entry:
+ %0 = bitcast <8 x half> %a to <4 x float>
+ %1 = bitcast <8 x half> %b to <4 x float>
+ %2 = tail call <4 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.128(<4 x float> %0, <4 x float> %1, <4 x float> <float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000>, i8 -1)
+ %3 = bitcast <4 x float> %2 to <8 x half>
+ %add.i = fadd <8 x half> %3, %acc
+ ret <8 x half> %add.i
+}
+
+define dso_local <8 x half> @test18(<8 x half> %acc, <8 x half> %a, <8 x half> %b) {
+; CHECK-LABEL: test18:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: vfmaddcph %xmm1, %xmm0, %xmm2
+; CHECK-NEXT: vmovaps %xmm2, %xmm0
+; CHECK-NEXT: retq
+entry:
+ %0 = bitcast <8 x half> %a to <4 x float>
+ %1 = bitcast <8 x half> %b to <4 x float>
+ %2 = tail call <4 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.128(<4 x float> %0, <4 x float> %1, <4 x float> <float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000, float 0xB790000000000000>, i8 -1)
+ %3 = bitcast <4 x float> %2 to <8 x half>
+ %add.i = fadd <8 x half> %3, %acc
+ ret <8 x half> %add.i
+}
+
+declare <16 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.512(<16 x float>, <16 x float>, <16 x float>, i16, i32 immarg)
+declare <16 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.512(<16 x float>, <16 x float>, <16 x float>, i16, i32 immarg)
+declare <8 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.256(<8 x float>, <8 x float>, <8 x float>, i8)
+declare <8 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.256(<8 x float>, <8 x float>, <8 x float>, i8)
+declare <4 x float> @llvm.x86.avx512fp16.mask.vfcmadd.cph.128(<4 x float>, <4 x float>, <4 x float>, i8)
+declare <4 x float> @llvm.x86.avx512fp16.mask.vfmadd.cph.128(<4 x float>, <4 x float>, <4 x float>, i8)
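For reference, a DAG-level sketch of the new case, written in the notation of
the comment this patch removes (node numbers and operand order illustrative;
"zeros" is either an all-zeros build vector, which requires nsz, or an
all -0.0 broadcast, which does not):

  t21: v16f32 = X86ISD::VFMADDC/VFCMADDC t7, t8, zeros
  t15: v32f16 = bitcast t21
  t16: v32f16 = fadd contract t15, t2
 into:
  t22: v16f32 = bitcast t2
  t23: v16f32 = X86ISD::VFMADDC/VFCMADDC t7, t8, t22
  t24: v32f16 = bitcast t23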