[llvm] b6c7907 - [MachineCombiner][RISCV] Add fmadd/fmsub/fnmsub instructions patterns

Anton Sidorenko via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 17 02:28:37 PST 2022


Author: Anton Sidorenko
Date: 2022-11-17T13:24:04+03:00
New Revision: b6c790736e77f79e819fa761fb13f2311296714a

URL: https://github.com/llvm/llvm-project/commit/b6c790736e77f79e819fa761fb13f2311296714a
DIFF: https://github.com/llvm/llvm-project/commit/b6c790736e77f79e819fa761fb13f2311296714a.diff

LOG: [MachineCombiner][RISCV] Add fmadd/fmsub/fnmsub instructions patterns

This patch adds transformation of fmul+fadd/fsub chains to fused multiply
instructions:
  * fmul+fadd->fmadd
  * fmul+fsub->fmsub/fnmsub

We also will try to combine these instructions if the fmul has more than one use
and cannot be deleted. However, removing the dependence between fmul and fadd can
still be profitable, and we rely on machine combiner approximations of scheduling.

Differential Revision: https://reviews.llvm.org/D136764

Added: 
    

Modified: 
    llvm/include/llvm/CodeGen/MachineCombinerPattern.h
    llvm/lib/CodeGen/MachineCombiner.cpp
    llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
    llvm/lib/Target/RISCV/RISCVInstrInfo.h
    llvm/test/CodeGen/RISCV/machine-combiner-mir.ll
    llvm/test/CodeGen/RISCV/machine-combiner.ll

Removed: 
    


################################################################################
diff --git a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
index 68c95679d4667..39e70d583710e 100644
--- a/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
+++ b/llvm/include/llvm/CodeGen/MachineCombinerPattern.h
@@ -169,6 +169,12 @@ enum class MachineCombinerPattern {
   FMULv4i32_indexed_OP2,
   FMULv8i16_indexed_OP1,
   FMULv8i16_indexed_OP2,
+
+  // RISCV FMADD, FMSUB, FNMSUB patterns
+  FMADD_AX,
+  FMADD_XA,
+  FMSUB,
+  FNMSUB,
 };
 
 } // end namespace llvm

diff --git a/llvm/lib/CodeGen/MachineCombiner.cpp b/llvm/lib/CodeGen/MachineCombiner.cpp
index ed0fe7ddf2772..bced97f7d49a0 100644
--- a/llvm/lib/CodeGen/MachineCombiner.cpp
+++ b/llvm/lib/CodeGen/MachineCombiner.cpp
@@ -319,6 +319,10 @@ static CombinerObjective getCombinerObjective(MachineCombinerPattern P) {
   case MachineCombinerPattern::REASSOC_XMM_AMM_BMM:
   case MachineCombinerPattern::SUBADD_OP1:
   case MachineCombinerPattern::SUBADD_OP2:
+  case MachineCombinerPattern::FMADD_AX:
+  case MachineCombinerPattern::FMADD_XA:
+  case MachineCombinerPattern::FMSUB:
+  case MachineCombinerPattern::FNMSUB:
     return CombinerObjective::MustReduceDepth;
   case MachineCombinerPattern::REASSOC_XY_BCA:
   case MachineCombinerPattern::REASSOC_XY_BAC:

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index c05527cff95d9..a560c87f3ab25 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -26,6 +26,7 @@
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/RegisterScavenging.h"
+#include "llvm/IR/DebugInfoMetadata.h"
 #include "llvm/MC/MCInstBuilder.h"
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -1176,6 +1177,17 @@ static bool isFADD(unsigned Opc) {
   }
 }
 
