[llvm] [RISCV][MachineCombiner] Combine `fadd X, (fneg Y)` to `fsub X, Y` (PR #107803)

Yingwei Zheng via llvm-commits llvm-commits at lists.llvm.org
Sun Sep 8 20:57:29 PDT 2024


https://github.com/dtcxzyw created https://github.com/llvm/llvm-project/pull/107803

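This patch adds a transformation from `fadd X, (fneg Y)` to `fsub X, Y`. It eliminates the unnecessary fneg instructions that appear when the FP immediate -0.5 is materialized with Zfa: `fli` cannot encode -0.5 directly, so the constant is built as `fli` 0.5 followed by `fneg`. I don't see value in combining the commuted version, because that case should already be handled by SelectionDAG.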

Closes https://github.com/llvm/llvm-project/issues/107772.


>From 4bfedc827dbf841761d336928a2900377a8ee5a5 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Mon, 9 Sep 2024 11:14:32 +0800
Subject: [PATCH 1/2] [RISCV][MachineCombiner] Add pre-commit tests. NFC.

---
 llvm/test/CodeGen/RISCV/float-zfa.ll | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/llvm/test/CodeGen/RISCV/float-zfa.ll b/llvm/test/CodeGen/RISCV/float-zfa.ll
index e5196ead1f8819..bfaf6f9d6bf9e3 100644
--- a/llvm/test/CodeGen/RISCV/float-zfa.ll
+++ b/llvm/test/CodeGen/RISCV/float-zfa.ll
@@ -269,3 +269,15 @@ define void @fli_remat() {
   tail call void @foo(float 1.000000e+00, float 1.000000e+00)
   ret void
 }
+
+; fadd of -0.5: with Zfa the constant is materialized as fli.s 0.5 followed by
+; fneg.s. The MachineCombiner is expected to fold that fneg into the fadd,
+; producing fsub.s of the unnegated constant.
+define float @add_negimm(float %x) {
+; CHECK-LABEL: add_negimm:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    fli.s fa5, 0.5
+; CHECK-NEXT:    fneg.s fa5, fa5
+; CHECK-NEXT:    fadd.s fa0, fa0, fa5
+; CHECK-NEXT:    ret
+entry:
+  %sub = fadd float %x, -5.000000e-01
+  ret float %sub
+}

>From 25b25d4107a4c8f264ea165e8dcf96ab2b5768d3 Mon Sep 17 00:00:00 2001
From: Yingwei Zheng <dtcxzyw2333 at gmail.com>
Date: Mon, 9 Sep 2024 11:32:08 +0800
Subject: [PATCH 2/2] [RISCV][MachineCombiner] Combine `fadd X, (fneg Y)` to
 `fsub X, Y`

---
 llvm/lib/Target/RISCV/RISCVInstrInfo.cpp | 118 +++++++++++++++++++----
 llvm/lib/Target/RISCV/RISCVInstrInfo.h   |   1 +
 llvm/test/CodeGen/RISCV/float-zfa.ll     |   3 +-
 3 files changed, 99 insertions(+), 23 deletions(-)

diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 0a64a8e1440084..1549ef0cf4937c 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -2019,6 +2019,55 @@ RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
 #undef RVV_OPC_LMUL_CASE
 }
 
+/// Return the instruction that defines the virtual register read by \p MO,
+/// provided it is a \p CombineOpc instruction located in \p MBB whose result
+/// has exactly one non-debug use. Return nullptr otherwise.
+static const MachineInstr *canCombine(const MachineBasicBlock &MBB,
+                                      const MachineOperand &MO,
+                                      unsigned CombineOpc) {
+  // Only a virtual register operand is looked through to its definition.
+  if (!MO.isReg() || !MO.getReg().isVirtual())
+    return nullptr;
+
+  const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
+  const MachineInstr *DefMI = MRI.getUniqueVRegDef(MO.getReg());
+  if (!DefMI || DefMI->getOpcode() != CombineOpc)
+    return nullptr;
+  // The definition has to be in the trace (this block); otherwise it has no
+  // depth.
+  if (DefMI->getParent() != &MBB)
+    return nullptr;
+
+  // The combined-away result must not be needed by any other user.
+  bool SingleUser = MRI.hasOneNonDBGUse(DefMI->getOperand(0).getReg());
+  return SingleUser ? DefMI : nullptr;
+}
+
+/// Try to match (fadd X, (fneg Y)) rooted at \p Root so that it can later be
+/// rewritten as (fsub X, Y). Only a negated right-hand operand is matched.
+/// The negation must be an FSGNJN of the same width as the FADD whose two
+/// source operands are the same register (i.e. fneg.{h,s,d}), and its result
+/// must have no other non-debug user. On success
+/// RISCVMachineCombinerPattern::FADD_NEGRHS is appended to \p Patterns.
+/// \returns true if a pattern was added.
+static bool getFPAddNegCombinePatterns(MachineInstr &Root,
+                                       SmallVectorImpl<unsigned> &Patterns) {
+  // Map the FADD width to the sign-injection opcode that expresses fneg.
+  unsigned NegOpc;
+  switch (Root.getOpcode()) {
+  default:
+    return false;
+  case RISCV::FADD_H:
+    NegOpc = RISCV::FSGNJN_H;
+    break;
+  case RISCV::FADD_S:
+    NegOpc = RISCV::FSGNJN_S;
+    break;
+  case RISCV::FADD_D:
+    NegOpc = RISCV::FSGNJN_D;
+    break;
+  }
+
+  const MachineInstr *NegMI =
+      canCombine(*Root.getParent(), Root.getOperand(2), NegOpc);
+  if (!NegMI)
+    return false;
+
+  // fsgnjn rd, rs, rs is fneg rd, rs; with two different sources it is a
+  // genuine sign injection rather than a negation.
+  const MachineOperand &NegSrc1 = NegMI->getOperand(1);
+  const MachineOperand &NegSrc2 = NegMI->getOperand(2);
+  bool IsFNeg = NegSrc1.isReg() && NegSrc2.isReg() &&
+                NegSrc1.getReg() == NegSrc2.getReg();
+  if (!IsFNeg)
+    return false;
+
+  Patterns.push_back(RISCVMachineCombinerPattern::FADD_NEGRHS);
+  return true;
+}
+
 static bool canCombineFPFusedMultiply(const MachineInstr &Root,
                                       const MachineOperand &MO,
                                       bool DoRegPressureReduce) {
@@ -2072,27 +2121,8 @@ static bool getFPFusedMultiplyPatterns(MachineInstr &Root,
 static bool getFPPatterns(MachineInstr &Root,
                           SmallVectorImpl<unsigned> &Patterns,
                           bool DoRegPressureReduce) {
-  return getFPFusedMultiplyPatterns(Root, Patterns, DoRegPressureReduce);
-}
-
-/// Utility routine that checks if \param MO is defined by an
-/// \param CombineOpc instruction in the basic block \param MBB
-static const MachineInstr *canCombine(const MachineBasicBlock &MBB,
-                                      const MachineOperand &MO,
-                                      unsigned CombineOpc) {
-  const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
-  const MachineInstr *MI = nullptr;
-
-  if (MO.isReg() && MO.getReg().isVirtual())
-    MI = MRI.getUniqueVRegDef(MO.getReg());
-  // And it needs to be in the trace (otherwise, it won't have a depth).
-  if (!MI || MI->getParent() != &MBB || MI->getOpcode() != CombineOpc)
-    return nullptr;
-  // Must only used by the user we combine with.
-  if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
-    return nullptr;
-
-  return MI;
+  // Collect every floating-point pattern rooted at Root and report whether
+  // any was added. Both matchers must run: short-circuiting with `||` would
+  // stop the fused-multiply patterns for this root from being recorded
+  // whenever the fadd/fneg pattern matched, leaving the MachineCombiner
+  // fewer alternatives to evaluate.
+  bool Added = getFPAddNegCombinePatterns(Root, Patterns);
+  Added |= getFPFusedMultiplyPatterns(Root, Patterns, DoRegPressureReduce);
+  return Added;
 }
 
 /// Utility routine that checks if \param MO is defined by a SLLI in \param
@@ -2319,6 +2349,49 @@ genShXAddAddShift(MachineInstr &Root, unsigned AddOpIdx,
   DelInstrs.push_back(&Root);
 }
 
+/// Rewrite (fadd X, (fneg Y)) as (fsub X, Y) for a root matched by
+/// getFPAddNegCombinePatterns (pattern FADD_NEGRHS). The new FSUB writes the
+/// root's destination register and reuses the root's rounding-mode immediate
+/// (operand 3), implicit operands and MI flags. Both the FADD and the feeding
+/// FSGNJN are queued in \p DelInstrs. \p InstrIdxForVirtReg is left untouched
+/// because no new virtual registers are created.
+static void combineFPAddNeg(MachineInstr &Root,
+                            SmallVectorImpl<MachineInstr *> &InsInstrs,
+                            SmallVectorImpl<MachineInstr *> &DelInstrs,
+                            DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) {
+  MachineFunction *MF = Root.getMF();
+  MachineRegisterInfo &MRI = MF->getRegInfo();
+  const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
+
+  unsigned SubOpc;
+  switch (Root.getOpcode()) {
+  case RISCV::FADD_H:
+    SubOpc = RISCV::FSUB_H;
+    break;
+  case RISCV::FADD_S:
+    SubOpc = RISCV::FSUB_S;
+    break;
+  case RISCV::FADD_D:
+    SubOpc = RISCV::FSUB_D;
+    break;
+  default:
+    llvm_unreachable("Unexpected opcode");
+  }
+
+  const MachineOperand &LHS = Root.getOperand(1);
+  const MachineOperand &RHS = Root.getOperand(2);
+  MachineInstr *NegMI = MRI.getUniqueVRegDef(RHS.getReg());
+  assert(NegMI && "FADD_NEGRHS requires a unique definition of the RHS");
+
+  // The matcher guarantees both fsgnjn sources are Y; a kill may sit on
+  // either of them.
+  Register NegSrcReg = NegMI->getOperand(1).getReg();
+  bool NegSrcIsKill =
+      NegMI->getOperand(1).isKill() || NegMI->getOperand(2).isKill();
+
+  // Y is now read at Root instead of at the fneg, which can extend its live
+  // range past an existing kill, so drop all kill flags on Y. A kill on the
+  // fneg itself is safe to re-apply: Y had no later uses, so Root becomes
+  // its last use.
+  MRI.clearKillFlags(NegSrcReg);
+
+  MachineInstrBuilder MIB =
+      BuildMI(*MF, MIMetadata(Root), TII->get(SubOpc),
+              Root.getOperand(0).getReg())
+          .addReg(LHS.getReg(), getKillRegState(LHS.isKill()))
+          .addReg(NegSrcReg, getKillRegState(NegSrcIsKill))
+          .addImm(Root.getOperand(3).getImm())
+          .copyImplicitOps(Root)
+          // Preserve the root's fast-math / nofpexcept flags; fsub X, Y
+          // computes exactly fadd X, -Y and raises the same exceptions.
+          .setMIFlags(Root.getFlags());
+
+  InsInstrs.push_back(MIB);
+  DelInstrs.push_back(NegMI);
+  DelInstrs.push_back(&Root);
+}
+
 void RISCVInstrInfo::genAlternativeCodeSequence(
     MachineInstr &Root, unsigned Pattern,
     SmallVectorImpl<MachineInstr *> &InsInstrs,
@@ -2348,6 +2421,9 @@ void RISCVInstrInfo::genAlternativeCodeSequence(
   case RISCVMachineCombinerPattern::SHXADD_ADD_SLLI_OP2:
     genShXAddAddShift(Root, 2, InsInstrs, DelInstrs, InstrIdxForVirtReg);
     return;
+  case RISCVMachineCombinerPattern::FADD_NEGRHS:
+    combineFPAddNeg(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg);
+    return;
   }
 }
 
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index 457db9b9860d00..64ff3e04b30d2c 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -57,6 +57,7 @@ enum RISCVMachineCombinerPattern : unsigned {
   FNMSUB,
   SHXADD_ADD_SLLI_OP1,
   SHXADD_ADD_SLLI_OP2,
+  FADD_NEGRHS,
 };
 
 class RISCVInstrInfo : public RISCVGenInstrInfo {
diff --git a/llvm/test/CodeGen/RISCV/float-zfa.ll b/llvm/test/CodeGen/RISCV/float-zfa.ll
index bfaf6f9d6bf9e3..1806fdfb39384a 100644
--- a/llvm/test/CodeGen/RISCV/float-zfa.ll
+++ b/llvm/test/CodeGen/RISCV/float-zfa.ll
@@ -274,8 +274,7 @@ define float @add_negimm(float %x) {
 ; CHECK-LABEL: add_negimm:
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    fli.s fa5, 0.5
-; CHECK-NEXT:    fneg.s fa5, fa5
-; CHECK-NEXT:    fadd.s fa0, fa0, fa5
+; CHECK-NEXT:    fsub.s fa0, fa0, fa5
 ; CHECK-NEXT:    ret
 entry:
   %sub = fadd float %x, -5.000000e-01



More information about the llvm-commits mailing list