[llvm] [RISCV][MachineCombiner] Combine `fadd X, (fneg Y)` to `fsub X, Y` (PR #107803)
via llvm-commits
llvm-commits at lists.llvm.org
Sun Sep 8 20:57:59 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-backend-risc-v
Author: Yingwei Zheng (dtcxzyw)
<details>
<summary>Changes</summary>
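This patch adds the transformation `fadd X, (fneg Y)` -> `fsub X, Y` to eliminate unnecessary fneg instructions introduced when materializing the FP immediate -0.5 with Zfa (`fli` cannot encode -0.5 directly, so it is emitted as `fli 0.5` followed by an fneg). I don't see the value in also combining the commuted form `fadd (fneg Y), X`, because that case should already be handled by SelectionDAG (SDAG).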
Closes https://github.com/llvm/llvm-project/issues/107772.
---
Full diff: https://github.com/llvm/llvm-project/pull/107803.diff
3 Files Affected:
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfo.cpp (+97-21)
- (modified) llvm/lib/Target/RISCV/RISCVInstrInfo.h (+1)
- (modified) llvm/test/CodeGen/RISCV/float-zfa.ll (+11)
``````````diff
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
index 0a64a8e1440084..1549ef0cf4937c 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
@@ -2019,6 +2019,55 @@ RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
#undef RVV_OPC_LMUL_CASE
}
+/// Utility routine that checks if \param MO is defined by an
+/// \param CombineOpc instruction in the basic block \param MBB
+static const MachineInstr *canCombine(const MachineBasicBlock &MBB,
+ const MachineOperand &MO,
+ unsigned CombineOpc) {
+ const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
+ const MachineInstr *MI = nullptr;
+
+ if (MO.isReg() && MO.getReg().isVirtual())
+ MI = MRI.getUniqueVRegDef(MO.getReg());
+ // And it needs to be in the trace (otherwise, it won't have a depth).
+ if (!MI || MI->getParent() != &MBB || MI->getOpcode() != CombineOpc)
+ return nullptr;
+ // Must only used by the user we combine with.
+ if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
+ return nullptr;
+
+ return MI;
+}
+
+/// Fold (fadd X, fneg Y) -> fsub X, Y
+static bool getFPAddNegCombinePatterns(MachineInstr &Root,
+ SmallVectorImpl<unsigned> &Patterns) {
+ unsigned Opc = Root.getOpcode();
+ unsigned NegOpc;
+ bool Added = false;
+ switch (Opc) {
+ case RISCV::FADD_H:
+ NegOpc = RISCV::FSGNJN_H;
+ break;
+ case RISCV::FADD_S:
+ NegOpc = RISCV::FSGNJN_S;
+ break;
+ case RISCV::FADD_D:
+ NegOpc = RISCV::FSGNJN_D;
+ break;
+ default:
+ return false;
+ }
+ const MachineBasicBlock &MBB = *Root.getParent();
+ const MachineInstr *NegMI = canCombine(MBB, Root.getOperand(2), NegOpc);
+ if (NegMI && NegMI->getOperand(1).isReg() && NegMI->getOperand(2).isReg() &&
+ NegMI->getOperand(1).getReg() == NegMI->getOperand(2).getReg()) {
+ Patterns.push_back(RISCVMachineCombinerPattern::FADD_NEGRHS);
+ Added = true;
+ }
+ return Added;
+}
+
static bool canCombineFPFusedMultiply(const MachineInstr &Root,
const MachineOperand &MO,
bool DoRegPressureReduce) {
@@ -2072,27 +2121,8 @@ static bool getFPFusedMultiplyPatterns(MachineInstr &Root,
static bool getFPPatterns(MachineInstr &Root,
SmallVectorImpl<unsigned> &Patterns,
bool DoRegPressureReduce) {
- return getFPFusedMultiplyPatterns(Root, Patterns, DoRegPressureReduce);
-}
-
-/// Utility routine that checks if \param MO is defined by an
-/// \param CombineOpc instruction in the basic block \param MBB
-static const MachineInstr *canCombine(const MachineBasicBlock &MBB,
- const MachineOperand &MO,
- unsigned CombineOpc) {
- const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
- const MachineInstr *MI = nullptr;
-
- if (MO.isReg() && MO.getReg().isVirtual())
- MI = MRI.getUniqueVRegDef(MO.getReg());
- // And it needs to be in the trace (otherwise, it won't have a depth).
- if (!MI || MI->getParent() != &MBB || MI->getOpcode() != CombineOpc)
- return nullptr;
- // Must only used by the user we combine with.
- if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
- return nullptr;
-
- return MI;
+ return getFPAddNegCombinePatterns(Root, Patterns) ||
+ getFPFusedMultiplyPatterns(Root, Patterns, DoRegPressureReduce);
}
/// Utility routine that checks if \param MO is defined by a SLLI in \param
@@ -2319,6 +2349,49 @@ genShXAddAddShift(MachineInstr &Root, unsigned AddOpIdx,
DelInstrs.push_back(&Root);
}
+// Combine (fadd X, (fneg Y)) -> fsub X, Y
+static void combineFPAddNeg(MachineInstr &Root,
+ SmallVectorImpl<MachineInstr *> &InsInstrs,
+ SmallVectorImpl<MachineInstr *> &DelInstrs,
+ DenseMap<unsigned, unsigned> &InstrIdxForVirtReg) {
+ MachineFunction *MF = Root.getMF();
+ MachineRegisterInfo &MRI = MF->getRegInfo();
+ const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();
+
+ unsigned Opc = Root.getOpcode();
+ unsigned SubOpc;
+ switch (Opc) {
+ case RISCV::FADD_H:
+ SubOpc = RISCV::FSUB_H;
+ break;
+ case RISCV::FADD_S:
+ SubOpc = RISCV::FSUB_S;
+ break;
+ case RISCV::FADD_D:
+ SubOpc = RISCV::FSUB_D;
+ break;
+ default:
+ llvm_unreachable("Unexpected opcode");
+ }
+
+ const MachineOperand &LHS = Root.getOperand(1);
+ const MachineOperand &RHS = Root.getOperand(2);
+ MachineInstr *NegMI = MRI.getUniqueVRegDef(RHS.getReg());
+ const MachineOperand &NegRHS = NegMI->getOperand(1);
+
+ MachineInstrBuilder MIB =
+ BuildMI(*MF, MIMetadata(Root), TII->get(SubOpc),
+ Root.getOperand(0).getReg())
+ .addReg(LHS.getReg(), getKillRegState(LHS.isKill()))
+ .addReg(NegRHS.getReg(), getKillRegState(NegRHS.isKill()))
+ .addImm(Root.getOperand(3).getImm())
+ .copyImplicitOps(Root);
+
+ InsInstrs.push_back(MIB);
+ DelInstrs.push_back(NegMI);
+ DelInstrs.push_back(&Root);
+}
+
void RISCVInstrInfo::genAlternativeCodeSequence(
MachineInstr &Root, unsigned Pattern,
SmallVectorImpl<MachineInstr *> &InsInstrs,
@@ -2348,6 +2421,9 @@ void RISCVInstrInfo::genAlternativeCodeSequence(
case RISCVMachineCombinerPattern::SHXADD_ADD_SLLI_OP2:
genShXAddAddShift(Root, 2, InsInstrs, DelInstrs, InstrIdxForVirtReg);
return;
+ case RISCVMachineCombinerPattern::FADD_NEGRHS:
+ combineFPAddNeg(Root, InsInstrs, DelInstrs, InstrIdxForVirtReg);
+ return;
}
}
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
index 457db9b9860d00..64ff3e04b30d2c 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h
@@ -57,6 +57,7 @@ enum RISCVMachineCombinerPattern : unsigned {
FNMSUB,
SHXADD_ADD_SLLI_OP1,
SHXADD_ADD_SLLI_OP2,
+ FADD_NEGRHS,
};
class RISCVInstrInfo : public RISCVGenInstrInfo {
diff --git a/llvm/test/CodeGen/RISCV/float-zfa.ll b/llvm/test/CodeGen/RISCV/float-zfa.ll
index e5196ead1f8819..1806fdfb39384a 100644
--- a/llvm/test/CodeGen/RISCV/float-zfa.ll
+++ b/llvm/test/CodeGen/RISCV/float-zfa.ll
@@ -269,3 +269,14 @@ define void @fli_remat() {
tail call void @foo(float 1.000000e+00, float 1.000000e+00)
ret void
}
+
+define float @add_negimm(float %x) {
+; CHECK-LABEL: add_negimm:
+; CHECK: # %bb.0: # %entry
+; CHECK-NEXT: fli.s fa5, 0.5
+; CHECK-NEXT: fsub.s fa0, fa0, fa5
+; CHECK-NEXT: ret
+entry:
+ %sub = fadd float %x, -5.000000e-01
+ ret float %sub
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/107803
More information about the llvm-commits mailing list