+static bool isFSUB(unsigned Opc) {
+  switch (Opc) {
+  default:
+    return false;
+  case RISCV::FSUB_H:
+  case RISCV::FSUB_S:
+  case RISCV::FSUB_D:
+    return true;
+  }
+}
+
 static bool isFMUL(unsigned Opc) {
   switch (Opc) {
   default:
@@ -1211,6 +1223,33 @@ static bool canReassociate(MachineInstr &Root, MachineOperand &MO) {
   return RISCV::hasEqualFRM(Root, *MI);
 }
 
+static bool canCombineFPFusedMultiply(const MachineInstr &Root,
+                                      const MachineOperand &MO,
+                                      bool DoRegPressureReduce) {
+  if (!MO.isReg() || !Register::isVirtualRegister(MO.getReg()))
+    return false;
+  const MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
+  MachineInstr *MI = MRI.getVRegDef(MO.getReg());
+  if (!MI || !isFMUL(MI->getOpcode()))
+    return false;
+
+  if (!Root.getFlag(MachineInstr::MIFlag::FmContract) ||
+      !MI->getFlag(MachineInstr::MIFlag::FmContract))
+    return false;
+
+  // Try combining even if fmul has more than one use as it eliminates
+  // dependency between fadd(fsub) and fmul. However, it can extend liveranges
+  // for fmul operands, so reject the transformation in register pressure
+  // reduction mode.
+  if (DoRegPressureReduce && !MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
+    return false;
+
+  // Do not combine instructions from different basic blocks.
+  if (Root.getParent() != MI->getParent())
+    return false;
+  return RISCV::hasEqualFRM(Root, *MI);
+}
+
 static bool
 getFPReassocPatterns(MachineInstr &Root,
                      SmallVectorImpl<MachineCombinerPattern> &Patterns) {
@@ -1228,25 +1267,148 @@ getFPReassocPatterns(MachineInstr &Root,
   return Added;
 }
 
-static bool getFPPatterns(MachineInstr &Root,
-                          SmallVectorImpl<MachineCombinerPattern> &Patterns) {
+static bool
+getFPFusedMultiplyPatterns(MachineInstr &Root,
+                           SmallVectorImpl<MachineCombinerPattern> &Patterns,
+                           bool DoRegPressureReduce) {
   unsigned Opc = Root.getOpcode();
-  if (isAssociativeAndCommutativeFPOpcode(Opc))
-    return getFPReassocPatterns(Root, Patterns);
-  return false;
+  bool IsFAdd = isFADD(Opc);
+  if (!IsFAdd && !isFSUB(Opc))
+    return false;
+  bool Added = false;
+  if (canCombineFPFusedMultiply(Root, Root.getOperand(1),
+                                DoRegPressureReduce)) {
+    Patterns.push_back(IsFAdd ? MachineCombinerPattern::FMADD_AX
+                              : MachineCombinerPattern::FMSUB);
+    Added = true;
+  }
+  if (canCombineFPFusedMultiply(Root, Root.getOperand(2),
+                                DoRegPressureReduce)) {
+    Patterns.push_back(IsFAdd ? MachineCombinerPattern::FMADD_XA
+                              : MachineCombinerPattern::FNMSUB);
+    Added = true;
+  }
+  return Added;
+}
+
+static bool getFPPatterns(MachineInstr &Root,
+                          SmallVectorImpl<MachineCombinerPattern> &Patterns,
+                          bool DoRegPressureReduce) {
+  bool Added = getFPFusedMultiplyPatterns(Root, Patterns, DoRegPressureReduce);
+  if (isAssociativeAndCommutativeFPOpcode(Root.getOpcode()))
+    Added |= getFPReassocPatterns(Root, Patterns);
+  return Added;
 }
 
 bool RISCVInstrInfo::getMachineCombinerPatterns(
     MachineInstr &Root, SmallVectorImpl<MachineCombinerPattern> &Patterns,
     bool DoRegPressureReduce) const {
 
-  if (getFPPatterns(Root, Patterns))
+  if (getFPPatterns(Root, Patterns, DoRegPressureReduce))
     return true;
 
   return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
                                                      DoRegPressureReduce);
 }
 
+static unsigned getFPFusedMultiplyOpcode(unsigned RootOpc,
+                                         MachineCombinerPattern Pattern) {
+  switch (RootOpc) {
+  default:
+    llvm_unreachable("Unexpected opcode");
+  case RISCV::FADD_H:
+    return RISCV::FMADD_H;
+  case RISCV::FADD_S:
+    return RISCV::FMADD_S;
+  case RISCV::FADD_D:
+    return RISCV::FMADD_D;
+  case RISCV::FSUB_H:
+    return Pattern == MachineCombinerPattern::FMSUB ? RISCV::FMSUB_H
+                                                    : RISCV::FNMSUB_H;
+  case RISCV::FSUB_S:
+    return Pattern == MachineCombinerPattern::FMSUB ? RISCV::FMSUB_S
+                                                    : RISCV::FNMSUB_S;
+  case RISCV::FSUB_D:
+    return Pattern == MachineCombinerPattern::FMSUB ? RISCV::FMSUB_D
+                                                    : RISCV::FNMSUB_D;
+  }
+}
+
+static unsigned getAddendOperandIdx(MachineCombinerPattern Pattern) {
+  switch (Pattern) {
+  default:
+    llvm_unreachable("Unexpected pattern");
+  case MachineCombinerPattern::FMADD_AX:
+  case MachineCombinerPattern::FMSUB:
+    return 2;
+  case MachineCombinerPattern::FMADD_XA:
+  case MachineCombinerPattern::FNMSUB:
+    return 1;
+  }
+}
+
+static void combineFPFusedMultiply(MachineInstr &Root, MachineInstr &Prev,
+                                   MachineCombinerPattern Pattern,
+                                   SmallVectorImpl<MachineInstr *> &InsInstrs,
+                                   SmallVectorImpl<MachineInstr *> &DelInstrs) {
+  MachineFunction *MF = Root.getMF();
+  MachineRegisterInfo &MRI = MF->getRegInfo();
+  const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
+
+  MachineOperand &Mul1 = Prev.getOperand(1);
+  MachineOperand &Mul2 = Prev.getOperand(2);
+  MachineOperand &Dst = Root.getOperand(0);
+  MachineOperand &Addend = Root.getOperand(getAddendOperandIdx(Pattern));
+
+  Register DstReg = Dst.getReg();
+  unsigned FusedOpc = getFPFusedMultiplyOpcode(Root.getOpcode(), Pattern);
+  auto IntersectedFlags = Root.getFlags() & Prev.getFlags();
+  DebugLoc MergedLoc =
+      DILocation::getMergedLocation(Root.getDebugLoc(), Prev.getDebugLoc());
+
+  MachineInstrBuilder MIB =
+      BuildMI(*MF, MergedLoc, TII->get(FusedOpc), DstReg)
+          .addReg(Mul1.getReg(), getKillRegState(Mul1.isKill()))
+          .addReg(Mul2.getReg(), getKillRegState(Mul2.isKill()))
+          .addReg(Addend.getReg(), getKillRegState(Addend.isKill()))
+          .setMIFlags(IntersectedFlags);
+
+  // Mul operands are not killed anymore.
+  Mul1.setIsKill(false);
+  Mul2.setIsKill(false);
+
+  InsInstrs.push_back(MIB);
+  if (MRI.hasOneNonDBGUse(Prev.getOperand(0).getReg()))
+    DelInstrs.push_back(&Prev);
+  DelInstrs.push_back(&Root);
+}
+
+void RISCVInstrInfo::genAlternativeCodeSequence(
+    MachineInstr &Root, MachineCombinerPattern Pattern,
+    SmallVectorImpl<MachineInstr *> &InsInstrs,
+    SmallVectorImpl<MachineInstr *> &DelInstrs,
+    DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const {
+  MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
+  switch (Pattern) {
+  default:
+    TargetInstrInfo::genAlternativeCodeSequence(Root, Pattern, InsInstrs,
+                                                DelInstrs, InstrIdxForVirtReg);
+    return;
+  case MachineCombinerPattern::FMADD_AX:
+  case MachineCombinerPattern::FMSUB: {
+    MachineInstr &Prev = *MRI.getVRegDef(Root.getOperand(1).getReg());
+    combineFPFusedMultiply(Root, Prev, Pattern, InsInstrs, DelInstrs);
+    return;
+  }
+  case MachineCombinerPattern::FMADD_XA:
+  case MachineCombinerPattern::FNMSUB: {
+    MachineInstr &Prev = *MRI.getVRegDef(Root.getOperand(2).getReg());
+    combineFPFusedMultiply(Root, Prev, Pattern, InsInstrs, DelInstrs);
+    return;
+  }
+  }
+}
+
 bool RISCVInstrInfo::verifyInstruction(const MachineInstr &MI,
                                        StringRef &ErrInfo) const {
   MCInstrDesc const &Desc = MI.getDesc();

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index fe95bb2fc4d13..3eaca1e781b95 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -196,6 +196,12 @@ class RISCVInstrInfo : public RISCVGenInstrInfo {
   finalizeInsInstrs(MachineInstr &Root, MachineCombinerPattern &P,
                     SmallVectorImpl<MachineInstr *> &InsInstrs) const override;
 
+  void genAlternativeCodeSequence(
+      MachineInstr &Root, MachineCombinerPattern Pattern,
+      SmallVectorImpl<MachineInstr *> &InsInstrs,
+      SmallVectorImpl<MachineInstr *> &DelInstrs,
+      DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) const override;
+
 protected:
   const RISCVSubtarget &STI;
 };

diff --git a/llvm/test/CodeGen/RISCV/machine-combiner-mir.ll b/llvm/test/CodeGen/RISCV/machine-combiner-mir.ll
index 42ea49a9e546f..1a78419f56a67 100644
--- a/llvm/test/CodeGen/RISCV/machine-combiner-mir.ll
+++ b/llvm/test/CodeGen/RISCV/machine-combiner-mir.ll
@@ -95,8 +95,8 @@ define double @test_fmadd(double %a0, double %a1, double %a2) {
   ; CHECK-NEXT:   [[COPY1:%[0-9]+]]:fpr64 = COPY $f11_d
   ; CHECK-NEXT:   [[COPY2:%[0-9]+]]:fpr64 = COPY $f10_d
   ; CHECK-NEXT:   [[FMUL_D:%[0-9]+]]:fpr64 = contract nofpexcept FMUL_D [[COPY2]], [[COPY1]], 7, implicit $frm
-  ; CHECK-NEXT:   [[FADD_D:%[0-9]+]]:fpr64 = contract nofpexcept FADD_D [[FMUL_D]], [[COPY]], 7, implicit $frm
-  ; CHECK-NEXT:   [[FDIV_D:%[0-9]+]]:fpr64 = nofpexcept FDIV_D killed [[FADD_D]], [[FMUL_D]], 7, implicit $frm
+  ; CHECK-NEXT:   [[FMADD_D:%[0-9]+]]:fpr64 = contract nofpexcept FMADD_D [[COPY2]], [[COPY1]], [[COPY]], 7, implicit $frm
+  ; CHECK-NEXT:   [[FDIV_D:%[0-9]+]]:fpr64 = nofpexcept FDIV_D killed [[FMADD_D]], [[FMUL_D]], 7, implicit $frm
   ; CHECK-NEXT:   $f10_d = COPY [[FDIV_D]]
   ; CHECK-NEXT:   PseudoRET implicit $f10_d
   %t0 = fmul contract double %a0, %a1

diff --git a/llvm/test/CodeGen/RISCV/machine-combiner.ll b/llvm/test/CodeGen/RISCV/machine-combiner.ll
index ca19c6bfa8ebe..e77c4abed7c66 100644
--- a/llvm/test/CodeGen/RISCV/machine-combiner.ll
+++ b/llvm/test/CodeGen/RISCV/machine-combiner.ll
@@ -188,10 +188,9 @@ define double @test_reassoc_fadd_flags_2(double %a0, double %a1, double %a2, dou
 define double @test_fmadd1(double %a0, double %a1, double %a2, double %a3) {
 ; CHECK-LABEL: test_fmadd1:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    fmul.d ft0, fa0, fa1
-; CHECK-NEXT:    fadd.d ft1, ft0, fa2
-; CHECK-NEXT:    fadd.d ft0, fa3, ft0
-; CHECK-NEXT:    fadd.d fa0, ft1, ft0
+; CHECK-NEXT:    fmadd.d ft0, fa0, fa1, fa2
+; CHECK-NEXT:    fmadd.d ft1, fa0, fa1, fa3
+; CHECK-NEXT:    fadd.d fa0, ft0, ft1
 ; CHECK-NEXT:    ret
   %t0 = fmul contract double %a0, %a1
   %t1 = fadd contract double %t0, %a2
@@ -204,7 +203,7 @@ define double @test_fmadd2(double %a0, double %a1, double %a2) {
 ; CHECK-LABEL: test_fmadd2:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    fmul.d ft0, fa0, fa1
-; CHECK-NEXT:    fadd.d ft1, ft0, fa2
+; CHECK-NEXT:    fmadd.d ft1, fa0, fa1, fa2
 ; CHECK-NEXT:    fdiv.d fa0, ft1, ft0
 ; CHECK-NEXT:    ret
   %t0 = fmul contract double %a0, %a1
@@ -217,7 +216,7 @@ define double @test_fmsub(double %a0, double %a1, double %a2) {
 ; CHECK-LABEL: test_fmsub:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    fmul.d ft0, fa0, fa1
-; CHECK-NEXT:    fsub.d ft1, ft0, fa2
+; CHECK-NEXT:    fmsub.d ft1, fa0, fa1, fa2
 ; CHECK-NEXT:    fdiv.d fa0, ft1, ft0
 ; CHECK-NEXT:    ret
   %t0 = fmul contract double %a0, %a1
@@ -230,7 +229,7 @@ define double @test_fnmsub(double %a0, double %a1, double %a2) {
 ; CHECK-LABEL: test_fnmsub:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    fmul.d ft0, fa0, fa1
-; CHECK-NEXT:    fsub.d ft1, fa2, ft0
+; CHECK-NEXT:    fnmsub.d ft1, fa0, fa1, fa2
 ; CHECK-NEXT:    fdiv.d fa0, ft1, ft0
 ; CHECK-NEXT:    ret
   %t0 = fmul contract double %a0, %a1


        


More information about the llvm-commits mailing list