[llvm] 2fe2a6d - [DAGCombiner] Use generalized pattern matcher in visitFMA to support vp.fma.
Yeting Kuo via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 7 18:40:28 PDT 2023
Author: Yeting Kuo
Date: 2023-06-08T09:40:21+08:00
New Revision: 2fe2a6d4b8a4647e49d69a5ff7161946aeb7cee1
URL: https://github.com/llvm/llvm-project/commit/2fe2a6d4b8a4647e49d69a5ff7161946aeb7cee1
DIFF: https://github.com/llvm/llvm-project/commit/2fe2a6d4b8a4647e49d69a5ff7161946aeb7cee1.diff
LOG: [DAGCombiner] Use generalized pattern matcher in visitFMA to support vp.fma.
Note: Some patterns in visitFMA need to be refined to support splat of constant.
Reviewed By: luke
Differential Revision: https://reviews.llvm.org/D152260
Added:
llvm/test/CodeGen/RISCV/rvv/vfma-vp-combine.ll
Modified:
llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Removed:
################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 02da5508656e6..8ffe00f802c31 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -491,7 +491,7 @@ namespace {
SDValue visitSTRICT_FADD(SDNode *N);
SDValue visitFSUB(SDNode *N);
SDValue visitFMUL(SDNode *N);
- SDValue visitFMA(SDNode *N);
+ template <class MatchContextClass> SDValue visitFMA(SDNode *N);
SDValue visitFDIV(SDNode *N);
SDValue visitFREM(SDNode *N);
SDValue visitFSQRT(SDNode *N);
@@ -1961,7 +1961,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::STRICT_FADD: return visitSTRICT_FADD(N);
case ISD::FSUB: return visitFSUB(N);
case ISD::FMUL: return visitFMUL(N);
- case ISD::FMA: return visitFMA(N);
+ case ISD::FMA: return visitFMA<EmptyMatchContext>(N);
case ISD::FDIV: return visitFDIV(N);
case ISD::FREM: return visitFREM(N);
case ISD::FSQRT: return visitFSQRT(N);
@@ -16320,7 +16320,7 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
return SDValue();
}
-SDValue DAGCombiner::visitFMA(SDNode *N) {
+template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
SDValue N2 = N->getOperand(2);
@@ -16331,6 +16331,7 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
const TargetOptions &Options = DAG.getTarget().Options;
// FMA nodes have flags that propagate to the created nodes.
SelectionDAG::FlagInserter FlagsInserter(DAG, N);
+ MatchContextClass matcher(DAG, TLI, N);
bool CanReassociate =
Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
@@ -16339,7 +16340,7 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
if (isa<ConstantFPSDNode>(N0) &&
isa<ConstantFPSDNode>(N1) &&
isa<ConstantFPSDNode>(N2)) {
- return DAG.getNode(ISD::FMA, DL, VT, N0, N1, N2);
+ return matcher.getNode(ISD::FMA, DL, VT, N0, N1, N2);
}
// (-N0 * -N1) + N2 --> (N0 * N1) + N2
@@ -16355,7 +16356,7 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
CostN1 == TargetLowering::NegatibleCost::Cheaper))
- return DAG.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);
+ return matcher.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);
}
// FIXME: use fast math flags instead of Options.UnsafeFPMath
@@ -16366,70 +16367,74 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
return N2;
}
+ // FIXME: Support splat of constant.
if (N0CFP && N0CFP->isExactlyValue(1.0))
- return DAG.getNode(ISD::FADD, SDLoc(N), VT, N1, N2);
+ return matcher.getNode(ISD::FADD, SDLoc(N), VT, N1, N2);
if (N1CFP && N1CFP->isExactlyValue(1.0))
- return DAG.getNode(ISD::FADD, SDLoc(N), VT, N0, N2);
+ return matcher.getNode(ISD::FADD, SDLoc(N), VT, N0, N2);
// Canonicalize (fma c, x, y) -> (fma x, c, y)
if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
!DAG.isConstantFPBuildVectorOrConstantFP(N1))
- return DAG.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);
+ return matcher.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);
if (CanReassociate) {
// (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
- if (N2.getOpcode() == ISD::FMUL && N0 == N2.getOperand(0) &&
+ if (matcher.match(N2, ISD::FMUL) && N0 == N2.getOperand(0) &&
DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
DAG.isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) {
- return DAG.getNode(ISD::FMUL, DL, VT, N0,
- DAG.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1)));
+ return matcher.getNode(
+ ISD::FMUL, DL, VT, N0,
+ matcher.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1)));
}
// (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
- if (N0.getOpcode() == ISD::FMUL &&
+ if (matcher.match(N0, ISD::FMUL) &&
DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
- return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
- DAG.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)),
- N2);
+ return matcher.getNode(
+ ISD::FMA, DL, VT, N0.getOperand(0),
+ matcher.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)), N2);
}
}
// (fma x, -1, y) -> (fadd (fneg x), y)
+ // FIXME: Support splat of constant.
if (N1CFP) {
if (N1CFP->isExactlyValue(1.0))
- return DAG.getNode(ISD::FADD, DL, VT, N0, N2);
+ return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
if (N1CFP->isExactlyValue(-1.0) &&
(!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) {
- SDValue RHSNeg = DAG.getNode(ISD::FNEG, DL, VT, N0);
+ SDValue RHSNeg = matcher.getNode(ISD::FNEG, DL, VT, N0);
AddToWorklist(RHSNeg.getNode());
- return DAG.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
+ return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
}
// fma (fneg x), K, y -> fma x -K, y
- if (N0.getOpcode() == ISD::FNEG &&
+ if (matcher.match(N0, ISD::FNEG) &&
(TLI.isOperationLegal(ISD::ConstantFP, VT) ||
- (N1.hasOneUse() && !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT,
- ForCodeSize)))) {
- return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
- DAG.getNode(ISD::FNEG, DL, VT, N1), N2);
+ (N1.hasOneUse() &&
+ !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
+ return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
+ matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
}
}
+ // FIXME: Support splat of constant.
if (CanReassociate) {
// (fma x, c, x) -> (fmul x, (c+1))
if (N1CFP && N0 == N2) {
- return DAG.getNode(
- ISD::FMUL, DL, VT, N0,
- DAG.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(1.0, DL, VT)));
+ return matcher.getNode(ISD::FMUL, DL, VT, N0,
+ matcher.getNode(ISD::FADD, DL, VT, N1,
+ DAG.getConstantFP(1.0, DL, VT)));
}
// (fma x, c, (fneg x)) -> (fmul x, (c-1))
- if (N1CFP && N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) {
- return DAG.getNode(
- ISD::FMUL, DL, VT, N0,
- DAG.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(-1.0, DL, VT)));
+ if (N1CFP && matcher.match(N2, ISD::FNEG) && N2.getOperand(0) == N0) {
+ return matcher.getNode(ISD::FMUL, DL, VT, N0,
+ matcher.getNode(ISD::FADD, DL, VT, N1,
+ DAG.getConstantFP(-1.0, DL, VT)));
}
}
@@ -16438,7 +16443,7 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
if (!TLI.isFNegFree(VT))
if (SDValue Neg = TLI.getCheaperNegatedExpression(
SDValue(N, 0), DAG, LegalOperations, ForCodeSize))
- return DAG.getNode(ISD::FNEG, DL, VT, Neg);
+ return matcher.getNode(ISD::FNEG, DL, VT, Neg);
return SDValue();
}
@@ -25695,6 +25700,8 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) {
return visitVP_FADD(N);
case ISD::VP_FSUB:
return visitVP_FSUB(N);
+ case ISD::VP_FMA:
+ return visitFMA<VPMatchContext>(N);
}
return SDValue();
}
diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 4c443600f9413..32be369d08cf7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -6816,7 +6816,7 @@ SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
NegatibleCost &Cost,
unsigned Depth) const {
// fneg is removable even if it has multiple uses.
- if (Op.getOpcode() == ISD::FNEG) {
+ if (Op.getOpcode() == ISD::FNEG || Op.getOpcode() == ISD::VP_FNEG) {
Cost = NegatibleCost::Cheaper;
return Op.getOperand(0);
}
diff --git a/llvm/test/CodeGen/RISCV/rvv/vfma-vp-combine.ll b/llvm/test/CodeGen/RISCV/rvv/vfma-vp-combine.ll
new file mode 100644
index 0000000000000..1f64a72230c3b
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vfma-vp-combine.ll
@@ -0,0 +1,70 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+experimental-zvfh,+v,+m -target-abi=ilp32d \
+; RUN: -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+experimental-zvfh,+v,+m -target-abi=lp64d \
+; RUN: -verify-machineinstrs < %s | FileCheck %s
+
+declare <vscale x 1 x double> @llvm.vp.fma.nxv1f64(<vscale x 1 x double>, <vscale x 1 x double>, <vscale x 1 x double>, <vscale x 1 x i1>, i32)
+declare <vscale x 1 x double> @llvm.vp.fneg.nxv1f64(<vscale x 1 x double>, <vscale x 1 x i1>, i32)
+declare <vscale x 1 x double> @llvm.vp.fmul.nxv1f64(<vscale x 1 x double>, <vscale x 1 x double>, <vscale x 1 x i1>, i32)
+
+; (-N0 * -N1) + N2 --> (N0 * N1) + N2
+define <vscale x 1 x double> @test1(<vscale x 1 x double> %a, <vscale x 1 x double> %b, <vscale x 1 x double> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: test1:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma
+; CHECK-NEXT: vfmadd.vv v9, v8, v10, v0.t
+; CHECK-NEXT: vmv.v.v v8, v9
+; CHECK-NEXT: ret
+ %nega = call <vscale x 1 x double> @llvm.vp.fneg.nxv1f64(<vscale x 1 x double> %a, <vscale x 1 x i1> %m, i32 %evl)
+ %negb = call <vscale x 1 x double> @llvm.vp.fneg.nxv1f64(<vscale x 1 x double> %b, <vscale x 1 x i1> %m, i32 %evl)
+ %v = call <vscale x 1 x double> @llvm.vp.fma.nxv1f64(<vscale x 1 x double> %nega, <vscale x 1 x double> %negb, <vscale x 1 x double> %c, <vscale x 1 x i1> %m, i32 %evl)
+ ret <vscale x 1 x double> %v
+}
+
+; (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
+define <vscale x 1 x double> @test2(<vscale x 1 x double> %a, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: test2:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lui a1, %hi(.LCPI1_0)
+; CHECK-NEXT: addi a1, a1, %lo(.LCPI1_0)
+; CHECK-NEXT: vsetvli a2, zero, e64, m1, ta, ma
+; CHECK-NEXT: vlse64.v v9, (a1), zero
+; CHECK-NEXT: lui a1, %hi(.LCPI1_1)
+; CHECK-NEXT: fld fa5, %lo(.LCPI1_1)(a1)
+; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma
+; CHECK-NEXT: vfadd.vf v9, v9, fa5, v0.t
+; CHECK-NEXT: vfmul.vv v8, v8, v9, v0.t
+; CHECK-NEXT: ret
+ %elt.head1 = insertelement <vscale x 1 x double> poison, double 2.0, i32 0
+ %c1 = shufflevector <vscale x 1 x double> %elt.head1, <vscale x 1 x double> poison, <vscale x 1 x i32> zeroinitializer
+ %t = call <vscale x 1 x double> @llvm.vp.fmul.nxv1f64(<vscale x 1 x double> %a, <vscale x 1 x double> %c1, <vscale x 1 x i1> %m, i32 %evl)
+ %elt.head2 = insertelement <vscale x 1 x double> poison, double 4.0, i32 0
+ %c2 = shufflevector <vscale x 1 x double> %elt.head2, <vscale x 1 x double> poison, <vscale x 1 x i32> zeroinitializer
+ %v = call fast <vscale x 1 x double> @llvm.vp.fma.nxv1f64(<vscale x 1 x double> %a, <vscale x 1 x double> %c2, <vscale x 1 x double> %t, <vscale x 1 x i1> %m, i32 %evl)
+ ret <vscale x 1 x double> %v
+}
+
+; (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
+define <vscale x 1 x double> @test3(<vscale x 1 x double> %a, <vscale x 1 x double> %b, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: test3:
+; CHECK: # %bb.0:
+; CHECK-NEXT: lui a1, %hi(.LCPI2_0)
+; CHECK-NEXT: addi a1, a1, %lo(.LCPI2_0)
+; CHECK-NEXT: vsetvli a2, zero, e64, m1, ta, ma
+; CHECK-NEXT: vlse64.v v10, (a1), zero
+; CHECK-NEXT: lui a1, %hi(.LCPI2_1)
+; CHECK-NEXT: fld fa5, %lo(.LCPI2_1)(a1)
+; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, ma
+; CHECK-NEXT: vfmul.vf v10, v10, fa5, v0.t
+; CHECK-NEXT: vfmadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT: vmv.v.v v8, v10
+; CHECK-NEXT: ret
+ %elt.head1 = insertelement <vscale x 1 x double> poison, double 2.0, i32 0
+ %c1 = shufflevector <vscale x 1 x double> %elt.head1, <vscale x 1 x double> poison, <vscale x 1 x i32> zeroinitializer
+ %t = call <vscale x 1 x double> @llvm.vp.fmul.nxv1f64(<vscale x 1 x double> %a, <vscale x 1 x double> %c1, <vscale x 1 x i1> %m, i32 %evl)
+ %elt.head2 = insertelement <vscale x 1 x double> poison, double 4.0, i32 0
+ %c2 = shufflevector <vscale x 1 x double> %elt.head2, <vscale x 1 x double> poison, <vscale x 1 x i32> zeroinitializer
+ %v = call fast <vscale x 1 x double> @llvm.vp.fma.nxv1f64(<vscale x 1 x double> %t, <vscale x 1 x double> %c2, <vscale x 1 x double> %b, <vscale x 1 x i1> %m, i32 %evl)
+ ret <vscale x 1 x double> %v
+}
More information about the llvm-commits
mailing list