[llvm] 2fe2a6d - [DAGCombiner] Use generalized pattern matcher in visitFMA to support vp.fma.

Yeting Kuo via llvm-commits llvm-commits at lists.llvm.org
Wed Jun 7 18:40:28 PDT 2023


Author: Yeting Kuo
Date: 2023-06-08T09:40:21+08:00
New Revision: 2fe2a6d4b8a4647e49d69a5ff7161946aeb7cee1

URL: https://github.com/llvm/llvm-project/commit/2fe2a6d4b8a4647e49d69a5ff7161946aeb7cee1
DIFF: https://github.com/llvm/llvm-project/commit/2fe2a6d4b8a4647e49d69a5ff7161946aeb7cee1.diff

LOG: [DAGCombiner] Use generalized pattern matcher in visitFMA to support vp.fma.

Note: Some patterns in visitFMA still need to be refined to support splats of constants.

Reviewed By: luke

Differential Revision: https://reviews.llvm.org/D152260
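
For readers who have not seen the match-context mechanism this change builds on: the combine body is written once as a template, and the context type decides whether nodes are matched and created as plain SDAG opcodes (EmptyMatchContext) or as their VP counterparts that carry the mask and EVL operands (VPMatchContext). The standalone C++ sketch below illustrates only that dispatch idea; the types and names here (Opcode, Node, EmptyContext, VPContext, rebuildFMA) are invented for the example and are not LLVM's actual implementation.

// Toy model of the "match context" pattern, NOT LLVM code: the same
// templated combine body builds either plain or predicated (VP) nodes
// depending on the context type it is instantiated with.
#include <cassert>
#include <utility>
#include <vector>

enum Opcode { FMUL, FADD, FMA, FNEG, VP_FMUL, VP_FADD, VP_FMA, VP_FNEG };

struct Node {
  Opcode Op;
  std::vector<int> Operands; // operand indices; stand-in for SDValues
};

// Plain context: build and match nodes with the opcode as written.
struct EmptyContext {
  Node getNode(Opcode Op, std::vector<int> Ops) {
    return {Op, std::move(Ops)};
  }
  bool match(const Node &N, Opcode Op) const { return N.Op == Op; }
};

// Predicated context: translate each opcode to its VP counterpart and
// append the mask/EVL operands taken from the root node being combined.
struct VPContext {
  int MaskIdx, EVLIdx;
  static Opcode toVP(Opcode Op) {
    switch (Op) {
    case FMUL: return VP_FMUL;
    case FADD: return VP_FADD;
    case FMA:  return VP_FMA;
    case FNEG: return VP_FNEG;
    default:   return Op;
    }
  }
  Node getNode(Opcode Op, std::vector<int> Ops) {
    Ops.push_back(MaskIdx);
    Ops.push_back(EVLIdx);
    return {toVP(Op), std::move(Ops)};
  }
  bool match(const Node &N, Opcode Op) const { return N.Op == toVP(Op); }
};

// One combine body written once: rebuild an fma from operands a, b, acc.
// With EmptyContext it yields an FMA node; with VPContext it yields a
// VP_FMA node that also carries the original mask and EVL.
template <class Ctx> Node rebuildFMA(Ctx &C, int A, int B, int Acc) {
  return C.getNode(FMA, {A, B, Acc});
}

int main() {
  EmptyContext E;
  assert(rebuildFMA(E, 0, 1, 2).Op == FMA);

  VPContext V{/*MaskIdx=*/3, /*EVLIdx=*/4};
  Node N = rebuildFMA(V, 0, 1, 2);
  assert(N.Op == VP_FMA && N.Operands.size() == 5);
  return 0;
}

A single visitFMA body written against such a context can therefore serve both ISD::FMA and ISD::VP_FMA without duplicating every pattern, which is what the diff below does.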

Added: 
    llvm/test/CodeGen/RISCV/rvv/vfma-vp-combine.ll

Modified: 
    llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
    llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Removed: 
    


################################################################################
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 02da5508656e6..8ffe00f802c31 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -491,7 +491,7 @@ namespace {
     SDValue visitSTRICT_FADD(SDNode *N);
     SDValue visitFSUB(SDNode *N);
     SDValue visitFMUL(SDNode *N);
-    SDValue visitFMA(SDNode *N);
+    template <class MatchContextClass> SDValue visitFMA(SDNode *N);
     SDValue visitFDIV(SDNode *N);
     SDValue visitFREM(SDNode *N);
     SDValue visitFSQRT(SDNode *N);
@@ -1961,7 +1961,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
   case ISD::STRICT_FADD:        return visitSTRICT_FADD(N);
   case ISD::FSUB:               return visitFSUB(N);
   case ISD::FMUL:               return visitFMUL(N);
-  case ISD::FMA:                return visitFMA(N);
+  case ISD::FMA:                return visitFMA<EmptyMatchContext>(N);
   case ISD::FDIV:               return visitFDIV(N);
   case ISD::FREM:               return visitFREM(N);
   case ISD::FSQRT:              return visitFSQRT(N);
@@ -16320,7 +16320,7 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
   return SDValue();
 }
 
-SDValue DAGCombiner::visitFMA(SDNode *N) {
+template <class MatchContextClass> SDValue DAGCombiner::visitFMA(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   SDValue N2 = N->getOperand(2);
@@ -16331,6 +16331,7 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
   const TargetOptions &Options = DAG.getTarget().Options;
   // FMA nodes have flags that propagate to the created nodes.
   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
+  MatchContextClass matcher(DAG, TLI, N);
 
   bool CanReassociate =
       Options.UnsafeFPMath || N->getFlags().hasAllowReassociation();
@@ -16339,7 +16340,7 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
   if (isa<ConstantFPSDNode>(N0) &&
       isa<ConstantFPSDNode>(N1) &&
       isa<ConstantFPSDNode>(N2)) {
-    return DAG.getNode(ISD::FMA, DL, VT, N0, N1, N2);
+    return matcher.getNode(ISD::FMA, DL, VT, N0, N1, N2);
   }
 
   // (-N0 * -N1) + N2 --> (N0 * N1) + N2
@@ -16355,7 +16356,7 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
         TLI.getNegatedExpression(N1, DAG, LegalOperations, ForCodeSize, CostN1);
     if (NegN1 && (CostN0 == TargetLowering::NegatibleCost::Cheaper ||
                   CostN1 == TargetLowering::NegatibleCost::Cheaper))
-      return DAG.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);
+      return matcher.getNode(ISD::FMA, DL, VT, NegN0, NegN1, N2);
   }
 
   // FIXME: use fast math flags instead of Options.UnsafeFPMath
@@ -16366,70 +16367,74 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
       return N2;
   }
 
+  // FIXME: Support splat of constant.
   if (N0CFP && N0CFP->isExactlyValue(1.0))
-    return DAG.getNode(ISD::FADD, SDLoc(N), VT, N1, N2);
+    return matcher.getNode(ISD::FADD, SDLoc(N), VT, N1, N2);
   if (N1CFP && N1CFP->isExactlyValue(1.0))
-    return DAG.getNode(ISD::FADD, SDLoc(N), VT, N0, N2);
+    return matcher.getNode(ISD::FADD, SDLoc(N), VT, N0, N2);
 
   // Canonicalize (fma c, x, y) -> (fma x, c, y)
   if (DAG.isConstantFPBuildVectorOrConstantFP(N0) &&
      !DAG.isConstantFPBuildVectorOrConstantFP(N1))
-    return DAG.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);
+    return matcher.getNode(ISD::FMA, SDLoc(N), VT, N1, N0, N2);
 
   if (CanReassociate) {
     // (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
-    if (N2.getOpcode() == ISD::FMUL && N0 == N2.getOperand(0) &&
+    if (matcher.match(N2, ISD::FMUL) && N0 == N2.getOperand(0) &&
         DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
         DAG.isConstantFPBuildVectorOrConstantFP(N2.getOperand(1))) {
-      return DAG.getNode(ISD::FMUL, DL, VT, N0,
-                         DAG.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1)));
+      return matcher.getNode(
+          ISD::FMUL, DL, VT, N0,
+          matcher.getNode(ISD::FADD, DL, VT, N1, N2.getOperand(1)));
     }
 
     // (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
-    if (N0.getOpcode() == ISD::FMUL &&
+    if (matcher.match(N0, ISD::FMUL) &&
         DAG.isConstantFPBuildVectorOrConstantFP(N1) &&
         DAG.isConstantFPBuildVectorOrConstantFP(N0.getOperand(1))) {
-      return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
-                         DAG.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)),
-                         N2);
+      return matcher.getNode(
+          ISD::FMA, DL, VT, N0.getOperand(0),
+          matcher.getNode(ISD::FMUL, DL, VT, N1, N0.getOperand(1)), N2);
     }
   }
 
   // (fma x, -1, y) -> (fadd (fneg x), y)
+  // FIXME: Support splat of constant.
   if (N1CFP) {
     if (N1CFP->isExactlyValue(1.0))
-      return DAG.getNode(ISD::FADD, DL, VT, N0, N2);
+      return matcher.getNode(ISD::FADD, DL, VT, N0, N2);
 
     if (N1CFP->isExactlyValue(-1.0) &&
         (!LegalOperations || TLI.isOperationLegal(ISD::FNEG, VT))) {
-      SDValue RHSNeg = DAG.getNode(ISD::FNEG, DL, VT, N0);
+      SDValue RHSNeg = matcher.getNode(ISD::FNEG, DL, VT, N0);
       AddToWorklist(RHSNeg.getNode());
-      return DAG.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
+      return matcher.getNode(ISD::FADD, DL, VT, N2, RHSNeg);
     }
 
     // fma (fneg x), K, y -> fma x -K, y
-    if (N0.getOpcode() == ISD::FNEG &&
+    if (matcher.match(N0, ISD::FNEG) &&
         (TLI.isOperationLegal(ISD::ConstantFP, VT) ||
-         (N1.hasOneUse() && !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT,
-                                              ForCodeSize)))) {
-      return DAG.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
-                         DAG.getNode(ISD::FNEG, DL, VT, N1), N2);
+         (N1.hasOneUse() &&
+          !TLI.isFPImmLegal(N1CFP->getValueAPF(), VT, ForCodeSize)))) {
+      return matcher.getNode(ISD::FMA, DL, VT, N0.getOperand(0),
+                             matcher.getNode(ISD::FNEG, DL, VT, N1), N2);
     }
   }
 
+  // FIXME: Support splat of constant.
   if (CanReassociate) {
     // (fma x, c, x) -> (fmul x, (c+1))
     if (N1CFP && N0 == N2) {
-      return DAG.getNode(
-          ISD::FMUL, DL, VT, N0,
-          DAG.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(1.0, DL, VT)));
+      return matcher.getNode(ISD::FMUL, DL, VT, N0,
+                             matcher.getNode(ISD::FADD, DL, VT, N1,
+                                             DAG.getConstantFP(1.0, DL, VT)));
     }
 
     // (fma x, c, (fneg x)) -> (fmul x, (c-1))
-    if (N1CFP && N2.getOpcode() == ISD::FNEG && N2.getOperand(0) == N0) {
-      return DAG.getNode(
-          ISD::FMUL, DL, VT, N0,
-          DAG.getNode(ISD::FADD, DL, VT, N1, DAG.getConstantFP(-1.0, DL, VT)));
+    if (N1CFP && matcher.match(N2, ISD::FNEG) && N2.getOperand(0) == N0) {
+      return matcher.getNode(ISD::FMUL, DL, VT, N0,
+                             matcher.getNode(ISD::FADD, DL, VT, N1,
+                                             DAG.getConstantFP(-1.0, DL, VT)));
     }
   }
 
@@ -16438,7 +16443,7 @@ SDValue DAGCombiner::visitFMA(SDNode *N) {
   if (!TLI.isFNegFree(VT))
     if (SDValue Neg = TLI.getCheaperNegatedExpression(
             SDValue(N, 0), DAG, LegalOperations, ForCodeSize))
-      return DAG.getNode(ISD::FNEG, DL, VT, Neg);
+      return matcher.getNode(ISD::FNEG, DL, VT, Neg);
   return SDValue();
 }
 
@@ -25695,6 +25700,8 @@ SDValue DAGCombiner::visitVPOp(SDNode *N) {
       return visitVP_FADD(N);
     case ISD::VP_FSUB:
       return visitVP_FSUB(N);
+    case ISD::VP_FMA:
+      return visitFMA<VPMatchContext>(N);
     }
     return SDValue();
   }

diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index 4c443600f9413..32be369d08cf7 100644
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -6816,7 +6816,7 @@ SDValue TargetLowering::getNegatedExpression(SDValue Op, SelectionDAG &DAG,
                                              NegatibleCost &Cost,
                                              unsigned Depth) const {
   // fneg is removable even if it has multiple uses.
-  if (Op.getOpcode() == ISD::FNEG) {
+  if (Op.getOpcode() == ISD::FNEG || Op.getOpcode() == ISD::VP_FNEG) {
     Cost = NegatibleCost::Cheaper;
     return Op.getOperand(0);
   }

diff --git a/llvm/test/CodeGen/RISCV/rvv/vfma-vp-combine.ll b/llvm/test/CodeGen/RISCV/rvv/vfma-vp-combine.ll
new file mode 100644
index 0000000000000..1f64a72230c3b
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vfma-vp-combine.ll
@@ -0,0 +1,70 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+d,+zfh,+experimental-zvfh,+v,+m -target-abi=ilp32d \
+; RUN:     -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple=riscv64 -mattr=+d,+zfh,+experimental-zvfh,+v,+m -target-abi=lp64d \
+; RUN:     -verify-machineinstrs < %s | FileCheck %s
+
+declare <vscale x 1 x double> @llvm.vp.fma.nxv1f64(<vscale x 1 x double>, <vscale x 1 x double>, <vscale x 1 x double>, <vscale x 1 x i1>, i32)
+declare <vscale x 1 x double> @llvm.vp.fneg.nxv1f64(<vscale x 1 x double>, <vscale x 1 x i1>, i32)
+declare <vscale x 1 x double> @llvm.vp.fmul.nxv1f64(<vscale x 1 x double>, <vscale x 1 x double>, <vscale x 1 x i1>, i32)
+
+; (-N0 * -N1) + N2 --> (N0 * N1) + N2
+define <vscale x 1 x double> @test1(<vscale x 1 x double> %a, <vscale x 1 x double> %b, <vscale x 1 x double> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: test1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a0, e64, m1, ta, ma
+; CHECK-NEXT:    vfmadd.vv v9, v8, v10, v0.t
+; CHECK-NEXT:    vmv.v.v v8, v9
+; CHECK-NEXT:    ret
+  %nega = call <vscale x 1 x double> @llvm.vp.fneg.nxv1f64(<vscale x 1 x double> %a, <vscale x 1 x i1> %m, i32 %evl)
+  %negb = call <vscale x 1 x double> @llvm.vp.fneg.nxv1f64(<vscale x 1 x double> %b, <vscale x 1 x i1> %m, i32 %evl)
+  %v = call <vscale x 1 x double> @llvm.vp.fma.nxv1f64(<vscale x 1 x double> %nega, <vscale x 1 x double> %negb, <vscale x 1 x double> %c, <vscale x 1 x i1> %m, i32 %evl)
+  ret <vscale x 1 x double> %v
+}
+
+; (fma x, c1, (fmul x, c2)) -> (fmul x, c1+c2)
+define <vscale x 1 x double> @test2(<vscale x 1 x double> %a, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: test2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lui a1, %hi(.LCPI1_0)
+; CHECK-NEXT:    addi a1, a1, %lo(.LCPI1_0)
+; CHECK-NEXT:    vsetvli a2, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vlse64.v v9, (a1), zero
+; CHECK-NEXT:    lui a1, %hi(.LCPI1_1)
+; CHECK-NEXT:    fld fa5, %lo(.LCPI1_1)(a1)
+; CHECK-NEXT:    vsetvli zero, a0, e64, m1, ta, ma
+; CHECK-NEXT:    vfadd.vf v9, v9, fa5, v0.t
+; CHECK-NEXT:    vfmul.vv v8, v8, v9, v0.t
+; CHECK-NEXT:    ret
+  %elt.head1 = insertelement <vscale x 1 x double> poison, double 2.0, i32 0
+  %c1 = shufflevector <vscale x 1 x double> %elt.head1, <vscale x 1 x double> poison, <vscale x 1 x i32> zeroinitializer
+  %t = call <vscale x 1 x double> @llvm.vp.fmul.nxv1f64(<vscale x 1 x double> %a, <vscale x 1 x double> %c1, <vscale x 1 x i1> %m, i32 %evl)
+  %elt.head2 = insertelement <vscale x 1 x double> poison, double 4.0, i32 0
+  %c2 = shufflevector <vscale x 1 x double> %elt.head2, <vscale x 1 x double> poison, <vscale x 1 x i32> zeroinitializer
+  %v = call fast <vscale x 1 x double> @llvm.vp.fma.nxv1f64(<vscale x 1 x double> %a, <vscale x 1 x double> %c2, <vscale x 1 x double> %t, <vscale x 1 x i1> %m, i32 %evl)
+  ret <vscale x 1 x double> %v
+}
+
+; (fma (fmul x, c1), c2, y) -> (fma x, c1*c2, y)
+define <vscale x 1 x double> @test3(<vscale x 1 x double> %a, <vscale x 1 x double> %b, <vscale x 1 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: test3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    lui a1, %hi(.LCPI2_0)
+; CHECK-NEXT:    addi a1, a1, %lo(.LCPI2_0)
+; CHECK-NEXT:    vsetvli a2, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vlse64.v v10, (a1), zero
+; CHECK-NEXT:    lui a1, %hi(.LCPI2_1)
+; CHECK-NEXT:    fld fa5, %lo(.LCPI2_1)(a1)
+; CHECK-NEXT:    vsetvli zero, a0, e64, m1, ta, ma
+; CHECK-NEXT:    vfmul.vf v10, v10, fa5, v0.t
+; CHECK-NEXT:    vfmadd.vv v10, v8, v9, v0.t
+; CHECK-NEXT:    vmv.v.v v8, v10
+; CHECK-NEXT:    ret
+  %elt.head1 = insertelement <vscale x 1 x double> poison, double 2.0, i32 0
+  %c1 = shufflevector <vscale x 1 x double> %elt.head1, <vscale x 1 x double> poison, <vscale x 1 x i32> zeroinitializer
+  %t = call <vscale x 1 x double> @llvm.vp.fmul.nxv1f64(<vscale x 1 x double> %a, <vscale x 1 x double> %c1, <vscale x 1 x i1> %m, i32 %evl)
+  %elt.head2 = insertelement <vscale x 1 x double> poison, double 4.0, i32 0
+  %c2 = shufflevector <vscale x 1 x double> %elt.head2, <vscale x 1 x double> poison, <vscale x 1 x i32> zeroinitializer
+  %v = call fast <vscale x 1 x double> @llvm.vp.fma.nxv1f64(<vscale x 1 x double> %t, <vscale x 1 x double> %c2, <vscale x 1 x double> %b, <vscale x 1 x i1> %m, i32 %evl)
+  ret <vscale x 1 x double> %v
+}

More information about the llvm-commits mailing list